Fahad-S committed on
Commit
c83c97e
·
verified ·
1 Parent(s): e39da6a

Upload A_MobileO_With_Edit/blip3o_fast_inference.py with huggingface_hub

Browse files
A_MobileO_With_Edit/blip3o_fast_inference.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ from PIL import Image
5
+ import torch.nn.functional as F
6
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
7
+ from transformers import Qwen2_5_VLConfig, Qwen2ForCausalLM, Qwen2Config, Qwen2Model
8
+ from blip3o.constants import UND_IMAGE_TOKEN_IDX, DEFAULT_IMAGE_TOKEN
9
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
10
+
11
+ from diffusers.utils.torch_utils import randn_tensor
12
+ from diffusers.pipelines.pipeline_utils import numpy_to_pil
13
+ import numpy as np
14
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
15
+ from tqdm import tqdm
16
+
17
+
18
class blip3oFastConfig(Qwen2Config):
    """Qwen2 config subclass tagged with a distinct model_type for Auto* registry lookup."""

    # Key under which this config is registered with AutoConfig (see bottom of file).
    model_type = "blip3o_fast_inference"
20
+
21
+
22
class blip3oFastModel(LlavaMetaModel, Qwen2Model):
    """Backbone that mixes LLaVA multimodal plumbing (vision tower, projector)
    into a Qwen2 transformer via cooperative multiple inheritance."""

    config_class = blip3oFastConfig

    def __init__(self, config: Qwen2_5_VLConfig):
        # Zero-argument super() walks the LlavaMetaModel -> Qwen2Model MRO,
        # identical to the explicit super(blip3oFastModel, self) form.
        super().__init__(config)
27
+
28
+
29
class blip3oFastForInferenceLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    """Causal LM used for fast BLIP3o inference: a Qwen2 LM that encodes
    (optionally multimodal) prompts and drives a diffusion sampler for images."""

    config_class = blip3oFastConfig

    def __init__(self, config):
        super(blip3oFastForInferenceLM, self).__init__(config)
        # NOTE(review): this overwrites the model_type with "blip3o_qwen_inference",
        # while the class is registered below under "blip3o_fast_inference" —
        # confirm the mismatch is intentional.
        config.model_type = "blip3o_qwen_inference"

        # Multimodal backbone plus an untied LM head over the full vocab.
        self.model = blip3oFastModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()
40
+
41
    def get_model(self):
        """Return the underlying blip3oFastModel backbone (LlavaMetaForCausalLM hook)."""
        return self.model
