Fahad-S commited on
Commit
874d7d9
·
verified ·
1 Parent(s): 4a21d2d

Upload teacher_code/blip3o_fast.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. teacher_code/blip3o_fast.py +497 -0
teacher_code/blip3o_fast.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Copyright 2023 Haotian Liu
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from PIL import Image
23
+
24
+ from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM, AutoTokenizer
25
+
26
+ from transformers.modeling_outputs import CausalLMOutputWithPast
27
+ from transformers.generation.utils import GenerateOutput
28
+
29
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
30
+
31
+
32
+ from blip3o.constants import UND_IMAGE_TOKEN_IDX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN_IDX
33
+
34
+
35
+
36
+ from diffusers.utils.torch_utils import randn_tensor
37
+ from diffusers.pipelines.pipeline_utils import numpy_to_pil
38
+ import numpy as np
39
+ from diffusers.models import AutoencoderKL
40
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
41
+ from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
42
+
43
+ def compute_prediction_divergence(teacher_pred, student_pred, method='kl'):
44
+ """
45
+ Compute divergence between teacher and student predictions
46
+ """
47
+ if method == 'kl':
48
+ # Treat predictions as parameters of Gaussian distributions
49
+ # Assume unit variance for simplicity
50
+ teacher_logits = teacher_pred.flatten(1) # [B, D]
51
+ student_logits = student_pred.flatten(1) # [B, D]
52
+
53
+ # KL divergence between two Gaussians with same variance
54
+ kl_div = 0.5 * torch.mean((teacher_logits - student_logits) ** 2)
55
+ return kl_div
56
+
57
+ elif method == 'cosine_distance':
58
+ # Cosine similarity between predictions
59
+ teacher_flat = teacher_pred.flatten(1)
60
+ student_flat = student_pred.flatten(1)
61
+
62
+ cosine_sim = F.cosine_similarity(teacher_flat, student_flat, dim=1)
63
+ cosine_distance = 1 - cosine_sim.mean()
64
+ return cosine_distance
65
+
66
+ elif method == 'js_divergence':
67
+ # Jensen-Shannon divergence approximation
68
+ teacher_flat = F.softmax(teacher_pred.flatten(1), dim=1)
69
+ student_flat = F.softmax(student_pred.flatten(1), dim=1)
70
+
71
+ m = 0.5 * (teacher_flat + student_flat)
72
+ js_div = 0.5 * F.kl_div(F.log_softmax(teacher_flat, dim=1), m, reduction='batchmean') + \
73
+ 0.5 * F.kl_div(F.log_softmax(student_flat, dim=1), m, reduction='batchmean')
74
+ return js_div
75
+
76
+
77
+ class blip3oFastConfig(Qwen2Config):
78
+ model_type = "llava_qwen2"
79
+
80
+
81
+ class blip3oFastModel(LlavaMetaModel, Qwen2Model):
82
+ config_class = blip3oFastConfig
83
+
84
+ def __init__(self, config: Qwen2Config):
85
+ super(blip3oFastModel, self).__init__(config)
86
+
87
+
88
+ class blip3oFastForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
89
+ config_class = blip3oFastConfig
90
+
91
+ def __init__(self, config):
92
+ super(Qwen2ForCausalLM, self).__init__(config)
93
+ self.model = blip3oFastModel(config)
94
+ # self.pretraining_tp = config.pretraining_tp
95
+ self.vocab_size = config.vocab_size
96
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
97
+
98
+ # Initialize weights and apply final processing
99
+ self.post_init()
100
+ self.kd_weight=10
101
+ #self.model.dit_teacher.eval()
102
+
103
+ def get_model(self):
104
+ return self.model
105
+
106
+ def visual(self, pixel_values: torch.Tensor, grid_thw: Optional[torch.Tensor] = None) -> torch.Tensor:
107
+ image_features = self.get_model().get_vision_tower()(pixel_values)
108
+ image_features = self.get_model().mm_projector(image_features)
109
+ return image_features
110
+
111
+ def forward(
112
+ self,
113
+ input_ids: torch.LongTensor = None,
114
+ attention_mask: Optional[torch.Tensor] = None,
115
+ position_ids: Optional[torch.LongTensor] = None,
116
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
117
+ inputs_embeds: Optional[torch.FloatTensor] = None,
118
+ labels: Optional[torch.LongTensor] = None,
119
+ teacher_prompts: Optional[List[str]] = None,
120
+ teacher_input_ids: torch.LongTensor = None,
121
+ teacher_attention_mask: Optional[torch.Tensor] = None,
122
+ ids: Optional[list] = None,
123
+ i_s_pos: Optional[list] = None,
124
+ use_cache: Optional[bool] = None,
125
+ output_attentions: Optional[bool] = None,
126
+ output_hidden_states: Optional[bool] = None,
127
+ gen_image: Optional[torch.FloatTensor] = None,
128
+ und_image: Optional[torch.FloatTensor] = None,
129
+ grid_thw: Optional[torch.FloatTensor] = None,
130
+ image_sizes: Optional[List[List[int]]] = None,
131
+ return_dict: Optional[bool] = None,
132
+ cache_position: Optional[torch.LongTensor] = None
133
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
134
+
135
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
136
+ output_hidden_states = (
137
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
138
+ )
139
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
140
+
141
+ if inputs_embeds is None:
142
+ (
143
+ input_ids,
144
+ position_ids,
145
+ attention_mask,
146
+ past_key_values,
147
+ inputs_embeds,
148
+ labels,
149
+ latents
150
+ ) = self.prepare_inputs_labels_for_multimodal(
151
+ input_ids,
152
+ position_ids,
153
+ attention_mask,
154
+ past_key_values,
155
+ labels,
156
+ gen_image,
157
+ und_image,
158
+ grid_thw,
159
+ i_s_pos,
160
+ image_sizes
161
+ )
162
+
163
+ outputs = self.model(
164
+ input_ids=input_ids,
165
+ attention_mask=attention_mask,
166
+ position_ids=position_ids,
167
+ past_key_values=past_key_values,
168
+ inputs_embeds=inputs_embeds,
169
+ use_cache=use_cache,
170
+ output_attentions=output_attentions,
171
+ output_hidden_states=output_hidden_states,
172
+ return_dict=return_dict,
173
+ )
174
+
175
+ hidden_states = outputs[0]
176
+ logits = self.lm_head(hidden_states)
177
+ logits = logits.float()
178
+
179
+ total_loss = None
180
+ if labels is not None:
181
+ # Shift so that tokens < n predict n
182
+ shift_logits = logits[..., :-1, :].contiguous()
183
+ shift_labels = labels[..., 1:].contiguous()
184
+ # Flatten the tokens
185
+ loss_fct = torch.nn.CrossEntropyLoss()
186
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
187
+ shift_labels = shift_labels.view(-1)
188
+ # Enable model parallelism
189
+ shift_labels = shift_labels.to(shift_logits.device)
190
+ loss = loss_fct(shift_logits, shift_labels)
191
+
192
+
193
+ img_hidden_states = []
194
+ device, dtype = self.get_model().dit.device, self.get_model().dit.dtype
195
+ for b in range(hidden_states.shape[0]):
196
+ img_hidden_states.append(hidden_states[b,i_s_pos[b]:i_s_pos[b]+8,:])
197
+ img_hidden_states = torch.stack(img_hidden_states, dim=0)
198
+ assert latents is not None, "latents should not be None"
199
+ noise = torch.randn_like(latents, device=latents.device)
200
+ weighting_scheme = "uniform"
201
+ u = compute_density_for_timestep_sampling(
202
+ weighting_scheme=weighting_scheme,
203
+ batch_size=latents.shape[0],
204
+ logit_mean=0.0,
205
+ logit_std=1.0,
206
+ mode_scale=1.29,
207
+ )
208
+ indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
209
+ timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
210
+ sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)
211
+ noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
212
+ diffusion_pred = self.get_model().dit(
213
+ hidden_states=noisy_latents,
214
+ timestep=timesteps,
215
+ encoder_hidden_states=self.get_model().diffusion_connector(self.mask_drop(img_hidden_states)),
216
+ encoder_attention_mask=None,
217
+ ).sample
218
+ with torch.no_grad():
219
+ all_prompt_embeds = self.get_model().text_encoder(teacher_input_ids, attention_mask=teacher_attention_mask)
220
+ prompt_embeds = all_prompt_embeds[0].to(device).to(dtype)
221
+
222
+ teacher_noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
223
+ teacher_diffusion_pred = self.get_model().dit_teacher(
224
+ hidden_states=teacher_noisy_latents,
225
+ timestep=timesteps,
226
+ encoder_hidden_states=prompt_embeds,
227
+ encoder_attention_mask=teacher_attention_mask,
228
+ )[0]
229
+
230
+
231
+ target = noise - latents
232
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
233
+ diff_loss = torch.mean(
234
+ (weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
235
+ 1,
236
+ )
237
+ kd_loss = F.mse_loss(diffusion_pred.float(), teacher_diffusion_pred.float())
238
+ diff_loss = diff_loss.mean()
239
+ total_loss = diff_loss + self.kd_weight * kd_loss
240
+ self._last_diff_loss = diff_loss.detach()
241
+ self._last_kd_loss = kd_loss.detach()
242
+
243
+ # In your forward method:
244
+ with torch.no_grad():
245
+ pred_divergence = compute_prediction_divergence(teacher_diffusion_pred, diffusion_pred, method='kl')
246
+ self._last_pred_divergence = pred_divergence.detach()
247
+
248
+ return CausalLMOutputWithPast(
249
+ loss=total_loss,
250
+ logits=logits,
251
+ past_key_values=outputs.past_key_values,
252
+ hidden_states=outputs.hidden_states,
253
+ attentions=outputs.attentions,
254
+ )
255
+
256
+
257
+ @torch.no_grad()
258
+ def generate(
259
+ self,
260
+ inputs: Optional[torch.Tensor] = None,
261
+ images: Optional[torch.Tensor] = None,
262
+ image_sizes: Optional[torch.Tensor] = None,
263
+ **kwargs,
264
+ ) -> Union[GenerateOutput, torch.LongTensor]:
265
+ position_ids = kwargs.pop("position_ids", None)
266
+ attention_mask = kwargs.pop("attention_mask", None)
267
+ if "inputs_embeds" in kwargs:
268
+ raise NotImplementedError("`inputs_embeds` is not supported")
269
+
270
+ if images is not None:
271
+ (
272
+ inputs,
273
+ position_ids,
274
+ attention_mask,
275
+ _,
276
+ inputs_embeds,
277
+ img_indicator,
278
+ _
279
+ ) = self.prepare_inputs_labels_for_understanding(
280
+ inputs,
281
+ position_ids,
282
+ attention_mask,
283
+ None,
284
+ None,
285
+ images,
286
+ image_sizes=image_sizes
287
+ )
288
+ else:
289
+ inputs_embeds = self.get_model().embed_tokens(inputs)
290
+
291
+ return super().generate(
292
+ position_ids=position_ids,
293
+ attention_mask=attention_mask,
294
+ inputs_embeds=inputs_embeds,
295
+ **kwargs
296
+ )
297
+
298
+ @torch.no_grad()
299
+ def generate_image(
300
+ self,
301
+ text: List[str],
302
+ tokenizer: AutoTokenizer,
303
+ pixel_values: Optional[torch.Tensor] = None,
304
+ image_grid_thw: Optional[torch.Tensor] = None,
305
+ max_var: Optional[float] = None,
306
+ ):
307
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
308
+
309
+
310
+ N_QUERY = self.get_n_query()
311
+ inputs = tokenizer(text, padding="longest", return_tensors="pt")
312
+ device = self.get_model().device
313
+ attention_mask = inputs.attention_mask.to(device)
314
+ input_ids = inputs.input_ids.to(device) # B x N
315
+ input_ids = torch.cat([input_ids, torch.tensor([[DEFAULT_IM_START_TOKEN_IDX]]).to(device)], dim=1)
316
+
317
+
318
+ text_embeds = self.get_model().embed_tokens(input_ids)
319
+ latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1)
320
+
321
+
322
+ if pixel_values is not None:
323
+ und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX)
324
+ pixel_values = pixel_values.type(self.visual.dtype)
325
+ und_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
326
+ text_embeds[und_image_idx] = und_image_embeds.to(text_embeds.device)[:und_image_idx.sum(), :]
327
+
328
+
329
+ text_embeds = torch.cat([text_embeds, latent_queries], dim=1)
330
+ attention_mask = torch.cat([attention_mask, torch.ones_like(latent_queries[:, :, 0])], dim=1)
331
+
332
+
333
+ outputs = self.model(
334
+ inputs_embeds=text_embeds,
335
+ attention_mask=attention_mask,
336
+ output_hidden_states=True,
337
+ return_dict=True,
338
+ )
339
+ hidden_states = outputs.hidden_states[-1][:,-N_QUERY:,:]
340
+ img_hidden_states = hidden_states
341
+ output_img = self.sample_images(img_hidden_states, scheduler)
342
+ output_img = output_img.view(1, 1792, -1).permute(0,2,1).contiguous()
343
+
344
+ return output_img
345
+
346
+ def sample_images(
347
+ self,
348
+ img_hidden_states,
349
+ scheduler,
350
+ guidance_scale: float = 3.0,
351
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
352
+ num_inference_steps: int = 30,
353
+ num_images_per_prompt: int = 1,
354
+ return_tensor=False,
355
+ **kwargs,
356
+ ):
357
+
358
+ device = img_hidden_states.device
359
+ dtype = img_hidden_states.dtype
360
+
361
+
362
+ img_hidden_states_null = torch.zeros_like(img_hidden_states, device=device, dtype=dtype)
363
+ img_hidden_states_input = torch.cat([img_hidden_states_null, img_hidden_states], 0)
364
+
365
+ batch_size = img_hidden_states.shape[0]
366
+ latent_size = self.get_model().dit.config.input_size
367
+ latent_channels = self.get_model().dit.config.in_channels
368
+
369
+ latents = randn_tensor(
370
+ shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
371
+ generator=generator,
372
+ device=device,
373
+ dtype=dtype,
374
+ )
375
+
376
+ # set step values
377
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
378
+ scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
379
+
380
+ # Repeat z_latents and conditions for each image per prompt
381
+ img_hidden_states_input = img_hidden_states_input.repeat_interleave(num_images_per_prompt, dim=0)
382
+
383
+ for t in scheduler.timesteps:
384
+ latent_model_input = latents.repeat(2, 1, 1, 1)
385
+ if hasattr(scheduler, "scale_model_input"):
386
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
387
+
388
+ # predict noise model_output
389
+ noise_pred = self.get_model().dit(
390
+ x=latent_model_input,
391
+ timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latent_model_input.device, torch.long),
392
+ z_latents=img_hidden_states_input,
393
+ )
394
+
395
+ # perform guidance
396
+ noise_pred_uncond, noise_pred = noise_pred.chunk(2)
397
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
398
+
399
+ # compute previous image: x_t -> x_t-1
400
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
401
+
402
+ # samples = self.decode_latents(latents, return_tensor=return_tensor)
403
+ # breakpoint()
404
+ return latents
405
+
406
+ def decode_latents(self, latents, normalize=True, return_tensor=False):
407
+ if isinstance(self.get_model().vae, AutoencoderKL):
408
+ latents = latents / self.get_model().vae.config.scaling_factor
409
+ if self.get_model().vae.config.shift_factor is not None:
410
+ latents = latents + self.get_model().vae.config.shift_factor
411
+ latents = latents.to(dtype=torch.float32)
412
+ samples = self.get_model().vae.decode(latents).sample
413
+ else:
414
+ samples = self.get_model().vae.decode(latents)
415
+ if normalize:
416
+ samples = (samples / 2 + 0.5).clamp(0, 1)
417
+ else:
418
+ samples = samples.clamp(-1, 1)
419
+ if return_tensor:
420
+ return samples
421
+ samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
422
+ samples = numpy_to_pil(samples)
423
+ return samples
424
+
425
+ def prepare_and_encode_inputs(
426
+ self,
427
+ inputs: List[str | Image.Image],
428
+ tokenizer: AutoTokenizer,
429
+ do_classifier_free_guidance: bool = False,
430
+ ):
431
+ # pdb.set_trace()
432
+ device = self.get_model().device
433
+ dtype = self.get_model().dtype
434
+
435
+ has_image, has_text = False, False
436
+ text_prompt, image_prompt = "", []
437
+ img_processor = self.get_vision_tower().image_processor
438
+ negative_prompt = {}
439
+
440
+ for x in inputs:
441
+ if isinstance(x, str):
442
+ has_text = True
443
+ text_prompt += x
444
+ else:
445
+ has_image = True
446
+ text_prompt += DEFAULT_IMAGE_TOKEN
447
+ image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
448
+ if len(image_prompt) == 0:
449
+ image_prompt = None
450
+ else:
451
+ image_prompt = torch.cat(image_prompt)
452
+ image_prompt = image_prompt.type(dtype).to(device)
453
+
454
+ if has_image and not has_text:
455
+ prompt = self.encode_images(image_prompt)
456
+ if do_classifier_free_guidance:
457
+ key = "[NULL_IMAGE]"
458
+ if key not in negative_prompt:
459
+ negative_image = torch.zeros_like(image_prompt)
460
+ negative_prompt[key] = self.encode_images(negative_image)
461
+ prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
462
+ else:
463
+ prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
464
+ if do_classifier_free_guidance:
465
+ key = ""
466
+ if key not in negative_prompt:
467
+ negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
468
+ prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
469
+
470
+ gen_pooling = self.get_gen_pooling()
471
+ n_query = self.get_n_query()
472
+ num_img, _, c = prompt.shape
473
+ if 'pool2d' in gen_pooling and has_text and not 'early' in gen_pooling:
474
+ stride = int(gen_pooling.split('_')[1])
475
+ sqrt_n = int(n_query**0.5)
476
+ prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
477
+ prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
478
+ prompt = prompt.reshape(num_img, c, -1).permute(0,2,1)
479
+ return prompt
480
+
481
+
482
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
483
+ inputs_embeds=None, **kwargs):
484
+ images = kwargs.pop("images", None)
485
+ image_sizes = kwargs.pop("image_sizes", None)
486
+ inputs = super().prepare_inputs_for_generation(
487
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
488
+ )
489
+ if images is not None:
490
+ inputs['images'] = images
491
+ if image_sizes is not None:
492
+ inputs['image_sizes'] = image_sizes
493
+ return inputs
494
+
495
+
496
+ AutoConfig.register("llava_qwen2", blip3oFastConfig)
497
+ AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForCausalLM)