Fahad-S commited on
Commit
16382bd
·
verified ·
1 Parent(s): fc156dd

Upload noqueries_code/blip3o_fast.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. noqueries_code/blip3o_fast.py +422 -0
noqueries_code/blip3o_fast.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Copyright 2023 Haotian Liu
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from PIL import Image
23
+
24
+ from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM, AutoTokenizer
25
+
26
+ from transformers.modeling_outputs import CausalLMOutputWithPast
27
+ from transformers.generation.utils import GenerateOutput
28
+
29
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
30
+
31
+
32
+ from blip3o.constants import UND_IMAGE_TOKEN_IDX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN_IDX
33
+
34
+
35
+
36
+ from diffusers.utils.torch_utils import randn_tensor
37
+ from diffusers.pipelines.pipeline_utils import numpy_to_pil
38
+ import numpy as np
39
+ from diffusers.models import AutoencoderKL
40
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
41
+ from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
42
+
43
+
class blip3oFastConfig(Qwen2Config):
    """Qwen2 configuration registered under the ``llava_qwen2`` model type."""
    model_type = "llava_qwen2"
46
+
47
+
class blip3oFastModel(LlavaMetaModel, Qwen2Model):
    """Qwen2 transformer backbone mixed with the LLaVA multimodal components.

    LlavaMetaModel contributes the modules used elsewhere in this file via
    ``get_model()`` (vision tower, ``mm_projector``, ``down_projector``,
    ``diffusion_connector``, ``dit``, ``noise_scheduler``, ``vae``,
    ``latent_queries``).
    """
    config_class = blip3oFastConfig

    def __init__(self, config: Qwen2Config):
        super(blip3oFastModel, self).__init__(config)
