import math
import os
from dataclasses import dataclass
from typing import Optional, List, Union, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from addict import Dict as ADict
from PIL import Image, ImageOps
from tqdm import tqdm
from transformers import Qwen2VLTextModel, Qwen2VLTextConfig, Qwen2VLPreTrainedModel, PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding

from data import (
    format_messages,
    load_pil_images,
    text_encode,
    BasicImageTransform,
    dynamic_preprocess,
    re_match,
    process_image_with_refs,
    NoEOSTextStreamer,
)
from encoder import build_sam_vit_b, build_clip_l, MlpProjector


class DeepQwenVLConfig(PretrainedConfig):
    """
    Configuration class for DeepQwenVL model.

    This config wraps both the Qwen2VL text config and DeepSeek vision config.
    When loading from a Qwen2-VL checkpoint, it will use the checkpoint's config
    directly for the text model.

    Dimension arguments that default to ``None`` fall back to another value
    (``hidden_size`` or ``projector_output_dim``); see the inline comments on
    the constructor signature.
    """

    model_type = "deepqwen_vl"

    def __init__(
        self,
        deepseek_vision_hidden_size: int = 2048,
        # Projector settings
        projector_type: str = "mlp",  # "vision_projector" or "mlp"
        projector_input_dim: int = 2048,
        projector_output_dim: Optional[int] = None,  # If None, uses hidden_size
        projector_hidden_dim: Optional[int] = None,  # If None, uses projector_output_dim
        # Learnable vision tokens
        image_newline_dim: Optional[int] = None,  # If None, uses hidden_size
        view_separator_dim: Optional[int] = None,  # If None, uses hidden_size
        hidden_size: int = 1536,
        intermediate_size: int = 8960,
        num_hidden_layers: int = 28,
        num_attention_heads: int = 12,
        num_key_value_heads: int = 2,
        hidden_act: str = "silu",
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        rms_norm_eps: float = 1e-6,
        use_cache: bool = True,
        tie_word_embeddings: bool = True,
        rope_theta: float = 1000000.0,
        attention_dropout: float = 0.0,
        vocab_size: int = 151936,
        bos_token_id: int = 151643,
        eos_token_id: int = 151645,
        pad_token_id: int = 151643,
        image_token_id: int = 151655,
        video_token_id: int = 151656,
        vision_start_token_id: int = 151652,
        vision_end_token_id: int = 151653,
        vision_token_id: int = 151654,
        rope_scaling: Optional[dict] = None,
        **kwargs
    ):
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )

        self.deepseek_vision_hidden_size = deepseek_vision_hidden_size

        # Projector settings
        self.projector_type = projector_type
        self.projector_input_dim = projector_input_dim
        self.projector_output_dim = projector_output_dim if projector_output_dim else hidden_size
        self.projector_hidden_dim = projector_hidden_dim if projector_hidden_dim else self.projector_output_dim

        # Learnable vision tokens
        self.image_newline_dim = image_newline_dim if image_newline_dim else hidden_size
        self.view_separator_dim = view_separator_dim if view_separator_dim else hidden_size

        # Text model settings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.vocab_size = vocab_size

        # Special tokens
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id
        self.vision_token_id = vision_token_id

        # Rope scaling: a fresh dict is built per instance so the default is never shared.
        if rope_scaling is None:
            rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}
        self.rope_scaling = rope_scaling

    def to_text_config(self) -> Qwen2VLTextConfig:
        """Convert to Qwen2VLTextConfig for the text model."""
        return Qwen2VLTextConfig(
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            hidden_act=self.hidden_act,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            rms_norm_eps=self.rms_norm_eps,
            use_cache=self.use_cache,
            tie_word_embeddings=self.tie_word_embeddings,
            rope_theta=self.rope_theta,
            attention_dropout=self.attention_dropout,
            vocab_size=self.vocab_size,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            rope_scaling=self.rope_scaling,
        )


@dataclass
class DeepQwenOutputWithPast(ModelOutput):
    """Output container returned by ``DeepQwenVLModel.forward`` when ``return_dict`` is truthy."""

    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


@dataclass
class DeepQwenCausalLMOutputWithPast(ModelOutput):
    """Output container returned by ``DeepQwenVLForCausalLM.forward``."""

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


