|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Optional, List, Union, Tuple |
|
|
from transformers import Qwen2VLTextModel, Qwen2VLTextConfig, Qwen2VLPreTrainedModel, PretrainedConfig |
|
|
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding |
|
|
from transformers.generation.utils import GenerationMixin |
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers.modeling_outputs import ModelOutput |
|
|
from PIL import Image, ImageOps |
|
|
from encoder import build_sam_vit_b, build_clip_l, MlpProjector |
|
|
from addict import Dict as ADict |
|
|
import os |
|
|
import math |
|
|
from data import ( |
|
|
format_messages, |
|
|
load_pil_images, |
|
|
text_encode, |
|
|
BasicImageTransform, |
|
|
dynamic_preprocess, |
|
|
re_match, |
|
|
process_image_with_refs, |
|
|
NoEOSTextStreamer, |
|
|
) |
|
|
from tqdm import tqdm |
|
|
from dataclasses import dataclass |
|
|
|
|
|
|
|
|
class DeepQwenVLConfig(PretrainedConfig):
    """
    Configuration for the DeepQwenVL model.

    Holds the Qwen2VL text-model hyper-parameters side by side with the
    DeepSeek vision/projector settings. When a Qwen2-VL checkpoint is loaded,
    the checkpoint's own values are used directly for the text model.

    Args:
        deepseek_vision_hidden_size: Width of the vision features fed to the
            projector.
        projector_type: Projector kind identifier (stored as-is).
        projector_input_dim: Projector input width (stored as-is).
        projector_output_dim: Projector output width; falls back to
            ``hidden_size`` when falsy.
        projector_hidden_dim: Projector hidden width; falls back to the
            resolved ``projector_output_dim`` when falsy.
        image_newline_dim: Size of the image-newline embedding; falls back to
            ``hidden_size`` when falsy.
        view_separator_dim: Size of the view-separator embedding; falls back
            to ``hidden_size`` when falsy.
        rope_scaling: RoPE scaling dict; defaults to the mrope layout
            ``{"type": "mrope", "mrope_section": [16, 24, 24]}`` when ``None``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    The remaining arguments (``hidden_size`` ... ``vocab_size`` and the special
    token ids) are stored unchanged under the same attribute names.
    """
    model_type = "deepqwen_vl"

    def __init__(
        self,
        deepseek_vision_hidden_size: int = 2048,
        projector_type: str = "mlp",
        projector_input_dim: int = 2048,
        projector_output_dim: int = None,
        projector_hidden_dim: int = None,
        image_newline_dim: int = None,
        view_separator_dim: int = None,
        hidden_size: int = 1536,
        intermediate_size: int = 8960,
        num_hidden_layers: int = 28,
        num_attention_heads: int = 12,
        num_key_value_heads: int = 2,
        hidden_act: str = "silu",
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        rms_norm_eps: float = 1e-6,
        use_cache: bool = True,
        tie_word_embeddings: bool = True,
        rope_theta: float = 1000000.0,
        attention_dropout: float = 0.0,
        vocab_size: int = 151936,
        bos_token_id: int = 151643,
        eos_token_id: int = 151645,
        pad_token_id: int = 151643,
        image_token_id: int = 151655,
        video_token_id: int = 151656,
        vision_start_token_id: int = 151652,
        vision_end_token_id: int = 151653,
        vision_token_id: int = 151654,
        rope_scaling: dict = None,
        **kwargs
    ):
        # The base class owns the generic token ids and the weight-tying flag.
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )

        # --- vision / projector settings -------------------------------------
        self.deepseek_vision_hidden_size = deepseek_vision_hidden_size
        self.projector_type = projector_type
        self.projector_input_dim = projector_input_dim
        # Falsy values (None or 0) fall back to the text hidden size.
        self.projector_output_dim = projector_output_dim or hidden_size
        self.projector_hidden_dim = projector_hidden_dim or self.projector_output_dim
        self.image_newline_dim = image_newline_dim or hidden_size
        self.view_separator_dim = view_separator_dim or hidden_size

        # --- text-model settings and vision special tokens, stored verbatim --
        verbatim_fields = {
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "num_key_value_heads": num_key_value_heads,
            "hidden_act": hidden_act,
            "max_position_embeddings": max_position_embeddings,
            "initializer_range": initializer_range,
            "rms_norm_eps": rms_norm_eps,
            "use_cache": use_cache,
            "rope_theta": rope_theta,
            "attention_dropout": attention_dropout,
            "vocab_size": vocab_size,
            "image_token_id": image_token_id,
            "video_token_id": video_token_id,
            "vision_start_token_id": vision_start_token_id,
            "vision_end_token_id": vision_end_token_id,
            "vision_token_id": vision_token_id,
        }
        for field_name, field_value in verbatim_fields.items():
            setattr(self, field_name, field_value)

        # Default to the mrope layout when the caller gives no scaling dict.
        self.rope_scaling = (
            {"type": "mrope", "mrope_section": [16, 24, 24]}
            if rope_scaling is None
            else rope_scaling
        )

    def to_text_config(self) -> Qwen2VLTextConfig:
        """Build the ``Qwen2VLTextConfig`` used to construct the text model."""
        forwarded = (
            "hidden_size",
            "intermediate_size",
            "num_hidden_layers",
            "num_attention_heads",
            "num_key_value_heads",
            "hidden_act",
            "max_position_embeddings",
            "initializer_range",
            "rms_norm_eps",
            "use_cache",
            "tie_word_embeddings",
            "rope_theta",
            "attention_dropout",
            "vocab_size",
            "bos_token_id",
            "eos_token_id",
            "pad_token_id",
            "rope_scaling",
        )
        return Qwen2VLTextConfig(**{name: getattr(self, name) for name in forwarded})
