# Iris1.5 / model.py
# Uploaded by DavidSeyserHF
# Commit message: "Add model.py source"
# Commit 14e3915 (verified)
"""IRIS-2 VLM.
Architecture (LLaVA-style):
- Vision: nvidia/RADIO [FROZEN]
- Language: ByteDance/Ouro-1.4B (LoopLM) [FROZEN, or LoRA adapters when ``use_lora``]
- Connector: configurable-depth MLP projector [TRAINED — always full weights]
Trainable parameters: the MLP projector; optionally PEFT LoRA on selected Ouro linears.
Default LoRA targets are attention + MLP projections; set ``lora_target_modules`` to
``["early_exit_gate"]`` to train **only** the universal-transformer / ACT exit head (and the
projector), leaving the rest of the LM frozen. Set ``lora_edge_layers: N`` to attach LoRA only
to the first **N** and last **N** Ouro decoder **blocks** (not LoRA rank), via PEFT
``layers_to_transform``.
"""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Optional
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
# Repository identifier handed to ``torch.hub.load`` when loading the RADIO vision tower
# (default for ``IrisConfig.radio_repo``).
RADIO_HUB_REPO = "NVlabs/RADIO"
# Default ``version=`` keyword for the hub's ``radio_model`` entry point
# (default for ``IrisConfig.radio_version``).
RADIO_VERSION = "radio_v2.5-h"
# Model id used for the Ouro tokenizer, config and causal-LM weights
# (default for ``IrisConfig.ouro_name``).
OURO_NAME = "ByteDance/Ouro-1.4B"
@dataclass
class IrisConfig:
    """Construction-time options for ``IrisVLM``.

    The fields fall into four groups: which backbones to load, how the MLP
    projector is shaped, dtype / compile / attention-backend knobs, and the
    optional PEFT LoRA setup for the Ouro language model.
    """

    # --- Backbones -------------------------------------------------------
    radio_version: str = RADIO_VERSION  # torch.hub version tag
    radio_repo: str = RADIO_HUB_REPO  # repository passed to ``torch.hub.load``
    ouro_name: str = OURO_NAME  # id used for tokenizer, config and LM weights
    image_size: int = 432  # must be divisible by RADIO patch size (16)

    # --- Projector shape -------------------------------------------------
    projector_hidden_mult: int = 1  # H = mult * llm_hidden
    projector_num_intermediate: int = 1  # H-wide layers before last Linear → llm_hidden

    # --- Runtime ---------------------------------------------------------
    # dtype for the LM weights and the projector; also selects the autocast
    # dtype used around the (fp32) vision tower.
    torch_dtype: torch.dtype = torch.bfloat16
    # NOTE(review): ``compile_mode`` / ``compile_dynamic`` are not read by the
    # code visible in this file; presumably the caller forwards them to
    # ``IrisVLM.compile_components`` — confirm against the training script.
    compile_mode: str | None = None  # e.g. "default", "reduce-overhead", "max-autotune"
    compile_dynamic: bool = True  # allow dynamic seq lengths
    # ``eager`` materializes attention weights (needed for viz); default lets HF pick (often SDPA).
    llm_attn_implementation: Optional[str] = None

    # --- Ouro LoRA (PEFT) ------------------------------------------------
    # Base LM weights stay frozen, only adapter matrices train.
    use_lora: bool = False
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    # Default (None): LoRA on attention + MLP linears. Use ``["early_exit_gate"]`` for ACT gate only.
    lora_target_modules: Optional[list[str]] = None
    # If set (e.g. 4), apply LoRA only on the first/last N **decoder** blocks (Ouro `layers.0` …
    # `layers.N-1`), via PEFT ``layers_to_transform``. Ignored for `early_exit_gate`-only LoRA
    # unless ``lora_dual_edge_and_gate`` is true. Cannot mix gate + layer in **one** adapter;
    # use ``lora_dual_edge_and_gate: true`` instead (two PEFT adapters).
    lora_edge_layers: Optional[int] = None
    # If true: two LoRA adapters — (1) edge blocks on ``lora_target_modules`` or defaults,
    # (2) ``early_exit_gate``. Requires ``lora_edge_layers`` and ``use_lora``.
    lora_dual_edge_and_gate: bool = False