43
+
44
+ def visual(self, pixel_values: torch.Tensor, grid_thw: Optional[torch.Tensor] = None) -> torch.Tensor:
45
+ image_features = self.get_model().get_vision_tower()(pixel_values)
46
+ image_features = self.get_model().mm_projector(image_features)
47
+ return image_features
48
+
49
+ @torch.no_grad()
50
+ def generate_image(
51
+ self,
52
+ input_ids: Optional[torch.Tensor] = None,
53
+ attention_mask: Optional[torch.Tensor] = None,
54
+ pixel_values: Optional[torch.Tensor] = None,
55
+ image_grid_thw: Optional[torch.Tensor] = None,
56
+ with_cfg: bool = True,
57
+ max_var: Optional[float] = None,
58
+ num_inference_steps: int = 20,
59
+ ):
60
+ text_embeds = self.get_model().embed_tokens(input_ids)
61
+
62
+
63
+ if pixel_values is not None:
64
+ und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX)
65
+ pixel_values = pixel_values.type(self.visual.dtype)
66
+ und_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
67
+ text_embeds[und_image_idx] = und_image_embeds.to(text_embeds.device)[:und_image_idx.sum(), :]
68
+
69
+ outputs = self.model(
70
+ inputs_embeds=text_embeds,
71
+ attention_mask=attention_mask,
72
+ output_hidden_states=True,
73
+ return_dict=True,
74
+ )
75
+ img_hidden_states = outputs.hidden_states
76
+ output_img = self.sample_images(img_hidden_states, attention_mask, with_cfg, num_inference_steps=num_inference_steps)
77
+ return output_img
78
    def sample_images(
        self,
        pred_latents,  # Tuple/list of hidden states from all layers
        attention_mask,
        with_cfg: bool = True,
        guidance_scale: float = 1.2,
        num_inference_steps: int = 20,
        num_images_per_prompt: int = 1,
        return_tensor=False,
        with_tqdm: bool = True,
        **kwargs,
    ):
        """Run the DiT diffusion sampler conditioned on LM hidden states.

        Args:
            pred_latents: Tuple of per-layer hidden states from the LM forward.
            attention_mask: Accepted but unused in this implementation.
            with_cfg: If True, apply classifier-free guidance (zero conditioning
                concatenated ahead of the real conditioning along the batch dim).
            guidance_scale: CFG mixing weight.
            num_inference_steps: Scheduler timestep count.
            num_images_per_prompt: Latent batch multiplier.
            return_tensor: Passed through to `decode_latents`.
            with_tqdm: Show a progress bar over timesteps.

        Returns:
            Decoded samples from `decode_latents` (PIL images unless
            `return_tensor` is True).
        """
        # Get device and dtype from first element of tuple
        device = pred_latents[0].device
        dtype = pred_latents[0].dtype

        # Store the original batch size BEFORE any CFG doubling.
        batch_size = pred_latents[0].shape[0]

        latent_size = self.get_model().dit.config.sample_size
        latent_channels = self.get_model().dit.config.in_channels

        # 1. CFG preparation: per layer, stack a zero (unconditional) copy
        #    ahead of the real hidden states -> batch layout [uncond; cond].
        if with_cfg:
            pred_latents_cfg = tuple(
                torch.cat([torch.zeros_like(layer), layer], dim=0)
                for layer in pred_latents
            )
        else:
            pred_latents_cfg = pred_latents

        # 2. Project LM hidden states into DiT conditioning space.
        encoder_hidden_states = self.model.diffusion_connector(pred_latents_cfg)
        # Shape: [B, N, hidden_dim] or [2*B, N, hidden_dim] if with_cfg

        # 3. Initialize latents with the ORIGINAL batch size; the CFG
        #    doubling happens per step on the model input instead.
        latents = randn_tensor(
            shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
            generator=None,
            device=device,
            dtype=dtype,
        )

        # 4. Denoising loop.
        self.model.noise_scheduler.set_timesteps(num_inference_steps)

        iterator = tqdm(self.model.noise_scheduler.timesteps,
                        desc="Sampling") if with_tqdm else self.model.noise_scheduler.timesteps

        for t in iterator:
            # Duplicate latents for CFG (unconditional + conditional halves).
            if with_cfg:
                latent_model_input = torch.cat([latents] * 2)
            else:
                latent_model_input = latents

            latent_model_input = latent_model_input.to(dtype)

            # Some schedulers rescale the input per timestep.
            if hasattr(self.model.noise_scheduler, "scale_model_input"):
                latent_model_input = self.model.noise_scheduler.scale_model_input(
                    latent_model_input, t
                )

            # DiT forward; timestep broadcast to the (possibly doubled) batch.
            noise_pred = self.model.dit(
                hidden_states=latent_model_input,
                encoder_hidden_states=encoder_hidden_states,
                timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(device),
                encoder_attention_mask=None
            ).sample

            # Mix unconditional/conditional predictions per CFG.
            if with_cfg:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # One scheduler step back toward the clean latent.
            latents = self.model.noise_scheduler.step(noise_pred, t, latents).prev_sample

        # 5. Decode latents to images via the VAE.
        samples = self.decode_latents(latents.to(self.model.vae.dtype), return_tensor=return_tensor)
        return samples
171
    def sample_images_(
        self,
        pred_latents,  # Tuple/list of hidden states from all layers
        attention_mask,
        with_cfg: bool = False,
        guidance_scale: float = 3.0,
        num_inference_steps: int = 30,
        num_images_per_prompt: int = 1,
        return_tensor=False,
        **kwargs,
    ):
        """NOTE(review): unfinished variant of `sample_images`.

        It stops after initializing the noise latents — there is no denoising
        loop, no decode, and no return statement, so it implicitly returns
        None. It is never called from this file; looks like dead code left
        from a refactor. Confirm and delete, or finish the implementation.
        """
        # Get device and dtype from first element of tuple
        device = pred_latents[0].device
        dtype = pred_latents[0].dtype

        # Compute conditioning ONCE before the loop (not inside!)
        encoder_hidden_states = self.model.diffusion_connector(pred_latents)
        # Shape: [B, N, 2304]

        # Handle classifier-free guidance
        if with_cfg:
            # Create null conditioning (zeros) for CFG
            encoder_hidden_states_null = torch.zeros_like(encoder_hidden_states)
            # Concatenate: [null_batch, text_batch]
            encoder_hidden_states = torch.cat([encoder_hidden_states_null, encoder_hidden_states], dim=0)

        # Batch size is read AFTER the CFG concat, so it is doubled when with_cfg.
        batch_size = encoder_hidden_states.shape[0]
        latent_size = self.get_model().dit.config.sample_size
        latent_channels = self.get_model().dit.config.in_channels

        # Initialize random noise latents (function ends here — see docstring).
        latents = randn_tensor(
            shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
            generator=None,
            device=device,
            dtype=dtype,
        )
209
+
210
    def sample_images_original(
        self,
        pred_latents,  # Tuple/list of hidden states from all layers
        attention_mask,
        with_cfg: bool = False,
        guidance_scale: float = 1.2,
        num_inference_steps: int = 30,
        num_images_per_prompt: int = 1,
        return_tensor=False,
        **kwargs,
    ):
        """Earlier diffusion-sampling implementation, kept alongside `sample_images`.

        Differences from `sample_images`: CFG nulls are created AFTER the
        connector (on its output), and `batch_size` is read from the
        CFG-doubled `encoder_hidden_states`, so with_cfg=True doubles the
        latent batch and then doubles the model input again per step.
        NOTE(review): that double-doubling looks unintended — `sample_images`
        stores the batch size before CFG; confirm which behavior is wanted.
        `attention_mask` is accepted but unused.
        """
        # Get device and dtype from first element of tuple
        device = pred_latents[0].device
        dtype = pred_latents[0].dtype

        # Project LM hidden states into DiT conditioning space.
        encoder_hidden_states = self.model.diffusion_connector(pred_latents)
        # Shape: [B, N, 2304]

        # Handle classifier-free guidance
        if with_cfg:
            # Create null conditioning (zeros) for CFG
            encoder_hidden_states_null = torch.zeros_like(encoder_hidden_states)
            # Concatenate: [null_batch, text_batch]
            encoder_hidden_states = torch.cat([encoder_hidden_states_null, encoder_hidden_states], dim=0)

        # Get batch size from processed (possibly CFG-doubled) hidden states.
        batch_size = encoder_hidden_states.shape[0]
        latent_size = self.get_model().dit.config.sample_size
        latent_channels = self.get_model().dit.config.in_channels

        # Initialize random noise latents
        latents = randn_tensor(
            shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
            generator=None,
            device=device,
            dtype=dtype,
        )

        # Set up timesteps for denoising
        self.model.noise_scheduler.set_timesteps(num_inference_steps)

        # Denoising loop
        for t in tqdm(self.model.noise_scheduler.timesteps, desc="Sampling images"):
            # Prepare model input
            if with_cfg:
                # Duplicate latents for CFG (unconditional + conditional)
                latent_model_input = torch.cat([latents] * 2)
            else:
                latent_model_input = latents

            latent_model_input = latent_model_input.to(dtype)

            # Scale model input if needed (for some schedulers)
            if hasattr(self.model.noise_scheduler, "scale_model_input"):
                latent_model_input = self.model.noise_scheduler.scale_model_input(
                    latent_model_input, t
                )

            # Predict noise; timestep broadcast to the model-input batch.
            noise_pred = self.model.dit(
                hidden_states=latent_model_input,
                encoder_hidden_states=encoder_hidden_states,
                timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(device),
                encoder_attention_mask=None
            ).sample

            # Apply classifier-free guidance
            if with_cfg:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Perform denoising step
            latents = self.model.noise_scheduler.step(noise_pred, t, latents).prev_sample

        # Decode latents to images
        samples = self.decode_latents(latents.to(self.model.vae.dtype), return_tensor=return_tensor)
        return samples