|
|
|
|
|
|
|
|
@dataclass
class DeepQwenOutputWithPast(ModelOutput):
    """
    Output container returned by ``DeepQwenVLModel.forward``.

    Each field is copied from the same-named attribute of the underlying text
    model's output. Field order matters: it defines the positional layout used
    by ``to_tuple()`` and integer indexing.
    """
    # Final hidden states produced by the text model.
    last_hidden_state: torch.FloatTensor = None
    # Key/value cache handed back by the text model (None when not returned).
    past_key_values: Optional[list[torch.FloatTensor]] = None
    # Hidden states as reported by the text model (None when not returned).
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Attention weights as reported by the text model (None when not returned).
    attentions: Optional[tuple[torch.FloatTensor]] = None
|
|
|
|
|
@dataclass
class DeepQwenCausalLMOutputWithPast(ModelOutput):
    """
    Output container returned by ``DeepQwenVLForCausalLM.forward``.

    Field order matters: it defines the positional layout used by
    ``to_tuple()`` and integer indexing.
    """
    # Language-modeling loss; only set when labels are passed to forward.
    loss: Optional[torch.FloatTensor] = None
    # lm_head output, cast to float32 in forward.
    logits: Optional[torch.FloatTensor] = None
    # Key/value cache passed through from the base model output.
    past_key_values: Optional[list[torch.FloatTensor]] = None
    # Hidden states passed through from the base model output.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Attention weights passed through from the base model output.
    attentions: Optional[tuple[torch.FloatTensor]] = None
|
|
|
|
|
|
|
|
class VisionProjector(nn.Module):
    """
    Two-stage projector from vision features to the text embedding space.

    Pipeline (default sizes in parentheses):
        deepseek_proj: Linear(input_dim=2048 -> hidden_dim=1280)
        SiLU
        norm:          LayerNorm(hidden_dim)
        adapter:       Linear(hidden_dim -> output_dim=1536)

    ``deepseek_proj`` is meant to receive DeepSeek's pretrained projection
    weights (and be kept frozen), while ``norm`` and ``adapter`` are the
    trainable part that maps into Qwen's embedding space. This class itself
    does not freeze anything; that is left to the training setup.
    """

    def __init__(self, input_dim: int = 2048, hidden_dim: int = 1280, output_dim: int = 1536):
        super().__init__()

        # Sub-module names and creation order are part of the contract:
        # they fix the state_dict keys and the order random init is drawn in.
        self.deepseek_proj = nn.Linear(input_dim, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        self.adapter = nn.Linear(hidden_dim, output_dim)

        self._init_adapter_weights()

    def _init_adapter_weights(self):
        """Reset norm/adapter; ``deepseek_proj`` keeps its default init until loaded."""
        layer_norm, head = self.norm, self.adapter
        # Identity-like LayerNorm: unit scale, zero shift.
        nn.init.constant_(layer_norm.weight, 1.0)
        nn.init.constant_(layer_norm.bias, 0.0)
        # Small random adapter weights, zero bias.
        nn.init.normal_(head.weight, mean=0.0, std=0.01)
        nn.init.constant_(head.bias, 0.0)

    def forward(self, x):
        """Apply projection -> SiLU -> LayerNorm -> adapter to ``x``."""
        activated = F.silu(self.deepseek_proj(x))
        return self.adapter(self.norm(activated))
|
|
|
|
|
class DeepQwenVLPreTrainedModel(PreTrainedModel):
    """
    Base class wiring ``DeepQwenVLConfig`` into the ``PreTrainedModel``
    machinery and providing the default weight initialisation.
    """
    config_class = DeepQwenVLConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_static_cache = True
    _supports_attention_backend = True

    # Vision-side parameters that a text-only checkpoint will not contain.
    _keys_to_ignore_on_load_missing = [
        "sam_model",
        "vision_model",
        "projector",
        "image_newline",
        "view_separator",
    ]

    def _init_weights(self, module):
        """Normal-initialise Linear/Embedding weights and zero Linear biases."""
        # Only Linear and Embedding layers are touched; everything else keeps
        # whatever initialisation it already has.
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        init_std = getattr(self.config, 'initializer_range', 0.02)
        module.weight.data.normal_(mean=0.0, std=init_std)

        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
|
|
|
|
|
|
|
|
class DeepQwenVLModel(Qwen2VLTextModel):
    """
    DeepQwenVL Model that combines DeepSeek's vision encoders with Qwen2VL's text model.

    Accepts either:
    - A DeepQwenVLConfig
    - A Qwen2VLTextConfig (for compatibility with from_pretrained from Qwen checkpoints)
    - A generic PretrainedConfig (will extract necessary fields)
    """
    config_class = DeepQwenVLConfig

    def __init__(self, config):
        """
        Build the text model, the two vision encoders, the projector and the
        learned ``image_newline`` / ``view_separator`` embeddings.

        Args:
            config: One of the config kinds listed in the class docstring.
        """
        if isinstance(config, DeepQwenVLConfig):
            text_config = config.to_text_config()
            output_hidden_size = config.projector_output_dim
            vision_dim = config.deepseek_vision_hidden_size
        elif isinstance(config, Qwen2VLTextConfig):
            text_config = config
            output_hidden_size = config.hidden_size
            vision_dim = 2048
        else:
            text_config = config
            output_hidden_size = getattr(config, 'hidden_size', 1536)
            vision_dim = getattr(config, 'deepseek_vision_hidden_size', 2048)

        super().__init__(text_config)

        # Keep the config object the caller passed in (not the derived text config).
        self.config = config
        self.output_hidden_size = output_hidden_size

        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()

        self.deepseek_vision_dim = vision_dim
        self.deepseek_hidden_dim = 1280

        self.projector = VisionProjector(
            input_dim=self.deepseek_vision_dim,
            hidden_dim=self.deepseek_hidden_dim,
            output_dim=output_hidden_size
        )

        # Separator embeddings start as N(0, 1/output_hidden_size).
        embed_std = 1 / torch.sqrt(torch.tensor(output_hidden_size, dtype=torch.float32))
        self.image_newline = nn.Parameter(torch.randn(output_hidden_size) * embed_std)
        self.view_separator = nn.Parameter(torch.randn(output_hidden_size) * embed_std)

    def _encode_view(self, sam_model, vision_model, pixels):
        """Run SAM then CLIP without gradients, fuse both, and project the result."""
        with torch.no_grad():
            sam_features = sam_model(pixels)
            clip_features = vision_model(pixels, sam_features)
            # Drop the first CLIP token (presumably a class token - TODO confirm)
            # and concatenate channel-wise with the SAM map flattened from dim 2 on.
            fused = torch.cat(
                (clip_features[:, 1:], sam_features.flatten(2).permute(0, 2, 1)), dim=-1
            )
        # Encoders stay out of the graph; only the projector receives gradients.
        return self.projector(fused.detach())

    def _append_newline(self, grid):
        """Append ``image_newline`` to every row of a (rows, cols, dim) grid and flatten."""
        rows, _, dim = grid.shape
        newline = self.image_newline[None, None, :].expand(rows, 1, dim)
        return torch.cat([grid, newline], dim=1).view(-1, dim)

    def _global_tokens(self, features):
        """Lay a single view's features out as a square grid with row newlines."""
        _, hw, n_dim = features.shape
        side = int(hw ** 0.5)
        return self._append_newline(features.view(side, side, n_dim))

    def _local_tokens(self, features, crop_shape):
        """Stitch per-crop feature grids into one big grid with row newlines."""
        _, hw, n_dim = features.shape
        side = int(hw ** 0.5)
        # crop_shape holds (width_crop_num, height_crop_num); int() accepts
        # both Python ints and 0-dim tensors.
        width_crop_num, height_crop_num = int(crop_shape[0]), int(crop_shape[1])
        grid = (
            features.view(height_crop_num, width_crop_num, side, side, n_dim)
            .permute(0, 2, 1, 3, 4)
            .reshape(height_crop_num * side, width_crop_num * side, n_dim)
        )
        return self._append_newline(grid)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Embed tokens, overwrite the positions flagged by ``images_seq_mask``
        with projected image features, then run the text model.

        Args:
            input_ids: Token ids; used only when ``inputs_embeds`` is None.
            inputs_embeds: Precomputed embeddings. They are modified in place
                when image features are scattered into them.
            images: Per-sample sequence whose items are indexed as
                ``item[0]`` (local crops) and ``item[1]`` (global view).
                An all-zero crop tensor means "no local crops".
            images_seq_mask: Per-sample mask marking the image positions.
            images_spatial_crop: Per-sample ``(width_crop_num, height_crop_num)``.
            return_dict: Truthy -> ``DeepQwenOutputWithPast``; falsy (including
                None) -> plain tuple.

        The remaining arguments are forwarded unchanged to the text model.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        sam_model = getattr(self, 'sam_model', None)
        vision_model = getattr(self, 'vision_model', None)

        # Sequence length comes from the embeddings so that calls passing only
        # inputs_embeds (input_ids is None) do not crash. Single-token steps
        # outside training are decode steps and carry no new image.
        should_process_images = (
            sam_model is not None
            and vision_model is not None
            and images is not None
            and images_seq_mask is not None
            and (inputs_embeds.shape[1] != 1 or self.training)
            and torch.sum(images[0][1]).item() != 0
        )

        if should_process_images:
            for idx, (image, crop_shape) in enumerate(zip(images, images_spatial_crop)):
                patches, image_ori = image[0], image[1]

                pieces = []
                if torch.sum(patches).item() != 0:
                    local_features = self._encode_view(sam_model, vision_model, patches)
                    pieces.append(self._local_tokens(local_features, crop_shape))

                global_features = self._encode_view(sam_model, vision_model, image_ori)
                pieces.append(self._global_tokens(global_features))
                pieces.append(self.view_separator[None, :])

                # Order: [local crops (if any), global view, view separator].
                image_tokens = torch.cat(pieces, dim=0)

                # masked_scatter_ needs mask and source on the target's device
                # and the source in the target's dtype.
                mask = images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device)
                inputs_embeds[idx].masked_scatter_(
                    mask,
                    image_tokens.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype),
                )

        # Always ask the parent for a ModelOutput so the attribute access below
        # is valid even when the caller requested a tuple.
        outputs = super().forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=True, cache_position=cache_position
        )

        result = DeepQwenOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return result if return_dict else result.to_tuple()
