"""Villanova VLM Model for HuggingFace. This is a standalone model file for use with trust_remote_code=True. It contains no imports from aithlas_trainer to ensure self-containment. """ from typing import Any import torch import torch.nn as nn from transformers import AutoModelForCausalLM, PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast from .configuration_villanova import VillanovaConfig class ViTEncoder(nn.Module): """Vision encoder for Villanova VLM using OpenCLIP. Supports both: - OpenCLIP CLIPA models (ViT-L-14-CLIPA-336) with direct visual transformer - SigLIP models (ViT-L-16-SigLIP-384) wrapped via TimmModel The model is loaded from OpenCLIP pretrained weights (not from safetensors). IMPORTANT: Uses manual forward pass to match training code exactly. Do NOT use output_tokens=True as it produces different outputs. """ def __init__(self, config: dict[str, Any]) -> None: super().__init__() self.hidden_size = config.get("hidden_size", 1024) # Support both old key (model_name) and new key (encoder_name) self.model_name = config.get("encoder_name") or config.get("model_name", "ViT-L-14-CLIPA-336") self.pretrained = config.get("pretrained", "datacomp1b") # Placeholder - will be loaded lazily self._clip_model: nn.Module | None = None self._is_siglip: bool = "SigLIP" in self.model_name def _ensure_clip_loaded(self) -> None: """Load OpenCLIP model if not already loaded.""" if self._clip_model is None: import open_clip model, _, _ = open_clip.create_model_and_transforms( self.model_name, pretrained=self.pretrained, ) # Use model.visual directly self._clip_model = model.visual self._clip_model.eval() # Freeze all parameters for param in self._clip_model.parameters(): param.requires_grad = False def _forward_siglip(self, pixel_values: torch.Tensor) -> torch.Tensor: """Forward pass for SigLIP models (TimmModel wrapper).""" visual = self._clip_model trunk = visual.trunk # VisionTransformer from timm # Patch embedding x = trunk.patch_embed(pixel_values) # (B, num_patches, hidden_dim) # Add positional embedding (SigLIP may or may not have cls_token) if trunk.cls_token is not None and trunk.cls_token.numel() > 0: cls_tokens = trunk.cls_token.expand(x.shape[0], -1, -1) x = torch.cat([cls_tokens, x], dim=1) # Add positional embedding x = x + trunk.pos_embed # Optional: position dropout (usually identity) x = trunk.pos_drop(x) # Optional: patch dropout (usually identity) if hasattr(trunk, "patch_drop") and trunk.patch_drop is not None: x = trunk.patch_drop(x) # Optional: pre-norm (some models have this) if hasattr(trunk, "norm_pre") and trunk.norm_pre is not None: x = trunk.norm_pre(x) # Apply transformer blocks x = trunk.blocks(x) # Final norm x = trunk.norm(x) # Remove CLS token if present, return only patch tokens if trunk.cls_token is not None and trunk.cls_token.numel() > 0: patch_tokens = x[:, 1:, :] else: patch_tokens = x return patch_tokens def _forward_clipa(self, pixel_values: torch.Tensor) -> torch.Tensor: """Forward pass for CLIPA models (standard OpenCLIP).""" visual = self._clip_model # Step 1: Get patch embeddings via conv1 x = visual.conv1(pixel_values) # (B, hidden_dim, grid, grid) x = x.reshape(x.shape[0], x.shape[1], -1) # (B, hidden_dim, num_patches) x = x.permute(0, 2, 1) # (B, num_patches, hidden_dim) # Step 2: Add positional embeddings (including CLS position) if hasattr(visual, "positional_embedding"): # OpenCLIP style: add CLS token and positional embeddings cls_pos = visual.class_embedding.expand(x.shape[0], 1, -1) x = torch.cat([cls_pos, x], dim=1) x = x + visual.positional_embedding.unsqueeze(0) elif hasattr(visual, "pos_embed"): # Alternative style x = x + visual.pos_embed[:, 1:, :] # Step 3: Apply layer norm before transformer x = visual.ln_pre(x) # Step 4: Apply transformer (expects seq_len first) x = x.permute(1, 0, 2) # (seq_len, B, hidden_dim) x = visual.transformer(x) x = x.permute(1, 0, 2) # (B, seq_len, hidden_dim) # Step 5: Apply post-transformer layer norm (CRITICAL for correct output scale) x = visual.ln_post(x) # Step 6: Remove CLS token, return only patch tokens patch_tokens = x[:, 1:, :] return patch_tokens def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: """Encode images to visual embeddings. Uses MANUAL forward pass through OpenCLIP vision encoder to match training code exactly. This is critical for correct inference. Args: pixel_values: Image tensor (batch_size, 3, H, W) Returns: Visual embeddings (batch_size, num_patches, hidden_size) """ self._ensure_clip_loaded() visual = self._clip_model # Convert model to input dtype if needed (critical for matching training behavior) input_dtype = pixel_values.dtype model_dtype = next(visual.parameters()).dtype if model_dtype != input_dtype: self._clip_model = visual.to(dtype=input_dtype) visual = self._clip_model # Move model to same device as input if next(visual.parameters()).device != pixel_values.device: self._clip_model = visual.to(pixel_values.device) visual = self._clip_model with torch.no_grad(): if self._is_siglip: return self._forward_siglip(pixel_values) else: return self._forward_clipa(pixel_values) class MLPProjector(nn.Module): """MLP Projector to map vision features to LLM embedding space. 2-layer MLP with GELU activation (no output LayerNorm by default). Structure matches the VillanovaVLM training checkpoint format: - mlp.0: Linear(input_size, hidden_size) - mlp.1: GELU (no params) - mlp.2: Linear(hidden_size, output_size) - output_norm: Identity() by default (no LayerNorm, like LLaVA) NOTE: LLaVA does NOT use LayerNorm on projector output. LLM embeddings have std≈0.008, LayerNorm forces std≈1, causing 140x scale mismatch. """ def __init__(self, config: dict[str, Any]) -> None: super().__init__() input_size = config.get("input_size", 1024) output_size = config.get("output_size", 2048) hidden_size = config.get("hidden_size", output_size) use_layer_norm = config.get("use_layer_norm", False) bias = config.get("bias", True) # Scale factor for output. Default 1.0 to match training behavior. # Note: If training used output_scale, it should be set in config. self.output_scale = config.get("output_scale", 1.0) # Build MLP layers to match checkpoint structure self.mlp = nn.Sequential( nn.Linear(input_size, hidden_size, bias=bias), nn.GELU(), nn.Linear(hidden_size, output_size, bias=bias), ) # Output normalization (separate from mlp to match checkpoint keys) if use_layer_norm: self.output_norm = nn.LayerNorm(output_size) else: self.output_norm = nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: """Project vision features to LLM space.""" x = self.mlp(x) x = self.output_norm(x) # Scale to match LLM embedding magnitude if self.output_scale != 1.0: x = x * self.output_scale return x class VillanovaVLMForConditionalGeneration(PreTrainedModel): """Villanova Vision-Language Model for conditional generation. Combines ViT-L-14-CLIPA-336 vision encoder, 2-layer MLP projector, and Villanova 2B language model. Example: >>> from transformers import AutoModelForImageTextToText, AutoProcessor >>> model = AutoModelForImageTextToText.from_pretrained( ... "VillanovaAI/Villanova-2B-VL-2512-Preview", ... trust_remote_code=True, ... ) >>> processor = AutoProcessor.from_pretrained( ... "VillanovaAI/Villanova-2B-VL-2512-Preview", ... trust_remote_code=True, ... ) """ config_class = VillanovaConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MLPProjector"] def __init__(self, config: VillanovaConfig) -> None: super().__init__(config) # Vision encoder self.vision_encoder = ViTEncoder(config.vision_config) # Projector self.projector = MLPProjector(config.projector_config) # Language model (will be loaded separately) self.language_model: PreTrainedModel | None = None # Image token index self.image_token_index = config.image_token_index self.post_init() def get_input_embeddings(self) -> nn.Module | None: """Get input embeddings from language model.""" if self.language_model is not None: return self.language_model.get_input_embeddings() return None def set_input_embeddings(self, value: nn.Module) -> None: """Set input embeddings in language model.""" if self.language_model is not None: self.language_model.set_input_embeddings(value) def get_output_embeddings(self) -> nn.Module | None: """Get output embeddings from language model.""" if self.language_model is not None: return self.language_model.get_output_embeddings() return None def set_output_embeddings(self, new_embeddings: nn.Module) -> None: """Set output embeddings in language model.""" if self.language_model is not None: self.language_model.set_output_embeddings(new_embeddings) def _merge_input_ids_with_image_features( self, input_ids: torch.Tensor, image_features: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """Merge text embeddings with image features at token positions. This uses the EXPANSION approach (like LLaVA): a single token in the input is replaced with all 576 visual feature tokens. The sequence length increases by (num_patches - 1). For training compatibility, we expand the single token to num_patches copies, then replace each with the corresponding visual feature. """ batch_size = input_ids.shape[0] num_patches = image_features.shape[1] # Get text embeddings text_embeddings = self.get_input_embeddings()(input_ids) # Find image token positions image_token_mask = input_ids == self.image_token_index new_embeddings_list = [] new_attention_mask_list = [] if attention_mask is not None else None for b in range(batch_size): image_positions = torch.where(image_token_mask[b])[0] num_image_tokens = len(image_positions) if num_image_tokens == 0: # No image tokens - keep original embeddings new_embeddings_list.append(text_embeddings[b]) if attention_mask is not None: new_attention_mask_list.append(attention_mask[b]) elif num_image_tokens == 1: # Single token - expand to num_patches visual features pos = image_positions[0].item() before = text_embeddings[b, :pos] after = text_embeddings[b, pos + 1:] # Insert all visual features at the single position merged = torch.cat([before, image_features[b], after], dim=0) new_embeddings_list.append(merged) if attention_mask is not None: mask_before = attention_mask[b, :pos] mask_after = attention_mask[b, pos + 1:] image_mask = torch.ones(num_patches, dtype=attention_mask.dtype, device=attention_mask.device) merged_mask = torch.cat([mask_before, image_mask, mask_after], dim=0) new_attention_mask_list.append(merged_mask) else: # Multiple tokens - replace each with corresponding visual feature # This matches the training behavior when tokens are pre-expanded output = text_embeddings[b].clone() actual_patches = min(num_patches, num_image_tokens) for i in range(actual_patches): pos = image_positions[i].item() output[pos] = image_features[b, i] new_embeddings_list.append(output) if attention_mask is not None: new_attention_mask_list.append(attention_mask[b]) # Pad to same length max_len = max(e.shape[0] for e in new_embeddings_list) padded_embeddings = torch.zeros( batch_size, max_len, text_embeddings.shape[-1], dtype=text_embeddings.dtype, device=text_embeddings.device ) for b, emb in enumerate(new_embeddings_list): padded_embeddings[b, :emb.shape[0]] = emb padded_attention_mask = None if new_attention_mask_list is not None: padded_attention_mask = torch.zeros( batch_size, max_len, dtype=attention_mask.dtype, device=attention_mask.device ) for b, mask in enumerate(new_attention_mask_list): padded_attention_mask[b, :mask.shape[0]] = mask return padded_embeddings, padded_attention_mask def forward( self, input_ids: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None, past_key_values: tuple | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, **kwargs: Any, ) -> CausalLMOutputWithPast | tuple: """Forward pass.""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.language_model is None: raise RuntimeError("Language model not initialized") # Process image if provided if pixel_values is not None and inputs_embeds is None: image_features = self.vision_encoder(pixel_values) # Cast to projector dtype (vision encoder may output float32) image_features = image_features.to(self.projector.mlp[0].weight.dtype) image_features = self.projector(image_features) inputs_embeds, attention_mask = self._merge_input_ids_with_image_features( input_ids, image_features, attention_mask ) input_ids = None return self.language_model( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, labels=labels, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) def generate( self, input_ids: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, max_new_tokens: int = 256, do_sample: bool = False, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, **kwargs: Any, ) -> torch.Tensor: """Generate text conditioned on image and prompt.""" if self.language_model is None: raise RuntimeError("Language model not initialized") if pixel_values is not None: image_features = self.vision_encoder(pixel_values) # Cast to projector dtype (vision encoder may output float32) image_features = image_features.to(self.projector.mlp[0].weight.dtype) image_features = self.projector(image_features) inputs_embeds, attention_mask = self._merge_input_ids_with_image_features( input_ids, image_features, attention_mask ) # Get token IDs from text_config or kwargs text_config = self.config.text_config pad_token_id = kwargs.pop("pad_token_id", None) or getattr(text_config, "pad_token_id", None) eos_token_id = kwargs.pop("eos_token_id", None) or getattr(text_config, "eos_token_id", None) return self.language_model.generate( inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k, pad_token_id=pad_token_id, eos_token_id=eos_token_id, **kwargs, ) return self.language_model.generate( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k, **kwargs, ) @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str, *model_args: Any, config: VillanovaConfig | None = None, torch_dtype: torch.dtype | str | None = None, device_map: str | dict | None = None, **kwargs: Any, ) -> "VillanovaVLMForConditionalGeneration": """Load pretrained model.""" from pathlib import Path from safetensors.torch import load_file from transformers import AutoConfig # Remove trust_remote_code from kwargs to avoid passing it twice kwargs.pop("trust_remote_code", None) # Handle dtype/torch_dtype - newer transformers uses 'dtype' instead of 'torch_dtype' if torch_dtype is None: torch_dtype = kwargs.pop("dtype", None) else: kwargs.pop("dtype", None) # Remove if both were passed # Load config if config is None: config = AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=True, **kwargs, ) # Handle torch_dtype string conversion if torch_dtype is not None: if isinstance(torch_dtype, str): torch_dtype = getattr(torch, torch_dtype.replace("torch.", "")) # Create model model = cls(config) # Create LLM from text_config # Get the text config dict text_config_dict = config.text_config.to_dict() if hasattr(config.text_config, "to_dict") else dict(config.text_config) # Check for nested text_config (used in VillanovaVLM training format) if "text_config" in text_config_dict and isinstance(text_config_dict["text_config"], dict): # Use the nested text_config which contains the actual LLM config llm_config_dict = dict(text_config_dict["text_config"]) else: llm_config_dict = text_config_dict # Get model type from config to determine which model class to use model_type = llm_config_dict.pop("model_type", "llama") # Remove non-config keys for key in ["_name_or_path", "transformers_version", "torch_dtype", "dtype"]: llm_config_dict.pop(key, None) # Create the LLM config and model from transformers import AutoConfig as HFAutoConfig, AutoModelForCausalLM as HFAutoModelForCausalLM llm_config = HFAutoConfig.for_model(model_type, **llm_config_dict) model.language_model = HFAutoModelForCausalLM.from_config(llm_config, torch_dtype=torch_dtype) # Load all weights from safetensors model_path = Path(pretrained_model_name_or_path) if model_path.exists(): safetensors_files = sorted(model_path.glob("*.safetensors")) else: from huggingface_hub import hf_hub_download, list_repo_files try: # Get list of safetensor files from the repo repo_files = list_repo_files(pretrained_model_name_or_path) sf_files = [f for f in repo_files if f.endswith(".safetensors")] safetensors_files = [] for sf in sf_files: sf_path = hf_hub_download(pretrained_model_name_or_path, sf) safetensors_files.append(Path(sf_path)) except Exception: safetensors_files = [] vision_state_dict = {} projector_state_dict = {} llm_state_dict = {} for sf_file in safetensors_files: state_dict = load_file(sf_file) for key, value in state_dict.items(): # Convert dtype if needed if torch_dtype is not None: value = value.to(torch_dtype) if key.startswith("vision_encoder."): new_key = key.replace("vision_encoder.", "") vision_state_dict[new_key] = value elif key.startswith("projector."): new_key = key.replace("projector.", "") projector_state_dict[new_key] = value elif key.startswith("language_model."): # LLM weights - strip the language_model. prefix new_key = key.replace("language_model.", "") llm_state_dict[new_key] = value else: # LLM weights without prefix (legacy format) llm_state_dict[key] = value # Load weights into model components # Note: vision_encoder uses OpenCLIP pretrained weights, not from safetensors if projector_state_dict: model.projector.load_state_dict(projector_state_dict, strict=False) if llm_state_dict: model.language_model.load_state_dict(llm_state_dict, strict=False) # Convert model to target dtype AFTER loading weights # load_state_dict doesn't change the model's dtype, so we must convert explicitly if torch_dtype is not None: model.projector = model.projector.to(dtype=torch_dtype) model.language_model = model.language_model.to(dtype=torch_dtype) # Handle device_map if device_map is not None: import accelerate if device_map == "auto": # Infer device map automatically device_map = accelerate.infer_auto_device_map( model, max_memory=None, no_split_module_classes=["MLPProjector", "ViTEncoder"], ) if isinstance(device_map, dict): model = accelerate.dispatch_model(model, device_map=device_map) else: # Simple device placement model = model.to(device_map) return model