class RadioImageTransform:
    """Turn PIL images into a RADIO-ready pixel batch.

    Normalisation is deliberately absent: RADIO carries its own
    ``input_conditioner`` for per-channel normalisation, so this transform
    stops at a bicubic resize to a square and conversion to a ``[0, 1]``
    float tensor.
    """

    def __init__(self, image_size: int):
        self.image_size = image_size
        square = (image_size, image_size)
        resize = T.Resize(square, interpolation=T.InterpolationMode.BICUBIC)
        to_tensor = T.ToTensor()  # HWC uint8 -> CHW float in [0, 1]
        self._tx = T.Compose([resize, to_tensor])

    def __call__(self, images, return_tensors: str = "pt"):
        """Preprocess one image or an iterable of images.

        Args:
            images: a single ``PIL.Image.Image`` or an iterable of them.
            return_tensors: ``"pt"`` returns a torch tensor; any other value
                returns the same batch converted with ``.numpy()``.

        Returns:
            ``{"pixel_values": batch}`` with the images stacked on dim 0.
        """
        batch = [images] if isinstance(images, Image.Image) else images
        # Force three channels so grayscale / RGBA inputs stack cleanly.
        stacked = torch.stack([self._tx(im.convert("RGB")) for im in batch], dim=0)
        payload = stacked if return_tensors == "pt" else stacked.numpy()
        return {"pixel_values": payload}
def _mlp_sequential(
    dims: Sequence[int], act: type[nn.Module] = nn.GELU
) -> nn.Sequential:
    """Build a ``Linear`` chain ``dims[0] → … → dims[-1]``.

    A fresh ``act()`` instance follows every ``Linear`` except the final one,
    so the output of the stack is a plain linear projection.

    Raises:
        ValueError: if fewer than two widths are supplied.
    """
    if len(dims) < 2:
        raise ValueError("mlp must have at least in_dim and out_dim")
    widths = list(dims)
    final_idx = len(widths) - 2  # index of the last Linear in the chain
    stack: list[nn.Module] = []
    # Walk consecutive (in, out) width pairs; Linears are created in order.
    for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
        stack.append(nn.Linear(fan_in, fan_out))
        if idx != final_idx:
            stack.append(act())
    return nn.Sequential(*stack)
