from typing import List, Optional, Union
| import torch |
| import torch.nn as nn |
| from PIL import Image |
| import torch.nn.functional as F |
| from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer |
from transformers import Qwen2ForCausalLM, Qwen2Config, Qwen2Model
| from blip3o.constants import UND_IMAGE_TOKEN_IDX, DEFAULT_IMAGE_TOKEN |
| from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM |
|
|
| from diffusers.utils.torch_utils import randn_tensor |
| from diffusers.pipelines.pipeline_utils import numpy_to_pil |
| import numpy as np |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler |
from tqdm import tqdm


class blip3oFastConfig(Qwen2Config):
    model_type = "blip3o_fast_inference"


class blip3oFastModel(LlavaMetaModel, Qwen2Model):
| config_class = blip3oFastConfig |
|
|
    def __init__(self, config: blip3oFastConfig):
        super().__init__(config)


class blip3oFastForInferenceLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
| config_class = blip3oFastConfig |
|
|
    def __init__(self, config):
        super().__init__(config)
        # Keep model_type consistent with the config registered at the bottom of this file.
        config.model_type = "blip3o_fast_inference"

        self.model = blip3oFastModel(config)
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
| |
| self.post_init() |
|
|
| def get_model(self): |
| return self.model |
|
|
| def visual(self, pixel_values: torch.Tensor, grid_thw: Optional[torch.Tensor] = None) -> torch.Tensor: |
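        # Encode pixels with the vision tower, then project into the LLM embedding
        # space. `grid_thw` is accepted for interface parity but unused by this tower.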
| image_features = self.get_model().get_vision_tower()(pixel_values) |
| image_features = self.get_model().mm_projector(image_features) |
| return image_features |
| |
| @torch.no_grad() |
| def generate_image( |
| self, |
| input_ids: Optional[torch.Tensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| pixel_values: Optional[torch.Tensor] = None, |
| image_grid_thw: Optional[torch.Tensor] = None, |
| max_var: Optional[float] = None, |
| ): |
        N_QUERY = self.get_n_query()
        # Embed the text tokens and append one learnable latent query per output slot.
        text_embeds = self.get_model().embed_tokens(input_ids)
        latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1)

        if pixel_values is not None:
            # Splice understanding-image features into the placeholder token positions.
            und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX)
            # `self.visual` is a bound method with no dtype; match the vision tower's dtype.
            pixel_values = pixel_values.type(self.get_model().get_vision_tower().dtype)
            und_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            text_embeds[und_image_idx] = und_image_embeds.to(text_embeds.device)[:und_image_idx.sum(), :]

        # Append the latent queries and extend the attention mask to cover them.
        text_embeds = torch.cat([text_embeds, latent_queries], dim=1)
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones(latent_queries.shape[:2])], dim=1
        )
| outputs = self.model( |
| inputs_embeds=text_embeds, |
| attention_mask=attention_mask, |
| output_hidden_states=True, |
| return_dict=True, |
| ) |
        # The hidden states of the latent queries condition the diffusion decoder.
        img_hidden_states = outputs.hidden_states[-1][:, -N_QUERY:, :]
| output_img = self.sample_images(img_hidden_states) |
| return output_img |
|
|
| def sample_images( |
| self, |
| pred_latents, |
| guidance_scale: float = 3.0, |
| num_inference_steps: int = 30, |
| num_images_per_prompt: int = 1, |
| return_tensor=False, |
| **kwargs, |
| ): |
| device = pred_latents.device |
| dtype = pred_latents.dtype |
|
|
| |
        # Classifier-free guidance: pair the conditioning with an all-zero
        # unconditional branch so both run in a single batch.
        img_hidden_states_null = torch.zeros_like(pred_latents, device=device, dtype=dtype)
        pred_latents = torch.cat([img_hidden_states_null, pred_latents], 0)
| batch_size = pred_latents.shape[0] |
| latent_size = self.get_n_query() |
| latent_channels = self.get_model().dit.config.in_channels |
|
|
| latents = randn_tensor( |
| shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size), |
| generator=None, |
| device=device, |
| dtype=dtype, |
| ) |
|
|
| |
        # FlowMatchEuler expects an explicit sigma schedule (here 1.0 down to 1/num_steps).
        if isinstance(self.model.noise_scheduler, FlowMatchEulerDiscreteScheduler):
| sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) |
| self.model.noise_scheduler.set_timesteps(num_inference_steps, sigmas=sigmas) |
| else: |
| self.model.noise_scheduler.set_timesteps(num_inference_steps) |
|
|
        for t in tqdm(self.model.noise_scheduler.timesteps, desc="Sampling images"):
            # Duplicate latents so the cond/uncond branches run in one forward pass.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = latent_model_input.to(pred_latents.dtype)
            # Probe the scheduler itself (not its timesteps tensor) for scale_model_input.
            if hasattr(self.model.noise_scheduler, "scale_model_input"):
                latent_model_input = self.model.noise_scheduler.scale_model_input(latent_model_input, t)
| |
| noise_pred = self.model.dit( |
| hidden_states=latent_model_input, |
| encoder_hidden_states=self.model.diffusion_connector(pred_latents), |
| timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latents.device), |
| encoder_attention_mask=None |
| ).sample |
|
|
            # Classifier-free guidance: uncond + scale * (cond - uncond).
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
| latents = self.model.noise_scheduler.step(noise_pred, t, latents).prev_sample |
|
|
        if self.model.vae is not None:
            latents = latents.to(self.model.vae.dtype)
        samples = self.decode_latents(latents, return_tensor=return_tensor)
| return samples |
|
|
| @torch.no_grad() |
| def decode_latents(self, latents, normalize=True, return_tensor=False): |
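        """Decode VAE latents to images; returns PIL images unless return_tensor=True."""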
| if self.model.vae is not None: |
| latents = latents / self.model.vae.config.scaling_factor |
| if "shift_factor" in self.model.vae.config and self.model.vae.config.shift_factor is not None: |
| latents = latents + self.model.vae.config.shift_factor |
| samples = self.model.vae.decode(latents).sample |
| else: |
| samples = latents |
| if normalize: |
| samples = (samples / 2 + 0.5).clamp(0, 1) |
| else: |
| samples = samples.clamp(-1, 1) |
| if return_tensor: |
| return samples |
| samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() |
| samples = numpy_to_pil(samples) |
| return samples |
| |
| def prepare_and_encode_inputs( |
| self, |
        inputs: List[Union[str, Image.Image]],
| tokenizer: AutoTokenizer, |
| do_classifier_free_guidance: bool = False, |
| ): |
| |
| device = self.get_model().device |
| dtype = self.get_model().dtype |
|
|
| has_image, has_text = False, False |
| text_prompt, image_prompt = "", [] |
| img_processor = self.get_vision_tower().image_processor |
| negative_prompt = {} |
|
|
| for x in inputs: |
| if isinstance(x, str): |
| has_text = True |
| text_prompt += x |
| else: |
| has_image = True |
| text_prompt += DEFAULT_IMAGE_TOKEN |
| image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values']) |
| |
| if len(image_prompt) == 0: |
| image_prompt = None |
| else: |
| image_prompt = torch.cat(image_prompt) |
| image_prompt = image_prompt.type(dtype).to(device) |
|
|
| if has_image and not has_text: |
| prompt = self.encode_images(image_prompt) |
| |
| if do_classifier_free_guidance: |
| key = "[NULL_IMAGE]" |
| if key not in negative_prompt: |
| negative_image = torch.zeros_like(image_prompt) |
| negative_prompt[key] = self.encode_images(negative_image) |
| prompt = torch.cat([prompt, negative_prompt[key]], dim=0) |
| else: |
            # generate_image expects token ids (see its signature above), not raw
            # strings, so tokenize the accumulated prompt first.
            text_inputs = tokenizer([text_prompt], return_tensors="pt").to(device)
            prompt = self.generate_image(input_ids=text_inputs.input_ids,
                                         attention_mask=text_inputs.attention_mask,
                                         pixel_values=image_prompt)
            if do_classifier_free_guidance:
                key = ""
                if key not in negative_prompt:
                    neg_inputs = tokenizer([""], return_tensors="pt").to(device)
                    negative_prompt[key] = self.generate_image(input_ids=neg_inputs.input_ids,
                                                               attention_mask=neg_inputs.attention_mask)
                prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
| |
| gen_pooling = self.get_gen_pooling() |
| n_query = self.get_n_query() |
| num_img, _, c = prompt.shape |
        if 'pool2d' in gen_pooling and has_text and 'early' not in gen_pooling:
            # e.g. gen_pooling == "pool2d_2" -> average-pool the
            # sqrt(n_query) x sqrt(n_query) token grid with a 2x2 kernel.
            stride = int(gen_pooling.split('_')[1])
            sqrt_n = int(n_query**0.5)
            prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
            prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
            prompt = prompt.reshape(num_img, c, -1).permute(0, 2, 1)
| return prompt |
|
|
|
|
| def prepare_inputs_for_generation(self, input_ids, past_key_values=None, |
| inputs_embeds=None, **kwargs): |
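        # Forward multimodal kwargs that the base implementation does not handle.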
| images = kwargs.pop("images", None) |
| image_sizes = kwargs.pop("image_sizes", None) |
| inputs = super().prepare_inputs_for_generation( |
| input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs |
| ) |
| if images is not None: |
| inputs['images'] = images |
| if image_sizes is not None: |
| inputs['image_sizes'] = image_sizes |
| return inputs |
|
|
| AutoConfig.register("blip3o_fast_inference", blip3oFastConfig) |
| AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForInferenceLM) |
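

# A minimal usage sketch (illustrative, not part of the released API): load a
# checkpoint through the Auto classes registered above and generate an image from
# a text prompt. The checkpoint path is a placeholder, and prompt formatting
# (chat template, special tokens) depends on how the model was trained.
if __name__ == "__main__":
    model_path = "path/to/blip3o-fast-checkpoint"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
    model = model.cuda().eval()

    text_inputs = tokenizer(["A red bicycle leaning against a brick wall"],
                            return_tensors="pt").to(model.device)
    # generate_image returns PIL images by default (see sample_images/decode_latents).
    images = model.generate_image(input_ids=text_inputs.input_ids,
                                  attention_mask=text_inputs.attention_mask)
    images[0].save("sample.png")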
|
|