# stylsteer-vlm / src/utils/backbone.py
# Deployed by abka03 — "Deploy StyleSteer-VLM demo" (commit e6f24ae, verified)
"""Backbone model loading utilities.
Handles loading of all 3 backbone VLMs (LLaVA-1.5, Qwen2.5, Gemma-3),
dtype management, and layer name verification.
"""
import logging
from typing import Optional
import torch
import yaml
logger = logging.getLogger(__name__)
# Layer prefix mapping — verified per backbone.
# Each value is a dot-separated *attribute path relative to the loaded model
# object*, walked via getattr by get_layer_module / get_num_layers below.
#   LLaVA-1.5  : LlavaForConditionalGeneration       → model.model.layers
#   Qwen2.5-VL : Qwen2_5_VLForConditionalGeneration  → model.language_model.layers
#   Gemma-3    : Gemma3ForConditionalGeneration      → model.language_model.layers
# NOTE(review): paths were presumably verified via print_layer_names during the
# S1 scaffold mentioned in its docstring — re-verify after transformers upgrades.
LAYER_PREFIXES = {
    "llava-hf/llava-1.5-7b-hf": "model.layers",
    "Qwen/Qwen2.5-VL-3B-Instruct": "model.language_model.layers",
    "google/gemma-3-4b-it": "model.language_model.layers",
}
# String name → torch dtype, used by load_backbone (unknown names fall back
# to float16 there).
DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}
def load_config(config_path: str = "configs/experiment.yaml") -> dict:
    """Read and parse the experiment YAML configuration file."""
    with open(config_path) as fh:
        cfg = yaml.safe_load(fh)
    return cfg
def get_backbone_config(config: dict, backbone_name: str = "primary") -> dict:
    """Resolve one backbone config entry.

    Accepts "primary", "transfer" / "transfer.N" (N is an index into the
    transfer list, defaulting to 0), or a raw HuggingFace model id matched
    against each entry's ``hf_id``.

    Raises:
        ValueError: if no entry matches ``backbone_name``.
    """
    backbones = config["backbones"]
    if backbone_name == "primary":
        return backbones["primary"]
    if backbone_name.startswith("transfer"):
        # "transfer.2" → index 2; bare "transfer" → index 0.
        idx = int(backbone_name.rsplit(".", 1)[-1]) if "." in backbone_name else 0
        return backbones["transfer"][idx]
    # Fall back to matching the HF id across primary + transfer entries.
    for entry in (backbones["primary"], *backbones["transfer"]):
        if entry["hf_id"] == backbone_name:
            return entry
    raise ValueError(f"Unknown backbone: {backbone_name}")
def load_backbone(
    hf_id: str,
    dtype: str = "float16",
    device: str = "cuda",
    cache_dir: Optional[str] = None,
):
    """Load a backbone VLM (or plain causal LM) and its processor.

    Args:
        hf_id: HuggingFace model identifier; the model family is detected
            from substrings of this id.
        dtype: One of "float16", "bfloat16", "float32". Unknown names
            silently fall back to float16 (matches DTYPE_MAP usage).
        device: Passed through as ``device_map`` (e.g. "cuda", "cpu", "auto").
        cache_dir: Optional HuggingFace cache directory.

    Returns:
        (model, processor) tuple; the model is put in eval mode. For the
        generic text-only fallback the "processor" is an AutoTokenizer.
    """
    from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer

    torch_dtype = DTYPE_MAP.get(dtype, torch.float16)
    logger.info(f"Loading backbone: {hf_id} (dtype={dtype}, device={device})")

    lowered = hf_id.lower()
    extra_kwargs = {}
    processor_cls = AutoProcessor
    if "llava" in lowered:
        # LLaVA-1.5 (LlavaForConditionalGeneration); see LAYER_PREFIXES.
        from transformers import LlavaForConditionalGeneration as model_cls
    elif "qwen" in lowered and "vl" in lowered:
        # Qwen2.5-VL (Qwen2_5_VLForConditionalGeneration) with layers at
        # model.language_model.layers.
        from transformers import Qwen2_5_VLForConditionalGeneration as model_cls
    elif "gemma-3" in lowered or "gemma3" in lowered:
        # Gemma-3 (Gemma3ForConditionalGeneration) with layers at
        # model.language_model.layers. attn_implementation="eager" avoids the
        # SDPA or_mask_function path, which requires torch>=2.6 (we have 2.5.x).
        from transformers import Gemma3ForConditionalGeneration as model_cls
        extra_kwargs["attn_implementation"] = "eager"
    else:
        # Text-only fallback: generic causal LM with a plain tokenizer.
        model_cls = AutoModelForCausalLM
        processor_cls = AutoTokenizer

    model = model_cls.from_pretrained(
        hf_id,
        torch_dtype=torch_dtype,
        device_map=device,
        cache_dir=cache_dir,
        **extra_kwargs,
    )
    processor = processor_cls.from_pretrained(hf_id, cache_dir=cache_dir)
    model.eval()
    return model, processor
def get_layer_module(model, layer_idx: int, hf_id: str):
    """Return the decoder-layer module at ``layer_idx``.

    Args:
        model: The loaded model.
        layer_idx: 0-based layer index.
        hf_id: HuggingFace model identifier, used to look up the attribute
            path in LAYER_PREFIXES (defaults to "model.layers").

    Returns:
        The layer module.
    """
    prefix = LAYER_PREFIXES.get(hf_id, "model.layers")
    node = model
    # Walk the dotted path; numeric segments index into a layer list.
    for part in f"{prefix}.{layer_idx}".split("."):
        node = node[int(part)] if part.isdigit() else getattr(node, part)
    return node
def print_layer_names(model, max_depth: int = 3):
    """Log the model's module names up to ``max_depth`` dotted levels.

    This MUST be called during S1 scaffold to verify layer paths
    for each backbone.
    """
    logger.info("=== Model Layer Names ===")
    for name, module in model.named_modules():
        # Depth = number of dots in the qualified module name.
        if name.count(".") <= max_depth:
            logger.info(f" {name}: {type(module).__name__}")
    logger.info("=========================")
def get_num_layers(model, hf_id: str) -> int:
    """Return the number of decoder layers in the model.

    Tries the LAYER_PREFIXES path for ``hf_id`` first, then probes the
    known layer paths across backbone families.

    Raises:
        AttributeError: if no candidate path resolves to a layer list.
    """

    def _resolve(root, dotted):
        # Walk a dot-separated attribute path from `root`.
        node = root
        for attr in dotted.split("."):
            node = getattr(node, attr)
        return node

    candidates = [
        LAYER_PREFIXES.get(hf_id, "model.layers"),
        # Fallbacks: common layer paths across backbone families.
        "model.language_model.layers",
        "model.model.layers",
        "model.layers",
    ]
    for path in candidates:
        try:
            return len(_resolve(model, path))
        except AttributeError:
            continue
    raise AttributeError(f"Cannot determine num_layers for {hf_id}")
def get_hidden_dim(model, hf_id: str) -> int:
    """Return the language model's hidden dimension.

    VLM configs nest the text settings under ``text_config``; plain LMs
    expose ``hidden_size`` at the top level.

    Raises:
        AttributeError: if neither location provides a hidden size.
    """
    cfg = model.config
    text_cfg = getattr(cfg, "text_config", None)
    if text_cfg is not None and hasattr(text_cfg, "hidden_size"):
        return text_cfg.hidden_size
    if hasattr(cfg, "hidden_size"):
        return cfg.hidden_size
    raise AttributeError(f"Cannot determine hidden_dim for {hf_id}")