# Source: FastVLM_SANA / noqueries_code/blip3o_fast_inference.py
# Author: Fahad-S
# Commit fc156dd (verified): "Upload noqueries_code/blip3o_fast_inference.py with huggingface_hub"
from typing import List, Optional
import torch
import torch.nn as nn
from PIL import Image
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import Qwen2_5_VLConfig, Qwen2ForCausalLM, Qwen2Config, Qwen2Model
from blip3o.constants import UND_IMAGE_TOKEN_IDX, DEFAULT_IMAGE_TOKEN
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import numpy_to_pil
import numpy as np
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from tqdm import tqdm
class blip3oFastConfig(Qwen2Config):
    """Qwen2 configuration registered under the ``blip3o_fast_inference`` model type.

    It adds no fields of its own. It exists so that the transformers Auto*
    factories can map this model type to ``blip3oFastForInferenceLM``.
    """

    model_type = "blip3o_fast_inference"
class blip3oFastModel(LlavaMetaModel, Qwen2Model):
    """Qwen2 decoder with the LLaVA multimodal mixin.

    ``super().__init__`` dispatches to ``LlavaMetaModel.__init__`` first
    (MRO order). That initializer is presumably responsible for chaining to
    ``Qwen2Model.__init__`` and building the vision components; see
    ``llava_arch`` to confirm.
    """

    config_class = blip3oFastConfig

    # The config is a ``blip3oFastConfig`` (a ``Qwen2Config``), not a
    # Qwen2.5-VL config as the previous annotation claimed.
    def __init__(self, config: blip3oFastConfig):
        super().__init__(config)