class VisionProjector(nn.Module):
    """
    Vision projector with DeepSeek's pretrained layer + trainable adapter.

    Architecture:
        deepseek_proj: Linear(2048→1280) [FROZEN - loaded from DeepSeek checkpoint]
        SiLU activation
        norm: LayerNorm(1280) [TRAINABLE]
        adapter: Linear(1280→1536) [TRAINABLE]

    This preserves DeepSeek's learned vision-text alignment while adapting
    to Qwen's embedding space. Total 2 layers like LLaVA's MLP projector.

    Note: this class does not freeze ``deepseek_proj`` itself; freezing is
    left to the caller.
    """

    def __init__(self, input_dim: int = 2048, hidden_dim: int = 1280, output_dim: int = 1536):
        super().__init__()

        # DeepSeek's original projection (will be frozen after loading weights)
        self.deepseek_proj = nn.Linear(input_dim, hidden_dim)

        # Adapter for Qwen (trainable)
        self.norm = nn.LayerNorm(hidden_dim)
        self.adapter = nn.Linear(hidden_dim, output_dim)

        self._init_adapter_weights()

    def _init_adapter_weights(self):
        """Initialize adapter weights. deepseek_proj will be loaded from checkpoint."""
        nn.init.ones_(self.norm.weight)
        nn.init.zeros_(self.norm.bias)
        nn.init.normal_(self.adapter.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.adapter.bias)

    def forward(self, x):
        """Apply ``deepseek_proj`` → SiLU → ``norm`` → ``adapter`` to ``x``."""
        x = self.deepseek_proj(x)
        x = F.silu(x)
        x = self.norm(x)
        x = self.adapter(x)
        return x


class DeepQwenVLPreTrainedModel(PreTrainedModel):
    """Base class carrying the shared Hugging Face class-level flags and weight init."""

    config_class = DeepQwenVLConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_static_cache = True
    _supports_attention_backend = True
    _keys_to_ignore_on_load_missing = [
        "sam_model",
        "vision_model",
        "projector",
        "image_newline",
        "view_separator",
    ]

    def _init_weights(self, module):
        """Initialize the weights."""
        std = getattr(self.config, 'initializer_range', 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)


