from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import torch
from torch import Tensor, nn
import torch.nn.functional as F
from PIL import Image

try:
    from safetensors.torch import load_file as safe_load_file
    from safetensors.torch import save_file as safe_save_file
except Exception:  # pragma: no cover - handled at runtime with better error.
    safe_load_file = None
    safe_save_file = None

try:
    from transformers import AutoImageProcessor, AutoModel
except Exception:  # pragma: no cover - handled at runtime with better error.
    AutoImageProcessor = None
    AutoModel = None

from .config import ProtoMorphConfig
from .hf_utils import get_hf_token


@dataclass
class DinoFeatures:
    cls: Tensor
    registers: Optional[Tensor]
    patches: Tensor
    patch_hw: Tuple[int, int]
    pixel_hw: Tuple[int, int]


class FrozenDINOv3(nn.Module):
    """Hugging Face DINOv3 wrapper that returns CLS/register/patch tokens.

    DINOv3 is kept frozen. Use torch.autocast during forward for memory savings
    on RTX 3090; the custom head remains regular PyTorch modules.
    """

    def __init__(self, model_name: str, image_size: int = 512, local_files_only: bool = False):
        super().__init__()
        if AutoImageProcessor is None or AutoModel is None:
            raise ImportError(
                "transformers is required. Install transformers>=4.56.0 before loading DINOv3."
            )
        self.model_name = model_name
        self.image_size = image_size
        hf_token = get_hf_token()
        hf_kwargs = {"local_files_only": local_files_only}
        if hf_token:
            # Supports RunPod env variable `hf_key` as well as standard HF_TOKEN.
            hf_kwargs["token"] = hf_token
        self.processor = AutoImageProcessor.from_pretrained(model_name, **hf_kwargs)
        self.model = AutoModel.from_pretrained(model_name, **hf_kwargs)
        self.model.eval().requires_grad_(False)
        config = self.model.config
        self.patch_size = int(getattr(config, "patch_size", 16))
        self.hidden_size = int(getattr(config, "hidden_size", 0))
        self.num_register_tokens = int(getattr(config, "num_register_tokens", 0))

    def _prepare_images(self, images: Image.Image | Sequence[Image.Image]) -> Dict[str, Tensor]:
        if isinstance(images, Image.Image):
            images = [images]
        # HF processors support overriding target size at call time for ViT-like image processors.
        # We request a square size that is divisible by patch_size for clean patch grids.
        size = {"height": self.image_size, "width": self.image_size}
        return self.processor(images=list(images), return_tensors="pt", size=size)

    @torch.no_grad()
    def forward(self, images: Image.Image | Sequence[Image.Image], device: torch.device | str) -> DinoFeatures:
        inputs = self._prepare_images(images)
        pixel_values = inputs["pixel_values"].to(device, non_blocking=True)
        outputs = self.model(pixel_values=pixel_values)
        tokens = outputs.last_hidden_state
        cls = tokens[:, 0]
        reg_start = 1
        reg_end = 1 + self.num_register_tokens
        registers = tokens[:, reg_start:reg_end] if self.num_register_tokens > 0 else None
        patches = tokens[:, reg_end:]
        h, w = pixel_values.shape[-2:]
        ph, pw = h // self.patch_size, w // self.patch_size
        expected = ph * pw
        if patches.shape[1] != expected:
            # Fallback for processors/checkpoints that return a non-square crop or resize.
            # This keeps inference running and makes the mismatch visible to the caller.
            side = int(patches.shape[1] ** 0.5)
            if side * side == patches.shape[1]:
                ph, pw = side, side
            else:
                ph, pw = patches.shape[1], 1
        return DinoFeatures(cls=cls, registers=registers, patches=patches, patch_hw=(ph, pw), pixel_hw=(h, w))


class FeedForward(nn.Module):
    def __init__(self, dim: int, expansion: int = 4, dropout: float = 0.0):
        super().__init__()
        hidden = dim * expansion
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)


class ProtoMorphBlock(nn.Module):
    """Prototype-morphing residual block over DINO patch tokens.

    It computes soft assignment of each patch token to learnable prototypes,
    then mixes original token, nearest prototype context, difference, and product.
    This creates a lightweight nonstandard CNN replacement over patch embeddings.
    """

    def __init__(self, dim: int, proto_count: int, dropout: float = 0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.prototypes = nn.Parameter(torch.randn(proto_count, dim) * 0.02)
        self.log_temperature = nn.Parameter(torch.tensor(0.0))
        self.mix = nn.Sequential(
            nn.Linear(dim * 4, dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 2, dim),
        )
        self.gamma = nn.Parameter(torch.tensor(0.1))
        self.out_norm = nn.LayerNorm(dim)

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        zn = self.norm(z)
        p = F.normalize(self.prototypes, dim=-1)
        q = F.normalize(zn, dim=-1)
        # cosine distance in [0, 2]
        dist = 1.0 - torch.matmul(q, p.t())
        temp = F.softplus(self.log_temperature) + 1e-4
        assign = F.softmax(-dist / temp, dim=-1)
        context = torch.matmul(assign, self.prototypes)
        mixed = self.mix(torch.cat([zn, context, zn - context, zn * context], dim=-1))
        z = z + self.gamma.tanh() * mixed
        return self.out_norm(z), assign


class LayerMemoryAttention(nn.Module):
    """A small learned memory bank attended by every patch token."""

    def __init__(self, dim: int, memory_tokens: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(memory_tokens, dim) * 0.02)
        self.norm_q = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.ffn = FeedForward(dim, expansion=4, dropout=dropout)
        self.gamma_attn = nn.Parameter(torch.tensor(0.1))
        self.gamma_ffn = nn.Parameter(torch.tensor(0.1))

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        b = z.shape[0]
        mem = self.memory.unsqueeze(0).expand(b, -1, -1)
        q = self.norm_q(z)
        attn_out, attn_weights = self.attn(q, mem, mem, need_weights=True)
        z = z + self.gamma_attn.tanh() * attn_out
        z = z + self.gamma_ffn.tanh() * self.ffn(self.norm_out(z))
        return z, attn_weights


class MainClassifier(nn.Module):
    def __init__(self, dim: int, num_classes: int, dropout: float = 0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim * 3)
        self.head = nn.Sequential(
            nn.Linear(dim * 3, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, num_classes),
        )

    def forward(self, cls: Tensor, z: Tensor) -> Tensor:
        mean_pool = z.mean(dim=1)
        max_pool = z.max(dim=1).values
        feat = torch.cat([cls, mean_pool, max_pool], dim=-1)
        return self.head(self.norm(feat))


class Top2FeedbackModulator(nn.Module):
    """Turns top-2 class probabilities into scale/shift over patch tokens."""

    def __init__(self, dim: int, num_classes: int):
        super().__init__()
        self.class_embed = nn.Embedding(num_classes, dim)
        self.stats_mlp = nn.Sequential(
            nn.Linear(4, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        self.to_scale_shift = nn.Sequential(
            nn.LayerNorm(dim * 2),
            nn.Linear(dim * 2, dim * 2),
        )

    def forward(self, z0: Tensor, logits: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
        probs = logits.softmax(dim=-1)
        top_probs, top_idx = probs.topk(k=min(2, probs.shape[-1]), dim=-1)
        if top_probs.shape[-1] == 1:
            top_probs = torch.cat([top_probs, torch.zeros_like(top_probs)], dim=-1)
            top_idx = torch.cat([top_idx, top_idx], dim=-1)
        p1 = top_probs[:, 0]
        p2 = top_probs[:, 1]
        margin = p1 - p2
        entropy = -(probs * (probs.clamp_min(1e-8)).log()).sum(dim=-1)
        class_vecs = self.class_embed(top_idx)  # [B, 2, C]
        weighted_class_vec = (class_vecs * top_probs.unsqueeze(-1)).sum(dim=1)
        stats = torch.stack([p1, p2, margin, entropy], dim=-1)
        stat_vec = self.stats_mlp(stats)
        scale_shift = self.to_scale_shift(torch.cat([weighted_class_vec, stat_vec], dim=-1))
        scale, shift = scale_shift.chunk(2, dim=-1)
        z_mod = z0 * (1.0 + 0.25 * torch.tanh(scale).unsqueeze(1)) + 0.25 * torch.tanh(shift).unsqueeze(1)
        return z_mod, {
            "p1": p1,
            "p2": p2,
            "margin": margin,
            "entropy": entropy,
            "top_idx": top_idx,
            "top_probs": top_probs,
        }


class DeltaRBFHardExpert(nn.Module):
    """RBF expert for hard examples, driven by feedback-modulated patch deltas."""

    def __init__(self, dim: int, rbf_count: int, num_classes: int, dropout: float = 0.0):
        super().__init__()
        self.delta_norm = nn.LayerNorm(dim)
        self.rbf_centers = nn.Parameter(torch.randn(rbf_count, dim) * 0.02)
        self.log_sigma = nn.Parameter(torch.zeros(rbf_count))
        self.rbf_to_logits = nn.Linear(rbf_count, num_classes)
        self.delta_mlp = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, num_classes),
        )

    def forward(self, z_base: Tensor, z_mod: Tensor) -> Tuple[Tensor, Tensor]:
        delta = self.delta_norm(z_mod - z_base)
        delta_mean = delta.mean(dim=1)
        delta_max = delta.max(dim=1).values
        q = F.normalize(delta, dim=-1)
        c = F.normalize(self.rbf_centers, dim=-1)
        dist = 1.0 - torch.matmul(q, c.t())  # [B, N, R]
        sigma = F.softplus(self.log_sigma).view(1, 1, -1) + 1e-4
        rbf = torch.exp(-dist / sigma).mean(dim=1)  # [B, R]
        expert_logits = self.rbf_to_logits(rbf) + self.delta_mlp(torch.cat([delta_mean, delta_max], dim=-1))
        return expert_logits, rbf


class LogitFusion(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(0.35))
        self.calibrate = nn.Sequential(
            nn.LayerNorm(num_classes * 2),
            nn.Linear(num_classes * 2, num_classes),
        )

    def forward(self, main_logits: Tensor, expert_logits: Tensor) -> Tensor:
        residual = self.calibrate(torch.cat([main_logits, expert_logits], dim=-1))
        return main_logits + self.alpha.sigmoid() * expert_logits + 0.1 * residual


class HardCaseGate(nn.Module):
    """Deterministic inference gate from probability confidence signals."""

    def __init__(self, pmax_threshold: float, margin_threshold: float, entropy_threshold: float):
        super().__init__()
        self.pmax_threshold = pmax_threshold
        self.margin_threshold = margin_threshold
        self.entropy_threshold = entropy_threshold

    def forward(self, logits: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
        probs = logits.softmax(dim=-1)
        top_probs = probs.topk(k=min(2, probs.shape[-1]), dim=-1).values
        if top_probs.shape[-1] == 1:
            p1 = top_probs[:, 0]
            p2 = torch.zeros_like(p1)
        else:
            p1, p2 = top_probs[:, 0], top_probs[:, 1]
        margin = p1 - p2
        entropy = -(probs * probs.clamp_min(1e-8).log()).sum(dim=-1)
        hard = (p1 < self.pmax_threshold) | (margin < self.margin_threshold) | (entropy > self.entropy_threshold)
        return hard, {"pmax": p1, "margin": margin, "entropy": entropy}


class ProtoMorphHead(nn.Module):
    def __init__(self, cfg: ProtoMorphConfig):
        super().__init__()
        self.cfg = cfg
        d = cfg.embed_dim
        self.input_norm = nn.LayerNorm(d)
        self.block1 = ProtoMorphBlock(d, cfg.proto_count, cfg.dropout)
        self.mem1 = LayerMemoryAttention(d, cfg.memory_tokens, cfg.num_heads, cfg.dropout)
        self.block2 = ProtoMorphBlock(d, cfg.proto_count, cfg.dropout)
        self.mem2 = LayerMemoryAttention(d, cfg.memory_tokens, cfg.num_heads, cfg.dropout)
        self.main = MainClassifier(d, cfg.num_classes, cfg.dropout)
        self.gate = HardCaseGate(cfg.hard_pmax_threshold, cfg.hard_margin_threshold, cfg.hard_entropy_threshold)
        self.feedback = Top2FeedbackModulator(d, cfg.num_classes)
        self.hard_expert = DeltaRBFHardExpert(d, cfg.rbf_count, cfg.num_classes, cfg.dropout)
        self.fusion = LogitFusion(cfg.num_classes)

    def forward(self, cls: Tensor, patches: Tensor, force_hard: bool = False) -> Dict[str, Tensor]:
        z0 = self.input_norm(patches)
        z, assign1 = self.block1(z0)
        z, mem_attn1 = self.mem1(z)
        z, assign2 = self.block2(z)
        z, mem_attn2 = self.mem2(z)
        main_logits = self.main(cls, z)
        hard_mask, gate_stats = self.gate(main_logits)
        if force_hard:
            hard_mask = torch.ones_like(hard_mask, dtype=torch.bool)
        z_mod, fb_stats = self.feedback(z0, main_logits)
        expert_logits, rbf = self.hard_expert(z0, z_mod)
        fused_logits = self.fusion(main_logits, expert_logits)
        final_logits = torch.where(hard_mask[:, None], fused_logits, main_logits)
        out = {
            "logits": final_logits,
            "main_logits": main_logits,
            "expert_logits": expert_logits,
            "hard_mask": hard_mask,
            "rbf": rbf,
            "assign1_mean": assign1.mean(dim=1),
            "assign2_mean": assign2.mean(dim=1),
            "mem_attn1_mean": mem_attn1.mean(dim=1),
            "mem_attn2_mean": mem_attn2.mean(dim=1),
        }
        out.update({f"gate_{k}": v for k, v in gate_stats.items()})
        out.update({f"fb_{k}": v for k, v in fb_stats.items() if isinstance(v, Tensor)})
        return out


class ProtoMorphDINOv3(nn.Module):
    """Full inference graph: frozen DINOv3 + custom ProtoMorph head."""

    def __init__(self, cfg: ProtoMorphConfig, local_files_only: bool = False):
        super().__init__()
        self.cfg = cfg
        self.backbone = FrozenDINOv3(cfg.dino_model_name, image_size=cfg.image_size, local_files_only=local_files_only)
        actual_dim = self.backbone.hidden_size
        if actual_dim and actual_dim != cfg.embed_dim:
            raise ValueError(
                f"Config embed_dim={cfg.embed_dim} but DINO hidden_size={actual_dim}. "
                f"Use the matching config or run scripts/create_random_head.py with --embed-dim {actual_dim}."
            )
        self.head = ProtoMorphHead(cfg)

    @torch.no_grad()
    def forward(
        self,
        images: Image.Image | Sequence[Image.Image],
        device: torch.device | str,
        force_hard: bool = False,
        use_bf16_autocast: Optional[bool] = None,
    ) -> Dict[str, Tensor | Tuple[int, int]]:
        use_amp = self.cfg.use_bf16_autocast if use_bf16_autocast is None else use_bf16_autocast
        device_obj = torch.device(device)
        amp_enabled = bool(use_amp and device_obj.type == "cuda")
        amp_dtype = torch.bfloat16
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp_enabled):
            feats = self.backbone(images, device=device_obj)
            cls = feats.cls
            patches = feats.patches
            if self.cfg.normalize_patch_tokens:
                cls = F.layer_norm(cls, cls.shape[-1:])
                patches = F.layer_norm(patches, patches.shape[-1:])
            head_out = self.head(cls, patches, force_hard=force_hard)
        head_out["patch_hw"] = feats.patch_hw
        head_out["pixel_hw"] = feats.pixel_hw
        return head_out

    def save_custom_head(self, checkpoint_path: str | Path) -> None:
        if safe_save_file is None:
            raise ImportError("safetensors is required: pip install safetensors")
        p = Path(checkpoint_path)
        p.parent.mkdir(parents=True, exist_ok=True)
        safe_save_file(self.head.state_dict(), str(p))

    def load_custom_head(self, checkpoint_path: str | Path, strict: bool = True) -> None:
        if safe_load_file is None:
            raise ImportError("safetensors is required: pip install safetensors")
        sd = safe_load_file(str(checkpoint_path), device="cpu")
        self.head.load_state_dict(sd, strict=strict)


def infer_embed_dim_from_model_name(model_name: str) -> int:
    """Useful defaults for DINOv3 ViT checkpoints."""
    name = model_name.lower()
    if "vits" in name:
        return 384
    if "vitb" in name:
        return 768
    if "vitl" in name:
        return 1024
    if "vith" in name:
        return 1280
    return 384
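

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library API). It shows one way to
# wire ProtoMorphDINOv3 together for a single-image smoke test. The checkpoint
# id and file paths below are placeholders/assumptions, and the ProtoMorphConfig
# call assumes the remaining fields (num_classes, proto_count, ...) have defaults
# in .config; pass them explicitly if they do not. Because this module uses
# relative imports, run it as a package module, e.g. `python -m <package>.<this_module>`.
if __name__ == "__main__":  # pragma: no cover - manual smoke test
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_name = "facebook/dinov3-vits16-pretrain-lvd1689m"  # assumed checkpoint id
    cfg = ProtoMorphConfig(
        dino_model_name=model_name,
        embed_dim=infer_embed_dim_from_model_name(model_name),
    )
    model = ProtoMorphDINOv3(cfg).to(device)
    # model.load_custom_head("checkpoints/protomorph_head.safetensors")  # optional, placeholder path

    image = Image.open("example.jpg").convert("RGB")  # placeholder image path
    out = model(image, device=device)
    probs = out["logits"].softmax(dim=-1)
    print("predicted class:", int(probs.argmax(dim=-1)[0]))
    print("hard-case gate fired:", bool(out["hard_mask"][0]))
    print("patch grid:", out["patch_hw"])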