287
+
288
+
289
    @torch.no_grad()
    def decode_latents(self, latents, normalize=True, return_tensor=False):
        """Decode diffusion latents into images.

        When a VAE is attached, undoes its scaling (and shift, if the config
        defines one) and decodes; otherwise treats `latents` as pixels.

        Args:
            latents: Latent tensor, expected NCHW.
            normalize: If True map from [-1, 1] to [0, 1]; else clamp to [-1, 1].
            return_tensor: If True return the tensor; else convert to PIL images.
        """
        if self.model.vae is not None:
            # Invert the VAE's latent scaling before decoding.
            latents = latents / self.model.vae.config.scaling_factor
            # Some VAEs (e.g. with a shifted latent space) also define a shift.
            if "shift_factor" in self.model.vae.config and self.model.vae.config.shift_factor is not None:
                latents = latents + self.model.vae.config.shift_factor
            samples = self.model.vae.decode(latents).sample
        else:
            samples = latents
        if normalize:
            # [-1, 1] -> [0, 1] for image output.
            samples = (samples / 2 + 0.5).clamp(0, 1)
        else:
            samples = samples.clamp(-1, 1)
        if return_tensor:
            return samples
        # NCHW -> NHWC float numpy, then PIL.
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
        samples = numpy_to_pil(samples)
        return samples
307
+
308
    def prepare_and_encode_inputs(
        self,
        inputs: List[str | Image.Image],
        tokenizer: AutoTokenizer,
        do_classifier_free_guidance: bool = False,
    ):
        """Turn an interleaved list of strings and PIL images into a prompt embedding.

        Strings are concatenated; each image contributes a DEFAULT_IMAGE_TOKEN
        placeholder plus preprocessed pixel values. Image-only inputs go through
        `encode_images`; anything with text goes through `generate_image`.

        NOTE(review): the `generate_image(text=..., image=..., tokenizer=...)`
        calls below do not match the `generate_image` signature defined in this
        file (which takes input_ids/pixel_values) — this branch would raise a
        TypeError as written; confirm which generate_image this is meant to hit.
        `encode_images`, `get_gen_pooling`, and `get_n_query` are defined
        elsewhere (presumably on LlavaMetaForCausalLM).
        """
        print("="*20, "prepare_and_encode_inputs", "="*20)
        # pdb.set_trace()
        device = self.get_model().device
        dtype = self.get_model().dtype

        has_image, has_text = False, False
        text_prompt, image_prompt = "", []
        img_processor = self.get_vision_tower().image_processor
        negative_prompt = {}

        # Accumulate text and image parts in input order.
        for x in inputs:
            if isinstance(x, str):
                has_text = True
                text_prompt += x
            else:
                has_image = True
                text_prompt += DEFAULT_IMAGE_TOKEN
                image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
        # pdb.set_trace()
        if len(image_prompt) == 0:
            image_prompt = None
        else:
            image_prompt = torch.cat(image_prompt)
            image_prompt = image_prompt.type(dtype).to(device)

        if has_image and not has_text:
            # Image-only prompt: encode directly with the vision pathway.
            prompt = self.encode_images(image_prompt)
            # pdb.set_trace()
            if do_classifier_free_guidance:
                # Negative conditioning = encoding of an all-zero image, cached by key.
                key = "[NULL_IMAGE]"
                if key not in negative_prompt:
                    negative_image = torch.zeros_like(image_prompt)
                    negative_prompt[key] = self.encode_images(negative_image)
                prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
        else:
            prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
            if do_classifier_free_guidance:
                # Negative conditioning = empty-text prompt, cached by key.
                key = ""
                if key not in negative_prompt:
                    negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
                prompt = torch.cat([prompt, negative_prompt[key]], dim=0)

        # Optional 2D average pooling of the prompt tokens, driven by the
        # gen_pooling spec (e.g. "pool2d_<stride>"); assumes n_query is a square.
        gen_pooling = self.get_gen_pooling()
        n_query = self.get_n_query()
        num_img, _, c = prompt.shape
        if 'pool2d' in gen_pooling and has_text and not 'early' in gen_pooling:
            stride = int(gen_pooling.split('_')[1])
            sqrt_n = int(n_query**0.5)
            prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
            prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
            prompt = prompt.reshape(num_img, c, -1).permute(0,2,1)
        return prompt
366
+
367
+
368
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
369
+ inputs_embeds=None, **kwargs):
370
+ print("="*20, "prepare_inputs_for_generation", "="*20)
371
+ images = kwargs.pop("images", None)
372
+ image_sizes = kwargs.pop("image_sizes", None)
373
+ inputs = super().prepare_inputs_for_generation(
374
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
375
+ )
376
+ if images is not None:
377
+ inputs['images'] = images
378
+ if image_sizes is not None:
379
+ inputs['image_sizes'] = image_sizes
380
+ return inputs
381
+
382
+ AutoConfig.register("blip3o_fast_inference", blip3oFastConfig)
383
+ AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForInferenceLM)