"""IRIS-2 VLM.

Architecture (LLaVA-style):
  - Vision:    nvidia/RADIO                      [FROZEN]
  - Language:  ByteDance/Ouro-1.4B (LoopLM)      [FROZEN, or LoRA adapters when ``use_lora``]
  - Connector: configurable-depth MLP projector  [TRAINED — always full weights]

Trainable parameters: the MLP projector; optionally PEFT LoRA on selected Ouro linears.
Default LoRA targets are attention + MLP projections; set ``lora_target_modules`` to
``["early_exit_gate"]`` to train **only** the universal-transformer / ACT exit head (and the
projector), leaving the rest of the LM frozen.

Set ``lora_edge_layers: N`` to attach LoRA only to the first **N** and last **N** Ouro decoder
**blocks** (not LoRA rank), via PEFT ``layers_to_transform``.
"""

from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Optional

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

RADIO_HUB_REPO = "NVlabs/RADIO"
RADIO_VERSION = "radio_v2.5-h"
OURO_NAME = "ByteDance/Ouro-1.4B"


@dataclass
class IrisConfig:
    """Settings for the vision tower, the projector, the language model and optional LoRA."""

    # --- Vision tower (loaded through torch.hub) -------------------------------------------
    radio_version: str = RADIO_VERSION  # torch.hub version tag
    radio_repo: str = RADIO_HUB_REPO

    # --- Language model --------------------------------------------------------------------
    ouro_name: str = OURO_NAME

    # Square input resolution; must be divisible by RADIO patch size (16).
    image_size: int = 432

    # --- Projector shape -------------------------------------------------------------------
    projector_hidden_mult: int = 1  # H = mult * llm_hidden
    projector_num_intermediate: int = 1  # H-wide layers before last Linear → llm_hidden

    torch_dtype: torch.dtype = torch.bfloat16

    # --- torch.compile knobs ---------------------------------------------------------------
    compile_mode: str | None = None  # e.g. "default", "reduce-overhead", "max-autotune"
    compile_dynamic: bool = True  # allow dynamic seq lengths

    # ``eager`` materializes attention weights (needed for viz); leaving this ``None`` lets
    # HF pick the implementation (often SDPA).
    llm_attn_implementation: Optional[str] = None

    # --- Ouro LoRA (PEFT) ------------------------------------------------------------------
    # Base LM weights stay frozen; only the adapter matrices train.
    use_lora: bool = False
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05

    # ``None`` (default) means LoRA on the attention + MLP linears.
    # Use ``["early_exit_gate"]`` to adapt the ACT gate only.
    lora_target_modules: Optional[list[str]] = None

    # When set (e.g. 4), LoRA is applied only on the first N and last N **decoder** blocks
    # via PEFT ``layers_to_transform``. Ignored for `early_exit_gate`-only LoRA unless
    # ``lora_dual_edge_and_gate`` is true. Gate and layer targets cannot be mixed in **one**
    # adapter; use ``lora_dual_edge_and_gate: true`` instead (two PEFT adapters).
    lora_edge_layers: Optional[int] = None

    # When true, two LoRA adapters are created: (1) edge blocks on ``lora_target_modules``
    # (or the defaults) and (2) ``early_exit_gate``. Requires ``lora_edge_layers`` and
    # ``use_lora``.
    lora_dual_edge_and_gate: bool = False


class RadioImageTransform:
    """Turn PIL images into the tensor batch RADIO expects.

    RADIO ships its own ``input_conditioner`` that handles per-channel normalization
    internally, so the pipeline here is limited to a bicubic resize followed by
    conversion to a ``[0, 1]`` float tensor.
    """

    def __init__(self, image_size: int):
        self.image_size = image_size
        target_hw = (image_size, image_size)
        steps = [
            T.Resize(target_hw, interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),  # HWC uint8 -> CHW float in [0, 1]
        ]
        self._tx = T.Compose(steps)

    def __call__(self, images, return_tensors: str = "pt"):
        """Preprocess one image or an iterable of images.

        Returns a dict with a single ``"pixel_values"`` entry: a stacked torch tensor when
        ``return_tensors == "pt"``, otherwise the same batch converted with ``.numpy()``.
        """
        # A lone image is promoted to a one-element batch.
        batch = [images] if isinstance(images, Image.Image) else images
        pixel_values = torch.stack([self._tx(im.convert("RGB")) for im in batch], dim=0)
        if return_tensors != "pt":
            pixel_values = pixel_values.numpy()
        return {"pixel_values": pixel_values}


def _mlp_sequential(
    dims: Sequence[int], act: type[nn.Module] = nn.GELU
) -> nn.Sequential:
    """Build ``Linear(dims[0], dims[1]) → act → … → Linear(dims[-2], dims[-1])``.

    An ``act()`` instance follows every ``Linear`` except the final one.

    Raises:
        ValueError: if ``dims`` has fewer than two entries.
    """
    if len(dims) < 2:
        raise ValueError("mlp must have at least in_dim and out_dim")

    widths = list(zip(dims[:-1], dims[1:]))
    final_idx = len(widths) - 1
    modules: list[nn.Module] = []
    for idx, (fan_in, fan_out) in enumerate(widths):
        modules.append(nn.Linear(fan_in, fan_out))
        if idx != final_idx:
            modules.append(act())
    return nn.Sequential(*modules)


class MLPProjector(nn.Module):
    """Nonlinear visual-token → LLM map: ``num_intermediate + 1`` `Linear` layers, GELU between."""

    def __init__(
        self,
        vision_dim: int,
        llm_hidden: int,
        hidden_mult: int = 1,
        num_intermediate: int = 1,
    ):
        super().__init__()
        if num_intermediate < 1:
            raise ValueError("num_intermediate must be >= 1")
        # Every intermediate layer shares the same width, ``llm_hidden * hidden_mult``.
        width = llm_hidden * hidden_mult
        self.net = _mlp_sequential([vision_dim, *([width] * num_intermediate), llm_hidden])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the projector stack to ``x``."""
        return self.net(x)


OURO_DEFAULT_LORA_TARGETS: tuple[str, ...] = (
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
)
# Ouro ``OuroModel.early_exit_gate`` (Linear → sigmoid): per–UT-step exit / ACT-style head.
OURO_ACT_GATE_LORA_TARGETS: tuple[str, ...] = ("early_exit_gate",)
# Submodule names that live under ``model.layers.{i}`` (PEFT `layers_to_transform` applies).
OURO_LAYER_LORA_NAME_PREFIXES: frozenset[str] = frozenset(
    {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
)


def lora_layer_indices_for_edges(num_hidden_layers: int, edge: int) -> list[int]:
    """Return the indices of the first ``edge`` and last ``edge`` decoder blocks.

    The result is sorted, free of duplicates and always within
    ``0 .. num_hidden_layers - 1``. When ``2 * edge`` exceeds ``num_hidden_layers``
    the two ranges overlap and are merged; when ``edge`` exceeds
    ``num_hidden_layers`` it is clamped, so every block is selected.

    Args:
        num_hidden_layers: total number of decoder blocks (``>= 1``).
        edge: how many blocks to take from each end (``>= 1``).

    Raises:
        ValueError: if ``edge < 1`` or ``num_hidden_layers < 1``.
    """
    e = int(edge)
    if e < 1:
        raise ValueError("edge must be >= 1")
    n = int(num_hidden_layers)
    if n < 1:
        raise ValueError("num_hidden_layers must be >= 1")
    # Clamp so neither range can run past the existing blocks.
    e = min(e, n)
    # Set union removes the overlap that appears once 2 * e > n.
    return sorted(set(range(e)) | set(range(n - e, n)))


class IrisVLM(nn.Module):
    """Frozen RADIO + trainable MLP projector + Ouro-1.4B (frozen or LoRA)."""

    def __init__(self, cfg: Optional[IrisConfig] = None):
        """Load RADIO, the tokenizer and Ouro, freeze them, then build the projector.

        Args:
            cfg: model configuration; a default ``IrisConfig()`` is used when omitted.

        Raises:
            ValueError: if the LoRA settings in ``cfg`` are inconsistent
                (see ``_attach_lora``).
        """
        super().__init__()
        self.cfg = cfg or IrisConfig()

        # nvidia/RADIO via torch.hub (HF's remote code is incompatible with
        # transformers >= 4.55's per-parameter weight loading). Kept in fp32
        # because RADIO's input_conditioner has fp32-pinned buffers; we run
        # the forward under autocast instead of casting weights.
        self.vision = torch.hub.load(
            self.cfg.radio_repo,
            "radio_model",
            version=self.cfg.radio_version,
            progress=True,
            skip_validation=True,
        )
        self.image_processor = RadioImageTransform(self.cfg.image_size)

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.cfg.ouro_name, trust_remote_code=True
        )
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # OuroConfig omits pad_token_id; inject from the tokenizer so the
        # model's __init__ can read config.pad_token_id.
        llm_config = AutoConfig.from_pretrained(self.cfg.ouro_name, trust_remote_code=True)
        if getattr(llm_config, "pad_token_id", None) is None:
            llm_config.pad_token_id = self.tokenizer.pad_token_id

        _llm_kw: dict = dict(
            pretrained_model_name_or_path=self.cfg.ouro_name,
            config=llm_config,
            trust_remote_code=True,
            torch_dtype=self.cfg.torch_dtype,
        )
        # Only forward an attention implementation when one was requested.
        if self.cfg.llm_attn_implementation is not None:
            _llm_kw["attn_implementation"] = self.cfg.llm_attn_implementation
        self.llm = AutoModelForCausalLM.from_pretrained(**_llm_kw)

        # Freeze RADIO and the Ouro *base* weights (PEFT adds trainable LoRA side matrices).
        for p in self.vision.parameters():
            p.requires_grad_(False)
        for p in self.llm.parameters():
            p.requires_grad_(False)
        self.vision.eval()

        self._use_lora = bool(self.cfg.use_lora)
        self._lora_dual_edge_gate = False
        if self._use_lora:
            self._attach_lora()
        else:
            self.llm.eval()

        # The projector input width is discovered by running RADIO once.
        vision_dim = self._probe_vision_dim()
        llm_hidden = self.llm.config.hidden_size
        self.projector = MLPProjector(
            vision_dim=vision_dim,
            llm_hidden=llm_hidden,
            hidden_mult=self.cfg.projector_hidden_mult,
            num_intermediate=self.cfg.projector_num_intermediate,
        ).to(self.cfg.torch_dtype)

    def _attach_lora(self) -> None:
        """Wrap ``self.llm`` with the PEFT LoRA adapter(s) selected by ``self.cfg``.

        Two modes:
          * ``lora_dual_edge_and_gate``: adapter ``edge_lora`` on the edge decoder blocks
            plus adapter ``gate_lora`` on ``early_exit_gate``; both are activated.
          * otherwise: a single adapter on ``lora_target_modules`` (or the defaults),
            restricted to edge blocks when ``lora_edge_layers`` is set and the targets
            are per-layer modules.

        Raises:
            ValueError: if dual mode lacks ``lora_edge_layers`` or lists
                ``early_exit_gate`` as a target, or if single mode mixes
                ``early_exit_gate`` with per-layer targets.
        """
        from peft import LoraConfig, TaskType, get_peft_model

        cfg = self.cfg
        # Hyper-parameters shared by every adapter built below.
        shared: dict = dict(
            r=cfg.lora_r,
            lora_alpha=cfg.lora_alpha,
            lora_dropout=cfg.lora_dropout,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )
        targets = cfg.lora_target_modules or list(OURO_DEFAULT_LORA_TARGETS)

        if cfg.lora_dual_edge_and_gate:
            if cfg.lora_edge_layers is None:
                raise ValueError("lora_dual_edge_and_gate requires lora_edge_layers (e.g. 4).")
            if "early_exit_gate" in targets:
                raise ValueError(
                    "lora_target_modules must not include early_exit_gate when using "
                    "lora_dual_edge_and_gate (the gate uses a separate adapter)."
                )
            n = int(self.llm.config.num_hidden_layers)
            lt = lora_layer_indices_for_edges(n, int(cfg.lora_edge_layers))
            cfg_edge = LoraConfig(
                target_modules=targets,
                layers_to_transform=lt,
                **shared,
            )
            cfg_gate = LoraConfig(
                target_modules=list(OURO_ACT_GATE_LORA_TARGETS),
                **shared,
            )
            self.llm = get_peft_model(self.llm, cfg_edge, adapter_name="edge_lora")
            self.llm.add_adapter("gate_lora", cfg_gate)
            # Activate both adapters together.
            self.llm.base_model.set_adapter(["edge_lora", "gate_lora"])
            self._lora_dual_edge_gate = True
            return

        ts_set = frozenset(targets)
        # True when at least one target is a per-decoder-block module name.
        layerish = bool(ts_set & OURO_LAYER_LORA_NAME_PREFIXES)
        if "early_exit_gate" in ts_set and layerish:
            raise ValueError(
                "lora_target_modules cannot list both `early_exit_gate` and layer modules "
                "(q_proj, mlp, …) in one PEFT adapter; set lora_dual_edge_and_gate: true "
                "with lora_edge_layers, or use one target family only."
            )
        peft_kw: dict = dict(shared, target_modules=targets)
        el = cfg.lora_edge_layers
        # ``lora_edge_layers`` is ignored when no target lives under a decoder block.
        if el is not None and layerish:
            n = int(self.llm.config.num_hidden_layers)
            peft_kw["layers_to_transform"] = lora_layer_indices_for_edges(n, int(el))
        self.llm = get_peft_model(self.llm, LoraConfig(**peft_kw))

    def compile_components(self, mode: str = "default", dynamic: bool = True) -> "IrisVLM":
        """Wrap heavy submodules with torch.compile.

        Only the frozen vision + LM forwards matter for throughput; the projector is tiny
        so we skip it (also keeps its ``state_dict`` keys prefix-free for checkpoints).

        ``dynamic=True`` avoids recompiles when prompt / response lengths vary across
        batches; we also raise the dynamo cache limit to absorb the first few shape
        specializations. PEFT-wrapped LMs are not ``torch.compile``'d (unsupported /
        fragile).

        Returns:
            ``self``, so the call can be chained.
        """
        import torch._dynamo as _dynamo

        _dynamo.config.cache_size_limit = max(_dynamo.config.cache_size_limit, 64)
        self.vision = torch.compile(self.vision, mode=mode, dynamic=dynamic, fullgraph=False)
        if not self._use_lora:
            self.llm = torch.compile(self.llm, mode=mode, dynamic=dynamic, fullgraph=False)
        return self

    @staticmethod
    def _spatial_features(out) -> torch.Tensor:
        """Extract dense (B,T,D) features from a RADIO forward output.

        Looks for a ``features`` or ``spatial_features`` attribute first, then falls
        back to element 1 of a tuple with at least two entries.

        Raises:
            TypeError: if neither form is present on ``out``.
        """
        for name in ("features", "spatial_features"):
            v = getattr(out, name, None)
            if v is not None:
                return v
        if isinstance(out, tuple) and len(out) >= 2:
            return out[1]
        raise TypeError(f"Cannot find spatial features on RADIO output: {type(out)}")

    def _vision_autocast(self):
        """Autocast context matching cfg.torch_dtype (bf16/fp16), or no-op for fp32."""
        import contextlib

        dtype = self.cfg.torch_dtype
        if dtype in (torch.bfloat16, torch.float16):
            # Autocast is keyed on the device type the vision tower currently lives on.
            device = next(self.vision.parameters()).device
            return torch.autocast(device_type=device.type, dtype=dtype)
        return contextlib.nullcontext()

    @torch.no_grad()
    def _probe_vision_dim(self) -> int:
        """Run RADIO once on a dummy image to discover spatial_features dim.

        Raises:
            ValueError: if the spatial features are not a rank-3 ``(B,T,D)`` tensor.
        """
        device = next(self.vision.parameters()).device
        dummy = torch.zeros(
            1,
            3,
            self.cfg.image_size,
            self.cfg.image_size,
            device=device,
            dtype=torch.float32,
        )
        with self._vision_autocast():
            spatial = self._spatial_features(self.vision(dummy))
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if spatial.dim() != 3:
            raise ValueError(f"expected (B,T,D), got {tuple(spatial.shape)}")
        return spatial.shape[-1]

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """(B,3,H,W) -> (B,T,llm_hidden) visual tokens in LLM embedding space.

        RADIO runs in fp32 weights under autocast so its mixed fp32 buffers
        (input_conditioner) stay consistent; output is cast to the projector dtype
        before the trainable MLP. The vision pass runs under ``no_grad``; the
        projector call does not, so gradients reach the projector.
        """
        with torch.no_grad(), self._vision_autocast():
            spatial = self._spatial_features(self.vision(pixel_values.float()))
        return self.projector(
            spatial.to(self.projector.net[0].weight.dtype)
        )

    def _embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up ``input_ids`` in the language model's input embedding table."""
        return self.llm.get_input_embeddings()(input_ids)

    def _forward_llm(
        self,
        visual: torch.Tensor,
        text_embeds: list[torch.Tensor],
        text_masks: list[torch.Tensor],
        text_labels: torch.Tensor,
        device: torch.device,
    ) -> Any:
        """Prepend visual tokens to the text pieces and call the language model.

        Visual positions get attention-mask value 1 (in the dtype of the first text
        mask) and label ``-100``.
        """
        B, T = visual.size(0), visual.size(1)
        inputs_embeds = torch.cat([visual, *text_embeds], dim=1)
        visual_mask = torch.ones(B, T, dtype=text_masks[0].dtype, device=device)
        attention_mask = torch.cat([visual_mask, *text_masks], dim=1)
        ignore_vis = torch.full((B, T), -100, dtype=torch.long, device=device)
        labels = torch.cat([ignore_vis, text_labels], dim=1)
        return self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            use_cache=False,
        )

    def forward(
        self,
        pixel_values: torch.Tensor,  # (B, 3, H, W)
        prompt_ids: torch.Tensor | None = None,
        prompt_mask: torch.Tensor | None = None,
        response_ids: torch.Tensor | None = None,
        response_mask: torch.Tensor | None = None,
        packed_text_ids: torch.Tensor | None = None,
        packed_text_mask: torch.Tensor | None = None,
        packed_text_labels: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> Any:
        """Vision-first forward.

        **Single-turn (default):** ``[visual, prompt, response]`` — CE only on
        ``response`` tokens (see ``prompt_ids`` … ``response_mask``).

        **Multiturn:** pass ``packed_text_ids``, ``packed_text_mask``, and
        ``packed_text_labels`` instead of prompt/response tensors. Layout is
        ``[visual, packed_text]`` with ``packed_text_labels`` already ``-100`` on user
        spans and padding (loss on every assistant span in one sequence).

        Extra keyword args are ignored so callers can pass through batch dicts.

        Returns:
            Whatever ``self.llm(...)`` returns when called with ``labels`` — not a bare
            tensor. NOTE(review): presumably an HF causal-LM output object exposing
            ``.loss`` / ``.logits``; confirm against Ouro's remote code.

        Raises:
            ValueError: if the packed inputs are incomplete, or if neither a complete
                packed set nor a complete prompt/response set is given.
        """
        _ = kwargs  # allow batch dicts with unused keys
        device = pixel_values.device

        # Validate before encoding so bad calls fail without running the vision tower.
        packed = packed_text_ids is not None
        if packed:
            if packed_text_mask is None or packed_text_labels is None:
                raise ValueError(
                    "packed_text_ids requires packed_text_mask and packed_text_labels"
                )
        elif (
            prompt_ids is None
            or prompt_mask is None
            or response_ids is None
            or response_mask is None
        ):
            raise ValueError(
                "Either packed_text_* (multiturn) or prompt_ids/response_ids (single-turn) is required"
            )

        visual = self.encode_images(pixel_values)  # (B, T, H)

        if packed:
            return self._forward_llm(
                visual,
                [self._embed_tokens(packed_text_ids)],
                [packed_text_mask],
                packed_text_labels,
                device,
            )

        prompt_emb = self._embed_tokens(prompt_ids)  # (B, Lp, H)
        resp_emb = self._embed_tokens(response_ids)  # (B, Lr, H)
        # No loss on the prompt, nor on response positions the mask marks as padding.
        ignore_prompt = torch.full(
            (pixel_values.size(0), prompt_ids.size(1)), -100, dtype=torch.long, device=device
        )
        resp_labels = response_ids.masked_fill(response_mask == 0, -100)
        text_labels = torch.cat([ignore_prompt, resp_labels], dim=1)
        return self._forward_llm(
            visual,
            [prompt_emb, resp_emb],
            [prompt_mask, response_mask],
            text_labels,
            device,
        )

    def trainable_parameters(self) -> list[torch.nn.Parameter]:
        """Return projector parameters plus, with LoRA enabled, the LM's trainable ones."""
        out = [p for p in self.projector.parameters() if p.requires_grad]
        if self._use_lora:
            out.extend(p for p in self.llm.parameters() if p.requires_grad)
        return out