# DeepQwenVL-Base / model.py
# Uploaded to the Hugging Face Hub by Sibgat-Ul (commit 32d2edb, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Union, Tuple
from transformers import Qwen2VLTextModel, Qwen2VLTextConfig, Qwen2VLPreTrainedModel, PretrainedConfig
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding
from transformers.generation.utils import GenerationMixin
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from PIL import Image, ImageOps
from encoder import build_sam_vit_b, build_clip_l, MlpProjector
from addict import Dict as ADict
import os
import math
from data import (
format_messages,
load_pil_images,
text_encode,
BasicImageTransform,
dynamic_preprocess,
re_match,
process_image_with_refs,
NoEOSTextStreamer,
)
from tqdm import tqdm
from dataclasses import dataclass
class DeepQwenVLConfig(PretrainedConfig):
    """
    Configuration class for the DeepQwenVL model.

    Wraps both the Qwen2VL text-model hyperparameters and the DeepSeek vision
    settings in a single config. When loading from a Qwen2-VL checkpoint, the
    checkpoint's config is used directly for the text model.

    Args:
        deepseek_vision_hidden_size: Channel dim of the DeepSeek vision
            features fed to the projector.
        projector_type: Projector architecture, "vision_projector" or "mlp".
        projector_input_dim: Input dim of the projector.
        projector_output_dim: Projector output dim; defaults to ``hidden_size``.
        projector_hidden_dim: Projector hidden dim; defaults to
            ``projector_output_dim``.
        image_newline_dim: Dim of the learnable image-newline token; defaults
            to ``hidden_size``.
        view_separator_dim: Dim of the learnable view-separator token; defaults
            to ``hidden_size``.
        rope_scaling: RoPE scaling dict; defaults to Qwen2-VL's mrope setup.
        (Remaining arguments mirror ``Qwen2VLTextConfig``.)
    """
    model_type = "deepqwen_vl"

    def __init__(
        self,
        deepseek_vision_hidden_size: int = 2048,
        # Projector settings
        projector_type: str = "mlp",  # "vision_projector" or "mlp"
        projector_input_dim: int = 2048,
        projector_output_dim: Optional[int] = None,
        projector_hidden_dim: Optional[int] = None,  # If None, uses projector_output_dim
        # Learnable vision tokens
        image_newline_dim: Optional[int] = None,  # If None, uses hidden_size
        view_separator_dim: Optional[int] = None,  # If None, uses hidden_size
        hidden_size: int = 1536,
        intermediate_size: int = 8960,
        num_hidden_layers: int = 28,
        num_attention_heads: int = 12,
        num_key_value_heads: int = 2,
        hidden_act: str = "silu",
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        rms_norm_eps: float = 1e-6,
        use_cache: bool = True,
        tie_word_embeddings: bool = True,
        rope_theta: float = 1000000.0,
        attention_dropout: float = 0.0,
        vocab_size: int = 151936,
        bos_token_id: int = 151643,
        eos_token_id: int = 151645,
        pad_token_id: int = 151643,
        image_token_id: int = 151655,
        video_token_id: int = 151656,
        vision_start_token_id: int = 151652,
        vision_end_token_id: int = 151653,
        vision_token_id: int = 151654,
        rope_scaling: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.deepseek_vision_hidden_size = deepseek_vision_hidden_size
        # Projector settings; unset dims fall back as documented above.
        self.projector_type = projector_type
        self.projector_input_dim = projector_input_dim
        self.projector_output_dim = projector_output_dim if projector_output_dim else hidden_size
        self.projector_hidden_dim = projector_hidden_dim if projector_hidden_dim else self.projector_output_dim
        # Learnable vision tokens
        self.image_newline_dim = image_newline_dim if image_newline_dim else hidden_size
        self.view_separator_dim = view_separator_dim if view_separator_dim else hidden_size
        # Text model settings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.vocab_size = vocab_size
        # Special tokens
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id
        self.vision_token_id = vision_token_id
        # Rope scaling: default to Qwen2-VL's multimodal RoPE sections.
        if rope_scaling is None:
            rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}
        self.rope_scaling = rope_scaling

    def to_text_config(self) -> Qwen2VLTextConfig:
        """Convert to Qwen2VLTextConfig for the text model."""
        return Qwen2VLTextConfig(
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            hidden_act=self.hidden_act,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            rms_norm_eps=self.rms_norm_eps,
            use_cache=self.use_cache,
            tie_word_embeddings=self.tie_word_embeddings,
            rope_theta=self.rope_theta,
            attention_dropout=self.attention_dropout,
            vocab_size=self.vocab_size,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            rope_scaling=self.rope_scaling,
        )
