from typing import Optional

import torch
import torch.distributed as dist
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PretrainedConfig,
)

from .attention_mask import make_mask
from .configuration_vora import VoRAConfig
from .vision_embedding import *  # hack: star-import so transformers can resolve the vision_embedding classes
from . import vision_embedding as VB
from .lora import apply_lora
from .vora_generation_utils import (
    VoraGenerationMixin,
    custom_prepare_4d_causal_attention_mask_with_cache_position,
)

try:
    from utils import logging
except ImportError:
    from transformers.utils import logging


logger = logging.get_logger(__name__)


class VoRAForCausalLM(PreTrainedModel):
    config_class = VoRAConfig
    _auto_class = 'AutoModelForCausalLM'
    supports_gradient_checkpointing = True
    supports_report_metrics: bool = True

    def __init__(self, config: Optional[PretrainedConfig] = None):
        # avoid a shared mutable default: build a fresh VoRAConfig per instance
        config = config if config is not None else VoRAConfig()
        super().__init__(config)
        self.config = config
        # -------------- Setup LLM ---------------------
        self.llm = AutoModelForCausalLM.from_pretrained(config.llm)
        self.tokenizer = AutoTokenizer.from_pretrained(config.llm)
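        # Dynamically mix VoraGenerationMixin into the loaded LLM's class and
        # install a custom 4-D mask builder, so .generate() can keep the
        # vision-aware attention mask consistent during cached decoding.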
        self.llm.__class__ = type(
            self.llm.__class__.__name__,
            (self.llm.__class__, VoraGenerationMixin),
            {},
        )
        self.llm.model._prepare_4d_causal_attention_mask_with_cache_position = staticmethod(
            custom_prepare_4d_causal_attention_mask_with_cache_position
        )

        self.config.update(self.llm.config.to_dict())

        # -------------- Setup LoRA -------------------
        if config.lora:
            for _, param in self.llm.named_parameters():
                param.requires_grad = False
            apply_lora(self.llm, config.lora)
        # ----------------------------------------------

        # ------------ Setup Vision Embedding ----------
        self.vision_embedding = getattr(VB, config.vision_embedding)(self.config)  # set up after the LLM so the hidden size is known
        # ----------------------------------------------

        # ------------- Setup Aux Vision ---------------
        self.enable_aux_vision = False
        if config.aux_vision:
            from .aux_vision import AuxVision
            self.enable_aux_vision = True
            self.aux_vision = AuxVision(self.config)
            if config.reuse_aux_vision_embedding_layers:
                weights = getattr(self.aux_vision.aux_model, config.reuse_aux_vision_embedding_layers).state_dict()
                msg = self.vision_embedding.load_state_dict(weights, strict=False)
                logger.info(f"Loaded aux vision weights into vision_embedding: {msg}")
                msg = self.vision_embedding.patchifier.load_state_dict(weights, strict=False)
                logger.info(f"Loaded aux vision weights into patchifier: {msg}")
        # ----------------------------------------------
        # log trainable and total parameters so we can verify the correct model is loaded
        logger.info("Trainable parameters:")
        for name, param in self.named_parameters():
            if param.requires_grad:
                logger.info(f"{name}: {param.numel()}")
        logger.info(f"Total parameters: {sum(p.numel() for p in self.parameters())}")

    def detach_and_gather_loss(self, loss, dtype, device):
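        """Average a detached scalar loss across all ranks for logging;
        returns a plain float (.item() directly when torch.distributed is
        not initialized).
        """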
        if not dist.is_initialized():
            return loss.item()
        gathered_loss = [torch.zeros_like(loss, device=device) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_loss, loss.detach().clone())
        avg_gathered_loss = torch.mean(torch.stack(gathered_loss))
        return avg_gathered_loss.item()

    def _encode_vision(self, images, n_frames):
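        """Embed a flat (N, C, H, W) batch of frames and split the result into
        per-sample chunks according to n_frames, together with matching
        attention masks and -100 (ignored) label targets.
        """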
        # TODO: we need a more elegant way here to deal with mixed image and pure text training
        if images.size(0) > 0:
            vision_embeds = self.vision_embedding(images)
        else:
            # FIXME: hacking for deepspeed training
            # we feed a dummy image tensor (1, 3, H, W) into vision_encoder when training a pure-text batch
            images = images.new_zeros((1, *images.shape[1:]))
            vision_embeds = self.vision_embedding(images)[0:0]
        vision_embeds = vision_embeds.split(n_frames, dim=0)
        attention_mask = [torch.ones(feat.size()[:-1], dtype=torch.long, device=feat.device) for feat in vision_embeds]
        vision_targets = [torch.full_like(mask, -100) for mask in attention_mask]

        image_shapes = images.shape[-2:]

        return vision_embeds, attention_mask, vision_targets, image_shapes

    def _concat_embedding(self, vision_encode_out, batch, vision_placeholder_index, left_padding=False):
        """ concat vision and text
        """

        vision_embeds, vision_atts, vision_targets, _ = vision_encode_out

        input_embeds = []
        attention_mask = []
        targets = []
        vision_mask = []  # set vision token as 1, text token as 0

        for cur_batch_idx, cur_input_ids in enumerate(batch["input_ids"]):
            cur_vision_embeds = vision_embeds[cur_batch_idx]
            cur_vision_attn = vision_atts[cur_batch_idx]
            cur_vision_targets = vision_targets[cur_batch_idx]
            cur_attn_masks = batch["attention_mask"][cur_batch_idx]

            image_token_indices = torch.where(cur_input_ids == vision_placeholder_index)[0]
            cur_image_num = len(image_token_indices)
            image_token_indices = image_token_indices.tolist() + [cur_input_ids.shape[0]]

            cur_input_embeds = []
            cur_attention_mask = []
            cur_target = []
            cur_vision_mask = []

            # convert text before 1st <image> to embedding
            image_token_index = image_token_indices[0]

            cur_input_embeds.append(
                self.llm.get_input_embeddings()(cur_input_ids[:image_token_index]),
            )
            cur_attention_mask.append(
                cur_attn_masks[:image_token_index],
            )
            cur_vision_mask.append(
                torch.zeros_like(cur_attn_masks[:image_token_index]).to(cur_attn_masks.device),
            )
            if "labels" in batch:
                cur_target.append(
                    batch["labels"][cur_batch_idx, :image_token_index],
                )

            if batch.get("vison_placeholder_mode", 0) == 1:
                assert cur_image_num <= 1, "multiple video input is not supported"
                cur_vision_embeds = cur_vision_embeds.unsqueeze(0)
                cur_vision_attn = cur_vision_attn.unsqueeze(0)
                cur_vision_targets = cur_vision_targets.unsqueeze(0)
            assert cur_image_num == len(cur_vision_embeds), \
                f"Size mismatch! cur_image_num: {cur_image_num}, len(cur_vision_embeds): {len(cur_vision_embeds)} " \
                f"in {batch['prompt'][cur_batch_idx]} & {batch['gt'][cur_batch_idx]} & {batch['input_ids'][cur_batch_idx]}"
            # convert each <image> xxx group into embedding
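            # .relu() clamps any negative placeholder ids to 0 so the embedding
            # lookup cannot index out of range; the embeddings at placeholder
            # positions are never used (the slices below start at image_token_index + 1)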
            text_embedding = self.llm.get_input_embeddings()(cur_input_ids.relu())
            for i in range(0, cur_image_num):
                image_token_index = image_token_indices[i]
                cur_input_embeds.extend([
                    cur_vision_embeds[i],
                    text_embedding[image_token_index+1:image_token_indices[i+1]]
                ])
                cur_attention_mask.extend([
                    cur_vision_attn[i],
                    cur_attn_masks[image_token_index+1:image_token_indices[i+1]]
                ])
                cur_vision_mask.extend([
                    torch.ones_like(cur_vision_attn[i]).to(cur_vision_attn[i].device),
                    torch.zeros_like(cur_attn_masks[image_token_index+1:image_token_indices[i+1]]).to(cur_vision_attn[i].device),
                ])
                if "labels" in batch:
                    cur_target.extend([
                        cur_vision_targets[i],
                        batch["labels"][cur_batch_idx, image_token_index+1:image_token_indices[i+1]],
                    ])

            input_embeds.append(torch.cat(cur_input_embeds))
            attention_mask.append(torch.cat(cur_attention_mask))
            vision_mask.append(torch.cat(cur_vision_mask))
            if "labels" in batch:
                targets.append(torch.cat(cur_target))

        # pad every sample to the longest sequence in the batch
        n_tokens = [embed.shape[0] for embed in input_embeds]
        max_token = max(n_tokens)
        pad_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id

        for i in range(len(input_embeds)):
            if max_token > n_tokens[i]:
                pad_token = torch.tensor([pad_id] * (max_token - n_tokens[i]))
                pad_embedding = self.llm.get_input_embeddings()(pad_token.to(batch["attention_mask"][i].device))
                pad_attention = torch.zeros(pad_embedding.shape[0], dtype=torch.long).to(batch["attention_mask"][i].device)
                pad_targets = torch.ones(pad_attention.size(), dtype=torch.long).to(batch["attention_mask"][i].device).fill_(-100)

                if left_padding:
                    input_embeds[i] = torch.cat([pad_embedding, input_embeds[i]])
                    attention_mask[i] = torch.cat([pad_attention, attention_mask[i]])
                    vision_mask[i] = torch.cat([pad_attention, vision_mask[i]])
                    if "labels" in batch:
                        targets[i] = torch.cat([pad_targets, targets[i]])
                else:
                    input_embeds[i] = torch.cat([input_embeds[i], pad_embedding])
                    attention_mask[i] = torch.cat([attention_mask[i], pad_attention])
                    vision_mask[i] = torch.cat([vision_mask[i], pad_attention])
                    if "labels" in batch:
                        targets[i] = torch.cat([targets[i], pad_targets])

        inputs_embeds = torch.stack(input_embeds, dim=0).type(self.llm.dtype)
        attention_mask = torch.stack(attention_mask, dim=0)
        vision_mask = torch.stack(vision_mask, dim=0).to(attention_mask.device)

        if len(targets) > 0:
            targets = torch.stack(targets, dim=0)

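        # expand the 2-D padding mask into the final attention mask; depending
        # on config.vision_attention_mask, make_mask can treat the vision-token
        # spans (marked by vision_mask) differently from text tokens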
        attention_mask = make_mask(
            attention_mask,
            mode=self.config.vision_attention_mask,
            vision_mask=vision_mask,
            dtype=inputs_embeds.dtype
        )

        return inputs_embeds, attention_mask, targets, vision_mask

    def forward(self, **batch):
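        """Training forward pass: splice vision embeddings into the text
        sequence, run the LLM, and optionally add auxiliary vision losses.

        Expected batch keys: "frames", "n_frames", "input_ids",
        "attention_mask", "vision_placeholder_index", and optionally "labels".
        """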
        # -------------- Vision/Text Embedding ----------
        vision_placeholder_index = batch.pop("vision_placeholder_index")
        images, n_frames = batch["frames"], batch["n_frames"]
        vision_encode_out = self._encode_vision(images, n_frames)
        inputs_embeds, attention_mask, targets, vision_mask = self._concat_embedding(
            vision_encode_out, batch, vision_placeholder_index)
        # -----------------------------------------------

        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=targets,
            return_dict=True,
            output_hidden_states=True,
        )

        llm_loss = outputs.loss
        device = llm_loss.device
        dtype = llm_loss.dtype

        metrics = {}

        metrics["llm_loss"] = self.detach_and_gather_loss(llm_loss, dtype, device)
        if self.enable_aux_vision:
            if images.size(0) > 0:
                aux_losses = self.aux_vision(images, outputs.hidden_states, vision_mask)
            else:
                # FIXME: hacking for deepspeed training
                aux_losses = {key: torch.tensor(0., dtype=dtype).to(device) for key in self.aux_vision.loss_keys}

            aux_loss = torch.tensor(0., dtype=dtype).to(device)
            n_aux = 0
            for _aux_key, _aux_loss in aux_losses.items():
                aux_loss += _aux_loss
                n_aux += 1
                metrics[_aux_key] = self.detach_and_gather_loss(_aux_loss, dtype, device)
            aux_loss /= max(n_aux, 1)  # guard against an empty aux loss dict

            outputs.loss = aux_loss + llm_loss
        metrics["total_loss"] = self.detach_and_gather_loss(outputs.loss, dtype, device)
        self.report_metrics(**metrics)

        return outputs

    def generate(self, batch, **generate_params):
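        """Inference entry point: encode vision, splice the embeddings into the
        prompt, then call the (mixin-augmented) HF generate with inputs_embeds.
        """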

        with torch.amp.autocast(
            enabled=(self.device != torch.device("cpu")),
            device_type=self.device.type,
        ):
            # get vision token
            vision_placeholder_index = batch.pop("vision_placeholder_index")

            # get vision features
            images, n_frames = batch["frames"], batch["n_frames"]
            vision_encode_out = self._encode_vision(images, n_frames)

            inputs_embeds, attention_mask, _, _ = self._concat_embedding(
                vision_encode_out, batch, vision_placeholder_index, left_padding=False)

        outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=True,
            **generate_params
        )

        return outputs
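
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the image size and batch values below are
# assumptions, not values defined in this file — elided fields stay elided):
#
#   config = VoRAConfig()                        # config.llm points at the base LLM
#   model = VoRAForCausalLM(config).eval()
#   batch = {
#       "frames": torch.randn(1, 3, 224, 224),   # one image, assumed 224x224
#       "n_frames": [1],                          # frames per sample
#       "input_ids": ...,                         # prompt ids with one placeholder token
#       "attention_mask": ...,
#       "vision_placeholder_index": ...,          # id that marks <image> in input_ids
#   }
#   out = model.generate(batch, max_new_tokens=64)
# ---------------------------------------------------------------------------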