import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers.utils import ModelOutput
from transformers import PreTrainedModel, PretrainedConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig
from safetensors.torch import load_file
import torchvision.transforms as transforms

from .De_DiffusionV2_Image import SDModel
from .sd_config import SDConfig
import argparse


def handle_module_prefix(state_dict):
    """Strip the 'module.' prefix (added by DataParallel/DDP) from state dict keys."""
    if any(k.startswith('module.') for k in state_dict.keys()):
        return {k.removeprefix('module.'): v for k, v in state_dict.items()}
    return state_dict


def create_model_args(args):
    """Create model arguments needed by SDModel."""
    model_args = argparse.Namespace()
    model_args.use_text_encoder = args.use_text_encoder
    model_args.batch_size = args.batch_size
    model_args.eval_batch_size = args.batch_size
    model_args.distributed_strategy = 'none'
    model_args.fp32 = args.fp32
    model_args.learnable_token_length = args.learnable_token_length
    model_args.num_inference_steps = args.num_inference_steps
    model_args.image_size = args.image_size
    model_args.guidance_scale = args.guidance_scale
    model_args.unfreeze_florence2_all = False
    model_args.unfreeze_florence2_language_model = False
    model_args.unfreeze_florence2_language_model_decoder = False
    return model_args


def load_model_checkpoint(model, model_path, device):
    """Load model checkpoint."""
    checkpoint = torch.load(model_path, map_location="cpu")

    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    state_dict = handle_module_prefix(state_dict)
    model.load_state_dict(state_dict, strict=False)

    return model


def initialize_diffusion_model(args):
    """Initialize the diffusion model."""
    config = SDConfig()
    diffusion_model_args = create_model_args(args)
    diffusion_model = SDModel(config, diffusion_model_args)
    _dtype = torch.float32 if diffusion_model_args.fp32 else torch.bfloat16

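    # The Stable Diffusion VAE and UNet are only needed for image generation, so
    # drop them here to free memory for the captioning path.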
    if hasattr(diffusion_model, 'vae'):
        del diffusion_model.vae
    if hasattr(diffusion_model, 'unet'):
        del diffusion_model.unet

    torch.cuda.empty_cache()

    diffusion_model = diffusion_model.to(_dtype)

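    # Freeze the language projection and the learnable query tokens; this model
    # is used for feature extraction only.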
    for param in diffusion_model.language_proj.parameters():
        param.requires_grad = False
    diffusion_model.query_embed.requires_grad = False

    return diffusion_model


class VLV_Config(PretrainedConfig):
    model_type = "VLV_decoder"

    def __init__(self, hidden_size=128, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class VLV_MODEL(PreTrainedModel):
    config_class = VLV_Config
    model_type = "VLV_decoder"

    def __init__(self, config):
        """Load the CLIPDecoder model."""
        super().__init__(config)

        device = "cuda"
        de_diffusion_model = initialize_diffusion_model(config)
        clip_decoder_model = CLIPDecoder(
            language_model=config.qwen_model,
            qwen2_config=config.qwen2_config,
            VLV_model=de_diffusion_model,
            device=device,
            bf16=config.mixed_precision,
        )
        clip_decoder_model.eval()

        self.model = clip_decoder_model
        self.max_new_tokens = config.max_length
        self.num_beams = config.num_beams
        self.transform = self.get_transform(config.image_size)

    def get_transform(self, image_size):
        """Transformation pipeline for input images."""
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop((image_size, image_size)),
            transforms.PILToTensor(),
        ])

    def forward(self, valid_images, max_length):
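        """Caption a batch of PIL images with the wrapped decoder (handles DDP-wrapped models)."""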
        valid_images = [self.transform(img) for img in valid_images]
        if hasattr(self.model, 'module'):
            outputs = self.model.module.generate(
                valid_images,
                max_new_tokens=max_length,
                num_beams=self.num_beams,
                early_stopping=True
            )
        else:
            outputs = self.model.generate(
                valid_images,
                max_new_tokens=max_length,
                num_beams=self.num_beams,
                early_stopping=True
            )
        return outputs


class MLP(nn.Module):
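    """Two-layer GELU MLP used to project features into another embedding space."""
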
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, x):
        return self.layers(x)


@dataclass
class CLIPDecoderOutput(ModelOutput):
    """
    Output class for the CLIP Decoder model.
    """
    last_hidden_state: Optional[torch.FloatTensor] = None
    generated_ids: Optional[torch.LongTensor] = None
    generated_text: Optional[list] = None


class CLIPDecoder(nn.Module):

    def __init__(
        self,
        language_model: str,
        qwen2_config: dict,
        VLV_model: SDModel,
        device: torch.device,
        bf16: Union[bool, str] = False,
        args: Optional[argparse.Namespace] = None
    ):
        """
        Initialize the CLIP Decoder model.

        Args:
            language_model: Path to the language model
            qwen2_config: Qwen2 config (currently unused; the config is reloaded from `language_model`)
            VLV_model: The VLV model instance
            device: The device to run the model on
            bf16: Whether to use bfloat16 precision (True or the string "bf16")
        """
        super().__init__()

        self._dtype = torch.bfloat16 if bf16 in (True, "bf16") else torch.float32
        self.qwen2_tokenizer = AutoTokenizer.from_pretrained(language_model)
        self.qwen2_config = AutoConfig.from_pretrained(language_model)
        self.qwen2_model = AutoModelForCausalLM.from_config(self.qwen2_config)
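        # Persist the freshly initialized Qwen2 model locally (legacy .bin format,
        # not safetensors); trained weights are presumably restored later from a
        # checkpoint (see load_model_checkpoint).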
        self.qwen2_model.save_pretrained("QWEN_new_2", safe_serialization=False)

        self.VLV_model = VLV_model
        self.device = device
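        # MLP projects the 1024-dim text-encoder features into the Qwen2 hidden
        # size so they can be fed to the language model as input embeddings.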
        self.mlp = MLP(input_dim=1024, output_dim=self.qwen2_model.config.hidden_size)
        self.ignore_token_id = -100

    def get_conditional_context(self, images, batch_size):
        """
        Get conditional context from images using the diffusion model.

        Args:
            images: Input images
            batch_size: Batch size

        Returns:
            Decoder hidden states from the diffusion model
        """
        prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
        inputs = self.VLV_model.processor(text=prompt, images=images, return_tensors="pt").to(self.device).to(torch.float32)

        self.VLV_model = self.VLV_model.to(inputs["input_ids"].device)
        self.qwen2_model = self.qwen2_model.to(inputs["input_ids"].device)
        self.mlp = self.mlp.to(inputs["input_ids"].device)
        self.VLV_model.model.language_model.model = self.VLV_model.model.language_model.model.to(inputs["input_ids"].device)

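        # Embed the prompt tokens and merge them with the encoded image features
        # (Florence-2 style multimodal input).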
| if inputs["input_ids"] is not None: |
| inputs_embeds = self.VLV_model.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self.device) |
| |
| if inputs["pixel_values"] is not None: |
| image_features = self.VLV_model.model._encode_image(inputs["pixel_values"]).to(self.device) |
| inputs_embeds, attention_mask = self.VLV_model.model._merge_input_ids_with_image_features( |
| image_features, inputs_embeds |
| ) |
| |
| if inputs_embeds is not None: |
| attention_mask = attention_mask.to(inputs_embeds.dtype) |
| |
| encoder_outputs = self.VLV_model.model.language_model.model.encoder( |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| output_hidden_states=True, |
| return_dict=True |
| ) |
| |
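        # The learnable query tokens cross-attend to the encoder output, yielding
        # a fixed-length (num_queries) conditioning sequence.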
        decoder_inputs_embeds = self.VLV_model.query_embed.expand(batch_size, -1, -1)
        decoder_attention_mask = torch.ones(
            (batch_size, self.VLV_model.num_queries),
            dtype=torch.float32,
            device=self.device
        )

        encoder_hidden_states = encoder_outputs.last_hidden_state.to(torch.float32)
        decoder_input_embeds = decoder_inputs_embeds.to(torch.float32)
        attention_mask = attention_mask.to(torch.float32)

        decoder_outputs = self.VLV_model.model.language_model.model.decoder(
            inputs_embeds=decoder_input_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        return decoder_outputs.last_hidden_state

    def process_image(self, images, batch_size):
        """
        Process images to get clip text embeddings.

        Args:
            images: Input images
            batch_size: Batch size

        Returns:
            Processed clip text embeddings and attention mask
        """
        decoder_hidden_states = self.get_conditional_context(images, batch_size)
        context_embeds = self.VLV_model.language_proj(decoder_hidden_states)

        context_attention_mask = torch.ones(
            (batch_size, context_embeds.shape[1]),
            dtype=torch.long,
            device=self.device
        )

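        # Run the projected context through the VLV text encoder, then map it
        # into the Qwen2 embedding space with the MLP.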
        clip_text_embeds = self.VLV_model.text_encoder(
            inputs_embeds=context_embeds,
            attention_mask=context_attention_mask
        ).last_hidden_state

        clip_text_embeds = self.mlp(clip_text_embeds)
        clip_text_embeds_attention_mask = torch.ones(
            (batch_size, self.VLV_model.num_queries),
            dtype=torch.long,
            device=self.device
        )

        return clip_text_embeds, clip_text_embeds_attention_mask

    def prepare_generation_inputs(self, clip_text_embeds, clip_text_attention_mask=None):
        """
        Prepare inputs for text generation.

        Args:
            clip_text_embeds: Processed clip text embeddings
            clip_text_attention_mask: Attention mask for clip text embeddings

        Returns:
            Dictionary of generation inputs
        """
        if clip_text_attention_mask is None:
            clip_text_attention_mask = torch.ones(
                (clip_text_embeds.shape[0], clip_text_embeds.shape[1]),
                dtype=torch.long,
                device=clip_text_embeds.device
            )

        return {
            "inputs_embeds": clip_text_embeds,
            "attention_mask": clip_text_attention_mask
        }

    def generate(self, images, max_new_tokens=300, num_beams=4, early_stopping=True):
        """
        Generate text from images.

        Args:
            images: Input images
            max_new_tokens: Maximum number of tokens to generate
            num_beams: Number of beams for beam search
            early_stopping: Whether to stop early in beam search

        Returns:
            CLIPDecoderOutput with generated ids and text
        """
        batch_size = len(images)
        clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)
        generation_inputs = self.prepare_generation_inputs(clip_text_embeds, clip_text_attention_mask)

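        # Generation is run in bfloat16, so cast the conditioning embeddings (and
        # their mask) to match.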
| generation_inputs["inputs_embeds"] = generation_inputs["inputs_embeds"].to(torch.bfloat16) |
| generation_inputs["attention_mask"] = generation_inputs["attention_mask"].to(torch.bfloat16) |
| |
| generated_ids = self.qwen2_model.generate( |
| inputs_embeds=generation_inputs["inputs_embeds"], |
| attention_mask=generation_inputs["attention_mask"], |
| max_new_tokens=max_new_tokens, |
| num_beams=num_beams, |
| early_stopping=early_stopping |
| ) |
| |
| generated_text = self.qwen2_tokenizer.batch_decode(generated_ids, skip_special_tokens=True) |
| |
| return CLIPDecoderOutput( |
| generated_ids=generated_ids, |
| generated_text=generated_text |
| ) |
| |
    def forward(self, images, captions=None):
        """
        Forward pass for training.

        Args:
            images: Input images
            captions: Target captions (optional, for training)

        Returns:
            The language-model output (with loss and logits) when captions are given,
            otherwise a CLIPDecoderOutput carrying the conditioning embeddings.
        """
        batch_size = images.shape[0]

        clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)

        if captions is None:
            return CLIPDecoderOutput(
                last_hidden_state=clip_text_embeds
            )
        assert len(captions) == batch_size

        qwen_input_ids = self.qwen2_tokenizer(
            text=captions,
            truncation=True,
            return_tensors="pt",
            padding="max_length",
            max_length=300,
            return_token_type_ids=False,
        ).input_ids
        qwen_attention_mask = qwen_input_ids.ne(self.qwen2_tokenizer.pad_token_id).to(torch.long).to(self.device)

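        # Build LM labels from the caption tokens, masking padding positions out
        # of the loss with ignore_token_id (-100).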
        labels = qwen_input_ids.clone()
        labels[labels == self.qwen2_tokenizer.pad_token_id] = self.ignore_token_id
        labels = labels.to(self.device)

        labels_for_embeddings = labels.clone()
        labels_for_embeddings[labels_for_embeddings == self.ignore_token_id] = self.qwen2_tokenizer.pad_token_id
        clip_text_embeds_qwen = self.qwen2_model.get_input_embeddings()(labels_for_embeddings)

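        # Prepend the visual conditioning embeddings to the caption embeddings and
        # exclude their positions from the loss.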
        inputs_embeds = torch.cat((clip_text_embeds, clip_text_embeds_qwen), dim=1)
        clip_seq_len = clip_text_embeds.shape[1]
        clip_ignore_labels = torch.full((labels.shape[0], clip_seq_len), self.ignore_token_id).to(labels)
        combined_labels = torch.cat((clip_ignore_labels, labels), dim=1)

        attention_mask = torch.cat((
            clip_text_attention_mask,
            qwen_attention_mask
        ), dim=1)

        outputs = self.qwen2_model(
            inputs_embeds=inputs_embeds,
            labels=combined_labels,
            attention_mask=attention_mask,
            use_cache=False
        )
        return outputs
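

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module).
    # Assumptions: VLV_Config must carry every attribute read above (qwen_model,
    # qwen2_config, mixed_precision, max_length, num_beams, image_size, plus the
    # diffusion-model fields consumed by create_model_args); the model name and
    # image path below are placeholders. Run via `python -m <package>.<module>`
    # so the relative imports resolve.
    from PIL import Image

    config = VLV_Config(
        qwen_model="Qwen/Qwen2.5-0.5B-Instruct",  # hypothetical LM checkpoint
        qwen2_config=None,
        mixed_precision="bf16",
        max_length=300,
        num_beams=4,
        image_size=384,
        use_text_encoder=True,
        batch_size=1,
        fp32=False,
        learnable_token_length=77,
        num_inference_steps=50,
        guidance_scale=7.5,
    )
    model = VLV_MODEL(config)
    result = model([Image.open("example.jpg")], max_length=config.max_length)
    print(result.generated_text)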