"""Backbone model loading utilities. Handles loading of all 3 backbone VLMs (LLaVA-1.5, Qwen2.5, Gemma-3), dtype management, and layer name verification. """ import logging from typing import Optional import torch import yaml logger = logging.getLogger(__name__) # Layer prefix mapping — verified per backbone # LLaVA-1.5 : LlavaForConditionalGeneration → model.model.layers # Qwen2.5-VL : Qwen2_5_VLForConditionalGeneration → model.language_model.layers # Gemma-3 : Gemma3ForConditionalGeneration → model.language_model.layers LAYER_PREFIXES = { "llava-hf/llava-1.5-7b-hf": "model.layers", "Qwen/Qwen2.5-VL-3B-Instruct": "model.language_model.layers", "google/gemma-3-4b-it": "model.language_model.layers", } DTYPE_MAP = { "float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32, } def load_config(config_path: str = "configs/experiment.yaml") -> dict: """Load experiment configuration.""" with open(config_path) as f: return yaml.safe_load(f) def get_backbone_config(config: dict, backbone_name: str = "primary") -> dict: """Get backbone config by name (primary) or index (transfer.0, transfer.1).""" if backbone_name == "primary": return config["backbones"]["primary"] elif backbone_name.startswith("transfer"): idx = int(backbone_name.split(".")[-1]) if "." in backbone_name else 0 return config["backbones"]["transfer"][idx] else: # Try matching by hf_id if config["backbones"]["primary"]["hf_id"] == backbone_name: return config["backbones"]["primary"] for t in config["backbones"]["transfer"]: if t["hf_id"] == backbone_name: return t raise ValueError(f"Unknown backbone: {backbone_name}") def load_backbone( hf_id: str, dtype: str = "float16", device: str = "cuda", cache_dir: Optional[str] = None, ): """Load a backbone model and processor. Returns: (model, processor) tuple """ from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer torch_dtype = DTYPE_MAP.get(dtype, torch.float16) logger.info(f"Loading backbone: {hf_id} (dtype={dtype}, device={device})") if "llava" in hf_id.lower(): from transformers import LlavaForConditionalGeneration, AutoProcessor model = LlavaForConditionalGeneration.from_pretrained( hf_id, torch_dtype=torch_dtype, device_map=device, cache_dir=cache_dir, ) processor = AutoProcessor.from_pretrained(hf_id, cache_dir=cache_dir) elif "qwen" in hf_id.lower() and "vl" in hf_id.lower(): # Qwen2.5-VL is a vision-language model (Qwen2_5_VLForConditionalGeneration) # with layers at model.language_model.layers from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor as _AP model = Qwen2_5_VLForConditionalGeneration.from_pretrained( hf_id, torch_dtype=torch_dtype, device_map=device, cache_dir=cache_dir, ) processor = _AP.from_pretrained(hf_id, cache_dir=cache_dir) elif "gemma-3" in hf_id.lower() or "gemma3" in hf_id.lower(): # Gemma-3 is a vision-language model (Gemma3ForConditionalGeneration) # with layers at model.language_model.layers # Use attn_implementation="eager" to avoid SDPA or_mask_function # which requires torch>=2.6 (we have 2.5.x) from transformers import Gemma3ForConditionalGeneration, AutoProcessor as _AP model = Gemma3ForConditionalGeneration.from_pretrained( hf_id, torch_dtype=torch_dtype, device_map=device, cache_dir=cache_dir, attn_implementation="eager", ) processor = _AP.from_pretrained(hf_id, cache_dir=cache_dir) else: model = AutoModelForCausalLM.from_pretrained( hf_id, torch_dtype=torch_dtype, device_map=device, cache_dir=cache_dir, ) processor = AutoTokenizer.from_pretrained(hf_id, cache_dir=cache_dir) model.eval() return model, processor def get_layer_module(model, layer_idx: int, hf_id: str): """Get a specific layer module by index. Args: model: The loaded model layer_idx: Layer index (0-based) hf_id: HuggingFace model identifier Returns: The layer module """ prefix = LAYER_PREFIXES.get(hf_id, "model.layers") layer_path = f"{prefix}.{layer_idx}" module = model for attr in layer_path.split("."): if attr.isdigit(): module = module[int(attr)] else: module = getattr(module, attr) return module def print_layer_names(model, max_depth: int = 3): """Print model layer names for verification. This MUST be called during S1 scaffold to verify layer paths for each backbone. """ logger.info("=== Model Layer Names ===") for name, module in model.named_modules(): depth = name.count(".") if depth <= max_depth: logger.info(f" {name}: {type(module).__name__}") logger.info("=========================") def get_num_layers(model, hf_id: str) -> int: """Get the number of decoder layers in the model.""" prefix = LAYER_PREFIXES.get(hf_id, "model.layers") module = model try: for attr in prefix.split("."): module = getattr(module, attr) return len(module) except AttributeError: # Fallback: try common paths for path in ["model.language_model.layers", "model.model.layers", "model.layers"]: try: m = model for attr in path.split("."): m = getattr(m, attr) return len(m) except AttributeError: continue raise AttributeError(f"Cannot determine num_layers for {hf_id}") def get_hidden_dim(model, hf_id: str) -> int: """Get hidden dimension of the model.""" if hasattr(model.config, "text_config") and hasattr(model.config.text_config, "hidden_size"): return model.config.text_config.hidden_size elif hasattr(model.config, "hidden_size"): return model.config.hidden_size else: raise AttributeError(f"Cannot determine hidden_dim for {hf_id}")