import torch
import torch.nn as nn

from transformers import (
    PreTrainedModel,
    AutoModelForCausalLM,
    CLIPVisionModel,
)

from .configuration_gemma_clip_vlm import GemmaCLIPVLMConfig


class GemmaCLIPVLM(PreTrainedModel):
    config_class = GemmaCLIPVLMConfig
    base_model_prefix = "gemma_clip_vlm"

    def __init__(self, config: GemmaCLIPVLMConfig):
        super().__init__(config)

        # Vision tower: a pretrained CLIP vision encoder.
        self.vision = CLIPVisionModel.from_pretrained(
            config.vision_model_name
        )

        # Language model: a pretrained causal LM (e.g. Gemma) used as the decoder.
        self.llm = AutoModelForCausalLM.from_pretrained(
            config.llm_model_name
        )

        # Number of placeholder positions reserved for image embeddings.
        self.num_image_tokens = config.num_image_tokens

        vision_dim = self.vision.config.hidden_size
        llm_dim = self.llm.config.hidden_size

        # Linear projector mapping vision features into the LLM embedding space.
        self.projector = nn.Linear(vision_dim, llm_dim)

        self.post_init()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        pixel_values=None,
        **kwargs
    ):
        """
        LLaVA-style forward pass:
        - the first `num_image_tokens` positions of `input_ids` are <image>
          placeholder tokens
        - their embeddings are REPLACED by the projected vision embeddings
        (see the usage sketch at the bottom of this module)
        """
        # Embed the full token sequence, placeholders included.
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        # Encode the image; the vision tower is kept out of the autograd graph.
        with torch.no_grad():
            vision_feats = self.vision(
                pixel_values=pixel_values
            ).last_hidden_state

        # L2-normalize the patch features before projection.
        vision_feats = vision_feats / vision_feats.norm(dim=-1, keepdim=True)

        # Project into the LLM embedding space and match the LLM dtype.
        img_tokens = self.projector(vision_feats)
        img_tokens = img_tokens.to(inputs_embeds.dtype)

        # Overwrite the placeholder positions with the image embeddings.
        # Assumes the vision tower emits exactly `num_image_tokens` tokens.
        inputs_embeds[:, : self.num_image_tokens, :] = img_tokens

        return self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            use_cache=False,
        )
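

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only): the checkpoint names,
    # the config keyword arguments, and the dummy tensor shapes below are
    # assumptions, not values shipped with this repo. With
    # "openai/clip-vit-base-patch32" at 224x224 input, the vision tower
    # returns 50 tokens (1 CLS + 49 patches), so num_image_tokens is 50 here.
    # Run with `python -m <package>.modeling_gemma_clip_vlm` so the relative
    # config import resolves.
    config = GemmaCLIPVLMConfig(
        vision_model_name="openai/clip-vit-base-patch32",
        llm_model_name="google/gemma-2b",
        num_image_tokens=50,
    )
    model = GemmaCLIPVLM(config)

    batch, seq_len = 2, 64
    # First `num_image_tokens` positions stand in for <image> placeholders.
    input_ids = torch.zeros(batch, seq_len, dtype=torch.long)
    attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
    pixel_values = torch.randn(batch, 3, 224, 224)

    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=input_ids.clone(),
        pixel_values=pixel_values,
    )
    print("loss:", outputs.loss.item())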