File size: 5,174 Bytes
f471fb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import importlib
import logging
import os
import sys
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration_m2_encoder import M2EncoderConfig


@dataclass
class M2EncoderOutput(ModelOutput):
    """CLIP-style output container for M2-Encoder forward passes.

    Mirrors the field layout of ``transformers``' CLIP outputs so downstream
    code can use attribute or tuple-style access via ``ModelOutput``.
    """

    # Contrastive loss; the forward pass in this file never sets it, so it is
    # only populated if a caller computes a loss externally.
    loss: Optional[torch.FloatTensor] = None
    # Pooled text embeddings ("cls_vlffn_feats" from the VLMo text tower).
    text_embeds: Optional[torch.FloatTensor] = None
    # Pooled image embeddings ("cls_vlffn_feats" from the VLMo image tower).
    image_embeds: Optional[torch.FloatTensor] = None
    # Scaled image->text similarity logits; None unless both modalities given.
    logits_per_image: Optional[torch.FloatTensor] = None
    # Transpose of ``logits_per_image``.
    logits_per_text: Optional[torch.FloatTensor] = None


class M2EncoderModel(PreTrainedModel):
    """Hugging Face wrapper around the VLMo-based M2-Encoder CLIP-style model.

    The ``vlmo`` package ships inside the checkpoint directory and is imported
    dynamically at construction time, so instances must be created through
    :meth:`from_pretrained`, which resolves the checkpoint directory and
    records it on the config as ``_model_dir``.
    """

    config_class = M2EncoderConfig
    base_model_prefix = "m2_encoder"
    main_input_name = "pixel_values"

    def __init__(self, config: M2EncoderConfig):
        """Build the wrapped VLMo model from ``config._model_dir``.

        Raises:
            ValueError: if ``config`` lacks ``_model_dir`` (i.e. it was not
                produced via :meth:`from_pretrained`).
        """
        super().__init__(config)
        model_dir = getattr(config, "_model_dir", None)
        if model_dir is None:
            raise ValueError(
                "M2EncoderConfig is missing `_model_dir`. Use "
                "`M2EncoderModel.from_pretrained(...)` so the checkpoint path can be resolved."
            )
        # NOTE: process-global side effect — the bundled `vlmo` package lives
        # inside the checkpoint directory and must be importable.
        if model_dir not in sys.path:
            sys.path.insert(0, model_dir)

        vlmo_default_config = importlib.import_module("vlmo.config").config
        VLMo = importlib.import_module("vlmo.modules").VLMo

        # Start from vlmo's defaults, then overlay the HF-config overrides
        # (these include the checkpoint `load_path`).
        vlmo_config = vlmo_default_config()
        vlmo_config.update(config.to_vlmo_overrides(model_dir))
        load_path = vlmo_config["load_path"]
        use_safetensors = load_path.endswith(".safetensors")
        if use_safetensors:
            # VLMo's own loader handles torch pickles only; blank the path so
            # VLMo skips loading and the safetensors file is loaded below.
            vlmo_config["load_path"] = ""

        if vlmo_config["flash_attn"]:
            patch_torch_scale_with_flash_attn = importlib.import_module(
                "vlmo.utils.patch_utils"
            ).patch_torch_scale_with_flash_attn
            patch_torch_scale_with_flash_attn()

        self.model = VLMo(vlmo_config)
        if use_safetensors:
            state_dict = load_file(load_path)
            # strict=False tolerates buffers/heads absent from the checkpoint,
            # but surface any mismatch instead of silently discarding it.
            incompatible = self.model.load_state_dict(state_dict, strict=False)
            missing, unexpected = incompatible
            if missing or unexpected:
                logging.getLogger(__name__).warning(
                    "M2-Encoder checkpoint key mismatch: %d missing, %d unexpected",
                    len(missing),
                    len(unexpected),
                )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config: Optional[M2EncoderConfig] = None,
        **kwargs,
    ):
        """Resolve a local directory or hub repo id and construct the model.

        Args:
            pretrained_model_name_or_path: local checkpoint directory, or a
                hub repo id to download via ``snapshot_download``.
            config: optional pre-built config; built from the directory when
                omitted.
            **kwargs: forwarded to ``M2EncoderConfig.from_pretrained``; the
                private ``m2_checkpoint_name`` kwarg overrides the checkpoint
                file name.

        Raises:
            FileNotFoundError: if the resolved checkpoint file is absent.
        """
        model_dir = pretrained_model_name_or_path
        if not os.path.isdir(model_dir):
            model_dir = snapshot_download(repo_id=pretrained_model_name_or_path)

        # Pop our private kwarg *before* building the config so it does not
        # leak into `M2EncoderConfig.from_pretrained` (transformers stores
        # unknown kwargs as stray config attributes).
        checkpoint_name = kwargs.pop("m2_checkpoint_name", None)

        if config is None:
            config = M2EncoderConfig.from_pretrained(model_dir, **kwargs)
        if checkpoint_name is None:
            checkpoint_name = config.model_file
        else:
            # Make the override effective: the actual load path is derived in
            # `config.to_vlmo_overrides`, presumably from `config.model_file`
            # — TODO confirm. Previously the override was validated here but
            # never reached the loader.
            config.model_file = checkpoint_name

        checkpoint_path = os.path.join(model_dir, checkpoint_name)
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Missing M2-Encoder checkpoint: {checkpoint_path}"
            )
        config._model_dir = model_dir
        return cls(config, *model_args)

    def get_text_features(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
    ) -> torch.FloatTensor:
        """Return pooled text embeddings from the VLMo text tower."""
        outputs = self.model.infer_text(
            {
                "text_ids": input_ids,
                "text_masks": attention_mask,
                "text_labels": None,
            }
        )
        return outputs["cls_vlffn_feats"]

    def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        """Return pooled image embeddings from the VLMo image tower."""
        # VLMo expects the image batch wrapped in a single-element list.
        outputs = self.model.infer_image({"image": [pixel_values]})
        return outputs["cls_vlffn_feats"]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[M2EncoderOutput, Tuple[torch.FloatTensor, ...]]:
        """Encode either or both modalities; compute CLIP logits when both given.

        Args:
            input_ids: token ids; when given without ``attention_mask``, an
                all-ones mask is assumed.
            pixel_values: image batch tensor.
            return_dict: when False, return a tuple of the non-None outputs.

        Returns:
            ``M2EncoderOutput`` (or tuple) with embeddings and, when both
            modalities are present, temperature-scaled similarity logits.
        """
        text_embeds = None
        image_embeds = None

        if input_ids is not None:
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
            text_embeds = self.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask
            )

        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values=pixel_values)

        logits_per_image = None
        logits_per_text = None
        if image_embeds is not None and text_embeds is not None:
            # CLIP-style learned temperature stored as a log-scale parameter.
            logit_scale = self.model.logit_scale.exp()
            logits_per_image = logit_scale * image_embeds @ text_embeds.t()
            logits_per_text = logits_per_image.t()

        if not return_dict:
            return tuple(
                value
                for value in (
                    text_embeds,
                    image_embeds,
                    logits_per_image,
                    logits_per_text,
                )
                if value is not None
            )

        return M2EncoderOutput(
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
        )