|
|
|
|
|
|
|
|
class DeepQwenVLForCausalLM(DeepQwenVLModel, GenerationMixin):
    """
    DeepQwenVL Model for causal language modeling with vision capabilities.

    Combines DeepSeek's vision encoders (SAM + CLIP) with Qwen2VL's text model
    and adds a linear language-modeling head on top.
    """
    config_class = DeepQwenVLConfig
    _tied_weights_keys = ["lm_head.weight"]

    _keys_to_ignore_on_load_missing = [
    ]

    def __init__(self, config):
        """
        Initialize the model.

        Args:
            config: Can be DeepQwenVLConfig, Qwen2VLTextConfig, or a generic config
                from a Qwen2-VL checkpoint.
        """
        super().__init__(config)

        hidden_size = getattr(config, 'hidden_size', 1536)
        vocab_size = getattr(config, 'vocab_size', 151936)

        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        self.post_init()

    def get_output_embeddings(self):
        """Return the LM head, or None if it has not been created yet."""
        return getattr(self, 'lm_head', None)

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        labels: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Run the base model, project to vocabulary logits (float32) and, when
        ``labels`` is given, compute the loss via ``self.loss_function``.

        ``return_dict`` is accepted for signature compatibility; a
        ``DeepQwenCausalLMOutputWithPast`` is always returned.
        """
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
            return_dict=True,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)

        return DeepQwenCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        images=None,
        images_seq_mask=None,
        images_spatial_crop=None,
        **kwargs,
    ):
        """
        Extend the parent's generation inputs with the image arguments.

        ``position_ids`` is always reset to None. Once ``cache_position`` no
        longer starts at 0 (i.e. after the first step), the image inputs are
        dropped so they are not processed again.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )

        first_step = cache_position is None or cache_position[0] == 0
        model_inputs["images"] = images if first_step else None
        model_inputs["images_seq_mask"] = images_seq_mask if first_step else None
        model_inputs["images_spatial_crop"] = images_spatial_crop if first_step else None
        model_inputs["position_ids"] = None

        return model_inputs

    def reinitialize_projector(self, vis_mlp=None, device=None, dtype=None):
        """
        Reinitialize the projector, image_newline, and view_separator.
        Call this after from_pretrained when loading from a Qwen checkpoint.

        Args:
            vis_mlp: Used only as a switch. Not None -> a fresh
                ``VisionProjector``; None -> a single ``nn.Linear``.
            device: Target device. Defaults to the device of the first
                non-meta parameter, or ``'cpu'`` if there is none.
            dtype: Target dtype. Defaults to ``torch.bfloat16``.
        """
        if device is None:
            device = next(
                (p.device for p in self.parameters() if p.device.type != 'meta'),
                'cpu',
            )
        if dtype is None:
            dtype = torch.bfloat16

        input_dim = self.deepseek_vision_dim
        output_dim = self.output_hidden_size

        if vis_mlp is not None:
            self.projector = VisionProjector(
                input_dim=input_dim,
                hidden_dim=getattr(self, 'deepseek_hidden_dim', 1280),
                output_dim=output_dim,
            ).to(device=device, dtype=dtype)
        else:
            self.projector = nn.Linear(in_features=input_dim, out_features=output_dim).to(device=device, dtype=dtype)
            nn.init.normal_(self.projector.weight, mean=0.0, std=0.01)
            if self.projector.bias is not None:
                nn.init.zeros_(self.projector.bias)

        # Same N(0, 1/output_dim) scale as used at construction time.
        embed_std = 1 / math.sqrt(output_dim)
        self.image_newline = nn.Parameter(
            torch.randn(output_dim, device=device, dtype=dtype) * embed_std
        )
        self.view_separator = nn.Parameter(
            torch.randn(output_dim, device=device, dtype=dtype) * embed_std
        )

        print(f"Projector reinitialized on {device} with dtype {dtype}")

    def load_pretrained_vision(self, pretrained_path: str):
        """
        Load ``sam_model`` and ``vision_model`` weights from
        ``<pretrained_path>/model-00001-of-000001.safetensors``.

        Keys starting with ``model.sam_model.`` / ``model.vision_model.`` are
        stripped of that prefix and loaded non-strictly into the matching
        sub-module. Only those tensors are read from the file.

        Raises:
            ImportError: If ``safetensors`` is not installed.
            AssertionError: If ``pretrained_path`` does not exist.
        """
        try:
            from safetensors import safe_open
        except ImportError as err:
            raise ImportError("Please install safetensors to load the pretrained vision model.") from err

        assert os.path.exists(pretrained_path), f"Pretrained path {pretrained_path} does not exist."

        prefixes = {
            "sam_model": "model.sam_model.",
            "vision_model": "model.vision_model.",
        }
        state_dicts = {name: {} for name in prefixes}

        checkpoint_file = f"{pretrained_path}/model-00001-of-000001.safetensors"
        with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
            for key in f.keys():
                for name, prefix in prefixes.items():
                    if key.startswith(prefix):
                        # Materialize only vision tensors, not the whole checkpoint.
                        state_dicts[name][key[len(prefix):]] = f.get_tensor(key)
                        break

        try:
            for name, state_dict in state_dicts.items():
                getattr(self, name).load_state_dict(state_dict, strict=False)

            print("Pretrained vision model loaded successfully.")
        except Exception as e:
            print("Error loading pretrained vision model:", e)
            raise

    def load_deepseek_projector(self, pretrained_path: str):
        """
        Load DeepSeek's projector weights into the deepseek_proj layer.

        DeepSeek checkpoint has:
        - projector.weight: shape (1280, 2048)
        - projector.bias: shape (1280,)

        These get loaded into self.projector.deepseek_proj. All ``.safetensors``
        files in ``pretrained_path`` are searched (in sorted order), and the
        tensors are moved to the layer's current device/dtype before being
        assigned. If neither ``projector.*`` nor ``model.projector.*`` keys
        are found, a warning is printed and nothing is changed.

        Raises:
            ImportError: If ``safetensors`` is not installed.
            AssertionError: If ``pretrained_path`` does not exist.
            FileNotFoundError: If the directory has no ``.safetensors`` file.
        """
        try:
            from safetensors import safe_open
        except ImportError as err:
            raise ImportError("Please install safetensors to load DeepSeek projector.") from err

        assert os.path.exists(pretrained_path), f"Pretrained path {pretrained_path} does not exist."

        # os.listdir order is arbitrary; sort for determinism and scan every
        # shard, since the projector may not live in the first file.
        safetensor_files = sorted(f for f in os.listdir(pretrained_path) if f.endswith('.safetensors'))
        if not safetensor_files:
            raise FileNotFoundError(f"No safetensors files found in {pretrained_path}")

        projector_weights = {}
        for file_name in safetensor_files:
            safetensor_path = os.path.join(pretrained_path, file_name)
            with safe_open(safetensor_path, framework="pt", device="cpu") as f:
                for k in f.keys():
                    if 'projector' in k:
                        projector_weights[k] = f.get_tensor(k)

        def _assign(param, tensor):
            # Match the live parameter's device/dtype; a meta parameter has no
            # storage yet, so it simply takes the checkpoint tensor.
            if param.device.type != 'meta':
                tensor = tensor.to(device=param.device, dtype=param.dtype)
            param.data = tensor

        target = self.projector.deepseek_proj
        if 'projector.weight' in projector_weights:
            _assign(target.weight, projector_weights['projector.weight'])
            _assign(target.bias, projector_weights['projector.bias'])
            print(f"Loaded DeepSeek projector weights: {self.projector.deepseek_proj.weight.shape}")
            print(f" Weight mean: {self.projector.deepseek_proj.weight.mean().item():.6f}")
            print(f" Weight std: {self.projector.deepseek_proj.weight.std().item():.6f}")
        elif 'model.projector.weight' in projector_weights:
            _assign(target.weight, projector_weights['model.projector.weight'])
            _assign(target.bias, projector_weights['model.projector.bias'])
            print(f"Loaded DeepSeek projector weights (model. prefix)")
        else:
            print(f"Warning: Could not find projector weights. Available keys: {list(projector_weights.keys())}")

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.

        NOTE: this replaces ``reset_parameters`` on ``torch.nn.Linear`` and
        ``torch.nn.LayerNorm`` at class level, so it affects every such layer
        created afterwards in the process, not only this model.
        """
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def infer(
        self,
        tokenizer,
        prompt='',
        image_file='',
        output_path = '',
        base_size=1024,
        image_size=640,
        crop_mode=True,
        test_compress=False,
        save_results=False,
        eval_mode=False
    ):
        """
        Run single-image generation and optional post-processing.

        Args:
            tokenizer: Tokenizer providing ``apply_chat_template``, ``decode``
                and ``eos_token`` / ``eos_token_id``.
            prompt: User text placed after the image in the conversation.
            image_file: Image reference handed to ``load_pil_images``.
            output_path: Directory for saved artifacts. Created (with an
                ``images`` sub-directory) when non-empty.
            base_size: Side length the global view is padded to in crop mode.
            image_size: Side length of local crops (crop mode) or of the
                single view (non-crop mode).
            crop_mode: When True, images larger than 640 px on either side are
                additionally split into local crops by ``dynamic_preprocess``.
            test_compress: Print token-count / compression statistics.
            save_results: Write ``result.mmd``, ``result_with_boxes.jpg`` and,
                for geometry output, ``geo.jpg`` into ``output_path``.
            eval_mode: Generate without a streamer and return the decoded text.

        Returns:
            The decoded, stop-string-stripped text when ``eval_mode`` is True;
            otherwise None.

        Raises:
            ValueError: If ``save_results`` is True but ``output_path`` is empty.
        """
        self.disable_torch_init()

        if save_results and not output_path:
            raise ValueError("output_path must be set when save_results=True")
        if output_path:
            # os.makedirs('') raises, so only create directories for a real path.
            os.makedirs(output_path, exist_ok=True)
            os.makedirs(f'{output_path}/images', exist_ok=True)

        conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": f"{image_file}",
                    },
                    {"type": "text", "text": f"{prompt}"},
                ],
            }
        ]

        formatted_prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

        patch_size = 16
        downsample_ratio = 4
        images = load_pil_images(conversation)

        valid_img_tokens = 0

        image_draw = images[0].copy()

        w, h = image_draw.size
        # 1.0 for a square image, approaching 0 as the aspect ratio grows.
        ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))

        image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
        pad_color = tuple(int(x * 255) for x in image_transform.mean)

        image_token = '<|image_pad|>'
        image_token_id = 151655
        text_splits = formatted_prompt.split(image_token)

        images_list, images_crop_list, images_seq_mask = [], [], []
        tokenized_str = []
        images_spatial_crop = []
        for text_sep, image in zip(text_splits, images):
            # Text preceding this image.
            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            if crop_mode:
                if image.size[0] <= 640 and image.size[1] <= 640:
                    crop_ratio = [1, 1]
                else:
                    images_crop_raw, crop_ratio = dynamic_preprocess(image)

                # Global view: pad to a base_size square with the mean colour.
                global_view = ImageOps.pad(image, (base_size, base_size), color=pad_color)

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)

                images_list.append(image_transform(global_view).to(torch.bfloat16))

                width_crop_num, height_crop_num = crop_ratio
                images_spatial_crop.append([width_crop_num, height_crop_num])

                has_local_views = width_crop_num > 1 or height_crop_num > 1
                if has_local_views:
                    # Local views: one tensor per crop.
                    images_crop_list.extend(
                        image_transform(crop).to(torch.bfloat16) for crop in images_crop_raw
                    )

                if image_size == 640:
                    valid_img_tokens += len(images_crop_list) * 100

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)

                # Placeholders: each grid row plus one newline slot, then one
                # separator slot, then the stitched local grid (if any).
                tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
                tokenized_image += [image_token_id]
                if has_local_views:
                    tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
                        num_queries * height_crop_num)
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

            else:
                # Single global view only.
                if image_size <= 640:
                    image = image.resize((image_size, image_size))
                global_view = ImageOps.pad(image, (image_size, image_size), color=pad_color)
                images_list.append(image_transform(global_view).to(torch.bfloat16))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)
                elif base_size == 640:
                    valid_img_tokens += int(100 * 1)
                elif base_size == 512:
                    valid_img_tokens += int(64 * 1)

                width_crop_num, height_crop_num = 1, 1
                images_spatial_crop.append([width_crop_num, height_crop_num])

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)

                tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
                tokenized_image += [image_token_id]

                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

        # Text after the last image placeholder.
        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        input_ids = torch.LongTensor(tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        if len(images_list) == 0:
            images_ori = torch.zeros((1, 3, image_size, image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
            images_crop = torch.zeros((1, 3, base_size, base_size))
        else:
            images_ori = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0)
            else:
                # All-zero crops signal "no local views" to forward().
                images_crop = torch.zeros((1, 3, base_size, base_size))

        # Follow the model's own device rather than assuming the current CUDA device.
        device = self.device
        prompt_len = input_ids.shape[0]

        # NOTE(review): temperature is passed without do_sample; whether it has
        # any effect depends on the generation config - confirm.
        generate_kwargs = dict(
            images=[(images_crop.to(device), images_ori.to(device))],
            images_seq_mask=images_seq_mask.unsqueeze(0).to(device),
            images_spatial_crop=images_spatial_crop,
            temperature=0.5,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=8192,
            no_repeat_ngram_size=35 if eval_mode else 20,
            use_cache=True,
        )
        if not eval_mode:
            # Interactive runs stream tokens as they are produced.
            generate_kwargs["streamer"] = NoEOSTextStreamer(
                tokenizer, skip_prompt=True, skip_special_tokens=False
            )

        with torch.autocast(device.type, dtype=torch.bfloat16), torch.no_grad():
            output_ids = self.generate(input_ids.unsqueeze(0).to(device), **generate_kwargs)

        has_image = any(
            (isinstance(item, dict) and item.get('type') == 'image')
            for msg in conversation
            for item in (msg.get('content', []) if isinstance(msg.get('content'), list) else [])
        )

        if not has_image or not (eval_mode or test_compress or save_results):
            return None

        # Decode only the newly generated part, once, for all consumers below.
        raw_outputs = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=False)

        stop_str = tokenizer.eos_token or '<|im_end|>'
        outputs = raw_outputs
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()

        if eval_mode:
            return outputs

        if test_compress:
            # Statistics use the raw decoded text (stop string included).
            pure_texts_outputs_token_length = len(text_encode(tokenizer, raw_outputs, bos=False, eos=False))
            print('='*50)
            print('image size: ', (w, h))
            print('valid image tokens: ', int(valid_img_tokens))
            print('output texts tokens (valid): ', pure_texts_outputs_token_length)
            if valid_img_tokens:
                print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2))
            else:
                # No tokens were counted for this size combination; avoid dividing by zero.
                print('compression ratio: ', 'n/a')
            print('='*50)

        if save_results:
            print('='*15 + 'save results:' + '='*15)

            matches_ref, matches_images, matches_other = re_match(outputs)
            result = process_image_with_refs(image_draw, matches_ref, output_path)

            # The replacement string was garbled in the reviewed source; it is
            # reconstructed as a Markdown image link.
            # NOTE(review): confirm images/<idx>.jpg matches what
            # process_image_with_refs writes.
            for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
                outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')

            for a_match_other in tqdm(matches_other, desc="other"):
                outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')

            with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile:
                afile.write(outputs)

            if 'line_type' in outputs:
                import matplotlib.pyplot as plt

                # SECURITY: eval() runs on model-generated text here and below.
                # Kept for compatibility; a literal-only parser would be safer
                # if the output format allows it.
                geometry = eval(outputs)['Line']
                lines = geometry['line']
                endpoints = geometry['line_endpoint']

                fig, ax = plt.subplots(figsize=(3,3), dpi=200)
                ax.set_xlim(-15, 15)
                ax.set_ylim(-15, 15)

                for line in lines:
                    try:
                        p0 = eval(line.split(' -- ')[0])
                        p1 = eval(line.split(' -- ')[-1])

                        # Every line type is drawn with the same style.
                        ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')

                        ax.scatter(p0[0], p0[1], s=5, color = 'k')
                        ax.scatter(p1[0], p1[1], s=5, color = 'k')
                    except Exception:
                        # Best effort: skip segments that cannot be parsed or drawn.
                        continue

                for endpoint in endpoints:
                    label = endpoint.split(': ')[0]
                    (x, y) = eval(endpoint.split(': ')[1])
                    ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                                fontsize=5, fontweight='light')

                plt.savefig(f'{output_path}/geo.jpg')
                plt.close()

            result.save(f"{output_path}/result_with_boxes.jpg")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|