from typing import List, Optional
import torch
import torch.nn as nn
from PIL import Image
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import Qwen2_5_VLConfig, Qwen2ForCausalLM, Qwen2Config, Qwen2Model
from blip3o.constants import UND_IMAGE_TOKEN_IDX, DEFAULT_IMAGE_TOKEN
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import numpy_to_pil
import numpy as np
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from tqdm import tqdm


class blip3oFastConfig(Qwen2Config):
    """Qwen2 configuration registered under the ``blip3o_fast_inference`` model type."""

    model_type = "blip3o_fast_inference"


class blip3oFastModel(LlavaMetaModel, Qwen2Model):
    """Qwen2 backbone combined with the LLaVA multimodal components (``LlavaMetaModel``)."""

    config_class = blip3oFastConfig

    def __init__(self, config: blip3oFastConfig):
        super(blip3oFastModel, self).__init__(config)


class blip3oFastForInferenceLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    """Causal LM head over ``blip3oFastModel`` with a diffusion-based image sampler.

    Image generation runs the language backbone to get hidden states, maps them
    through ``model.diffusion_connector`` into DiT conditioning, denoises latents
    with ``model.dit`` / ``model.noise_scheduler`` and decodes them with
    ``model.vae`` (when present).
    """

    config_class = blip3oFastConfig

    def __init__(self, config):
        super(blip3oFastForInferenceLM, self).__init__(config)
        # NOTE(review): this instance-level model_type differs from the registered
        # "blip3o_fast_inference" type; kept as-is, confirm it is intentional.
        config.model_type = "blip3o_qwen_inference"

        # Replace the plain Qwen2 backbone created by the parent constructor.
        self.model = blip3oFastModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the multimodal backbone (``blip3oFastModel``)."""
        return self.model

    def visual(self, pixel_values: torch.Tensor, grid_thw: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode pixels with the vision tower and project them into the LM space.

        ``grid_thw`` is accepted for signature compatibility but is not used.
        """
        image_features = self.get_model().get_vision_tower()(pixel_values)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    @torch.no_grad()
    def generate_image(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        with_cfg: bool = False,
        max_var: Optional[float] = None,
    ):
        """Generate images conditioned on a (optionally multimodal) prompt.

        Args:
            input_ids: prompt token ids; positions equal to ``UND_IMAGE_TOKEN_IDX``
                are overwritten with projected image features.
            attention_mask: forwarded to the language backbone.
            pixel_values: optional understanding images for the placeholder tokens.
            image_grid_thw: forwarded to ``visual`` (currently unused there).
            with_cfg: enable classifier-free guidance in the sampler.
            max_var: accepted for API compatibility; not used.

        Returns:
            Whatever ``sample_images`` returns (a list of PIL images by default).
        """
        text_embeds = self.get_model().embed_tokens(input_ids)

        if pixel_values is not None:
            und_image_idx = input_ids == UND_IMAGE_TOKEN_IDX
            # `self.visual` is a method and has no `dtype`; use the backbone dtype
            # (same choice as prepare_and_encode_inputs).
            pixel_values = pixel_values.type(self.get_model().dtype)
            und_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            # Flatten any leading (image, token) dims so that rows line up 1:1 with
            # the placeholder positions selected by the boolean mask (no-op if 2-D).
            und_image_embeds = und_image_embeds.reshape(-1, und_image_embeds.shape[-1])
            n_placeholders = int(und_image_idx.sum())
            text_embeds[und_image_idx] = und_image_embeds[:n_placeholders].to(
                device=text_embeds.device, dtype=text_embeds.dtype
            )

        outputs = self.model(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            output_hidden_states=False,
            return_dict=True,
        )
        img_hidden_states = outputs.last_hidden_state
        output_img = self.sample_images(img_hidden_states, attention_mask, with_cfg)
        return output_img

    @torch.no_grad()
    def sample_images(
        self,
        pred_latents,
        attention_mask,
        with_cfg: bool = False,
        guidance_scale: float = 3.0,
        num_inference_steps: int = 30,
        num_images_per_prompt: int = 1,
        return_tensor=False,
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ):
        """Denoise random latents conditioned on ``pred_latents`` and decode them.

        Args:
            pred_latents: backbone hidden states used as DiT conditioning
                (after ``model.diffusion_connector``); one row per prompt.
            attention_mask: accepted for API compatibility; not forwarded to the DiT.
            with_cfg: if True, an all-zeros conditioning is used as the
                unconditional branch and guidance ``guidance_scale`` is applied.
            guidance_scale: CFG scale.
            num_inference_steps: number of scheduler steps.
            num_images_per_prompt: images generated per conditioning row.
            return_tensor: return a tensor instead of PIL images.
            generator: optional ``torch.Generator`` for reproducible noise.
            **kwargs: ignored.
        """
        device = pred_latents.device
        dtype = pred_latents.dtype
        model = self.model
        scheduler = model.noise_scheduler

        # One conditioning row per generated image (no-op for the default of 1).
        if num_images_per_prompt > 1:
            pred_latents = pred_latents.repeat_interleave(num_images_per_prompt, dim=0)

        # Number of images actually produced. Must be taken BEFORE the CFG concat,
        # otherwise the latent batch is doubled twice and mismatches the conditioning.
        batch_size = pred_latents.shape[0]

        if with_cfg:
            img_hidden_states_null = torch.zeros_like(pred_latents)
            # Order [uncond, cond] matches the chunk(2) split below.
            pred_latents = torch.cat([img_hidden_states_null, pred_latents], 0)

        latent_size = model.dit.config.sample_size
        latent_channels = model.dit.config.in_channels
        latents = randn_tensor(
            shape=(batch_size, latent_channels, latent_size, latent_size),
            generator=generator,
            device=device,
            dtype=dtype,
        )

        # The conditioning does not depend on the timestep; compute it once.
        encoder_hidden_states = model.diffusion_connector(pred_latents)

        scheduler.set_timesteps(num_inference_steps)
        for t in tqdm(scheduler.timesteps, desc="Sampling images"):
            latent_model_input = torch.cat([latents] * 2) if with_cfg else latents
            latent_model_input = latent_model_input.to(pred_latents.dtype)
            # Check the scheduler itself (the original checked the timesteps tensor,
            # which never has this attribute).
            if hasattr(scheduler, "scale_model_input"):
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            noise_pred = model.dit(
                hidden_states=latent_model_input,
                encoder_hidden_states=encoder_hidden_states,
                timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latents.device),
                # NOTE: attention_mask is currently not forwarded to the DiT.
                encoder_attention_mask=None,
            ).sample

            if with_cfg:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # decode_latents supports vae=None, so only cast when a VAE exists.
        if model.vae is not None:
            latents = latents.to(model.vae.dtype)
        return self.decode_latents(latents, return_tensor=return_tensor)

    @torch.no_grad()
    def decode_latents(self, latents, normalize=True, return_tensor=False):
        """Decode latents with the VAE (if any) and convert to images.

        Args:
            latents: latent batch; passed through unchanged when ``model.vae`` is None.
            normalize: map outputs from [-1, 1] to [0, 1]; otherwise only clamp to [-1, 1].
            return_tensor: return the clamped tensor instead of PIL images.
        """
        vae = self.model.vae
        if vae is not None:
            latents = latents / vae.config.scaling_factor
            if "shift_factor" in vae.config and vae.config.shift_factor is not None:
                latents = latents + vae.config.shift_factor
            samples = vae.decode(latents).sample
        else:
            samples = latents

        if normalize:
            samples = (samples / 2 + 0.5).clamp(0, 1)
        else:
            samples = samples.clamp(-1, 1)

        if return_tensor:
            return samples

        # NCHW -> NHWC float numpy, then to PIL.
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
        samples = numpy_to_pil(samples)
        return samples

    def prepare_and_encode_inputs(
        self,
        inputs: List[str | Image.Image],
        tokenizer: AutoTokenizer,
        do_classifier_free_guidance: bool = False,
    ):
        """Build a conditioning prompt tensor from interleaved text and images.

        Strings are concatenated; each non-string item is preprocessed as an image
        and represented in the text by ``DEFAULT_IMAGE_TOKEN``. When CFG is requested,
        a negative prompt is appended along dim 0.
        """
        device = self.get_model().device
        dtype = self.get_model().dtype

        has_image, has_text = False, False
        text_prompt, image_prompt = "", []
        img_processor = self.get_vision_tower().image_processor
        negative_prompt = {}

        for x in inputs:
            if isinstance(x, str):
                has_text = True
                text_prompt += x
            else:
                has_image = True
                text_prompt += DEFAULT_IMAGE_TOKEN
                image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])

        if len(image_prompt) == 0:
            image_prompt = None
        else:
            image_prompt = torch.cat(image_prompt)
            image_prompt = image_prompt.type(dtype).to(device)

        if has_image and not has_text:
            prompt = self.encode_images(image_prompt)
            if do_classifier_free_guidance:
                key = "[NULL_IMAGE]"
                if key not in negative_prompt:
                    negative_image = torch.zeros_like(image_prompt)
                    negative_prompt[key] = self.encode_images(negative_image)
                prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
        else:
            # NOTE(review): generate_image() takes (input_ids, attention_mask,
            # pixel_values, ...) and returns images, not embeddings; these
            # text=/image=/tokenizer= calls do not match it. Confirm the intended API.
            prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
            if do_classifier_free_guidance:
                key = ""
                if key not in negative_prompt:
                    negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
                prompt = torch.cat([prompt, negative_prompt[key]], dim=0)

        gen_pooling = self.get_gen_pooling()
        n_query = self.get_n_query()
        num_img, _, c = prompt.shape
        if 'pool2d' in gen_pooling and has_text and 'early' not in gen_pooling:
            # gen_pooling looks like "pool2d_<stride>..."; pool the query tokens
            # as a sqrt(n_query) x sqrt(n_query) grid.
            stride = int(gen_pooling.split('_')[1])
            sqrt_n = int(n_query**0.5)
            prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
            prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
            prompt = prompt.reshape(num_img, c, -1).permute(0, 2, 1)
        return prompt

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Extend the parent hook to carry ``images`` / ``image_sizes`` through generation."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs


AutoConfig.register("blip3o_fast_inference", blip3oFastConfig)
AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForInferenceLM)