File size: 6,416 Bytes
e6f24ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""Backbone model loading utilities.

Handles loading of all 3 backbone VLMs (LLaVA-1.5, Qwen2.5, Gemma-3),
dtype management, and layer name verification.
"""

import logging
from typing import Optional

import torch
import yaml

logger = logging.getLogger(__name__)

# Attribute path of each backbone's decoder-layer list, relative to the loaded
# model object.  get_layer_module/get_num_layers start the getattr walk at
# `model`, so e.g. "model.layers" resolves to `model.model.layers`:
#   LLaVA-1.5  : LlavaForConditionalGeneration       → model.model.layers
#   Qwen2.5-VL : Qwen2_5_VLForConditionalGeneration  → model.language_model.layers
#   Gemma-3    : Gemma3ForConditionalGeneration       → model.language_model.layers
LAYER_PREFIXES = {
    "llava-hf/llava-1.5-7b-hf": "model.layers",
    "Qwen/Qwen2.5-VL-3B-Instruct": "model.language_model.layers",
    "google/gemma-3-4b-it": "model.language_model.layers",
}

# Maps dtype strings (as written in the YAML config) to torch dtypes.
DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


def load_config(config_path: str = "configs/experiment.yaml") -> dict:
    """Parse the experiment YAML file at *config_path* and return it as a dict."""
    with open(config_path) as fh:
        config = yaml.safe_load(fh)
    return config


def get_backbone_config(config: dict, backbone_name: str = "primary") -> dict:
    """Look up one backbone entry in the experiment config.

    Accepted forms for *backbone_name*:
      * ``"primary"``                 -> ``config["backbones"]["primary"]``
      * ``"transfer"`` / ``"transfer.N"`` -> Nth transfer entry (0 when no index)
      * anything else                 -> matched against each entry's ``hf_id``

    Raises:
        ValueError: if *backbone_name* matches nothing.
    """
    backbones = config["backbones"]
    if backbone_name == "primary":
        return backbones["primary"]
    if backbone_name.startswith("transfer"):
        # "transfer.N" selects by index; bare "transfer" defaults to index 0.
        index = int(backbone_name.split(".")[-1]) if "." in backbone_name else 0
        return backbones["transfer"][index]
    # Fall back to matching the HuggingFace id, primary first.
    for entry in (backbones["primary"], *backbones["transfer"]):
        if entry["hf_id"] == backbone_name:
            return entry
    raise ValueError(f"Unknown backbone: {backbone_name}")


def load_backbone(
    hf_id: str,
    dtype: str = "float16",
    device: str = "cuda",
    cache_dir: Optional[str] = None,
):
    """Load a backbone model and its processor/tokenizer.

    Dispatches on substrings of *hf_id* to the correct transformers class
    (LLaVA / Qwen2.5-VL / Gemma-3), falling back to AutoModelForCausalLM +
    AutoTokenizer for anything else.

    Args:
        hf_id: HuggingFace model identifier.
        dtype: One of "float16", "bfloat16", "float32"; unknown values fall
            back to float16.
        device: Passed as ``device_map`` to ``from_pretrained``.
        cache_dir: Optional HuggingFace cache directory.

    Returns:
        (model, processor) tuple, with the model in eval mode.
    """
    from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer

    torch_dtype = DTYPE_MAP.get(dtype, torch.float16)
    # Lazy %-formatting so the message is only built if the level is enabled.
    logger.info("Loading backbone: %s (dtype=%s, device=%s)", hf_id, dtype, device)

    name = hf_id.lower()
    # Keyword args shared by every from_pretrained call below.
    common = dict(torch_dtype=torch_dtype, device_map=device, cache_dir=cache_dir)

    if "llava" in name:
        from transformers import LlavaForConditionalGeneration

        model = LlavaForConditionalGeneration.from_pretrained(hf_id, **common)
        processor = AutoProcessor.from_pretrained(hf_id, cache_dir=cache_dir)
    elif "qwen" in name and "vl" in name:
        # Qwen2.5-VL is a vision-language model (Qwen2_5_VLForConditionalGeneration)
        # with layers at model.language_model.layers
        from transformers import Qwen2_5_VLForConditionalGeneration

        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(hf_id, **common)
        processor = AutoProcessor.from_pretrained(hf_id, cache_dir=cache_dir)
    elif "gemma-3" in name or "gemma3" in name:
        # Gemma-3 is a vision-language model (Gemma3ForConditionalGeneration)
        # with layers at model.language_model.layers.
        # attn_implementation="eager" avoids the SDPA or_mask_function path,
        # which requires torch>=2.6 (we have 2.5.x).
        from transformers import Gemma3ForConditionalGeneration

        model = Gemma3ForConditionalGeneration.from_pretrained(
            hf_id, attn_implementation="eager", **common
        )
        processor = AutoProcessor.from_pretrained(hf_id, cache_dir=cache_dir)
    else:
        # Text-only fallback: tokenizer instead of a multimodal processor.
        model = AutoModelForCausalLM.from_pretrained(hf_id, **common)
        processor = AutoTokenizer.from_pretrained(hf_id, cache_dir=cache_dir)

    model.eval()
    return model, processor


def get_layer_module(model, layer_idx: int, hf_id: str):
    """Resolve the decoder layer module at *layer_idx*.

    Args:
        model: The loaded model.
        layer_idx: Layer index (0-based).
        hf_id: HuggingFace model identifier, used to pick the layer path
            (defaults to "model.layers" for unknown ids).

    Returns:
        The layer module.
    """
    prefix = LAYER_PREFIXES.get(hf_id, "model.layers")
    node = model
    # Walk the dotted path; purely-numeric segments index into the layer list.
    for part in f"{prefix}.{layer_idx}".split("."):
        node = node[int(part)] if part.isdigit() else getattr(node, part)
    return node


def print_layer_names(model, max_depth: int = 3):
    """Log named submodules up to *max_depth* levels of dotted nesting.

    This MUST be called during S1 scaffold to verify layer paths
    for each backbone.
    """
    logger.info("=== Model Layer Names ===")
    for name, module in model.named_modules():
        # Depth = number of dots in the qualified name.
        if name.count(".") <= max_depth:
            logger.info(f"  {name}: {type(module).__name__}")
    logger.info("=========================")


def get_num_layers(model, hf_id: str) -> int:
    """Return the number of decoder layers in the model.

    Tries the registered path for *hf_id* first, then common fallback paths.
    Raises AttributeError when no candidate path resolves.
    """
    candidates = (
        LAYER_PREFIXES.get(hf_id, "model.layers"),
        "model.language_model.layers",
        "model.model.layers",
        "model.layers",
    )
    for path in candidates:
        node = model
        try:
            for attr in path.split("."):
                node = getattr(node, attr)
        except AttributeError:
            continue
        return len(node)
    raise AttributeError(f"Cannot determine num_layers for {hf_id}")


def get_hidden_dim(model, hf_id: str) -> int:
    """Return the hidden dimension of the model.

    VLM wrappers expose it under ``config.text_config.hidden_size``;
    plain language models under ``config.hidden_size``.
    """
    cfg = model.config
    text_cfg = getattr(cfg, "text_config", None)
    if text_cfg is not None and hasattr(text_cfg, "hidden_size"):
        return text_cfg.hidden_size
    if hasattr(cfg, "hidden_size"):
        return cfg.hidden_size
    raise AttributeError(f"Cannot determine hidden_dim for {hf_id}")