54
+
class blip3oFastForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    """Qwen2 causal LM extended with BLIP3-o multimodal understanding and
    flow-matching image generation (a diffusion head driven by latent-query
    hidden states)."""

    config_class = blip3oFastConfig

    def __init__(self, config):
        # Deliberately skip Qwen2ForCausalLM.__init__ so the multimodal
        # backbone (blip3oFastModel) can be installed as ``self.model``.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = blip3oFastModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the underlying multimodal backbone."""
        return self.model

    def visual(self, pixel_values: torch.Tensor, grid_thw: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode images with the vision tower and project them into the LM
        embedding space.

        NOTE(review): ``grid_thw`` is accepted but never forwarded to the
        vision tower — confirm whether the tower needs it.
        """
        image_features = self.get_model().get_vision_tower()(pixel_values)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        ids: Optional[list] = None,
        i_s_pos: Optional[list] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        gen_image: Optional[torch.FloatTensor] = None,
        und_image: Optional[torch.FloatTensor] = None,
        grid_thw: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Training forward pass: runs the LM over mixed text/image inputs and
        returns the rectified-flow diffusion loss on the image latents.

        Note the returned ``loss`` is the diffusion loss only; ``labels`` are
        consumed by ``prepare_inputs_labels_for_multimodal`` and no LM
        cross-entropy is computed here.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # BUG FIX: ``latents`` was only bound inside the branch below, so a
        # caller passing precomputed ``inputs_embeds`` hit a NameError instead
        # of the explicit assertion further down.
        latents = None
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
                latents,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                gen_image,
                und_image,
                grid_thw,
                i_s_pos,
                image_sizes,
            )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        # Project LM hidden states into the diffusion conditioning space.
        img_hidden_states = self.get_model().down_projector(hidden_states)
        # BUG FIX: the original message said "when latents is None" — the
        # opposite of what is being asserted.
        assert latents is not None, (
            "Image loss requires latents; prepare_inputs_labels_for_multimodal must return them"
        )

        # Rectified-flow training target: interpolate latents toward noise at
        # a sampled timestep and regress the velocity (noise - latents).
        noise = torch.randn_like(latents)
        weighting_scheme = "uniform"
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme,
            batch_size=latents.shape[0],
            logit_mean=0.0,
            logit_std=1.0,
            mode_scale=1.29,
        )
        indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
        timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
        sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)
        noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
        diffusion_pred = self.get_model().dit(
            hidden_states=noisy_latents,
            timestep=timesteps,
            encoder_hidden_states=self.get_model().diffusion_connector(img_hidden_states),
            encoder_attention_mask=attention_mask,
        ).sample
        target = noise - latents
        weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
        # Per-sample weighted MSE, then mean over the batch.
        diff_loss = torch.mean(
            (weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
            1,
        )
        total_loss = diff_loss.mean()

        return CausalLMOutputWithPast(
            loss=total_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Text generation with optional image-understanding inputs.

        ``inputs_embeds`` is always built here (from images or token ids), so
        passing it through ``kwargs`` is rejected.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                img_indicator,
                _,
            ) = self.prepare_inputs_labels_for_understanding(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes,
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    @torch.no_grad()
    def generate_image(
        self,
        text: List[str],
        tokenizer: AutoTokenizer,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        max_var: Optional[float] = None,
    ):
        """Generate image latents conditioned on ``text`` (and, optionally, on
        understanding images supplied via ``pixel_values``).

        Returns a tensor reshaped to ``(1, N, 1792)`` — the batch-size-1
        assumption is baked in (see the [IM_START] append below).
        """
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            "Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler"
        )

        N_QUERY = self.get_n_query()
        inputs = tokenizer(text, padding="longest", return_tensors="pt")
        device = self.get_model().device
        attention_mask = inputs.attention_mask.to(device)
        input_ids = inputs.input_ids.to(device)  # B x N
        # NOTE(review): appending a single [IM_START] row assumes batch size 1.
        input_ids = torch.cat([input_ids, torch.tensor([[DEFAULT_IM_START_TOKEN_IDX]]).to(device)], dim=1)

        text_embeds = self.get_model().embed_tokens(input_ids)
        latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1)

        if pixel_values is not None:
            und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX)
            # BUG FIX: ``self.visual`` is a method, so ``self.visual.dtype``
            # raised AttributeError; use the backbone dtype instead.
            pixel_values = pixel_values.type(self.get_model().dtype)
            und_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            # Scatter image embeddings into the understanding-image token slots.
            text_embeds[und_image_idx] = und_image_embeds.to(text_embeds.device)[: und_image_idx.sum(), :]

        # Append the learned latent queries and extend the mask to cover them.
        text_embeds = torch.cat([text_embeds, latent_queries], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(latent_queries[:, :, 0])], dim=1)

        outputs = self.model(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        # Hidden states of the trailing latent-query positions condition the
        # diffusion sampler.
        img_hidden_states = outputs.hidden_states[-1][:, -N_QUERY:, :]
        output_img = self.sample_images(img_hidden_states, scheduler)
        output_img = output_img.view(1, 1792, -1).permute(0, 2, 1).contiguous()

        return output_img

    def sample_images(
        self,
        img_hidden_states,
        scheduler,
        guidance_scale: float = 3.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 30,
        num_images_per_prompt: int = 1,
        return_tensor=False,
        **kwargs,
    ):
        """Run the flow-matching sampler with classifier-free guidance and
        return the final latents (VAE decoding is intentionally left to the
        caller — see the commented-out ``decode_latents`` call)."""
        device = img_hidden_states.device
        dtype = img_hidden_states.dtype

        # Unconditional branch for classifier-free guidance is an all-zeros
        # conditioning tensor, stacked in front of the conditional branch.
        img_hidden_states_null = torch.zeros_like(img_hidden_states, device=device, dtype=dtype)
        img_hidden_states_input = torch.cat([img_hidden_states_null, img_hidden_states], 0)

        batch_size = img_hidden_states.shape[0]
        latent_size = self.get_model().dit.config.input_size
        latent_channels = self.get_model().dit.config.in_channels

        latents = randn_tensor(
            shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
            generator=generator,
            device=device,
            dtype=dtype,
        )

        # set step values
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)

        # Repeat z_latents and conditions for each image per prompt
        img_hidden_states_input = img_hidden_states_input.repeat_interleave(num_images_per_prompt, dim=0)

        for t in scheduler.timesteps:
            # Duplicate latents so uncond/cond run in one DiT forward.
            latent_model_input = latents.repeat(2, 1, 1, 1)
            if hasattr(scheduler, "scale_model_input"):
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # predict noise model_output
            # NOTE(review): this call signature (x=/z_latents=) differs from the
            # keyword set used in forward() (hidden_states=/encoder_hidden_states=)
            # — confirm the DiT accepts both.
            noise_pred = self.get_model().dit(
                x=latent_model_input,
                timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latent_model_input.device, torch.long),
                z_latents=img_hidden_states_input,
            )

            # perform guidance
            noise_pred_uncond, noise_pred = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

            # compute previous image: x_t -> x_t-1
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # samples = self.decode_latents(latents, return_tensor=return_tensor)
        return latents

    def decode_latents(self, latents, normalize=True, return_tensor=False):
        """Decode VAE latents to images.

        Returns PIL images by default, or a tensor when ``return_tensor`` is
        set; ``normalize`` maps [-1, 1] output into [0, 1].
        """
        if isinstance(self.get_model().vae, AutoencoderKL):
            # Undo the scaling (and optional shift) applied at encode time.
            latents = latents / self.get_model().vae.config.scaling_factor
            if self.get_model().vae.config.shift_factor is not None:
                latents = latents + self.get_model().vae.config.shift_factor
            latents = latents.to(dtype=torch.float32)
            samples = self.get_model().vae.decode(latents).sample
        else:
            samples = self.get_model().vae.decode(latents)
        if normalize:
            samples = (samples / 2 + 0.5).clamp(0, 1)
        else:
            samples = samples.clamp(-1, 1)
        if return_tensor:
            return samples
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
        samples = numpy_to_pil(samples)
        return samples

    def prepare_and_encode_inputs(
        self,
        inputs: List[str | Image.Image],
        tokenizer: AutoTokenizer,
        do_classifier_free_guidance: bool = False,
    ):
        """Encode an interleaved list of strings and PIL images into a prompt
        tensor, optionally concatenating a negative (null) prompt for CFG."""
        device = self.get_model().device
        dtype = self.get_model().dtype

        has_image, has_text = False, False
        text_prompt, image_prompt = "", []
        img_processor = self.get_vision_tower().image_processor
        # NOTE(review): this dict is recreated per call, so the "cache" lookups
        # below never hit; harmless but ineffective.
        negative_prompt = {}

        for x in inputs:
            if isinstance(x, str):
                has_text = True
                text_prompt += x
            else:
                has_image = True
                text_prompt += DEFAULT_IMAGE_TOKEN
                image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
        if len(image_prompt) == 0:
            image_prompt = None
        else:
            image_prompt = torch.cat(image_prompt)
            image_prompt = image_prompt.type(dtype).to(device)

        if has_image and not has_text:
            prompt = self.encode_images(image_prompt)
            if do_classifier_free_guidance:
                key = "[NULL_IMAGE]"
                if key not in negative_prompt:
                    negative_image = torch.zeros_like(image_prompt)
                    negative_prompt[key] = self.encode_images(negative_image)
                prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
        else:
            # BUG FIX: generate_image has no ``image`` parameter; the images
            # must be passed as ``pixel_values``.
            prompt = self.generate_image(text=[text_prompt], pixel_values=image_prompt, tokenizer=tokenizer)
            if do_classifier_free_guidance:
                key = ""
                if key not in negative_prompt:
                    negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
                prompt = torch.cat([prompt, negative_prompt[key]], dim=0)

        gen_pooling = self.get_gen_pooling()
        n_query = self.get_n_query()
        num_img, _, c = prompt.shape
        if 'pool2d' in gen_pooling and has_text and not 'early' in gen_pooling:
            # Spatially downsample the prompt tokens (assumes n_query is a
            # perfect square arranged as a sqrt_n x sqrt_n grid).
            stride = int(gen_pooling.split('_')[1])
            sqrt_n = int(n_query**0.5)
            prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
            prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
            prompt = prompt.reshape(num_img, c, -1).permute(0, 2, 1)
        return prompt

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Forward multimodal extras (images, image_sizes) through the HF
        generation input-preparation hook."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs
419
+
420
+
421
+ AutoConfig.register("llava_qwen2", blip3oFastConfig)
422
+ AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForCausalLM)