asdjghh committed
Commit fd19f15 · verified · Parent(s): 3c7fc5a

Upload modeling_vlm.py with huggingface_hub

Files changed (1): modeling_vlm.py (+521, -0)
modeling_vlm.py ADDED
@@ -0,0 +1,521 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ from dataclasses import dataclass
+
+ import PIL.Image
+ import torch
+ from attrdict import AttrDict
+ from einops import rearrange
+ from torch.nn import CrossEntropyLoss
+ from transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     LlamaConfig,
+     LlamaForCausalLM,
+     PreTrainedModel,
+ )
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from janus.models.clip_encoder import CLIPVisionTower
+ from janus.models.projector import MlpProjector
+
+
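+ # Two-layer MLP head mapping LLM hidden states (n_embed) through a GELU to
+ # logits over the discrete image-token vocabulary (image_token_size). It is
+ # resolved by name via model_name_to_cls and used below as `gen_head`.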
+ class vision_head(torch.nn.Module):
+     def __init__(self, params):
+         super().__init__()
+         self.output_mlp_projector = torch.nn.Linear(
+             params.n_embed, params.image_token_embed
+         )
+         self.vision_activation = torch.nn.GELU()
+         self.vision_head = torch.nn.Linear(
+             params.image_token_embed, params.image_token_size
+         )
+
+     def forward(self, x):
+         x = self.output_mlp_projector(x)
+         x = self.vision_activation(x)
+         x = self.vision_head(x)
+         return x
+
+
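+ # Resolves the `cls` string stored in each sub-config to the class implementing
+ # that component, e.g. (the "VQ-16" key is illustrative, assuming it exists in
+ # janus.models.vq_model.VQ_models):
+ #     model_name_to_cls("CLIPVisionTower")  # -> CLIPVisionTower
+ #     model_name_to_cls("vision_head")      # -> vision_head
+ #     model_name_to_cls("VQ-16")            # -> VQ_models["VQ-16"]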
+ def model_name_to_cls(cls_name):
+     if "MlpProjector" in cls_name:
+         cls = MlpProjector
+     elif "CLIPVisionTower" in cls_name:
+         cls = CLIPVisionTower
+     elif "VQ" in cls_name:
+         from janus.models.vq_model import VQ_models
+
+         cls = VQ_models[cls_name]
+     elif "vision_head" in cls_name:
+         cls = vision_head
+     else:
+         raise ValueError(f"class_name {cls_name} is invalid.")
+
+     return cls
+
+
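+ # The five wrapper configs below share one shape: a `cls` string naming the
+ # component class (resolved via model_name_to_cls) and a `params` AttrDict of
+ # constructor arguments for that class.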
+ class VisionConfig(PretrainedConfig):
+     model_type = "vision"
+     cls: str = ""
+     params: AttrDict = {}
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.cls = kwargs.get("cls", "")
+         if not isinstance(self.cls, str):
+             self.cls = self.cls.__name__
+
+         self.params = AttrDict(kwargs.get("params", {}))
+
+
+ class AlignerConfig(PretrainedConfig):
+     model_type = "aligner"
+     cls: str = ""
+     params: AttrDict = {}
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.cls = kwargs.get("cls", "")
+         if not isinstance(self.cls, str):
+             self.cls = self.cls.__name__
+
+         self.params = AttrDict(kwargs.get("params", {}))
+
+
+ class GenVisionConfig(PretrainedConfig):
+     model_type = "gen_vision"
+     cls: str = ""
+     params: AttrDict = {}
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.cls = kwargs.get("cls", "")
+         if not isinstance(self.cls, str):
+             self.cls = self.cls.__name__
+
+         self.params = AttrDict(kwargs.get("params", {}))
+
+
+ class GenAlignerConfig(PretrainedConfig):
+     model_type = "gen_aligner"
+     cls: str = ""
+     params: AttrDict = {}
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.cls = kwargs.get("cls", "")
+         if not isinstance(self.cls, str):
+             self.cls = self.cls.__name__
+
+         self.params = AttrDict(kwargs.get("params", {}))
+
+
+ class GenHeadConfig(PretrainedConfig):
+     model_type = "gen_head"
+     cls: str = ""
+     params: AttrDict = {}
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.cls = kwargs.get("cls", "")
+         if not isinstance(self.cls, str):
+             self.cls = self.cls.__name__
+
+         self.params = AttrDict(kwargs.get("params", {}))
+
+
+ @dataclass
+ class VLChatProcessorOutput:
+     sft_format: str
+     input_ids: torch.Tensor
+     pixel_values: torch.Tensor
+     num_image_tokens: torch.IntTensor
+
+     def __len__(self):
+         return len(self.input_ids)
+
+
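+ # MultiModalityConfig nests one sub-config per component. A minimal sketch of
+ # the expected layout (values are illustrative, not taken from this commit):
+ #     MultiModalityConfig(
+ #         vision_config={"cls": "CLIPVisionTower", "params": {...}},
+ #         aligner_config={"cls": "MlpProjector", "params": {...}},
+ #         gen_vision_config={"cls": "VQ-16", "params": {...}},
+ #         gen_aligner_config={"cls": "MlpProjector", "params": {...}},
+ #         gen_head_config={"cls": "vision_head", "params": {...}},
+ #         language_config={},
+ #     )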
+ class MultiModalityConfig(PretrainedConfig):
+     model_type = "multi_modality"
+     vision_config: VisionConfig
+     aligner_config: AlignerConfig
+
+     gen_vision_config: GenVisionConfig
+     gen_aligner_config: GenAlignerConfig
+     gen_head_config: GenHeadConfig
+
+     language_config: LlamaConfig
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         vision_config = kwargs.get("vision_config", {})
+         self.vision_config = VisionConfig(**vision_config)
+
+         aligner_config = kwargs.get("aligner_config", {})
+         self.aligner_config = AlignerConfig(**aligner_config)
+
+         gen_vision_config = kwargs.get("gen_vision_config", {})
+         self.gen_vision_config = GenVisionConfig(**gen_vision_config)
+
+         gen_aligner_config = kwargs.get("gen_aligner_config", {})
+         self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
+
+         gen_head_config = kwargs.get("gen_head_config", {})
+         self.gen_head_config = GenHeadConfig(**gen_head_config)
+
+         language_config = kwargs.get("language_config", {})
+         if isinstance(language_config, LlamaConfig):
+             self.language_config = language_config
+         else:
+             self.language_config = LlamaConfig(**language_config)
+
+
+ class MultiModalityPreTrainedModel(PreTrainedModel):
+     config_class = MultiModalityConfig
+     base_model_prefix = "multi_modality"
+     _no_split_modules = []
+     _skip_keys_device_placement = "past_key_values"
+
+
+ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
+     def __init__(self, config: MultiModalityConfig):
+         super().__init__(config)
+
+         vision_config = config.vision_config
+         vision_cls = model_name_to_cls(vision_config.cls)
+         self.vision_model = vision_cls(**vision_config.params)
+
+         aligner_config = config.aligner_config
+         aligner_cls = model_name_to_cls(aligner_config.cls)
+         self.aligner = aligner_cls(aligner_config.params)
+
+         gen_vision_config = config.gen_vision_config
+         gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
+         self.gen_vision_model = gen_vision_cls()
+
+         gen_aligner_config = config.gen_aligner_config
+         gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
+         self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
+
+         gen_head_config = config.gen_head_config
+         gen_head_cls = model_name_to_cls(gen_head_config.cls)
+         self.gen_head = gen_head_cls(gen_head_config.params)
+
+         self.gen_embed = torch.nn.Embedding(
+             gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
+         )
+
+         language_config = config.language_config
+         self.language_model = LlamaForCausalLM(language_config)
+
+     def prepare_inputs_embeds(
+         self,
+         input_ids: torch.LongTensor,
+         pixel_values: torch.FloatTensor,
+         images_seq_mask: torch.LongTensor = None,
+         images_emb_mask: torch.LongTensor = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             input_ids (torch.LongTensor): [b, T]
+             pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
+             images_seq_mask (torch.BoolTensor): [b, T]
+             images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
+
+         Precondition: torch.sum(images_seq_mask) == torch.sum(images_emb_mask),
+         i.e. the sequence mask selects exactly as many positions as there are
+         image embeddings to scatter into them.
+
+         Returns:
+             inputs_embeds (torch.Tensor): [b, T, D]
+         """
+         bs, n = pixel_values.shape[0:2]
+         images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
+         # [b x n, T2, D]
+         images_embeds = self.aligner(self.vision_model(images))
+
+         # [b x n, T2, D] -> [b, n x T2, D]
+         images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
+         # [b, n, T2] -> [b, n x T2]
+         images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
+
+         # [b, T, D]: map negative (image placeholder) ids to 0 before the
+         # embedding lookup; those positions are overwritten below anyway.
+         input_ids[input_ids < 0] = 0
+         inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+         # Replace the masked positions with the image embeddings.
+         inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
+
+         return inputs_embeds
+
+     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
+         return self.gen_aligner(self.gen_embed(image_ids))
+
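+     # forward() dispatches on `task`:
+     #   - "understanding":     run the LLM over text embeddings with image
+     #                          embeddings spliced in (prepare_inputs_embeds).
+     #   - "generation":        score image-token sequences with classifier-free
+     #                          guidance (CFG) over cond/uncond row pairs.
+     #   - "generation_direct": teacher-forced scoring of raw input_ids, no CFG.
+     #   - "image_editing":     CFG with a cond-full / cond-part / uncond split.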
+     def forward(
+         self,
+         vl_chat_processor,
+         input_ids,
+         labels=None,
+         task="understanding",
+         return_dict=True,
+         pixel_values=None,
+         images_seq_mask=None,
+         images_emb_mask=None,
+         **kwargs,
+     ):
+         if task == "understanding":
+             inputs_embeds = self.prepare_inputs_embeds(
+                 input_ids, pixel_values, images_seq_mask, images_emb_mask
+             )
+             return self.language_model.forward(
+                 inputs_embeds=inputs_embeds,
+                 labels=labels,
+                 **kwargs,
+             )
+
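+         # Classifier-free guidance: each sample is duplicated into a conditional
+         # row and an unconditional row whose prompt is masked to the pad token,
+         # and the guided logits are
+         #     all_logits = logits_uncond + cfg_weight * (logits_cond - logits_uncond)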
+         elif task == "generation":
+             cfg_weight = 5
+
+             # First half of the rows keeps the prompt (conditional); second half
+             # masks it out with the pad token id (unconditional).
+             tokens = torch.zeros(
+                 (2 * input_ids.size(0), input_ids.size(1)), dtype=torch.int
+             ).cuda()
+             for i in range(2):
+                 tokens[i * input_ids.size(0):(i + 1) * input_ids.size(0), :] = input_ids
+                 if i % 2 != 0:
+                     tokens[i * input_ids.size(0):(i + 1) * input_ids.size(0), 1:-1] = 100015  # pad_id
+
+             inputs_embeds = self.language_model.get_input_embeddings()(tokens)
+
+             outputs = self.language_model.model(
+                 inputs_embeds=inputs_embeds, use_cache=True, past_key_values=None
+             )
+
+             hidden_states = outputs.last_hidden_state
+             logits = self.gen_head(hidden_states)
+
+             logits_cond = logits[0::2, :]
+             logits_uncond = logits[1::2, :]
+
+             all_logits = logits_uncond + cfg_weight * (logits_cond - logits_uncond)
+
+             if labels is not None:
+                 loss_fct = CrossEntropyLoss()
+                 # Shift so that tokens < n predict n.
+                 shift_logits = all_logits[..., :-1, :].contiguous()
+                 shift_logits = shift_logits.view(
+                     -1, self.config.gen_head_config.params.image_token_size
+                 )
+                 shift_labels = labels[..., 1:].contiguous()
+                 shift_labels = shift_labels.view(-1)
+                 shift_labels = shift_labels.to(shift_logits.device)
+                 loss = loss_fct(shift_logits, shift_labels)
+             else:
+                 loss = None
+
+             if not return_dict:
+                 output = (logits,) + outputs[1:]
+                 return ((loss,) + output) if loss is not None else output
+
+             return CausalLMOutputWithPast(
+                 loss=loss,
+                 logits=logits,
+                 past_key_values=outputs.past_key_values,
+                 hidden_states=outputs.hidden_states,
+                 attentions=outputs.attentions,
+             )
+
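+         # Teacher-forced scoring of the provided input_ids through the base LLM
+         # and gen_head, with no guidance duplication.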
+         elif task == "generation_direct":
+             outputs = self.language_model.model(input_ids=input_ids, **kwargs)
+             hidden_states = outputs[0]  # last_hidden_state
+             logits = self.gen_head(hidden_states)
+             logits = logits.float()
+
+             # Shift so that tokens < n predict n.
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_logits = shift_logits.view(
+                 -1, self.config.gen_head_config.params.image_token_size
+             )
+
+             if labels is not None:
+                 shift_labels = labels[..., 1:].contiguous()
+                 # Flatten the tokens.
+                 loss_fct = CrossEntropyLoss()
+                 shift_labels = shift_labels.view(-1)
+                 # Enable model parallelism.
+                 shift_labels = shift_labels.to(shift_logits.device)
+                 loss = loss_fct(shift_logits, shift_labels)
+             else:
+                 loss = None
+
+             if not return_dict:
+                 output = (logits,) + outputs[1:]
+                 return ((loss,) + output) if loss is not None else output
+
+             return CausalLMOutputWithPast(
+                 loss=loss,
+                 logits=logits,
+                 past_key_values=outputs.past_key_values,
+                 hidden_states=outputs.hidden_states,
+                 attentions=outputs.attentions,
+             )
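+         # Image editing uses three rows per sample: fully conditioned (prompt +
+         # source image), partially conditioned, and unconditional (prompt masked
+         # to pad). The two conditional streams are blended, then standard CFG is
+         # applied against the unconditional stream.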
+         elif task == "image_editing":
+             cfg_weight = 5
+
+             # Three rows per sample: full-cond, part-cond, uncond.
+             tokens = torch.zeros(
+                 (3 * input_ids.size(0), input_ids.size(1)), dtype=torch.int
+             ).cuda()
+             pre_data = []
+             img_len = len(kwargs["source_image"])
+             images = [
+                 PIL.Image.open(image_path).convert("RGB")
+                 for image_path in kwargs["source_image"]
+             ]
+             encoder_pixel_values = vl_chat_processor.image_processor(
+                 images, return_tensors="pt"
+             )["pixel_values"]
+             for i in range(3 * input_ids.size(0)):
+                 tokens[i, :] = input_ids[i // 3, :]
+                 if i % 3 == 2:
+                     # Unconditional row: mask the prompt with the pad token id.
+                     tokens[i, 1:-1] = 100015
+                     pre_data.append(
+                         VLChatProcessorOutput(
+                             sft_format=kwargs["sft_format"][i // 3],
+                             pixel_values=encoder_pixel_values[i // 3, :],
+                             input_ids=tokens[i - 2],
+                             num_image_tokens=[vl_chat_processor.num_image_tokens] * 1,
+                         )
+                     )
+                     pre_data.append(
+                         VLChatProcessorOutput(
+                             sft_format=kwargs["sft_format"][i // 3],
+                             pixel_values=encoder_pixel_values[i // 3, :],
+                             input_ids=tokens[i - 1],
+                             num_image_tokens=[vl_chat_processor.num_image_tokens] * 1,
+                         )
+                     )
+                     pre_data.append(
+                         VLChatProcessorOutput(
+                             sft_format=kwargs["sft_format"][i // 3],
+                             pixel_values=None,
+                             input_ids=tokens[i],
+                             num_image_tokens=[],
+                         )
+                     )
+
+             # Positions of the image-placeholder token (id 100580) in each row.
+             ppp = (tokens == 100580).nonzero()
+
+             prepare_inputs = vl_chat_processor.batchify(pre_data)
+
+             inputs_embeds = self.prepare_inputs_embeds(
+                 input_ids=tokens.cuda(),
+                 pixel_values=prepare_inputs["pixel_values"].to(torch.bfloat16).cuda(),
+                 images_emb_mask=prepare_inputs["images_emb_mask"].cuda(),
+                 images_seq_mask=prepare_inputs["images_seq_mask"].cuda(),
+             )
+
+             # Encode the source images with the generation-side VQ model and
+             # splice their token embeddings into the prepared sequences at
+             # selected placeholder positions.
+             input_image_pixel_values = encoder_pixel_values.to(torch.bfloat16).cuda()
+             quant_input, emb_loss_input, info_input = self.gen_vision_model.encode(
+                 input_image_pixel_values
+             )
+             image_tokens_input = info_input[2].detach().reshape(
+                 input_image_pixel_values.shape[0], -1
+             )
+             image_embeds_input = self.prepare_gen_img_embeds(image_tokens_input)
+             for ii, ind in enumerate(ppp):
+                 if ii % 4 == 0:
+                     offset = ind[1] + 2
+                     inputs_embeds[
+                         ind[0], offset : offset + image_embeds_input.shape[1], :
+                     ] = image_embeds_input[(ii // 2) % img_len]
+
+             outputs = self.language_model.model(
+                 inputs_embeds=inputs_embeds, use_cache=True, past_key_values=None
+             )
+
+             hidden_states = outputs.last_hidden_state
+             logits = self.gen_head(hidden_states)
+
+             logit_cond_full = logits[0::3, :]
+             logit_cond_part = logits[1::3, :]
+             logit_uncond = logits[2::3, :]
+
+             # Blend the two conditional streams, then apply standard CFG.
+             cfg_weight2 = 5
+             logit_cond = (logit_cond_full + cfg_weight2 * logit_cond_part) / (1 + cfg_weight2)
+             all_logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+
+             if labels is not None:
+                 loss_fct = CrossEntropyLoss()
+                 shift_logits = all_logits[..., :-1, :].contiguous()
+                 shift_logits = shift_logits.view(
+                     -1, self.config.gen_head_config.params.image_token_size
+                 )
+                 shift_labels = labels[..., 1:].contiguous()
+                 shift_labels = shift_labels.view(-1)
+                 shift_labels = shift_labels.to(shift_logits.device)
+                 loss = loss_fct(shift_logits, shift_labels)
+             else:
+                 loss = None
+
+             if not return_dict:
+                 output = (logits,) + outputs[1:]
+                 return ((loss,) + output) if loss is not None else output
+
+             return CausalLMOutputWithPast(
+                 loss=loss,
+                 logits=logits,
+                 past_key_values=outputs.past_key_values,
+                 hidden_states=outputs.hidden_states,
+                 attentions=outputs.attentions,
+             )
+
+
+ AutoConfig.register("vision", VisionConfig)
+ AutoConfig.register("aligner", AlignerConfig)
+ AutoConfig.register("gen_vision", GenVisionConfig)
+ AutoConfig.register("gen_aligner", GenAlignerConfig)
+ AutoConfig.register("gen_head", GenHeadConfig)
+ AutoConfig.register("multi_modality", MultiModalityConfig)
+ AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
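+
+ # A minimal usage sketch (assumptions: the checkpoint ships this file plus a
+ # compatible VLChatProcessor, and `model_path` points at such a repo; the call
+ # signature follows forward() above):
+ #
+ #     model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+ #     out = model(
+ #         vl_chat_processor,
+ #         input_ids,
+ #         task="understanding",
+ #         pixel_values=pixel_values,
+ #         images_seq_mask=images_seq_mask,
+ #         images_emb_mask=images_emb_mask,
+ #     )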