class DeepQwenVLModel(Qwen2VLTextModel):
    """
    DeepQwenVL Model that combines DeepSeek's vision encoders with Qwen2VL's text model.

    Accepts either:
    - A DeepQwenVLConfig
    - A Qwen2VLTextConfig (for compatibility with from_pretrained from Qwen checkpoints)
    - A generic PretrainedConfig (will extract necessary fields)
    """

    config_class = DeepQwenVLConfig

    def __init__(self, config):
        if isinstance(config, DeepQwenVLConfig):
            text_config = config.to_text_config()
            output_hidden_size = config.projector_output_dim
            vision_dim = config.deepseek_vision_hidden_size
        elif isinstance(config, Qwen2VLTextConfig):
            text_config = config
            output_hidden_size = config.hidden_size
            vision_dim = 2048
        else:
            text_config = config
            output_hidden_size = getattr(config, 'hidden_size', 1536)
            vision_dim = getattr(config, 'deepseek_vision_hidden_size', 2048)

        super().__init__(text_config)
        # Keep the caller's config object (not the derived text config) on the model.
        self.config = config
        self.output_hidden_size = output_hidden_size

        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()

        self.deepseek_vision_dim = vision_dim
        self.deepseek_hidden_dim = 1280  # DeepSeek's projector output dimension

        # New projector: DeepSeek layer (frozen) + adapter (trainable)
        self.projector = VisionProjector(
            input_dim=self.deepseek_vision_dim,   # 2048
            hidden_dim=self.deepseek_hidden_dim,  # 1280 (DeepSeek's output)
            output_dim=output_hidden_size,        # 1536 (Qwen's hidden size)
        )

        embed_std = 1.0 / math.sqrt(output_hidden_size)
        self.image_newline = nn.Parameter(torch.randn(output_hidden_size) * embed_std)
        self.view_separator = nn.Parameter(torch.randn(output_hidden_size) * embed_std)

    def _encode_view(self, pixels: torch.Tensor) -> torch.Tensor:
        """Encode ``pixels`` with SAM + CLIP (no gradients) and project the result.

        The CLIP output with its first token dropped is concatenated on the
        last axis with the flattened SAM feature map, detached, and passed
        through ``self.projector`` (which does receive gradients).
        """
        with torch.no_grad():
            sam_features = self.sam_model(pixels)
            clip_features = self.vision_model(pixels, sam_features)
            features = torch.cat(
                (clip_features[:, 1:], sam_features.flatten(2).permute(0, 2, 1)),
                dim=-1,
            )
        return self.projector(features.detach())

    def _append_row_newline(self, grid: torch.Tensor) -> torch.Tensor:
        """Append ``image_newline`` to every row of a (rows, cols, dim) grid and flatten it.

        Returns a tensor of shape (rows * (cols + 1), dim).
        """
        rows, _, dim = grid.shape
        newline = self.image_newline[None, None, :].expand(rows, 1, dim)
        return torch.cat([grid, newline], dim=1).view(-1, dim)

    def _global_view_tokens(self, image_ori: torch.Tensor) -> torch.Tensor:
        """Encode the global view and lay it out as a square grid with row newlines."""
        global_features = self._encode_view(image_ori)
        _, hw, n_dim = global_features.shape
        h = w = int(hw ** 0.5)
        # The view below only succeeds when the leading (batch) dimension is 1.
        return self._append_row_newline(global_features.view(h, w, n_dim))

    def _local_view_tokens(self, patches: torch.Tensor, crop_shape) -> torch.Tensor:
        """Encode the local crops and stitch them into one grid with row newlines.

        ``crop_shape`` is read as ``(width_crop_num, height_crop_num)``.
        """
        local_features = self._encode_view(patches)
        _, hw2, n_dim2 = local_features.shape
        h2 = w2 = int(hw2 ** 0.5)
        width_crop_num, height_crop_num = int(crop_shape[0]), int(crop_shape[1])

        # (crops, h2*w2, dim) -> (rows of crops, cols of crops, h2, w2, dim)
        # -> one (height_crop_num*h2, width_crop_num*w2, dim) grid.
        stitched = local_features.view(
            height_crop_num, width_crop_num, h2, w2, n_dim2
        ).permute(0, 2, 1, 3, 4).reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
        return self._append_row_newline(stitched)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[DeepQwenOutputWithPast, Tuple]:
        """Embed the tokens, splice projected image features in, and run the text model.

        Args:
            input_ids: Token ids; only used to compute ``inputs_embeds`` when
                those are not supplied.
            inputs_embeds: Optional precomputed embeddings. They are modified
                in place when image features are scattered into them.
            images: Indexed per sample as ``image[0]`` (local crops) and
                ``image[1]`` (global view). An all-zero ``images[0][1]``
                disables image processing; all-zero crops select the
                global-only path for that sample.
            images_seq_mask: Per-sample mask over sequence positions marking
                where image features are written.
            images_spatial_crop: Per-sample ``(width_crop_num, height_crop_num)``.
            return_dict: When truthy a ``DeepQwenOutputWithPast`` is returned,
                otherwise (including ``None``) a tuple.

        The remaining arguments are passed through to the parent ``forward``.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Images are skipped on single-token steps outside training.
        # The sequence length is taken from inputs_embeds so that callers
        # passing only embeddings (input_ids=None) are supported as well.
        should_process_images = (
            getattr(self, 'sam_model', None) is not None
            and images is not None
            and images_seq_mask is not None
            and (inputs_embeds.shape[1] != 1 or self.training)
            and torch.sum(images[0][1]).item() != 0
        )

        if should_process_images:
            for idx, (image, crop_shape) in enumerate(zip(images, images_spatial_crop)):
                patches = image[0]
                image_ori = image[1]

                if torch.sum(patches).item() != 0:
                    # Local crops first, then the global view, then the separator.
                    local_tokens = self._local_view_tokens(patches, crop_shape)
                    global_tokens = self._global_view_tokens(image_ori)
                    image_tokens = torch.cat(
                        [local_tokens, global_tokens, self.view_separator[None, :]], dim=0
                    )
                else:
                    # Global-only branch (small images)
                    global_tokens = self._global_view_tokens(image_ori)
                    image_tokens = torch.cat(
                        [global_tokens, self.view_separator[None, :]], dim=0
                    )

                # masked_scatter_ needs a bool mask on the same device and a
                # source of the same dtype as the destination.
                mask = images_seq_mask[idx].unsqueeze(-1).to(
                    device=inputs_embeds.device, dtype=torch.bool
                )
                inputs_embeds[idx].masked_scatter_(
                    mask,
                    image_tokens.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype),
                )

        # Always request a ModelOutput from the parent so its fields can be
        # read by name; the tuple conversion for the caller happens below.
        outputs = super().forward(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
        )

        result = DeepQwenOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return result if return_dict else result.to_tuple()


class DeepQwenVLForCausalLM(DeepQwenVLModel, GenerationMixin):
    """
    DeepQwenVL Model for causal language modeling with vision capabilities.

    Combines DeepSeek's vision encoders (SAM + CLIP) with Qwen2VL's text model.
    """

    config_class = DeepQwenVLConfig
    _tied_weights_keys = ["lm_head.weight"]
    _keys_to_ignore_on_load_missing = [
        # "sam_model",
        # "vision_model",
        # "projector",
        # "image_newline",
        # "view_separator",
    ]

    def __init__(self, config):
        """
        Initialize the model.

        Args:
            config: Can be DeepQwenVLConfig, Qwen2VLTextConfig, or a generic config
                from a Qwen2-VL checkpoint.
        """
        super().__init__(config)

        hidden_size = getattr(config, 'hidden_size', 1536)
        vocab_size = getattr(config, 'vocab_size', 151936)

        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        self.post_init()

    def get_output_embeddings(self):
        """Return ``lm_head``, or ``None`` if it has not been created yet."""
        return getattr(self, 'lm_head', None)

    def set_output_embeddings(self, new_embeddings):
        """Replace ``lm_head`` with ``new_embeddings``."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        labels: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> DeepQwenCausalLMOutputWithPast:
        """Run the base model, project to float32 vocabulary logits, and compute the loss.

        ``return_dict`` is accepted for signature compatibility but not
        consulted: a ``DeepQwenCausalLMOutputWithPast`` is always returned.
        ``loss`` is only populated when ``labels`` is given.
        """
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
            return_dict=True,
            cache_position=cache_position,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)

        return DeepQwenCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        images=None,
        images_seq_mask=None,
        images_spatial_crop=None,
        **kwargs,
    ):
        """Extend the parent's generation inputs with the image arguments.

        ``position_ids`` is always reset to ``None`` in the returned dict, and
        the image arguments are dropped once ``cache_position`` shows that
        generation is past the first position.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )

        # Clear images after first forward pass (cache_position[0] != 0 means subsequent tokens)
        past_first_step = cache_position is not None and cache_position[0] != 0

        model_inputs["images"] = None if past_first_step else images
        model_inputs["images_seq_mask"] = None if past_first_step else images_seq_mask
        model_inputs["images_spatial_crop"] = None if past_first_step else images_spatial_crop
        model_inputs["position_ids"] = None

        return model_inputs

    def reinitialize_projector(self, vis_mlp=None, device=None, dtype=None):
        """
        Reinitialize the projector, image_newline, and view_separator.
        Call this after from_pretrained when loading from a Qwen checkpoint.

        Args:
            vis_mlp: When not ``None`` a ``VisionProjector`` is created;
                otherwise a single ``nn.Linear`` with normal(0, 0.01) weights
                and zero bias. Only its None-ness is inspected.
            device: Target device. Defaults to the device of the first
                non-meta parameter, or ``'cpu'`` if there is none.
            dtype: Target dtype. Defaults to ``torch.bfloat16``.
        """
        if device is None:
            device = next(
                (param.device for param in self.parameters() if param.device.type != 'meta'),
                'cpu',
            )
        if dtype is None:
            dtype = torch.bfloat16

        input_dim = self.deepseek_vision_dim
        output_dim = self.output_hidden_size

        if vis_mlp is not None:
            self.projector = VisionProjector(input_dim=input_dim, output_dim=output_dim).to(device=device, dtype=dtype)
        else:
            self.projector = nn.Linear(in_features=input_dim, out_features=output_dim).to(device=device, dtype=dtype)
            nn.init.normal_(self.projector.weight, mean=0.0, std=0.01)
            if self.projector.bias is not None:
                nn.init.zeros_(self.projector.bias)

        embed_std = 1.0 / math.sqrt(output_dim)
        self.image_newline = nn.Parameter(
            torch.randn(output_dim, device=device, dtype=dtype) * embed_std
        )
        self.view_separator = nn.Parameter(
            torch.randn(output_dim, device=device, dtype=dtype) * embed_std
        )

        print(f"Projector reinitialized on {device} with dtype {dtype}")

    def load_pretrained_vision(self, pretrained_path: str):
        """Load ``sam_model`` and ``vision_model`` weights from a safetensors checkpoint.

        Reads ``model-00001-of-000001.safetensors`` inside ``pretrained_path``,
        strips the ``model.sam_model.`` / ``model.vision_model.`` prefixes, and
        loads the two state dicts with ``strict=False``.

        Raises:
            ImportError: If ``safetensors`` is not installed.
            AssertionError: If ``pretrained_path`` does not exist.
        """
        try:
            from safetensors import safe_open
        except ImportError as err:
            raise ImportError("Please install safetensors to load the pretrained vision model.") from err

        assert os.path.exists(pretrained_path), f"Pretrained path {pretrained_path} does not exist."

        prefixes = {
            "sam_model": "model.sam_model.",
            "vision_model": "model.vision_model.",
        }

        # Only materialize tensors that belong to one of the vision modules.
        state_dicts = {name: {} for name in prefixes}
        with safe_open(f"{pretrained_path}/model-00001-of-000001.safetensors", framework="pt", device="cpu") as f:
            for key in f.keys():
                for name, prefix in prefixes.items():
                    if key.startswith(prefix):
                        state_dicts[name][key[len(prefix):]] = f.get_tensor(key)
                        break

        try:
            for name, state_dict in state_dicts.items():
                getattr(self, name).load_state_dict(state_dict, strict=False)
            print("Pretrained vision model loaded successfully.")
        except Exception as e:
            print("Error loading pretrained vision model:", e)
            raise

    def load_deepseek_projector(self, pretrained_path: str):
        """
        Load DeepSeek's projector weights into the deepseek_proj layer.

        DeepSeek checkpoint has:
            - projector.weight: shape (1280, 2048)
            - projector.bias: shape (1280,)

        These get loaded into self.projector.deepseek_proj

        All ``.safetensors`` files in ``pretrained_path`` are scanned in sorted
        order. The keys may also carry a ``model.`` prefix. The loaded tensors
        are converted to the device and dtype of the existing layer. If neither
        key variant is found, a warning is printed and nothing is loaded.

        Raises:
            ImportError: If ``safetensors`` is not installed.
            AssertionError: If ``pretrained_path`` does not exist.
            FileNotFoundError: If the directory has no ``.safetensors`` file.
        """
        try:
            from safetensors import safe_open
        except ImportError as err:
            raise ImportError("Please install safetensors to load DeepSeek projector.") from err

        assert os.path.exists(pretrained_path), f"Pretrained path {pretrained_path} does not exist."

        # Find safetensors files (sorted: os.listdir order is arbitrary)
        safetensor_files = sorted(f for f in os.listdir(pretrained_path) if f.endswith('.safetensors'))
        if not safetensor_files:
            raise FileNotFoundError(f"No safetensors files found in {pretrained_path}")

        projector_weights = {}
        for file_name in safetensor_files:
            with safe_open(os.path.join(pretrained_path, file_name), framework="pt", device="cpu") as f:
                for k in f.keys():
                    if 'projector' in k:
                        projector_weights[k] = f.get_tensor(k)

        target = self.projector.deepseek_proj

        def _assign(weight: torch.Tensor, bias: torch.Tensor) -> None:
            # Match the layer's current placement; a meta-device layer has no
            # real storage yet, so the checkpoint tensor's device is kept.
            device = target.weight.device
            if device.type == 'meta':
                device = weight.device
            dtype = target.weight.dtype
            target.weight.data = weight.to(device=device, dtype=dtype)
            target.bias.data = bias.to(device=device, dtype=dtype)

        # Load into deepseek_proj
        if 'projector.weight' in projector_weights:
            _assign(projector_weights['projector.weight'], projector_weights['projector.bias'])
            print(f"Loaded DeepSeek projector weights: {self.projector.deepseek_proj.weight.shape}")
            print(f"  Weight mean: {self.projector.deepseek_proj.weight.mean().item():.6f}")
            print(f"  Weight std: {self.projector.deepseek_proj.weight.std().item():.6f}")
        elif 'model.projector.weight' in projector_weights:
            _assign(projector_weights['model.projector.weight'], projector_weights['model.projector.bias'])
            print(f"Loaded DeepSeek projector weights (model. prefix)")
        else:
            print(f"Warning: Could not find projector weights. Available keys: {list(projector_weights.keys())}")

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.

        NOTE: this patches ``reset_parameters`` on the ``torch.nn.Linear`` and
        ``torch.nn.LayerNorm`` classes themselves, so it affects every such
        layer created afterwards in the process, not just this model.
        """
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def infer(
        self,
        tokenizer,
        prompt='',
        image_file='',
        output_path='',
        base_size=1024,
        image_size=640,
        crop_mode=True,
        test_compress=False,
        save_results=False,
        eval_mode=False
    ):
        """Run single-image generation for ``prompt`` and optionally save the results.

        Args:
            tokenizer: Tokenizer providing ``apply_chat_template``, ``decode``,
                ``eos_token`` and ``eos_token_id``.
            prompt: User text placed after the image in the conversation.
            image_file: Image reference stored in the conversation and
                resolved by ``load_pil_images``.
            output_path: Directory for saved artifacts. It and its ``images``
                sub-directory are created when the path is non-empty.
            base_size: Side length the global view is padded to in crop mode.
            image_size: Side length used for local crops (crop mode) or for
                the single view (non-crop mode).
            crop_mode: When true, images larger than 640 px on either side are
                additionally split into crops by ``dynamic_preprocess``.
            test_compress: Print image/text token counts and their ratio.
            save_results: Write ``result.mmd`` and ``result_with_boxes.jpg``
                (and ``geo.jpg`` when the output contains ``line_type``) under
                ``output_path``.
            eval_mode: Do not stream; return the decoded text with a trailing
                EOS string removed.

        Returns:
            The decoded output string when ``eval_mode`` is true, else ``None``.
        """
        self.disable_torch_init()

        # os.makedirs('') raises, so only create directories for a real path.
        if output_path:
            os.makedirs(output_path, exist_ok=True)
            os.makedirs(f'{output_path}/images', exist_ok=True)

        conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": f"{image_file}",
                    },
                    {"type": "text", "text": f"{prompt}"},
                ],
            }
        ]

        formatted_prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

        patch_size = 16
        downsample_ratio = 4
        images = load_pil_images(conversation)

        valid_img_tokens = 0

        image_draw = images[0].copy()
        w, h = image_draw.size
        # Short side / long side of the image.
        ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))

        image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
        pad_color = tuple(int(x * 255) for x in image_transform.mean)

        image_token = '<|image_pad|>'
        # Prefer the configured id; 151655 is the previously hard-coded value.
        image_token_id = getattr(self.config, 'image_token_id', None)
        if image_token_id is None:
            image_token_id = 151655
        text_splits = formatted_prompt.split(image_token)

        images_list, images_crop_list, images_seq_mask = [], [], []
        tokenized_str = []
        images_spatial_crop = []

        for text_sep, image in zip(text_splits, images):
            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            if crop_mode:
                if image.size[0] <= 640 and image.size[1] <= 640:
                    crop_ratio = [1, 1]
                else:
                    images_crop_raw, crop_ratio = dynamic_preprocess(image)

                global_view = ImageOps.pad(image, (base_size, base_size), color=pad_color)

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)

                images_list.append(image_transform(global_view).to(torch.bfloat16))

                width_crop_num, height_crop_num = crop_ratio
                images_spatial_crop.append([width_crop_num, height_crop_num])
                has_crops = width_crop_num > 1 or height_crop_num > 1

                if has_crops:
                    # process the local views
                    images_crop_list.extend(
                        image_transform(crop).to(torch.bfloat16) for crop in images_crop_raw
                    )

                if image_size == 640:
                    valid_img_tokens += len(images_crop_list) * 100

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)

                # add image tokens: one extra per row, plus one more after the base grid
                tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
                tokenized_image += [image_token_id]
                if has_crops:
                    tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
                        num_queries * height_crop_num)
            else:
                # process the global view
                if image_size <= 640:
                    image = image.resize((image_size, image_size))
                global_view = ImageOps.pad(image, (image_size, image_size), color=pad_color)
                images_list.append(image_transform(global_view).to(torch.bfloat16))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)
                elif base_size == 640:
                    valid_img_tokens += int(100 * 1)
                elif base_size == 512:
                    valid_img_tokens += int(64 * 1)

                images_spatial_crop.append([1, 1])

                # add image tokens
                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
                tokenized_image += [image_token_id]

            tokenized_str += tokenized_image
            images_seq_mask += [True] * len(tokenized_image)

        # process the last text split
        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        # Qwen2VL has NO bos_token (bos_token_id is None)
        # The chat template already handles proper formatting
        input_ids = torch.LongTensor(tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        if not images_list:
            images_ori = torch.zeros((1, 3, image_size, image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
            images_crop = torch.zeros((1, 3, base_size, base_size))
        else:
            images_ori = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0)
            else:
                # All-zero crops make the model's forward pass take its global-only path.
                images_crop = torch.zeros((1, 3, base_size, base_size))

        # Run on whatever device the model lives on instead of assuming CUDA.
        device = self.device
        prompt_len = input_ids.shape[0]

        generate_kwargs = dict(
            images=[(images_crop.to(device), images_ori.to(device))],
            images_seq_mask=images_seq_mask.unsqueeze(0).to(device),
            images_spatial_crop=images_spatial_crop,
            temperature=0.5,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=8192,
            no_repeat_ngram_size=35 if eval_mode else 20,
            use_cache=True,
        )
        if not eval_mode:
            generate_kwargs["streamer"] = NoEOSTextStreamer(
                tokenizer, skip_prompt=True, skip_special_tokens=False
            )

        with torch.autocast(device.type, dtype=torch.bfloat16), torch.no_grad():
            output_ids = self.generate(input_ids.unsqueeze(0).to(device), **generate_kwargs)

        # Check if conversation has image
        has_image = any(
            (isinstance(item, dict) and item.get('type') == 'image')
            for msg in conversation
            for item in (msg.get('content', []) if isinstance(msg.get('content'), list) else [])
        )

        if not has_image or not (eval_mode or test_compress or save_results):
            return None

        decoded = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=False)
        # Qwen2VL's EOS token is <|im_end|>
        stop_str = tokenizer.eos_token or '<|im_end|>'

        if eval_mode:
            outputs = decoded
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            return outputs.strip()

        if test_compress:
            pure_texts_outputs_token_length = len(text_encode(tokenizer, decoded, bos=False, eos=False))
            # valid_img_tokens stays 0 for size combinations not counted above.
            compression_ratio = (
                round(pure_texts_outputs_token_length / valid_img_tokens, 2)
                if valid_img_tokens else float('nan')
            )
            print('='*50)
            print('image size: ', (w, h))
            print('valid image tokens: ', int(valid_img_tokens))
            print('output texts tokens (valid): ', pure_texts_outputs_token_length)
            print('compression ratio: ', compression_ratio)
            print('='*50)

        if save_results:
            outputs = decoded
            print('='*15 + 'save results:' + '='*15)

            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            matches_ref, matches_images, mathes_other = re_match(outputs)
            result = process_image_with_refs(image_draw, matches_ref, output_path)

            for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
                outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')

            for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
                outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')

            with open(f'{output_path}/result.mmd', 'w', encoding='utf-8') as afile:
                afile.write(outputs)

            if 'line_type' in outputs:
                import matplotlib.pyplot as plt

                # NOTE(review): eval() runs model-generated text as Python code.
                # Consider ast.literal_eval if the output is always a plain literal.
                geometry = eval(outputs)['Line']
                lines = geometry['line']
                line_type = geometry['line_type']
                endpoints = geometry['line_endpoint']

                fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
                ax.set_xlim(-15, 15)
                ax.set_ylim(-15, 15)

                for idx, line in enumerate(lines):
                    try:
                        p0 = eval(line.split(' -- ')[0])
                        p1 = eval(line.split(' -- ')[-1])

                        # NOTE(review): both line types are currently drawn with
                        # the same style — confirm whether '--' should differ.
                        if line_type[idx] == '--':
                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
                        else:
                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')

                        ax.scatter(p0[0], p0[1], s=5, color='k')
                        ax.scatter(p1[0], p1[1], s=5, color='k')
                    except Exception as err:
                        # Best effort: a malformed line is skipped, not fatal.
                        print(f"Skipping line {line!r}: {err}")

                for endpoint in endpoints:
                    label = endpoint.split(': ')[0]
                    (x, y) = eval(endpoint.split(': ')[1])
                    ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                                fontsize=5, fontweight='light')

                plt.savefig(f'{output_path}/geo.jpg')
                plt.close()

            result.save(f"{output_path}/result_with_boxes.jpg")

        return None


# TODO: new training loop
#   image -> vision encoder -> projection ->! txt_decoder -> embedding -> pool
#     => alignment(text_pooling, image_pooling)
#   text -> text encoder -> projection -> embedding -> pool
#   can't let projection layer output into text decoder