class blip3oFastForInferenceLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    """Qwen2 causal LM with a LLaVA-style vision encoder and a diffusion image head.

    The pipeline works in three stages:

    1. The language model encodes the prompt, which may be text interleaved
       with understanding images.
    2. Its final hidden states condition a DiT denoiser (``self.model.dit``)
       through ``self.model.diffusion_connector``.
    3. ``self.model.vae`` decodes the denoised latents into images.
    """

    config_class = blip3oFastConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): this overwrites the instance's model_type with a value
        # that differs from the registered "blip3o_fast_inference". Confirm
        # whether anything downstream relies on it.
        config.model_type = "blip3o_qwen_inference"
        # Replace the plain Qwen2Model built by the parent constructor with the
        # multimodal variant.
        self.model = blip3oFastModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def visual(self, pixel_values: torch.Tensor, grid_thw: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode understanding images with the vision tower and ``mm_projector``.

        ``grid_thw`` is accepted so callers can pass it, but it is not used.
        """
        model = self.get_model()
        image_features = model.get_vision_tower()(pixel_values)
        return model.mm_projector(image_features)

    @torch.no_grad()
    def generate_image(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        with_cfg: bool = False,
        max_var: Optional[float] = None,
    ):
        """Generate images conditioned on a (possibly multimodal) prompt.

        Args:
            input_ids: Prompt token ids. Positions equal to
                ``UND_IMAGE_TOKEN_IDX`` are overwritten with image features.
            attention_mask: Attention mask for the language model.
            pixel_values: Optional understanding images, encoded by :meth:`visual`.
            image_grid_thw: Forwarded to :meth:`visual`, which ignores it.
            with_cfg: Enable classifier-free guidance in :meth:`sample_images`.
            max_var: Unused; kept for API compatibility.

        Returns:
            The result of :meth:`sample_images`, which is a list of PIL images
            with the default arguments.
        """
        text_embeds = self.get_model().embed_tokens(input_ids)
        if pixel_values is not None:
            und_image_idx = input_ids == UND_IMAGE_TOKEN_IDX
            # ``self.visual`` is a bound method and has no ``dtype``; cast to
            # the model's parameter dtype instead.
            pixel_values = pixel_values.to(self.dtype)
            und_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            # Flatten any leading (image, token) dims so there is one feature
            # row per placeholder token. This is a no-op for 2-D features.
            und_image_embeds = und_image_embeds.reshape(-1, und_image_embeds.shape[-1])
            num_slots = int(und_image_idx.sum())
            # index_put requires matching dtypes, so cast alongside the device move.
            text_embeds[und_image_idx] = und_image_embeds[:num_slots].to(
                device=text_embeds.device, dtype=text_embeds.dtype
            )
        outputs = self.model(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            output_hidden_states=False,
            return_dict=True,
        )
        img_hidden_states = outputs.last_hidden_state
        return self.sample_images(img_hidden_states, attention_mask, with_cfg)

    @torch.no_grad()
    def sample_images(
        self,
        pred_latents,
        attention_mask,
        with_cfg: bool = False,
        guidance_scale: float = 3.0,
        num_inference_steps: int = 30,
        num_images_per_prompt: int = 1,
        return_tensor=False,
        **kwargs,
    ):
        """Run the diffusion sampling loop conditioned on language-model hidden states.

        Args:
            pred_latents: Batch-first conditioning hidden states, assumed to be
                ``(batch, seq, hidden)``. TODO: confirm.
            attention_mask: Currently unused. The DiT is called without an
                encoder attention mask.
            with_cfg: If True, apply classifier-free guidance. All-zero hidden
                states serve as the unconditional branch.
            guidance_scale: Classifier-free guidance scale.
            num_inference_steps: Number of scheduler steps.
            num_images_per_prompt: Number of images sampled per conditioning row.
            return_tensor: Return a tensor instead of PIL images.
            **kwargs: Ignored.

        Returns:
            The output of :meth:`decode_latents`.
        """
        model = self.get_model()
        scheduler = model.noise_scheduler
        device = pred_latents.device
        dtype = pred_latents.dtype

        # One conditioning row per output image.
        if num_images_per_prompt > 1:
            pred_latents = pred_latents.repeat_interleave(num_images_per_prompt, dim=0)
        # Number of conditional samples. This must be taken BEFORE the CFG
        # concat below; otherwise the latents would be doubled twice and no
        # longer match the conditioning batch.
        batch_size = pred_latents.shape[0]

        if with_cfg:
            img_hidden_states_null = torch.zeros_like(pred_latents)
            # Unconditional rows first, which matches the chunk order used below.
            pred_latents = torch.cat([img_hidden_states_null, pred_latents], 0)

        latent_size = model.dit.config.sample_size
        latent_channels = model.dit.config.in_channels
        latents = randn_tensor(
            shape=(batch_size, latent_channels, latent_size, latent_size),
            generator=None,
            device=device,
            dtype=dtype,
        )

        # The conditioning is constant across denoising steps, so project it once.
        encoder_hidden_states = model.diffusion_connector(pred_latents)

        scheduler.set_timesteps(num_inference_steps)
        for t in tqdm(scheduler.timesteps, desc="Sampling images"):
            latent_model_input = torch.cat([latents] * 2) if with_cfg else latents
            latent_model_input = latent_model_input.to(dtype)
            # Only schedulers that define ``scale_model_input`` need it.
            # Check the scheduler itself, not its ``timesteps`` tensor.
            if hasattr(scheduler, "scale_model_input"):
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            noise_pred = model.dit(
                hidden_states=latent_model_input,
                encoder_hidden_states=encoder_hidden_states,
                timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latents.device),
                encoder_attention_mask=None,
            ).sample
            if with_cfg:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # decode_latents supports a missing VAE, so don't dereference it unconditionally.
        if model.vae is not None:
            latents = latents.to(model.vae.dtype)
        return self.decode_latents(latents, return_tensor=return_tensor)

    @torch.no_grad()
    def decode_latents(self, latents, normalize=True, return_tensor=False):
        """Decode latents into images.

        Args:
            latents: Latents from the sampler. If ``self.model.vae`` is None,
                they are used directly as channels-first pixel samples.
            normalize: If True, map from [-1, 1] to [0, 1] (clamped).
                Otherwise clamp to [-1, 1].
            return_tensor: Return the channels-first tensor instead of PIL images.
        """
        vae = self.model.vae
        if vae is not None:
            # Undo the VAE latent scaling (and the optional shift) before decoding.
            latents = latents / vae.config.scaling_factor
            if "shift_factor" in vae.config and vae.config.shift_factor is not None:
                latents = latents + vae.config.shift_factor
            samples = vae.decode(latents).sample
        else:
            samples = latents
        if normalize:
            samples = (samples / 2 + 0.5).clamp(0, 1)
        else:
            samples = samples.clamp(-1, 1)
        if return_tensor:
            return samples
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
        return numpy_to_pil(samples)

    def prepare_and_encode_inputs(
        self,
        inputs: List[str | Image.Image],
        tokenizer: AutoTokenizer,
        do_classifier_free_guidance: bool = False,
    ):
        """Build conditioning embeddings from an interleaved list of text and images.

        Strings are concatenated into one text prompt. Each image contributes a
        ``DEFAULT_IMAGE_TOKEN`` placeholder and is preprocessed with the vision
        tower's image processor.

        NOTE(review): this method depends on ``encode_images``,
        ``get_gen_pooling`` and ``get_n_query``. They are presumably provided
        by ``LlavaMetaForCausalLM``, which is not visible here.

        NOTE(review): the text branch calls :meth:`generate_image` with
        ``text=``/``image=``/``tokenizer=`` keywords, which
        :meth:`generate_image` does not accept. As written, that branch
        raises ``TypeError``. Confirm the intended encoder before relying on it.
        """
        device = self.get_model().device
        dtype = self.get_model().dtype
        has_image, has_text = False, False
        text_prompt, image_prompt = "", []
        img_processor = self.get_vision_tower().image_processor
        for x in inputs:
            if isinstance(x, str):
                has_text = True
                text_prompt += x
            else:
                has_image = True
                text_prompt += DEFAULT_IMAGE_TOKEN
                image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
        if not image_prompt:
            image_prompt = None
        else:
            image_prompt = torch.cat(image_prompt).type(dtype).to(device)

        if has_image and not has_text:
            prompt = self.encode_images(image_prompt)
            if do_classifier_free_guidance:
                # Unconditional branch: the same encoder applied to all-zero images.
                negative = self.encode_images(torch.zeros_like(image_prompt))
                prompt = torch.cat([prompt, negative], dim=0)
        else:
            prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
            if do_classifier_free_guidance:
                negative = self.generate_image(text=[""], tokenizer=tokenizer)
                prompt = torch.cat([prompt, negative], dim=0)

        gen_pooling = self.get_gen_pooling()
        n_query = self.get_n_query()
        num_img, _, c = prompt.shape
        if 'pool2d' in gen_pooling and has_text and 'early' not in gen_pooling:
            # Average-pool the square grid of query tokens. The stride is the
            # second '_'-separated field of the pooling name.
            stride = int(gen_pooling.split('_')[1])
            sqrt_n = int(n_query ** 0.5)
            prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
            prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
            prompt = prompt.reshape(num_img, c, -1).permute(0, 2, 1)
        return prompt

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Forward ``images``/``image_sizes`` through the parent's generation-input builder."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs
# Expose the custom config/model through the transformers Auto* factories,
# keyed by the config's own model_type ("blip3o_fast_inference").
AutoConfig.register(blip3oFastConfig.model_type, blip3oFastConfig)
AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForInferenceLM)