File size: 3,387 Bytes
d425e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""molmo.py.

File for providing the Molmo model implementation.
"""
import logging

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

from src.models.base import ModelBase
from src.models.config import Config


class MolmoModel(ModelBase):
    """Molmo model implementation."""

    def __init__(self, config: Config) -> None:
        """Initialization of the molmo model.

        Args:
            config (Config): Parsed config
        """
        # initialize the parent class
        super().__init__(config)

    def _load_specific_model(self) -> None:
        """Overridden function to populate self.model."""
        # NOTE(review): this reads self.model_path while _init_processor reads
        # self.config.model_path — presumably both resolve to the same value
        # via ModelBase; confirm and unify on one attribute.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path, **getattr(self.config, 'model', {}), trust_remote_code=True
        )

    def _init_processor(self) -> None:
        """Initializes the processor."""
        # trust_remote_code is required: Molmo ships custom processor code
        # on the Hub rather than in transformers itself.
        self.processor = AutoProcessor.from_pretrained(
            self.config.model_path, **getattr(self.config, 'model', {}), trust_remote_code=True
        )

    def _generate_prompt(self, prompt: str, add_generation_prompt: bool = True, has_images: bool = False) -> str:
        """Generates the Molmo model prompt which will not use the chat template.

        [Note from Martin] I'd hack these parameters a bit for gradio, follow Base.

        Args:
            prompt (str): The prompt to return, set by the config.
            add_generation_prompt (bool): Whether to add a start token of a bot
                response. Unused here — Molmo does not apply a chat template.
            has_images (bool): Whether the model has images or not. Unused here.

        Returns:
            str: The prompt, returned unchanged.
        """
        return prompt

    def _generate_processor_output(self, prompt: str, img_path: str) -> dict:
        """Generate the processor argument to be input into the processor.

        Args:
            prompt (str): The generated prompt string with the input text and
                the image labels.
            img_path (str): The specified image path.

        Returns:
            dict: The corresponding processor arguments per image and prompt,
                moved to the configured device with a leading batch dim of 1.

        Raises:
            ValueError: If no image path is provided — Molmo requires an image
                and cannot run text-only generation.
        """
        if img_path is None:
            raise ValueError('Molmo cannot have text-only generation.')

        # prepare the data inputs according to
        # https://huggingface.co/allenai/Molmo-7B-D-0924
        data_inputs = self.processor.process(
            images=[Image.open(img_path)],
            text=prompt
        )

        # move inputs to the correct device and make a batch of size 1
        return {
            k: v.to(self.config.device).unsqueeze(0)
            for k, v in data_inputs.items()
        }

    def _forward(self, data: dict) -> None:
        """Given some input data, performs a single forward pass.

        This function itself can be overridden, while _hook_and_eval
        should be left intact.

        Args:
            data (dict): The processor output batch (as produced by
                _generate_processor_output) to generate from.
        """
        generation_config = self.config.forward
        # inference only — no gradients needed, saves memory
        with torch.no_grad():
            _ = self.model.generate_from_batch(
                data,
                GenerationConfig(**generation_config),
                tokenizer=self.processor.tokenizer
            )
        logging.debug('Completed forward pass...')