Fahad-S committed
Commit 4a21d2d · verified · 1 Parent(s): a630e05

Upload teacher_code/blip3o_fast_inference.py with huggingface_hub

Files changed (1)
  1. teacher_code/blip3o_fast_inference.py +231 -0

teacher_code/blip3o_fast_inference.py ADDED
@@ -0,0 +1,231 @@
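"""blip3o fast-inference model.

A Qwen2-based multimodal LM: learnable latent queries are appended to the
text sequence, their final hidden states condition a DiT diffusion head
through a connector, and the sampled latents are decoded to pixels by an
optional VAE.  Sampling uses classifier-free guidance and a diffusers
noise scheduler (FlowMatchEulerDiscreteScheduler is handled explicitly).
"""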
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import Qwen2_5_VLConfig, Qwen2ForCausalLM, Qwen2Config, Qwen2Model
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import numpy_to_pil
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

from blip3o.constants import UND_IMAGE_TOKEN_IDX, DEFAULT_IMAGE_TOKEN
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class blip3oFastConfig(Qwen2Config):
    model_type = "blip3o_fast_inference"


class blip3oFastModel(LlavaMetaModel, Qwen2Model):
    config_class = blip3oFastConfig

    def __init__(self, config: Qwen2_5_VLConfig):
        super(blip3oFastModel, self).__init__(config)


class blip3oFastForInferenceLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    config_class = blip3oFastConfig

    def __init__(self, config):
        super(blip3oFastForInferenceLM, self).__init__(config)
        config.model_type = "blip3o_qwen_inference"

        self.model = blip3oFastModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def visual(self, pixel_values: torch.Tensor, grid_thw: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Encode pixels with the vision tower, then project into the LLM embedding space.
        image_features = self.get_model().get_vision_tower()(pixel_values)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

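    # Image generation: embed the prompt, append the learnable latent queries,
    # run one forward pass, and use the queries' final hidden states (the last
    # N_QUERY positions) as conditioning for the diffusion decoder.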
    @torch.no_grad()
    def generate_image(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        max_var: Optional[float] = None,
    ):
        N_QUERY = self.get_n_query()
        text_embeds = self.get_model().embed_tokens(input_ids)
        latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1)

        if pixel_values is not None:
            # Splice vision features into the positions reserved for the
            # understanding-image placeholder tokens.
            und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX)
            # Cast to the model dtype before running the vision tower
            # (`self.visual` is a method, so it carries no dtype of its own).
            pixel_values = pixel_values.type(self.dtype)
            und_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            text_embeds[und_image_idx] = und_image_embeds.to(text_embeds.device)[:und_image_idx.sum(), :]

        # Append the latent queries after the text and extend the attention mask.
        text_embeds = torch.cat([text_embeds, latent_queries], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(latent_queries[:, :, 0])], dim=1)
        outputs = self.model(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        img_hidden_states = outputs.hidden_states[-1][:, -N_QUERY:, :]
        return self.sample_images(img_hidden_states)

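    # Diffusion sampling with classifier-free guidance: the conditional context
    # and an all-zero unconditional context are batched together, the scheduler
    # is stepped num_inference_steps times, and the guided prediction is
    # uncond + guidance_scale * (cond - uncond).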
    def sample_images(
        self,
        pred_latents,
        guidance_scale: float = 3.0,
        num_inference_steps: int = 30,
        num_images_per_prompt: int = 1,
        return_tensor=False,
        **kwargs,
    ):
        device = pred_latents.device
        dtype = pred_latents.dtype

        # An all-zero context serves as the unconditional guidance branch.
        img_hidden_states_null = torch.zeros_like(pred_latents, device=device, dtype=dtype)
        pred_latents = torch.cat([img_hidden_states_null, pred_latents], 0)
        batch_size = pred_latents.shape[0]
        latent_size = self.get_n_query()
        latent_channels = self.get_model().dit.config.in_channels

        latents = randn_tensor(
            shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
            generator=None,
            device=device,
            dtype=dtype,
        )

        # Set step values; flow matching takes an explicit sigma schedule
        # running from 1.0 down to 1/num_inference_steps.
        if isinstance(self.model.noise_scheduler, FlowMatchEulerDiscreteScheduler):
            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
            self.model.noise_scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
        else:
            self.model.noise_scheduler.set_timesteps(num_inference_steps)

        for t in tqdm(self.model.noise_scheduler.timesteps, desc="Sampling images"):
            # Duplicate the latents so both guidance branches share one DiT pass.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = latent_model_input.to(pred_latents.dtype)
            # Check the scheduler (not its timesteps tensor) for scale_model_input.
            if hasattr(self.model.noise_scheduler, "scale_model_input"):
                latent_model_input = self.model.noise_scheduler.scale_model_input(latent_model_input, t)
            noise_pred = self.model.dit(
                hidden_states=latent_model_input,
                encoder_hidden_states=self.model.diffusion_connector(pred_latents),
                timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latents.device),
                encoder_attention_mask=None,
            ).sample

            noise_pred_uncond, noise_pred = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
            latents = self.model.noise_scheduler.step(noise_pred, t, latents).prev_sample

        samples = self.decode_latents(
            latents.to(self.model.vae.dtype) if self.model.vae is not None else latents,
            return_tensor=return_tensor,
        )
        return samples

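    # Latent decoding: undo the VAE's scaling/shift, decode to pixel space,
    # and optionally convert the result to PIL images.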
    @torch.no_grad()
    def decode_latents(self, latents, normalize=True, return_tensor=False):
        if self.model.vae is not None:
            latents = latents / self.model.vae.config.scaling_factor
            if "shift_factor" in self.model.vae.config and self.model.vae.config.shift_factor is not None:
                latents = latents + self.model.vae.config.shift_factor
            samples = self.model.vae.decode(latents).sample
        else:
            samples = latents
        if normalize:
            # Map from [-1, 1] to [0, 1].
            samples = (samples / 2 + 0.5).clamp(0, 1)
        else:
            samples = samples.clamp(-1, 1)
        if return_tensor:
            return samples
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
        samples = numpy_to_pil(samples)
        return samples

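    # Prompt preparation: flatten an interleaved list of strings and PIL images
    # into one token sequence (images become placeholder tokens), encode it,
    # and, for classifier-free guidance, append a negative (null) branch.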
    def prepare_and_encode_inputs(
        self,
        inputs: List[str | Image.Image],
        tokenizer: AutoTokenizer,
        do_classifier_free_guidance: bool = False,
    ):
        device = self.get_model().device
        dtype = self.get_model().dtype

        has_image, has_text = False, False
        text_prompt, image_prompt = "", []
        img_processor = self.get_vision_tower().image_processor
        negative_prompt = {}

        # Walk the interleaved inputs in reading order, replacing each image
        # with the image placeholder token.
        for x in inputs:
            if isinstance(x, str):
                has_text = True
                text_prompt += x
            else:
                has_image = True
                text_prompt += DEFAULT_IMAGE_TOKEN
                image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
        if len(image_prompt) == 0:
            image_prompt = None
        else:
            image_prompt = torch.cat(image_prompt)
            image_prompt = image_prompt.type(dtype).to(device)

        if has_image and not has_text:
            prompt = self.encode_images(image_prompt)
            if do_classifier_free_guidance:
                # The negative branch conditions on an all-zero image.
                key = "[NULL_IMAGE]"
                if key not in negative_prompt:
                    negative_image = torch.zeros_like(image_prompt)
                    negative_prompt[key] = self.encode_images(negative_image)
                prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
        else:
            prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
            if do_classifier_free_guidance:
                # The negative branch conditions on an empty prompt.
                key = ""
                if key not in negative_prompt:
                    negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
                prompt = torch.cat([prompt, negative_prompt[key]], dim=0)

        # Optional 2D average pooling over the square query grid to reduce
        # the number of conditioning tokens.
        gen_pooling = self.get_gen_pooling()
        n_query = self.get_n_query()
        num_img, _, c = prompt.shape
        if 'pool2d' in gen_pooling and has_text and 'early' not in gen_pooling:
            stride = int(gen_pooling.split('_')[1])
            sqrt_n = int(n_query ** 0.5)
            prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
            prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
            prompt = prompt.reshape(num_img, c, -1).permute(0, 2, 1)
        return prompt

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        # Pass `images`/`image_sizes` through the standard generation path.
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs


AutoConfig.register("blip3o_fast_inference", blip3oFastConfig)
AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForInferenceLM)
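A minimal usage sketch (not part of the uploaded file), assuming this module has been imported so the registrations above have run, and that a checkpoint directory saved for the blip3o_fast_inference architecture exists; the path, device, and prompt are placeholders:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "./blip3o-checkpoint"  # placeholder: directory with blip3o_fast_inference weights
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16).to("cuda").eval()

enc = tokenizer(["a watercolor fox in a snowy forest"], return_tensors="pt").to("cuda")
images = model.generate_image(input_ids=enc.input_ids, attention_mask=enc.attention_mask)
images[0].save("fox.png")  # sample_images returns PIL images by default

Because blip3oFastConfig is registered with AutoConfig and AutoModelForCausalLM, from_pretrained resolves the custom class from the model_type string in the checkpoint's config.json.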