Spaces:
Running
Running
| """Backbone model loading utilities. | |
| Handles loading of all 3 backbone VLMs (LLaVA-1.5, Qwen2.5, Gemma-3), | |
| dtype management, and layer name verification. | |
| """ | |
| import logging | |
| from typing import Optional | |
| import torch | |
| import yaml | |
| logger = logging.getLogger(__name__) | |
# Maps HuggingFace model id -> attribute path of the decoder-layer list on the
# loaded model object. Consumed by get_layer_module() and get_num_layers().
# Backbone classes, per the loaders below:
#   LLaVA-1.5   : LlavaForConditionalGeneration
#   Qwen2.5-VL  : Qwen2_5_VLForConditionalGeneration
#   Gemma-3     : Gemma3ForConditionalGeneration
# NOTE(review): the original header comment claimed LLaVA layers live at
# "model.model.layers", yet the mapping below uses "model.layers" — confirm
# against print_layer_names() output for the installed transformers version.
LAYER_PREFIXES = {
    "llava-hf/llava-1.5-7b-hf": "model.layers",
    "Qwen/Qwen2.5-VL-3B-Instruct": "model.language_model.layers",
    "google/gemma-3-4b-it": "model.language_model.layers",
}

# Dtype names as they appear in the experiment config -> torch dtypes.
# load_backbone() falls back to float16 for names not listed here.
DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}
def load_config(config_path: str = "configs/experiment.yaml") -> dict:
    """Load the experiment configuration from a YAML file.

    Args:
        config_path: Path to the YAML config (default: configs/experiment.yaml).

    Returns:
        The parsed configuration as a dict.

    Raises:
        FileNotFoundError: If the config file does not exist.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Explicit encoding: without it, open() uses the platform default, which
    # can mis-read UTF-8 configs on Windows.
    with open(config_path, encoding="utf-8") as f:
        return yaml.safe_load(f)
def get_backbone_config(config: dict, backbone_name: str = "primary") -> dict:
    """Resolve a backbone config entry.

    Accepts the literal name "primary", an indexed transfer name
    ("transfer", "transfer.0", "transfer.1", ...), or a HuggingFace model id
    matched against the `hf_id` field of each entry.

    Raises:
        ValueError: If no entry matches `backbone_name`.
    """
    backbones = config["backbones"]

    if backbone_name == "primary":
        return backbones["primary"]

    if backbone_name.startswith("transfer"):
        # "transfer" alone means index 0; "transfer.N" selects index N.
        parts = backbone_name.split(".")
        idx = int(parts[-1]) if len(parts) > 1 else 0
        return backbones["transfer"][idx]

    # Otherwise treat the name as an hf_id and scan primary + transfer entries.
    for candidate in (backbones["primary"], *backbones["transfer"]):
        if candidate["hf_id"] == backbone_name:
            return candidate
    raise ValueError(f"Unknown backbone: {backbone_name}")
def load_backbone(
    hf_id: str,
    dtype: str = "float16",
    device: str = "cuda",
    cache_dir: Optional[str] = None,
):
    """Load a backbone model and its processor/tokenizer.

    Dispatches on the model id: LLaVA, Qwen2.5-VL, and Gemma-3 are loaded
    with their dedicated *ForConditionalGeneration classes plus an
    AutoProcessor; anything else falls back to AutoModelForCausalLM plus
    an AutoTokenizer. The model is put in eval mode before returning.

    Args:
        hf_id: HuggingFace model identifier.
        dtype: One of "float16", "bfloat16", "float32"; unrecognized values
            fall back to float16 (logged as a warning).
        device: Passed to `from_pretrained` as `device_map`.
        cache_dir: Optional HuggingFace cache directory.

    Returns:
        (model, processor) tuple.
    """
    # Local import keeps transformers out of module import time.
    from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer

    if dtype not in DTYPE_MAP:
        # Original behavior (silent float16 fallback) preserved, but now logged.
        logger.warning(f"Unrecognized dtype '{dtype}', falling back to float16")
    torch_dtype = DTYPE_MAP.get(dtype, torch.float16)
    logger.info(f"Loading backbone: {hf_id} (dtype={dtype}, device={device})")

    common_kwargs = dict(
        torch_dtype=torch_dtype,
        device_map=device,
        cache_dir=cache_dir,
    )
    lowered = hf_id.lower()

    if "llava" in lowered:
        from transformers import LlavaForConditionalGeneration

        model = LlavaForConditionalGeneration.from_pretrained(hf_id, **common_kwargs)
        processor = AutoProcessor.from_pretrained(hf_id, cache_dir=cache_dir)
    elif "qwen" in lowered and "vl" in lowered:
        # Qwen2.5-VL is a vision-language model (Qwen2_5_VLForConditionalGeneration)
        # with layers at model.language_model.layers
        from transformers import Qwen2_5_VLForConditionalGeneration

        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(hf_id, **common_kwargs)
        processor = AutoProcessor.from_pretrained(hf_id, cache_dir=cache_dir)
    elif "gemma-3" in lowered or "gemma3" in lowered:
        # Gemma-3 is a vision-language model (Gemma3ForConditionalGeneration)
        # with layers at model.language_model.layers.
        # attn_implementation="eager" avoids the SDPA or_mask_function path,
        # which requires torch>=2.6 (we have 2.5.x).
        from transformers import Gemma3ForConditionalGeneration

        model = Gemma3ForConditionalGeneration.from_pretrained(
            hf_id,
            attn_implementation="eager",
            **common_kwargs,
        )
        processor = AutoProcessor.from_pretrained(hf_id, cache_dir=cache_dir)
    else:
        # Plain causal LM fallback: tokenizer instead of a multimodal processor.
        model = AutoModelForCausalLM.from_pretrained(hf_id, **common_kwargs)
        processor = AutoTokenizer.from_pretrained(hf_id, cache_dir=cache_dir)

    model.eval()
    return model, processor
def get_layer_module(model, layer_idx: int, hf_id: str):
    """Return the decoder layer module at `layer_idx` (0-based).

    The attribute path is looked up in LAYER_PREFIXES by `hf_id`
    (defaulting to "model.layers"), with the layer index appended;
    numeric path components index into the module list.

    Args:
        model: The loaded model.
        layer_idx: Layer index (0-based).
        hf_id: HuggingFace model identifier.

    Returns:
        The layer module.
    """
    prefix = LAYER_PREFIXES.get(hf_id, "model.layers")
    current = model
    for part in f"{prefix}.{layer_idx}".split("."):
        current = current[int(part)] if part.isdigit() else getattr(current, part)
    return current
def print_layer_names(model, max_depth: int = 3):
    """Log the model's module names up to `max_depth` dotted levels.

    This MUST be called during S1 scaffold to verify the layer paths
    recorded in LAYER_PREFIXES for each backbone.
    """
    logger.info("=== Model Layer Names ===")
    for name, module in model.named_modules():
        # Depth = number of dots in the qualified name; shallow entries only.
        if name.count(".") <= max_depth:
            logger.info(f"  {name}: {type(module).__name__}")
    logger.info("=========================")
def get_num_layers(model, hf_id: str) -> int:
    """Return the number of decoder layers in the model.

    Tries the LAYER_PREFIXES path for `hf_id` first, then a fixed list of
    common fallback paths.

    Raises:
        AttributeError: If no candidate path resolves to a layer list.
    """
    preferred = LAYER_PREFIXES.get(hf_id, "model.layers")
    candidates = (
        preferred,
        "model.language_model.layers",
        "model.model.layers",
        "model.layers",
    )
    for path in candidates:
        target = model
        try:
            for attr in path.split("."):
                target = getattr(target, attr)
            return len(target)
        except AttributeError:
            # Path not present on this architecture — try the next one.
            continue
    raise AttributeError(f"Cannot determine num_layers for {hf_id}")
def get_hidden_dim(model, hf_id: str) -> int:
    """Return the hidden dimension of the model.

    Multimodal configs nest the size under `text_config.hidden_size`;
    plain LM configs expose `hidden_size` directly.

    Raises:
        AttributeError: If neither location provides a hidden size.
    """
    config = model.config
    # Prefer the nested text config (multimodal backbones).
    text_config = getattr(config, "text_config", None)
    if text_config is not None and hasattr(text_config, "hidden_size"):
        return text_config.hidden_size
    if hasattr(config, "hidden_size"):
        return config.hidden_size
    raise AttributeError(f"Cannot determine hidden_dim for {hf_id}")