from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from PIL import Image
try:
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
except Exception: # pragma: no cover - handled at runtime with better error.
safe_load_file = None
safe_save_file = None
try:
from transformers import AutoImageProcessor, AutoModel
except Exception: # pragma: no cover - handled at runtime with better error.
AutoImageProcessor = None
AutoModel = None
from .config import ProtoMorphConfig
from .hf_utils import get_hf_token
@dataclass
class DinoFeatures:
cls: Tensor
registers: Optional[Tensor]
patches: Tensor
patch_hw: Tuple[int, int]
pixel_hw: Tuple[int, int]
class FrozenDINOv3(nn.Module):
"""Hugging Face DINOv3 wrapper that returns CLS/register/patch tokens.
DINOv3 is kept frozen. Use torch.autocast during forward for memory savings
on RTX 3090; the custom head remains regular PyTorch modules.
"""
def __init__(self, model_name: str, image_size: int = 512, local_files_only: bool = False):
super().__init__()
if AutoImageProcessor is None or AutoModel is None:
raise ImportError(
"transformers is required. Install transformers>=4.56.0 before loading DINOv3."
)
self.model_name = model_name
self.image_size = image_size
hf_token = get_hf_token()
hf_kwargs = {"local_files_only": local_files_only}
if hf_token:
# Supports RunPod env variable `hf_key` as well as standard HF_TOKEN.
hf_kwargs["token"] = hf_token
self.processor = AutoImageProcessor.from_pretrained(model_name, **hf_kwargs)
self.model = AutoModel.from_pretrained(model_name, **hf_kwargs)
self.model.eval().requires_grad_(False)
config = self.model.config
self.patch_size = int(getattr(config, "patch_size", 16))
self.hidden_size = int(getattr(config, "hidden_size", 0))
self.num_register_tokens = int(getattr(config, "num_register_tokens", 0))
def _prepare_images(self, images: Image.Image | Sequence[Image.Image]) -> Dict[str, Tensor]:
if isinstance(images, Image.Image):
images = [images]
        # ViT-style HF image processors accept a target-size override at call
        # time; we request a square size divisible by patch_size for a clean patch grid.
size = {"height": self.image_size, "width": self.image_size}
return self.processor(images=list(images), return_tensors="pt", size=size)
@torch.no_grad()
def forward(self, images: Image.Image | Sequence[Image.Image], device: torch.device | str) -> DinoFeatures:
inputs = self._prepare_images(images)
pixel_values = inputs["pixel_values"].to(device, non_blocking=True)
outputs = self.model(pixel_values=pixel_values)
tokens = outputs.last_hidden_state
cls = tokens[:, 0]
reg_start = 1
reg_end = 1 + self.num_register_tokens
registers = tokens[:, reg_start:reg_end] if self.num_register_tokens > 0 else None
patches = tokens[:, reg_end:]
h, w = pixel_values.shape[-2:]
ph, pw = h // self.patch_size, w // self.patch_size
expected = ph * pw
if patches.shape[1] != expected:
# Fallback for processors/checkpoints that return a non-square crop or resize.
# This keeps inference running and makes the mismatch visible to the caller.
side = int(patches.shape[1] ** 0.5)
if side * side == patches.shape[1]:
ph, pw = side, side
else:
ph, pw = patches.shape[1], 1
return DinoFeatures(cls=cls, registers=registers, patches=patches, patch_hw=(ph, pw), pixel_hw=(h, w))
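# A minimal shape sketch for the token layout above, assuming Hub access (the
# checkpoint id is illustrative; any DINOv3 ViT checkpoint should work). With
# image_size=512 and a 16px patch, the processor yields a 32x32 patch grid,
# i.e. 1024 patch tokens once the CLS and register tokens are split off.
def _demo_backbone_shapes() -> None:
    backbone = FrozenDINOv3("facebook/dinov3-vitb16-pretrain-lvd1689m", image_size=512)
    feats = backbone(Image.new("RGB", (640, 480)), device="cpu")
    assert feats.cls.shape == (1, backbone.hidden_size)
    assert feats.patches.shape[1] == feats.patch_hw[0] * feats.patch_hw[1]
    print(feats.patch_hw)  # (32, 32) for a square 512px resize with 16px patches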
class FeedForward(nn.Module):
def __init__(self, dim: int, expansion: int = 4, dropout: float = 0.0):
super().__init__()
hidden = dim * expansion
self.net = nn.Sequential(
nn.Linear(dim, hidden),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden, dim),
nn.Dropout(dropout),
)
def forward(self, x: Tensor) -> Tensor:
return self.net(x)
class ProtoMorphBlock(nn.Module):
"""Prototype-morphing residual block over DINO patch tokens.
It computes soft assignment of each patch token to learnable prototypes, then
mixes original token, nearest prototype context, difference, and product.
This creates a lightweight nonstandard CNN replacement over patch embeddings.
"""
def __init__(self, dim: int, proto_count: int, dropout: float = 0.0):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.prototypes = nn.Parameter(torch.randn(proto_count, dim) * 0.02)
self.log_temperature = nn.Parameter(torch.tensor(0.0))
self.mix = nn.Sequential(
nn.Linear(dim * 4, dim * 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim * 2, dim),
)
self.gamma = nn.Parameter(torch.tensor(0.1))
self.out_norm = nn.LayerNorm(dim)
def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
zn = self.norm(z)
p = F.normalize(self.prototypes, dim=-1)
q = F.normalize(zn, dim=-1)
# cosine distance in [0, 2]
dist = 1.0 - torch.matmul(q, p.t())
temp = F.softplus(self.log_temperature) + 1e-4
assign = F.softmax(-dist / temp, dim=-1)
context = torch.matmul(assign, self.prototypes)
mixed = self.mix(torch.cat([zn, context, zn - context, zn * context], dim=-1))
z = z + self.gamma.tanh() * mixed
return self.out_norm(z), assign
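# Standalone sketch of the soft assignment above, on random tensors instead of
# DINO tokens: each row of `assign` is a softmax over prototypes and sums to 1,
# and a smaller temperature sharpens it toward the nearest prototype.
def _demo_protomorph_block() -> None:
    block = ProtoMorphBlock(dim=384, proto_count=64)
    z = torch.randn(2, 1024, 384)  # [batch, patch tokens, dim]
    out, assign = block(z)
    assert out.shape == z.shape
    assert assign.shape == (2, 1024, 64)
    assert torch.allclose(assign.sum(dim=-1), torch.ones(2, 1024), atol=1e-5)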
class LayerMemoryAttention(nn.Module):
"""A small learned memory bank attended by every patch token."""
def __init__(self, dim: int, memory_tokens: int, num_heads: int, dropout: float = 0.0):
super().__init__()
self.memory = nn.Parameter(torch.randn(memory_tokens, dim) * 0.02)
self.norm_q = nn.LayerNorm(dim)
self.norm_out = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, dropout=dropout, batch_first=True)
self.ffn = FeedForward(dim, expansion=4, dropout=dropout)
self.gamma_attn = nn.Parameter(torch.tensor(0.1))
self.gamma_ffn = nn.Parameter(torch.tensor(0.1))
def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
b = z.shape[0]
mem = self.memory.unsqueeze(0).expand(b, -1, -1)
q = self.norm_q(z)
attn_out, attn_weights = self.attn(q, mem, mem, need_weights=True)
z = z + self.gamma_attn.tanh() * attn_out
z = z + self.gamma_ffn.tanh() * self.ffn(self.norm_out(z))
return z, attn_weights
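# Shape sketch: every patch token attends to the same learned memory bank, so
# the returned weights have one row per token and one column per memory slot
# (averaged over heads, the nn.MultiheadAttention default).
def _demo_layer_memory_attention() -> None:
    mem = LayerMemoryAttention(dim=384, memory_tokens=16, num_heads=6)
    z = torch.randn(2, 1024, 384)
    out, attn = mem(z)
    assert out.shape == z.shape
    assert attn.shape == (2, 1024, 16)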
class MainClassifier(nn.Module):
def __init__(self, dim: int, num_classes: int, dropout: float = 0.0):
super().__init__()
self.norm = nn.LayerNorm(dim * 3)
self.head = nn.Sequential(
nn.Linear(dim * 3, dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim, num_classes),
)
def forward(self, cls: Tensor, z: Tensor) -> Tensor:
mean_pool = z.mean(dim=1)
max_pool = z.max(dim=1).values
feat = torch.cat([cls, mean_pool, max_pool], dim=-1)
return self.head(self.norm(feat))
class Top2FeedbackModulator(nn.Module):
"""Turns top-2 class probabilities into scale/shift over patch tokens."""
def __init__(self, dim: int, num_classes: int):
super().__init__()
self.class_embed = nn.Embedding(num_classes, dim)
self.stats_mlp = nn.Sequential(
nn.Linear(4, dim),
nn.GELU(),
nn.Linear(dim, dim),
)
self.to_scale_shift = nn.Sequential(
nn.LayerNorm(dim * 2),
nn.Linear(dim * 2, dim * 2),
)
def forward(self, z0: Tensor, logits: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
probs = logits.softmax(dim=-1)
top_probs, top_idx = probs.topk(k=min(2, probs.shape[-1]), dim=-1)
if top_probs.shape[-1] == 1:
top_probs = torch.cat([top_probs, torch.zeros_like(top_probs)], dim=-1)
top_idx = torch.cat([top_idx, top_idx], dim=-1)
p1 = top_probs[:, 0]
p2 = top_probs[:, 1]
margin = p1 - p2
entropy = -(probs * (probs.clamp_min(1e-8)).log()).sum(dim=-1)
        class_vecs = self.class_embed(top_idx)  # [B, 2, dim]
weighted_class_vec = (class_vecs * top_probs.unsqueeze(-1)).sum(dim=1)
stats = torch.stack([p1, p2, margin, entropy], dim=-1)
stat_vec = self.stats_mlp(stats)
scale_shift = self.to_scale_shift(torch.cat([weighted_class_vec, stat_vec], dim=-1))
scale, shift = scale_shift.chunk(2, dim=-1)
z_mod = z0 * (1.0 + 0.25 * torch.tanh(scale).unsqueeze(1)) + 0.25 * torch.tanh(shift).unsqueeze(1)
return z_mod, {
"p1": p1,
"p2": p2,
"margin": margin,
"entropy": entropy,
"top_idx": top_idx,
"top_probs": top_probs,
}
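# Worked example of the confidence statistics that drive the modulator: a
# confidently classified sample yields a margin near 1 and a near-zero entropy,
# while an ambiguous top-2 yields a small margin and a large entropy, which the
# MLPs above turn into a per-sample scale/shift over the patch tokens.
def _demo_feedback_stats() -> None:
    fb = Top2FeedbackModulator(dim=384, num_classes=10)
    z0 = torch.randn(2, 1024, 384)
    logits = torch.zeros(2, 10)
    logits[0, 0] = 8.0                     # confident sample
    logits[1, 0], logits[1, 1] = 1.0, 0.9  # ambiguous top-2
    z_mod, stats = fb(z0, logits)
    assert z_mod.shape == z0.shape
    print(stats["margin"])   # roughly [1.0, 0.02]
    print(stats["entropy"])  # near 0 for the first sample, ~2.2 for the second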
class DeltaRBFHardExpert(nn.Module):
"""RBF expert for hard examples, driven by feedback-modulated patch deltas."""
def __init__(self, dim: int, rbf_count: int, num_classes: int, dropout: float = 0.0):
super().__init__()
self.delta_norm = nn.LayerNorm(dim)
self.rbf_centers = nn.Parameter(torch.randn(rbf_count, dim) * 0.02)
self.log_sigma = nn.Parameter(torch.zeros(rbf_count))
self.rbf_to_logits = nn.Linear(rbf_count, num_classes)
self.delta_mlp = nn.Sequential(
nn.Linear(dim * 2, dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim, num_classes),
)
def forward(self, z_base: Tensor, z_mod: Tensor) -> Tuple[Tensor, Tensor]:
delta = self.delta_norm(z_mod - z_base)
delta_mean = delta.mean(dim=1)
delta_max = delta.max(dim=1).values
q = F.normalize(delta, dim=-1)
c = F.normalize(self.rbf_centers, dim=-1)
dist = 1.0 - torch.matmul(q, c.t()) # [B, N, R]
sigma = F.softplus(self.log_sigma).view(1, 1, -1) + 1e-4
rbf = torch.exp(-dist / sigma).mean(dim=1) # [B, R]
expert_logits = self.rbf_to_logits(rbf) + self.delta_mlp(torch.cat([delta_mean, delta_max], dim=-1))
return expert_logits, rbf
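# Shape sketch for the hard-case expert: it operates purely on the delta
# between the original patch tokens and their feedback-modulated version,
# scoring that delta against RBF centers and through a pooled MLP.
def _demo_delta_rbf_expert() -> None:
    expert = DeltaRBFHardExpert(dim=384, rbf_count=32, num_classes=10)
    z_base = torch.randn(2, 1024, 384)
    expert_logits, rbf = expert(z_base, z_base + 0.1 * torch.randn_like(z_base))
    assert expert_logits.shape == (2, 10)
    assert rbf.shape == (2, 32)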
class LogitFusion(nn.Module):
def __init__(self, num_classes: int):
super().__init__()
self.alpha = nn.Parameter(torch.tensor(0.35))
self.calibrate = nn.Sequential(
nn.LayerNorm(num_classes * 2),
nn.Linear(num_classes * 2, num_classes),
)
def forward(self, main_logits: Tensor, expert_logits: Tensor) -> Tensor:
residual = self.calibrate(torch.cat([main_logits, expert_logits], dim=-1))
return main_logits + self.alpha.sigmoid() * expert_logits + 0.1 * residual
class HardCaseGate(nn.Module):
"""Deterministic inference gate from probability confidence signals."""
def __init__(self, pmax_threshold: float, margin_threshold: float, entropy_threshold: float):
super().__init__()
self.pmax_threshold = pmax_threshold
self.margin_threshold = margin_threshold
self.entropy_threshold = entropy_threshold
def forward(self, logits: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
probs = logits.softmax(dim=-1)
top_probs = probs.topk(k=min(2, probs.shape[-1]), dim=-1).values
if top_probs.shape[-1] == 1:
p1 = top_probs[:, 0]
p2 = torch.zeros_like(p1)
else:
p1, p2 = top_probs[:, 0], top_probs[:, 1]
margin = p1 - p2
entropy = -(probs * probs.clamp_min(1e-8).log()).sum(dim=-1)
hard = (p1 < self.pmax_threshold) | (margin < self.margin_threshold) | (entropy > self.entropy_threshold)
return hard, {"pmax": p1, "margin": margin, "entropy": entropy}
class ProtoMorphHead(nn.Module):
def __init__(self, cfg: ProtoMorphConfig):
super().__init__()
self.cfg = cfg
d = cfg.embed_dim
self.input_norm = nn.LayerNorm(d)
self.block1 = ProtoMorphBlock(d, cfg.proto_count, cfg.dropout)
self.mem1 = LayerMemoryAttention(d, cfg.memory_tokens, cfg.num_heads, cfg.dropout)
self.block2 = ProtoMorphBlock(d, cfg.proto_count, cfg.dropout)
self.mem2 = LayerMemoryAttention(d, cfg.memory_tokens, cfg.num_heads, cfg.dropout)
self.main = MainClassifier(d, cfg.num_classes, cfg.dropout)
self.gate = HardCaseGate(cfg.hard_pmax_threshold, cfg.hard_margin_threshold, cfg.hard_entropy_threshold)
self.feedback = Top2FeedbackModulator(d, cfg.num_classes)
self.hard_expert = DeltaRBFHardExpert(d, cfg.rbf_count, cfg.num_classes, cfg.dropout)
self.fusion = LogitFusion(cfg.num_classes)
def forward(self, cls: Tensor, patches: Tensor, force_hard: bool = False) -> Dict[str, Tensor]:
z0 = self.input_norm(patches)
z, assign1 = self.block1(z0)
z, mem_attn1 = self.mem1(z)
z, assign2 = self.block2(z)
z, mem_attn2 = self.mem2(z)
main_logits = self.main(cls, z)
hard_mask, gate_stats = self.gate(main_logits)
if force_hard:
hard_mask = torch.ones_like(hard_mask, dtype=torch.bool)
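        # The expert path always runs for the whole batch; the gate then selects
        # the fused logits only for samples flagged hard, keeping shapes static.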
z_mod, fb_stats = self.feedback(z0, main_logits)
expert_logits, rbf = self.hard_expert(z0, z_mod)
fused_logits = self.fusion(main_logits, expert_logits)
final_logits = torch.where(hard_mask[:, None], fused_logits, main_logits)
out = {
"logits": final_logits,
"main_logits": main_logits,
"expert_logits": expert_logits,
"hard_mask": hard_mask,
"rbf": rbf,
"assign1_mean": assign1.mean(dim=1),
"assign2_mean": assign2.mean(dim=1),
"mem_attn1_mean": mem_attn1.mean(dim=1),
"mem_attn2_mean": mem_attn2.mean(dim=1),
}
out.update({f"gate_{k}": v for k, v in gate_stats.items()})
out.update({f"fb_{k}": v for k, v in fb_stats.items() if isinstance(v, Tensor)})
return out
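# End-to-end sketch of the head alone, with random tokens instead of a
# backbone. The ProtoMorphConfig construction is an assumption: the field
# names match the attribute accesses above, but see .config for the
# authoritative definition and for any additional required fields.
def _demo_head_forward() -> None:
    cfg = ProtoMorphConfig(
        embed_dim=384, num_classes=10, proto_count=64, memory_tokens=16,
        num_heads=6, rbf_count=32, dropout=0.0, hard_pmax_threshold=0.75,
        hard_margin_threshold=0.20, hard_entropy_threshold=1.00,
    )
    head = ProtoMorphHead(cfg)
    out = head(cls=torch.randn(2, 384), patches=torch.randn(2, 1024, 384))
    assert out["logits"].shape == (2, 10)
    assert out["hard_mask"].dtype == torch.bool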
class ProtoMorphDINOv3(nn.Module):
"""Full inference graph: frozen DINOv3 + custom ProtoMorph head."""
def __init__(self, cfg: ProtoMorphConfig, local_files_only: bool = False):
super().__init__()
self.cfg = cfg
self.backbone = FrozenDINOv3(cfg.dino_model_name, image_size=cfg.image_size, local_files_only=local_files_only)
actual_dim = self.backbone.hidden_size
if actual_dim and actual_dim != cfg.embed_dim:
raise ValueError(
f"Config embed_dim={cfg.embed_dim} but DINO hidden_size={actual_dim}. "
f"Use the matching config or run scripts/create_random_head.py with --embed-dim {actual_dim}."
)
self.head = ProtoMorphHead(cfg)
@torch.no_grad()
def forward(
self,
images: Image.Image | Sequence[Image.Image],
device: torch.device | str,
force_hard: bool = False,
use_bf16_autocast: Optional[bool] = None,
) -> Dict[str, Tensor | Tuple[int, int]]:
use_amp = self.cfg.use_bf16_autocast if use_bf16_autocast is None else use_bf16_autocast
device_obj = torch.device(device)
amp_enabled = bool(use_amp and device_obj.type == "cuda")
amp_dtype = torch.bfloat16
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp_enabled):
feats = self.backbone(images, device=device_obj)
cls = feats.cls
patches = feats.patches
if self.cfg.normalize_patch_tokens:
cls = F.layer_norm(cls, cls.shape[-1:])
patches = F.layer_norm(patches, patches.shape[-1:])
head_out = self.head(cls, patches, force_hard=force_hard)
head_out["patch_hw"] = feats.patch_hw
head_out["pixel_hw"] = feats.pixel_hw
return head_out
def save_custom_head(self, checkpoint_path: str | Path) -> None:
if safe_save_file is None:
raise ImportError("safetensors is required: pip install safetensors")
p = Path(checkpoint_path)
p.parent.mkdir(parents=True, exist_ok=True)
safe_save_file(self.head.state_dict(), str(p))
def load_custom_head(self, checkpoint_path: str | Path, strict: bool = True) -> None:
if safe_load_file is None:
raise ImportError("safetensors is required: pip install safetensors")
sd = safe_load_file(str(checkpoint_path), device="cpu")
self.head.load_state_dict(sd, strict=strict)
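# Full-pipeline usage sketch plus a checkpoint roundtrip. The config is
# assumed to be built as in _demo_head_forward, with dino_model_name and
# image_size also set; only the custom head is serialized (a single
# safetensors file), while the frozen DINOv3 weights are always re-fetched by
# name from the Hub or the local cache.
def _demo_full_inference(cfg: ProtoMorphConfig) -> None:
    model = ProtoMorphDINOv3(cfg)
    out = model(Image.new("RGB", (512, 512)), device="cpu")
    print(out["patch_hw"], out["hard_mask"], out["logits"].softmax(dim=-1).argmax(dim=-1))
    model.save_custom_head("checkpoints/head.safetensors")
    model.load_custom_head("checkpoints/head.safetensors")  # strict=True checks every key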
def infer_embed_dim_from_model_name(model_name: str) -> int:
"""Useful defaults for DINOv3 ViT checkpoints."""
name = model_name.lower()
if "vits" in name:
return 384
if "vitb" in name:
return 768
if "vitl" in name:
return 1024
if "vith" in name:
return 1280
return 384
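# Quick check of the name-based defaults above; only the "vits"/"vitb"/"vitl"/
# "vith" substring matters, the full Hub id is shown for illustration only.
def _demo_infer_embed_dim() -> None:
    assert infer_embed_dim_from_model_name("facebook/dinov3-vitb16-pretrain-lvd1689m") == 768
    assert infer_embed_dim_from_model_name("unknown/backbone") == 384  # ViT-S fallback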