class MLPProjector(nn.Module):
    """Visual-token → LLM-embedding MLP.

    The stack has ``num_intermediate + 1`` ``Linear`` layers with GELU between
    them: ``vision_dim → H → … → H → llm_hidden`` where
    ``H = llm_hidden * hidden_mult`` appears ``num_intermediate`` times.
    """

    def __init__(
        self,
        vision_dim: int,
        llm_hidden: int,
        hidden_mult: int = 1,
        num_intermediate: int = 1,
    ):
        super().__init__()
        if num_intermediate < 1:
            raise ValueError("num_intermediate must be >= 1")
        width = llm_hidden * hidden_mult
        layer_widths: list[int] = [vision_dim, *([width] * num_intermediate), llm_hidden]
        self.net = _mlp_sequential(layer_widths)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x``."""
        projected = self.net(x)
        return projected
# Linear submodule names inside an Ouro decoder block, split by role.
_OURO_ATTENTION_LINEARS: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")
_OURO_MLP_LINEARS: tuple[str, ...] = ("gate_proj", "up_proj", "down_proj")

# Default LoRA targets: attention + MLP projections.
OURO_DEFAULT_LORA_TARGETS: tuple[str, ...] = _OURO_ATTENTION_LINEARS + _OURO_MLP_LINEARS

# Ouro ``OuroModel.early_exit_gate`` (Linear → sigmoid): per–UT-step exit / ACT-style head.
OURO_ACT_GATE_LORA_TARGETS: tuple[str, ...] = ("early_exit_gate",)

# Submodule names that live under ``model.layers.{i}`` (PEFT `layers_to_transform` applies).
OURO_LAYER_LORA_NAME_PREFIXES: frozenset[str] = frozenset(
    _OURO_ATTENTION_LINEARS + _OURO_MLP_LINEARS
)
def lora_layer_indices_for_edges(num_hidden_layers: int, edge: int) -> list[int]:
    """Return the indices of the first and last ``edge`` decoder blocks.

    For a model with ``num_hidden_layers`` blocks this yields
    ``0 .. edge-1`` followed by ``num_hidden_layers-edge .. num_hidden_layers-1``.
    The result is always sorted, free of duplicates and limited to valid
    block indices: when the two edge windows overlap (``2 * edge`` exceeds
    the layer count) or ``edge`` is larger than the model, the windows are
    clamped and merged instead of repeating or inventing indices.

    Args:
        num_hidden_layers: number of decoder blocks in the model (>= 1).
        edge: how many blocks to take from each end (>= 1).

    Returns:
        Ascending list of block indices.

    Raises:
        ValueError: if ``edge`` or ``num_hidden_layers`` is smaller than 1.
    """
    e = int(edge)
    if e < 1:
        raise ValueError("edge must be >= 1")
    n = int(num_hidden_layers)
    if n < 1:
        raise ValueError("num_hidden_layers must be >= 1")
    # Never reach past the model: an edge wider than the stack means "all blocks".
    width = min(e, n)
    head = range(width)
    tail = range(n - width, n)
    # Set union merges overlapping windows; sorting restores ascending order.
    return sorted(set(head) | set(tail))
class IrisVLM(nn.Module):
    """Frozen RADIO + trainable MLP projector + Ouro-1.4B (frozen or LoRA).

    Construction loads the RADIO vision tower through ``torch.hub`` and the
    Ouro tokenizer / config / causal LM through ``transformers``, freezes both
    backbones, optionally wraps the LM with PEFT LoRA adapters, and finally
    builds the ``MLPProjector`` that maps RADIO's dense features to the LM
    hidden size.

    Attributes set in ``__init__``:
        cfg: the ``IrisConfig`` in use.
        vision: RADIO model, frozen and put in eval mode.
        image_processor: ``RadioImageTransform`` for ``cfg.image_size``.
        tokenizer: Ouro tokenizer; ``pad_token`` falls back to ``eos_token``.
        llm: Ouro causal LM, PEFT-wrapped when ``cfg.use_lora`` is true.
        projector: trainable ``MLPProjector`` cast to ``cfg.torch_dtype``.
    """

    def __init__(self, cfg: Optional[IrisConfig] = None):
        super().__init__()
        self.cfg = cfg or IrisConfig()

        # nvidia/RADIO via torch.hub (HF's remote code is incompatible with
        # transformers >= 4.55's per-parameter weight loading). Kept in fp32
        # because RADIO's input_conditioner has fp32-pinned buffers; we run
        # the forward under autocast instead of casting weights.
        self.vision = torch.hub.load(
            self.cfg.radio_repo,
            "radio_model",
            version=self.cfg.radio_version,
            progress=True,
            skip_validation=True,
        )
        self.image_processor = RadioImageTransform(self.cfg.image_size)

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.cfg.ouro_name, trust_remote_code=True
        )
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # OuroConfig omits pad_token_id; inject from the tokenizer so the
        # model's __init__ can read config.pad_token_id.
        llm_config = AutoConfig.from_pretrained(self.cfg.ouro_name, trust_remote_code=True)
        if getattr(llm_config, "pad_token_id", None) is None:
            llm_config.pad_token_id = self.tokenizer.pad_token_id
        llm_kwargs: dict = dict(
            pretrained_model_name_or_path=self.cfg.ouro_name,
            config=llm_config,
            trust_remote_code=True,
            torch_dtype=self.cfg.torch_dtype,
        )
        # Only pass the attention backend when the user chose one, so HF
        # keeps its own default otherwise.
        if self.cfg.llm_attn_implementation is not None:
            llm_kwargs["attn_implementation"] = self.cfg.llm_attn_implementation
        self.llm = AutoModelForCausalLM.from_pretrained(**llm_kwargs)

        # Freeze RADIO and the Ouro *base* weights (PEFT adds trainable LoRA side matrices).
        for p in self.vision.parameters():
            p.requires_grad_(False)
        for p in self.llm.parameters():
            p.requires_grad_(False)
        self.vision.eval()

        # NOTE: ``cfg.lora_dual_edge_and_gate`` is only consulted when
        # ``cfg.use_lora`` is true; without LoRA it has no effect.
        self._use_lora = bool(self.cfg.use_lora)
        self._lora_dual_edge_gate = False
        if self._use_lora:
            self._attach_lora()
        else:
            self.llm.eval()

        vision_dim = self._probe_vision_dim()
        llm_hidden = self.llm.config.hidden_size
        self.projector = MLPProjector(
            vision_dim=vision_dim,
            llm_hidden=llm_hidden,
            hidden_mult=self.cfg.projector_hidden_mult,
            num_intermediate=self.cfg.projector_num_intermediate,
        ).to(self.cfg.torch_dtype)

    def _attach_lora(self) -> None:
        """Wrap ``self.llm`` with PEFT LoRA adapter(s) according to ``self.cfg``.

        Two layouts are supported:

        * dual (``lora_dual_edge_and_gate``): adapter ``edge_lora`` on the
          first/last ``lora_edge_layers`` decoder blocks plus adapter
          ``gate_lora`` on ``early_exit_gate``; both are activated.
        * single: one adapter on ``lora_target_modules`` (or the defaults),
          restricted to the edge blocks when ``lora_edge_layers`` is set and
          the targets are decoder-block linears.

        Raises:
            ValueError: on inconsistent LoRA options (see messages).
        """
        # Imported lazily so ``peft`` is only required when LoRA is enabled.
        from peft import LoraConfig, TaskType, get_peft_model

        cfg = self.cfg
        # Hyper-parameters shared by every adapter this method creates.
        shared = dict(
            r=cfg.lora_r,
            lora_alpha=cfg.lora_alpha,
            lora_dropout=cfg.lora_dropout,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )

        if cfg.lora_dual_edge_and_gate:
            if cfg.lora_edge_layers is None:
                raise ValueError("lora_dual_edge_and_gate requires lora_edge_layers (e.g. 4).")
            edge_targets = cfg.lora_target_modules or list(OURO_DEFAULT_LORA_TARGETS)
            if "early_exit_gate" in edge_targets:
                raise ValueError(
                    "lora_target_modules must not include early_exit_gate when using "
                    "lora_dual_edge_and_gate (the gate uses a separate adapter)."
                )
            num_layers = int(self.llm.config.num_hidden_layers)
            edge_layers = lora_layer_indices_for_edges(num_layers, int(cfg.lora_edge_layers))
            edge_cfg = LoraConfig(
                target_modules=edge_targets,
                layers_to_transform=edge_layers,
                **shared,
            )
            gate_cfg = LoraConfig(
                target_modules=list(OURO_ACT_GATE_LORA_TARGETS),
                **shared,
            )
            self.llm = get_peft_model(self.llm, edge_cfg, adapter_name="edge_lora")
            self.llm.add_adapter("gate_lora", gate_cfg)
            # Activate both adapters together.
            self.llm.base_model.set_adapter(["edge_lora", "gate_lora"])
            self._lora_dual_edge_gate = True
            return

        targets = cfg.lora_target_modules or list(OURO_DEFAULT_LORA_TARGETS)
        target_set = frozenset(targets)
        touches_layers = bool(target_set & OURO_LAYER_LORA_NAME_PREFIXES)
        if "early_exit_gate" in target_set and touches_layers:
            raise ValueError(
                "lora_target_modules cannot list both `early_exit_gate` and layer modules "
                "(q_proj, mlp, …) in one PEFT adapter; set lora_dual_edge_and_gate: true "
                "with lora_edge_layers, or use one target family only."
            )
        peft_kwargs: dict = dict(shared, target_modules=targets)
        # ``lora_edge_layers`` only applies to decoder-block linears; it is
        # deliberately ignored for gate-only targets.
        if cfg.lora_edge_layers is not None and touches_layers:
            num_layers = int(self.llm.config.num_hidden_layers)
            peft_kwargs["layers_to_transform"] = lora_layer_indices_for_edges(
                num_layers, int(cfg.lora_edge_layers)
            )
        self.llm = get_peft_model(self.llm, LoraConfig(**peft_kwargs))

    def compile_components(self, mode: str = "default", dynamic: bool = True) -> "IrisVLM":
        """Wrap heavy submodules with torch.compile.

        Only the frozen vision + LM forwards matter for throughput; the
        projector is tiny so we skip it (also keeps its ``state_dict`` keys
        prefix-free for checkpoints). ``dynamic=True`` avoids recompiles when
        prompt / response lengths vary across batches; we also raise the
        dynamo cache limit to absorb the first few shape specializations.
        PEFT-wrapped LMs are not ``torch.compile``'d (unsupported / fragile).

        Returns:
            ``self``, so the call can be chained.
        """
        import torch._dynamo as _dynamo

        _dynamo.config.cache_size_limit = max(_dynamo.config.cache_size_limit, 64)
        self.vision = torch.compile(self.vision, mode=mode, dynamic=dynamic, fullgraph=False)
        if not self._use_lora:
            self.llm = torch.compile(self.llm, mode=mode, dynamic=dynamic, fullgraph=False)
        return self

    @staticmethod
    def _spatial_features(out) -> torch.Tensor:
        """Extract dense (B,T,D) features from a RADIO forward output.

        Looks for a ``features`` then a ``spatial_features`` attribute and
        finally falls back to element 1 of a tuple output.

        Raises:
            TypeError: if none of those locations yields a value.
        """
        for name in ("features", "spatial_features"):
            v = getattr(out, name, None)
            if v is not None:
                return v
        if isinstance(out, tuple) and len(out) >= 2:
            return out[1]
        raise TypeError(f"Cannot find spatial features on RADIO output: {type(out)}")

    def _vision_autocast(self):
        """Autocast context matching cfg.torch_dtype (bf16/fp16), or no-op for fp32."""
        import contextlib

        dtype = self.cfg.torch_dtype
        if dtype in (torch.bfloat16, torch.float16):
            # Autocast is keyed on device type, so follow wherever RADIO lives.
            device = next(self.vision.parameters()).device
            return torch.autocast(device_type=device.type, dtype=dtype)
        return contextlib.nullcontext()

    @torch.no_grad()
    def _probe_vision_dim(self) -> int:
        """Run RADIO once on a dummy image to discover spatial_features dim."""
        device = next(self.vision.parameters()).device
        dummy = torch.zeros(
            1, 3, self.cfg.image_size, self.cfg.image_size,
            device=device, dtype=torch.float32,
        )
        with self._vision_autocast():
            spatial = self._spatial_features(self.vision(dummy))
        assert spatial.dim() == 3, f"expected (B,T,D), got {tuple(spatial.shape)}"
        return spatial.shape[-1]

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """(B,3,H,W) -> (B,T,llm_hidden) visual tokens in LLM embedding space.

        RADIO runs in fp32 weights under autocast so its mixed fp32 buffers
        (input_conditioner) stay consistent; output is cast to the projector
        dtype before the trainable MLP. The vision pass runs under
        ``torch.no_grad()``; the projector call does not.
        """
        with torch.no_grad(), self._vision_autocast():
            spatial = self._spatial_features(self.vision(pixel_values.float()))
        return self.projector(
            spatial.to(self.projector.net[0].weight.dtype)
        )

    def _embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up ``input_ids`` in the LM's input embedding table."""
        return self.llm.get_input_embeddings()(input_ids)

    def forward(
        self,
        pixel_values: torch.Tensor,  # (B, 3, H, W)
        prompt_ids: torch.Tensor | None = None,
        prompt_mask: torch.Tensor | None = None,
        response_ids: torch.Tensor | None = None,
        response_mask: torch.Tensor | None = None,
        packed_text_ids: torch.Tensor | None = None,
        packed_text_mask: torch.Tensor | None = None,
        packed_text_labels: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> Any:
        """Vision-first forward.

        **Single-turn (default):** ``[visual, prompt, response]`` — CE only on
        ``response`` tokens (see ``prompt_ids`` … ``response_mask``).

        **Multiturn:** pass ``packed_text_ids``, ``packed_text_mask``, and
        ``packed_text_labels`` instead of prompt/response tensors. Layout is
        ``[visual, packed_text]`` with ``packed_text_labels`` already ``-100`` on
        user spans and padding (loss on every assistant span in one sequence).

        Extra keyword args are ignored so callers can pass through batch dicts.

        Returns:
            Whatever ``self.llm`` returns when called with ``inputs_embeds``,
            ``attention_mask``, ``labels`` and ``use_cache=False``.

        Raises:
            ValueError: if neither a complete packed-text triple nor a
                complete prompt/response quadruple is supplied.
        """
        _ = kwargs  # allow batch dicts with unused keys

        # Validate the argument combination before paying for the vision pass.
        multiturn = packed_text_ids is not None
        if multiturn:
            if packed_text_mask is None or packed_text_labels is None:
                raise ValueError(
                    "packed_text_ids requires packed_text_mask and packed_text_labels"
                )
        elif (
            prompt_ids is None
            or prompt_mask is None
            or response_ids is None
            or response_mask is None
        ):
            raise ValueError(
                "Either packed_text_* (multiturn) or prompt_ids/response_ids (single-turn) is required"
            )

        batch = pixel_values.size(0)
        device = pixel_values.device
        visual = self.encode_images(pixel_values)  # (B, T, H)
        num_visual = visual.size(1)

        if multiturn:
            text_emb = self._embed_tokens(packed_text_ids)
            text_mask = packed_text_mask
            text_labels = packed_text_labels
        else:
            prompt_emb = self._embed_tokens(prompt_ids)  # (B, Lp, H)
            resp_emb = self._embed_tokens(response_ids)  # (B, Lr, H)
            text_emb = torch.cat([prompt_emb, resp_emb], dim=1)
            text_mask = torch.cat([prompt_mask, response_mask], dim=1)
            # No loss on the prompt, nor on padded response positions.
            prompt_ignore = torch.full(
                (batch, prompt_ids.size(1)), -100, dtype=torch.long, device=device
            )
            resp_labels = response_ids.masked_fill(response_mask == 0, -100)
            text_labels = torch.cat([prompt_ignore, resp_labels], dim=1)

        # Visual tokens are always attended to and never contribute to the loss.
        visual_mask = torch.ones(batch, num_visual, dtype=text_mask.dtype, device=device)
        visual_ignore = torch.full(
            (batch, num_visual), -100, dtype=torch.long, device=device
        )
        return self.llm(
            inputs_embeds=torch.cat([visual, text_emb], dim=1),  # (B, L, H)
            attention_mask=torch.cat([visual_mask, text_mask], dim=1),
            labels=torch.cat([visual_ignore, text_labels], dim=1),
            use_cache=False,
        )

    def trainable_parameters(self) -> list[torch.nn.Parameter]:
        """Return projector parameters plus, with LoRA, the LM's trainable ones."""
        out = [p for p in self.projector.parameters() if p.requires_grad]
        if self._use_lora:
            out.extend(p for p in self.llm.parameters() if p.requires_grad)
        return out