# vlm-step2 / modeling_gemma_clip_vlm.py
import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    AutoModelForCausalLM,
    CLIPVisionModel,
)
from .configuration_gemma_clip_vlm import GemmaCLIPVLMConfig


class GemmaCLIPVLM(PreTrainedModel):
    config_class = GemmaCLIPVLMConfig
    base_model_prefix = "gemma_clip_vlm"

    def __init__(self, config: GemmaCLIPVLMConfig):
        super().__init__(config)

        # ---- Vision encoder ----
        self.vision = CLIPVisionModel.from_pretrained(
            config.vision_model_name
        )

        # ---- Language model ----
        self.llm = AutoModelForCausalLM.from_pretrained(
            config.llm_model_name
        )

        self.num_image_tokens = config.num_image_tokens

        vision_dim = self.vision.config.hidden_size
        llm_dim = self.llm.config.hidden_size

        # ---- Projector (ONLY trainable in Step 1) ----
        self.projector = nn.Linear(vision_dim, llm_dim)

        # Required by HF
        self.post_init()
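
    # Illustrative helper (an addition for clarity, not present in the original
    # file): in a Step-1-style alignment run only the projector is trained, so
    # the CLIP tower and the LLM would typically be frozen along these lines.
    # The method name is an assumption.
    def freeze_backbones(self):
        for p in self.vision.parameters():
            p.requires_grad = False
        for p in self.llm.parameters():
            p.requires_grad = False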

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        pixel_values=None,
        **kwargs
    ):
        """
        LLaVA-style forward:
        - <image> tokens already exist in input_ids
        - their embeddings are REPLACED by projected vision embeddings
        """
        # ---- Text embeddings (includes <image> tokens) ----
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        # ---- Vision forward (frozen in Step 1) ----
        with torch.no_grad():
            vision_feats = self.vision(
                pixel_values=pixel_values
            ).last_hidden_state  # [B, N_img, vision_dim]

        # Normalize (important for stability)
        vision_feats = vision_feats / vision_feats.norm(dim=-1, keepdim=True)

        # ---- Project to LLM hidden ----
        img_tokens = self.projector(vision_feats)
        img_tokens = img_tokens.to(inputs_embeds.dtype)

        # ---- LLaVA-style replacement ----
        inputs_embeds[:, : self.num_image_tokens, :] = img_tokens
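        # Note: the slice assignment above assumes the <image> placeholders sit
        # at the very start of every sequence and that num_image_tokens equals
        # the number of CLIP output tokens (N_img). If the placeholders could
        # appear anywhere, a hedged alternative (assuming the placeholder id is
        # exposed as config.image_token_id, which this file does not define)
        # would be:
        #
        #   image_mask = (input_ids == self.config.image_token_id)
        #   inputs_embeds = inputs_embeds.masked_scatter(
        #       image_mask.unsqueeze(-1).expand_as(inputs_embeds),
        #       img_tokens.to(inputs_embeds.dtype),
        #   )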

        # ---- Run LLM ----
        return self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            use_cache=False,
        )
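

# Minimal smoke-test sketch (illustrative only): the checkpoint names, the
# num_image_tokens value, and the way input_ids / pixel_values are built below
# are assumptions, not defined anywhere in this file.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoTokenizer, CLIPImageProcessor

    config = GemmaCLIPVLMConfig(
        vision_model_name="openai/clip-vit-base-patch32",
        llm_model_name="google/gemma-2b",
        num_image_tokens=50,  # CLS + 49 patches for CLIP ViT-B/32 (assumed)
    )
    model = GemmaCLIPVLM(config)

    tokenizer = AutoTokenizer.from_pretrained(config.llm_model_name)
    image_processor = CLIPImageProcessor.from_pretrained(config.vision_model_name)

    # Prepend num_image_tokens placeholder positions in front of the text ids;
    # their embeddings are overwritten in forward(), so any valid token id
    # works here (bos is used purely as a placeholder).
    text_ids = tokenizer("Describe the image.", return_tensors="pt").input_ids
    placeholder = torch.full((1, config.num_image_tokens), tokenizer.bos_token_id)
    input_ids = torch.cat([placeholder, text_ids], dim=1)
    attention_mask = torch.ones_like(input_ids)

    # A blank image is enough for a shape/loss smoke test.
    image = Image.new("RGB", (224, 224), color="white")
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pixel_values=pixel_values,
        labels=input_ids,  # crude: also computes loss on placeholder positions
    )
    print(outputs.loss)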