@dataclass
class DeepQwenOutputWithPast(ModelOutput):
    # Output container for DeepQwenVLModel.forward; fields are populated
    # from the wrapped Qwen2VL text model's output. Field order matters:
    # ModelOutput.to_tuple() follows declaration order.
    last_hidden_state: torch.FloatTensor = None
    # KV cache returned when use_cache is enabled.
    past_key_values: Optional[list[torch.FloatTensor]] = None
    # Per-layer hidden states, present when output_hidden_states=True.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Attention weights, present when output_attentions=True.
    attentions: Optional[tuple[torch.FloatTensor]] = None
@dataclass
class DeepQwenCausalLMOutputWithPast(ModelOutput):
    # Output container for DeepQwenVLForCausalLM.forward. Field order matters:
    # ModelOutput.to_tuple() follows declaration order.
    # Causal-LM loss, only set when labels are provided.
    loss: Optional[torch.FloatTensor] = None
    # Vocabulary logits produced by lm_head (upcast to float32 in forward).
    logits: Optional[torch.FloatTensor] = None
    # KV cache returned when use_cache is enabled.
    past_key_values: Optional[list[torch.FloatTensor]] = None
    # Per-layer hidden states, present when output_hidden_states=True.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Attention weights, present when output_attentions=True.
    attentions: Optional[tuple[torch.FloatTensor]] = None
class VisionProjector(nn.Module):
    """Two-stage vision projector: a DeepSeek-pretrained linear layer followed
    by a trainable LayerNorm + linear adapter.

    Pipeline:
        Linear(input_dim -> hidden_dim)   -- DeepSeek weights, frozen after loading
        SiLU
        LayerNorm(hidden_dim)             -- trainable
        Linear(hidden_dim -> output_dim)  -- trainable

    Keeping DeepSeek's projection intact preserves its learned vision-text
    alignment while the adapter maps features into Qwen's embedding space.
    Two linear layers in total, like LLaVA's MLP projector.
    """

    def __init__(self, input_dim: int = 2048, hidden_dim: int = 1280, output_dim: int = 1536):
        super().__init__()
        # DeepSeek's original projection; its weights get overwritten from the
        # checkpoint and are then frozen by the training code.
        self.deepseek_proj = nn.Linear(input_dim, hidden_dim)
        # Trainable adapter into the target (Qwen) embedding space.
        self.norm = nn.LayerNorm(hidden_dim)
        self.adapter = nn.Linear(hidden_dim, output_dim)
        self._init_adapter_weights()

    def _init_adapter_weights(self):
        """Set starting values for the trainable half only; deepseek_proj is
        expected to be loaded from a checkpoint afterwards."""
        nn.init.ones_(self.norm.weight)
        nn.init.zeros_(self.norm.bias)
        # Small-variance start so the adapter begins near an identity-free map.
        nn.init.normal_(self.adapter.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.adapter.bias)

    def forward(self, x):
        # deepseek_proj -> SiLU -> LayerNorm -> adapter, as one pipeline.
        return self.adapter(self.norm(F.silu(self.deepseek_proj(x))))
class DeepQwenVLPreTrainedModel(PreTrainedModel):
    """Base class wiring DeepQwenVL into the HF PreTrainedModel machinery:
    config binding, capability flags, and weight initialization."""
    config_class = DeepQwenVLConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_static_cache = True
    _supports_attention_backend = True
    # Vision-side modules are loaded separately (see load_pretrained_vision /
    # load_deepseek_projector), so missing keys for them are expected when
    # loading from a text-only Qwen checkpoint.
    _keys_to_ignore_on_load_missing = [
        "sam_model",
        "vision_model",
        "projector",
        "image_newline",
        "view_separator",
    ]

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, std); zero biases."""
        # getattr with a default replaces the original hasattr-guarded ternary.
        std = getattr(self.config, 'initializer_range', 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
class DeepQwenVLModel(Qwen2VLTextModel):
    """
    DeepQwenVL Model that combines DeepSeek's vision encoders with Qwen2VL's text model.
    Accepts either:
    - A DeepQwenVLConfig
    - A Qwen2VLTextConfig (for compatibility with from_pretrained from Qwen checkpoints)
    - A generic PretrainedConfig (will extract necessary fields)
    """
    config_class = DeepQwenVLConfig

    def __init__(self, config):
        # Normalize the three accepted config flavors down to the text config
        # plus the two dimensions the vision stack needs.
        if isinstance(config, DeepQwenVLConfig):
            text_config = config.to_text_config()
            output_hidden_size = config.projector_output_dim
            vision_dim = config.deepseek_vision_hidden_size
        elif isinstance(config, Qwen2VLTextConfig):
            text_config = config
            output_hidden_size = config.hidden_size
            vision_dim = 2048  # DeepSeek default when loading a plain Qwen config
        else:
            text_config = config
            output_hidden_size = getattr(config, 'hidden_size', 1536)
            vision_dim = getattr(config, 'deepseek_vision_hidden_size', 2048)
        super(DeepQwenVLModel, self).__init__(text_config)
        self.config = config
        self.output_hidden_size = output_hidden_size
        # DeepSeek-style dual vision towers (SAM ViT-B + CLIP-L), built from
        # the local `encoder` module.
        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()
        self.deepseek_vision_dim = vision_dim
        self.deepseek_hidden_dim = 1280  # DeepSeek's projector output dimension
        # New projector: DeepSeek layer (frozen) + adapter (trainable)
        self.projector = VisionProjector(
            input_dim=self.deepseek_vision_dim,   # 2048
            hidden_dim=self.deepseek_hidden_dim,  # 1280 (DeepSeek's output)
            output_dim=output_hidden_size         # 1536 (Qwen's hidden size)
        )
        # Learnable separator embeddings, scaled like a standard embedding init.
        embed_std = 1 / torch.sqrt(torch.tensor(output_hidden_size, dtype=torch.float32))
        self.image_newline = nn.Parameter(torch.randn(output_hidden_size) * embed_std)
        self.view_separator = nn.Parameter(torch.randn(output_hidden_size) * embed_std)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Embed text tokens, splice projected vision features into the
        positions marked True in `images_seq_mask`, then run the Qwen2VL
        text model.

        `images` is an iterable of (local_patches, global_image) tensor pairs,
        one per batch element (see `infer`, which passes
        `[(images_crop, images_ori)]`); `images_spatial_crop` carries the
        (width, height) crop grid for each image.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
        # getattr guards allow running as a text-only model with the vision
        # towers absent.
        sam_model = getattr(self, 'sam_model', None)
        vision_model = getattr(self, 'vision_model', None)
        # Only run the vision towers on the prefill step (seq_len > 1) or
        # during training, and only when the first global image is non-zero.
        should_process_images = (
            sam_model is not None
            and images is not None
            and images_seq_mask is not None
            and (input_ids.shape[1] != 1 or self.training)
            and torch.sum(images[0][1]).item() != 0
        )
        if should_process_images:
            idx = 0
            for image, crop_shape in zip(images, images_spatial_crop):
                images_in_this_batch = []
                patches = image[0]    # local crop tiles
                image_ori = image[1]  # padded global view
                # An all-zero patch tensor signals "no local crops" (see infer).
                if torch.sum(patches).item() != 0:
                    # Process local patches
                    with torch.no_grad():
                        local_features_1 = sam_model(patches)
                        local_features_2 = vision_model(patches, local_features_1)
                    # Concatenate CLIP tokens (minus the first token) with the
                    # flattened SAM feature map along the channel dim.
                    local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                    local_features = local_features.detach()
                    # Projector is trainable, so it runs outside no_grad.
                    local_features = self.projector(local_features)
                    # Process global image
                    with torch.no_grad():
                        global_features_1 = sam_model(image_ori)
                        global_features_2 = vision_model(image_ori, global_features_1)
                    global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                    global_features = global_features.detach()
                    global_features = self.projector(global_features)
                    # Reshape and add newline tokens
                    _, hw, n_dim = global_features.shape
                    h = w = int(hw ** 0.5)  # assumes a square token grid -- TODO confirm
                    _2, hw2, n_dim2 = local_features.shape
                    h2 = w2 = int(hw2 ** 0.5)
                    width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]
                    # Global view: append an image_newline token to each row.
                    global_features = global_features.view(h, w, n_dim)
                    global_features = torch.cat(
                        [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
                    )
                    global_features = global_features.view(-1, n_dim)
                    # Local views: stitch the crop grid back into one large
                    # (height_crop_num*h2) x (width_crop_num*w2) feature map,
                    # then append a newline token per stitched row.
                    local_features = local_features.view(
                        height_crop_num, width_crop_num, h2, w2, n_dim2
                    ).permute(0, 2, 1, 3, 4).reshape(height_crop_num*h2, width_crop_num*w2, n_dim2)
                    local_features = torch.cat(
                        [local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1
                    )
                    local_features = local_features.view(-1, n_dim2)
                    # Token order: local view, then global view, then separator.
                    global_local_features = torch.cat([local_features, global_features, self.view_separator[None, :]], dim=0)
                    images_in_this_batch.append(global_local_features)
                else:
                    # Global-only branch (small images)
                    with torch.no_grad():
                        global_features_1 = sam_model(image_ori)
                        global_features_2 = vision_model(image_ori, global_features_1)
                    global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                    global_features = global_features.detach()
                    global_features = self.projector(global_features)
                    _, hw, n_dim = global_features.shape
                    h = w = int(hw ** 0.5)
                    global_features = global_features.view(h, w, n_dim)
                    global_features = torch.cat(
                        [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
                    )
                    global_features = global_features.view(-1, n_dim)
                    global_local_features = torch.cat([global_features, self.view_separator[None, :]], dim=0)
                    images_in_this_batch.append(global_local_features)
                if images_in_this_batch:
                    images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
                    # In-place scatter of vision embeddings into the masked
                    # (image token) positions. NOTE(review): the hardcoded
                    # .cuda() assumes a CUDA device is available.
                    inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
                idx += 1
        # Run the text model on the (possibly vision-spliced) embeddings;
        # input_ids is set to None because inputs_embeds is always provided.
        outputs = super().forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict, cache_position=cache_position
        )
        # Re-wrap in this module's output dataclass, or a tuple when
        # return_dict is falsy.
        return DeepQwenOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        ) if return_dict else outputs.to_tuple()
class DeepQwenVLForCausalLM(DeepQwenVLModel, GenerationMixin):
    """
    DeepQwenVL Model for causal language modeling with vision capabilities.
    Combines DeepSeek's vision encoders (SAM + CLIP) with Qwen2VL's text model.
    """
    config_class = DeepQwenVLConfig
    # The LM head is tied to the input embeddings (tie_word_embeddings=True).
    _tied_weights_keys = ["lm_head.weight"]
    _keys_to_ignore_on_load_missing = [
        # "sam_model",
        # "vision_model",
        # "projector",
        # "image_newline",
        # "view_separator",
    ]

    def __init__(self, config):
        """
        Initialize the model.
        Args:
            config: Can be DeepQwenVLConfig, Qwen2VLTextConfig, or a generic config
                from a Qwen2-VL checkpoint.
        """
        super().__init__(config)
        # Fallback defaults match the config defaults used elsewhere in this file.
        hidden_size = getattr(config, 'hidden_size', 1536)
        vocab_size = getattr(config, 'vocab_size', 151936)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.post_init()

    def get_output_embeddings(self):
        # getattr guard: lm_head may not exist yet during partial initialization.
        return getattr(self, 'lm_head', None)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        labels: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Run the multimodal backbone, project hidden states to vocabulary
        logits, and compute the causal-LM loss when `labels` is given."""
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
            return_dict=True,
            cache_position=cache_position,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        # Upcast for numerical stability of loss / sampling.
        logits = logits.float()
        loss = None
        if labels is not None:
            # loss_function is supplied by the transformers base classes.
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
        return DeepQwenCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        images=None,
        images_seq_mask=None,
        images_spatial_crop=None,
        **kwargs,
    ):
        """Extend GenerationMixin's input preparation with the image kwargs,
        and drop them after the prefill step so the decode steps skip the
        vision towers."""
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )
        model_inputs["images"] = images
        model_inputs["images_seq_mask"] = images_seq_mask
        model_inputs["images_spatial_crop"] = images_spatial_crop
        # Force the model to derive position ids itself.
        model_inputs["position_ids"] = None
        # Clear images after first forward pass (cache_position[0] != 0 means subsequent tokens)
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["images"] = None
            model_inputs["images_seq_mask"] = None
            model_inputs["images_spatial_crop"] = None
        return model_inputs

    def reinitialize_projector(self, vis_mlp=None, device=None, dtype=None):
        """
        Reinitialize the projector, image_newline, and view_separator.
        Call this after from_pretrained when loading from a Qwen checkpoint.

        Args:
            vis_mlp: if not None, a two-stage VisionProjector is built;
                otherwise a single nn.Linear projector is used.
            device: target device; inferred from the first non-meta parameter
                (falling back to 'cpu') when omitted.
            dtype: target dtype; defaults to bfloat16.
        """
        if device is None:
            # Pick the device of the first materialized (non-meta) parameter.
            for param in self.parameters():
                if param.device.type != 'meta':
                    device = param.device
                    break
            if device is None:
                device = 'cpu'
        if dtype is None:
            dtype = torch.bfloat16
        input_dim = self.deepseek_vision_dim
        output_dim = self.output_hidden_size
        if vis_mlp is not None:
            # VisionProjector initializes its own adapter weights.
            self.projector = VisionProjector(input_dim=input_dim, output_dim=output_dim).to(device=device, dtype=dtype)
        else:
            self.projector = nn.Linear(in_features=input_dim, out_features=output_dim).to(device=device, dtype=dtype)
            nn.init.normal_(self.projector.weight, mean=0.0, std=0.01)
            if self.projector.bias is not None:
                nn.init.zeros_(self.projector.bias)
        # Same embedding-style scaling as in DeepQwenVLModel.__init__.
        embed_std = 1 / torch.sqrt(torch.tensor(output_dim, dtype=torch.float32))
        self.image_newline = nn.Parameter(
            torch.randn(output_dim, device=device, dtype=dtype) * embed_std.item()
        )
        self.view_separator = nn.Parameter(
            torch.randn(output_dim, device=device, dtype=dtype) * embed_std.item()
        )
        print(f"Projector reinitialized on {device} with dtype {dtype}")

    def load_pretrained_vision(self, pretrained_path: str):
        """Load SAM and CLIP weights from a single-shard safetensors checkpoint
        that stores them under 'model.sam_model.' / 'model.vision_model.'
        prefixes. Raises on any loading error."""
        try:
            from safetensors import safe_open
        except ImportError:
            raise ImportError("Please install safetensors to load the pretrained vision model.")
        assert os.path.exists(pretrained_path), f"Pretrained path {pretrained_path} does not exist."
        vision_weights = {}
        with safe_open(f"{pretrained_path}/model-00001-of-000001.safetensors", framework="pt", device="cpu") as f:
            for k in f.keys():
                vision_weights[k] = f.get_tensor(k)
        prefixes = {
            "sam_model": "model.sam_model.",
            "vision_model": "model.vision_model.",
        }
        try:
            for p in prefixes.keys():
                # Strip the checkpoint prefix so keys match the submodule's
                # own parameter names.
                state_dict = {}
                for k, v in vision_weights.items():
                    if k.startswith(prefixes[p]):
                        new_key = k[len(prefixes[p]):]
                        state_dict[new_key] = v
                getattr(self, p).load_state_dict(state_dict, strict=False)
            print("Pretrained vision model loaded successfully.")
        except Exception as e:
            print("Error loading pretrained vision model:", e)
            raise e

    def load_deepseek_projector(self, pretrained_path: str):
        """
        Load DeepSeek's projector weights into the deepseek_proj layer.
        DeepSeek checkpoint has:
        - projector.weight: shape (1280, 2048)
        - projector.bias: shape (1280,)
        These get loaded into self.projector.deepseek_proj
        """
        try:
            from safetensors import safe_open
        except ImportError:
            raise ImportError("Please install safetensors to load DeepSeek projector.")
        assert os.path.exists(pretrained_path), f"Pretrained path {pretrained_path} does not exist."
        # Find safetensors file
        safetensor_files = [f for f in os.listdir(pretrained_path) if f.endswith('.safetensors')]
        if not safetensor_files:
            raise FileNotFoundError(f"No safetensors files found in {pretrained_path}")
        # NOTE(review): only the first shard is scanned -- assumes the
        # projector lives there.
        safetensor_path = os.path.join(pretrained_path, safetensor_files[0])
        projector_weights = {}
        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                if 'projector' in k:
                    projector_weights[k] = f.get_tensor(k)
        # Load into deepseek_proj, handling both bare and "model."-prefixed keys.
        if 'projector.weight' in projector_weights:
            self.projector.deepseek_proj.weight.data = projector_weights['projector.weight']
            self.projector.deepseek_proj.bias.data = projector_weights['projector.bias']
            print(f"Loaded DeepSeek projector weights: {self.projector.deepseek_proj.weight.shape}")
            print(f"  Weight mean: {self.projector.deepseek_proj.weight.mean().item():.6f}")
            print(f"  Weight std: {self.projector.deepseek_proj.weight.std().item():.6f}")
        elif 'model.projector.weight' in projector_weights:
            self.projector.deepseek_proj.weight.data = projector_weights['model.projector.weight']
            self.projector.deepseek_proj.bias.data = projector_weights['model.projector.bias']
            print(f"Loaded DeepSeek projector weights (model. prefix)")
        else:
            print(f"Warning: Could not find projector weights. Available keys: {list(projector_weights.keys())}")

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.

        NOTE(review): this monkey-patches nn.Linear/nn.LayerNorm globally for
        the whole process, affecting any module created afterwards.
        """
        import torch
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def infer(
        self,
        tokenizer,
        prompt='',
        image_file='',
        output_path='',
        base_size=1024,
        image_size=640,
        crop_mode=True,
        test_compress=False,
        save_results=False,
        eval_mode=False
    ):
        """Single-image chat inference.

        Builds a chat prompt around one image, prepares a padded global view
        (base_size) plus optional local crops (image_size) when crop_mode is
        on, interleaves image-pad token ids with the text, runs generate(),
        and -- depending on the flags -- returns the decoded text (eval_mode),
        prints compression stats (test_compress), or writes markdown/images
        to output_path (save_results).

        NOTE(review): requires CUDA (tensors are moved with .cuda()), and the
        save_results geometry branch calls eval() on model output, which is
        unsafe on untrusted output.
        """
        self.disable_torch_init()
        os.makedirs(output_path, exist_ok=True)
        os.makedirs(f'{output_path}/images', exist_ok=True)
        conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": f"{image_file}",
                    },
                    {"type": "text", "text": f"{prompt}"},
                ],
            }
        ]
        formatted_prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        patch_size = 16
        downsample_ratio = 4
        images = load_pil_images(conversation)
        valid_img_tokens = 0
        ratio = 1
        image_draw = images[0].copy()
        w, h = image_draw.size
        # Aspect-ratio factor: 1 for square images, smaller when elongated.
        ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))
        image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
        images_seq_mask = []
        image_token = '<|image_pad|>'
        image_token_id = 151655
        # Split the rendered prompt at each image placeholder.
        text_splits = formatted_prompt.split(image_token)
        images_list, images_crop_list, images_seq_mask = [], [], []
        tokenized_str = []
        images_spatial_crop = []
        for text_sep, image in zip(text_splits, images):
            # Text preceding this image placeholder.
            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)
            if crop_mode:
                # Small images keep a 1x1 "crop grid"; larger ones are tiled.
                if image.size[0] <= 640 and image.size[1] <= 640:
                    crop_ratio = [1, 1]
                else:
                    if crop_mode:
                        images_crop_raw, crop_ratio = dynamic_preprocess(image)
                    else:
                        crop_ratio = [1, 1]
                # Pad the full image to a square global view (mean-colored fill).
                global_view = ImageOps.pad(image, (base_size, base_size),
                                           color=tuple(int(x * 255) for x in image_transform.mean))
                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)
                # elif base_size == 640:
                #     valid_img_tokens += int(100 * ratio)
                images_list.append(image_transform(global_view).to(torch.bfloat16))
                # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
                width_crop_num, height_crop_num = crop_ratio
                images_spatial_crop.append([width_crop_num, height_crop_num])
                if width_crop_num > 1 or height_crop_num > 1:
                    """process the local views"""
                    for i in range(len(images_crop_raw)):
                        images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
                    if image_size == 640:
                        valid_img_tokens += len(images_crop_list) * 100
                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)
                """add image tokens"""
                # One pad token per query plus a trailing token per row,
                # plus one extra separator token for the global view.
                tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
                tokenized_image += [image_token_id]
                if width_crop_num > 1 or height_crop_num > 1:
                    tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
                        num_queries * height_crop_num)
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)
                # num_image_tokens.append(len(tokenized_image))
            else:
                """process the global view"""
                if image_size <= 640:
                    image = image.resize((image_size, image_size))
                global_view = ImageOps.pad(image, (image_size, image_size),
                                           color=tuple(int(x * 255) for x in image_transform.mean))
                images_list.append(image_transform(global_view).to(torch.bfloat16))
                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)
                elif base_size == 640:
                    valid_img_tokens += int(100 * 1)
                elif base_size == 512:
                    valid_img_tokens += int(64 * 1)
                width_crop_num, height_crop_num = 1, 1
                images_spatial_crop.append([width_crop_num, height_crop_num])
                """add image tokens"""
                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
                tokenized_image += [image_token_id]
                # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
                #     num_queries * height_crop_num)
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)
                # num_image_tokens.append(len(tokenized_image))
        """process the last text split"""
        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)
        # Qwen2VL has NO bos_token (bos_token_id is None)
        # The chat template already handles proper formatting
        input_ids = torch.LongTensor(tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
        if len(images_list) == 0:
            # Text-only fallback: feed all-zero image tensors (the model's
            # forward treats zero-sum images as "no image").
            images_ori = torch.zeros((1, 3, image_size, image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
            images_crop = torch.zeros((1, 3, base_size, base_size))
        else:
            images_ori = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0)
            else:
                # Zero crop tensor signals "global view only" to forward.
                images_crop = torch.zeros((1, 3, base_size, base_size))
        if not eval_mode:
            # Interactive mode streams tokens as they are generated.
            streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
            with torch.autocast("cuda", dtype=torch.bfloat16):
                with torch.no_grad():
                    output_ids = self.generate(
                        input_ids.unsqueeze(0).cuda(),
                        images=[(images_crop.cuda(), images_ori.cuda())],
                        images_seq_mask=images_seq_mask.unsqueeze(0).cuda(),
                        images_spatial_crop=images_spatial_crop,
                        temperature=0.5,
                        eos_token_id=tokenizer.eos_token_id,
                        streamer=streamer,
                        max_new_tokens=8192,
                        no_repeat_ngram_size=20,
                        use_cache=True
                    )
        else:
            # Batch/eval mode: no streamer, larger no-repeat window.
            with torch.autocast("cuda", dtype=torch.bfloat16):
                with torch.no_grad():
                    output_ids = self.generate(
                        input_ids.unsqueeze(0).cuda(),
                        images=[(images_crop.cuda(), images_ori.cuda())],
                        images_seq_mask=images_seq_mask.unsqueeze(0).cuda(),
                        images_spatial_crop=images_spatial_crop,
                        temperature=0.5,
                        eos_token_id=tokenizer.eos_token_id,
                        max_new_tokens=8192,
                        no_repeat_ngram_size=35,
                        use_cache=True
                    )
        # Check if conversation has image
        has_image = any(
            (isinstance(item, dict) and item.get('type') == 'image')
            for msg in conversation
            for item in (msg.get('content', []) if isinstance(msg.get('content'), list) else [])
        )
        if has_image and eval_mode:
            # Decode only the newly generated tokens (skip the prompt).
            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:], skip_special_tokens=False)
            # Qwen2VL's EOS token is <|im_end|>
            stop_str = tokenizer.eos_token or '<|im_end|>'
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()
            return outputs
        if has_image and test_compress:
            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:], skip_special_tokens=False)
            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
            print('=' * 50)
            print('image size: ', (w, h))
            print('valid image tokens: ', int(valid_img_tokens))
            print('output texts tokens (valid): ', pure_texts_outputs_token_length)
            print('compression ratio: ', round(pure_texts_outputs_token_length / valid_img_tokens, 2))
            print('=' * 50)
        if has_image and save_results:
            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:], skip_special_tokens=False)
            # Qwen2VL's EOS token
            stop_str = tokenizer.eos_token or '<|im_end|>'
            print('=' * 15 + 'save results:' + '=' * 15)
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()
            # Extract grounding refs / embedded images / other markup spans.
            matches_ref, matches_images, mathes_other = re_match(outputs)
            result = process_image_with_refs(image_draw, matches_ref, output_path)
            # Rewrite embedded-image spans as markdown links to saved crops.
            for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
                outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')
            for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
                outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
            with open(f'{output_path}/result.mmd', 'w', encoding='utf-8') as afile:
                afile.write(outputs)
            if 'line_type' in outputs:
                # Geometry rendering branch.
                # NOTE(review): eval() on model output is unsafe on untrusted
                # input -- a structured parser would be safer.
                import matplotlib.pyplot as plt
                lines = eval(outputs)['Line']['line']
                line_type = eval(outputs)['Line']['line_type']
                endpoints = eval(outputs)['Line']['line_endpoint']
                fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
                ax.set_xlim(-15, 15)
                ax.set_ylim(-15, 15)
                for idx, line in enumerate(lines):
                    try:
                        p0 = eval(line.split(' -- ')[0])
                        p1 = eval(line.split(' -- ')[-1])
                        if line_type[idx] == '--':
                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
                        else:
                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
                        ax.scatter(p0[0], p0[1], s=5, color='k')
                        ax.scatter(p1[0], p1[1], s=5, color='k')
                    except:
                        # Best-effort rendering: skip lines that fail to parse.
                        pass
                for endpoint in endpoints:
                    label = endpoint.split(': ')[0]
                    (x, y) = eval(endpoint.split(': ')[1])
                    ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                                fontsize=5, fontweight='light')
                plt.savefig(f'{output_path}/geo.jpg')
                plt.close()
            result.save(f"{output_path}/result_with_boxes.jpg")
## TODO: new training loop
#   image -> vision encoder -> projection -> text decoder -> embedding -> pool
#   text  -> text encoder -> projection -> embedding -> pool
#   => alignment(text_pooling, image_pooling)
# Constraint: the projection layer's output cannot be fed directly into the text decoder.