from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import torch
from torch import Tensor, nn
import torch.nn.functional as F
from PIL import Image

try:
    from safetensors.torch import load_file as safe_load_file
    from safetensors.torch import save_file as safe_save_file
except Exception:
    safe_load_file = None
    safe_save_file = None

try:
    from transformers import AutoImageProcessor, AutoModel
except Exception:
    AutoImageProcessor = None
    AutoModel = None

from .config import ProtoMorphConfig
from .hf_utils import get_hf_token


@dataclass
class DinoFeatures:
    cls: Tensor                  # (B, hidden_size) CLS token
    registers: Optional[Tensor]  # (B, num_register_tokens, hidden_size) or None
    patches: Tensor              # (B, num_patches, hidden_size)
    patch_hw: Tuple[int, int]    # patch-grid (height, width)
    pixel_hw: Tuple[int, int]    # input pixel (height, width)


class FrozenDINOv3(nn.Module):
    """Hugging Face DINOv3 wrapper that returns CLS/register/patch tokens.

    DINOv3 is kept frozen. Wrap the forward pass in torch.autocast for memory
    savings on an RTX 3090; the custom head remains a stack of regular PyTorch
    modules.
    """

    def __init__(self, model_name: str, image_size: int = 512, local_files_only: bool = False):
        super().__init__()
        if AutoImageProcessor is None or AutoModel is None:
            raise ImportError(
                "transformers is required. Install transformers>=4.56.0 before loading DINOv3."
            )
        self.model_name = model_name
        self.image_size = image_size
        hf_token = get_hf_token()
        hf_kwargs = {"local_files_only": local_files_only}
        if hf_token:
            # Only pass an auth token when one is configured.
            hf_kwargs["token"] = hf_token
        self.processor = AutoImageProcessor.from_pretrained(model_name, **hf_kwargs)
        self.model = AutoModel.from_pretrained(model_name, **hf_kwargs)
        self.model.eval().requires_grad_(False)

        config = self.model.config
        self.patch_size = int(getattr(config, "patch_size", 16))
        self.hidden_size = int(getattr(config, "hidden_size", 0))
        self.num_register_tokens = int(getattr(config, "num_register_tokens", 0))

    def _prepare_images(self, images: Image.Image | Sequence[Image.Image]) -> Dict[str, Tensor]:
        if isinstance(images, Image.Image):
            images = [images]

        size = {"height": self.image_size, "width": self.image_size}
        return self.processor(images=list(images), return_tensors="pt", size=size)

    @torch.no_grad()
    def forward(self, images: Image.Image | Sequence[Image.Image], device: torch.device | str) -> DinoFeatures:
        inputs = self._prepare_images(images)
        pixel_values = inputs["pixel_values"].to(device, non_blocking=True)
        outputs = self.model(pixel_values=pixel_values)

        # Token layout: [CLS, register tokens..., patch tokens...].
        tokens = outputs.last_hidden_state
        cls = tokens[:, 0]
        reg_start = 1
        reg_end = 1 + self.num_register_tokens
        registers = tokens[:, reg_start:reg_end] if self.num_register_tokens > 0 else None
        patches = tokens[:, reg_end:]

        h, w = pixel_values.shape[-2:]
        ph, pw = h // self.patch_size, w // self.patch_size
        expected = ph * pw
        if patches.shape[1] != expected:
            # Token count disagrees with the expected grid: fall back to a square
            # grid if possible, otherwise report a flat (N, 1) layout.
            side = int(patches.shape[1] ** 0.5)
            if side * side == patches.shape[1]:
                ph, pw = side, side
            else:
                ph, pw = patches.shape[1], 1
        return DinoFeatures(cls=cls, registers=registers, patches=patches, patch_hw=(ph, pw), pixel_hw=(h, w))
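
# Usage sketch (illustrative only; the model name below is a placeholder, use
# whatever checkpoint id your ProtoMorphConfig points to):
#
#     backbone = FrozenDINOv3("facebook/dinov3-vitb16-pretrain-lvd1689m", image_size=512)
#     feats = backbone(Image.open("sample.jpg").convert("RGB"), device="cpu")
#     feats.cls.shape      # (1, hidden_size)
#     feats.patches.shape  # (1, patch_hw[0] * patch_hw[1], hidden_size)
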
class FeedForward(nn.Module):
    def __init__(self, dim: int, expansion: int = 4, dropout: float = 0.0):
        super().__init__()
        hidden = dim * expansion
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)


class ProtoMorphBlock(nn.Module):
    """Prototype-morphing residual block over DINO patch tokens.

    Each patch token is softly assigned to a bank of learnable prototypes; the
    block then mixes the original token, its prototype context, their difference,
    and their product. This acts as a lightweight, non-standard replacement for a
    convolutional stage over the patch embeddings.
    """

    def __init__(self, dim: int, proto_count: int, dropout: float = 0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.prototypes = nn.Parameter(torch.randn(proto_count, dim) * 0.02)
        self.log_temperature = nn.Parameter(torch.tensor(0.0))
        self.mix = nn.Sequential(
            nn.Linear(dim * 4, dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 2, dim),
        )
        self.gamma = nn.Parameter(torch.tensor(0.1))
        self.out_norm = nn.LayerNorm(dim)

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        zn = self.norm(z)
        p = F.normalize(self.prototypes, dim=-1)
        q = F.normalize(zn, dim=-1)
        # Cosine distance between tokens and prototypes, turned into a soft
        # assignment with a learned temperature.
        dist = 1.0 - torch.matmul(q, p.t())
        temp = F.softplus(self.log_temperature) + 1e-4
        assign = F.softmax(-dist / temp, dim=-1)
        context = torch.matmul(assign, self.prototypes)
        mixed = self.mix(torch.cat([zn, context, zn - context, zn * context], dim=-1))
        z = z + self.gamma.tanh() * mixed
        return self.out_norm(z), assign
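
# Shape sketch for ProtoMorphBlock.forward (B = batch, N = patch tokens,
# D = embed dim, P = proto_count):
#
#     z       : (B, N, D)
#     assign  : (B, N, P) = softmax(-(1 - cos(z_n, p_k)) / softplus(log_temperature))
#     context : (B, N, D) = assign @ prototypes
#
# The mixed residual is scaled by tanh(gamma) (initialised at 0.1), so a freshly
# initialised block stays close to the identity map.
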
class LayerMemoryAttention(nn.Module):
    """A small learned memory bank attended to by every patch token."""

    def __init__(self, dim: int, memory_tokens: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(memory_tokens, dim) * 0.02)
        self.norm_q = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.ffn = FeedForward(dim, expansion=4, dropout=dropout)
        self.gamma_attn = nn.Parameter(torch.tensor(0.1))
        self.gamma_ffn = nn.Parameter(torch.tensor(0.1))

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        b = z.shape[0]
        mem = self.memory.unsqueeze(0).expand(b, -1, -1)
        q = self.norm_q(z)
        attn_out, attn_weights = self.attn(q, mem, mem, need_weights=True)
        z = z + self.gamma_attn.tanh() * attn_out
        z = z + self.gamma_ffn.tanh() * self.ffn(self.norm_out(z))
        return z, attn_weights


class MainClassifier(nn.Module):
    def __init__(self, dim: int, num_classes: int, dropout: float = 0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim * 3)
        self.head = nn.Sequential(
            nn.Linear(dim * 3, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, num_classes),
        )

    def forward(self, cls: Tensor, z: Tensor) -> Tensor:
        # Concatenate the CLS token with mean- and max-pooled patch tokens.
        mean_pool = z.mean(dim=1)
        max_pool = z.max(dim=1).values
        feat = torch.cat([cls, mean_pool, max_pool], dim=-1)
        return self.head(self.norm(feat))


class Top2FeedbackModulator(nn.Module):
    """Turns the top-2 class probabilities into a scale/shift over patch tokens."""

    def __init__(self, dim: int, num_classes: int):
        super().__init__()
        self.class_embed = nn.Embedding(num_classes, dim)
        self.stats_mlp = nn.Sequential(
            nn.Linear(4, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        self.to_scale_shift = nn.Sequential(
            nn.LayerNorm(dim * 2),
            nn.Linear(dim * 2, dim * 2),
        )

    def forward(self, z0: Tensor, logits: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
        probs = logits.softmax(dim=-1)
        top_probs, top_idx = probs.topk(k=min(2, probs.shape[-1]), dim=-1)
        if top_probs.shape[-1] == 1:
            # Single-class edge case: pad so the code below can always index a
            # top-2 pair.
            top_probs = torch.cat([top_probs, torch.zeros_like(top_probs)], dim=-1)
            top_idx = torch.cat([top_idx, top_idx], dim=-1)

        p1 = top_probs[:, 0]
        p2 = top_probs[:, 1]
        margin = p1 - p2
        entropy = -(probs * (probs.clamp_min(1e-8)).log()).sum(dim=-1)
        class_vecs = self.class_embed(top_idx)
        weighted_class_vec = (class_vecs * top_probs.unsqueeze(-1)).sum(dim=1)
        stats = torch.stack([p1, p2, margin, entropy], dim=-1)
        stat_vec = self.stats_mlp(stats)
        scale_shift = self.to_scale_shift(torch.cat([weighted_class_vec, stat_vec], dim=-1))
        scale, shift = scale_shift.chunk(2, dim=-1)
        z_mod = z0 * (1.0 + 0.25 * torch.tanh(scale).unsqueeze(1)) + 0.25 * torch.tanh(shift).unsqueeze(1)
        return z_mod, {
            "p1": p1,
            "p2": p2,
            "margin": margin,
            "entropy": entropy,
            "top_idx": top_idx,
            "top_probs": top_probs,
        }
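
# Modulation sketch (B = batch, N = tokens, D = dim, C = classes):
#
#     probs        = softmax(main_logits)                                (B, C)
#     class_vec    = p1 * embed(c1) + p2 * embed(c2)   (top-2 classes)   (B, D)
#     stats        = [p1, p2, p1 - p2, entropy(probs)]                   (B, 4)
#     scale, shift = Linear(LayerNorm([class_vec ; MLP(stats)])).chunk(2)
#     z_mod        = z0 * (1 + 0.25 * tanh(scale)) + 0.25 * tanh(shift)
#
# A FiLM-style, confidence-aware perturbation of the original patch tokens.
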
class DeltaRBFHardExpert(nn.Module):
    """RBF expert for hard examples, driven by feedback-modulated patch deltas."""

    def __init__(self, dim: int, rbf_count: int, num_classes: int, dropout: float = 0.0):
        super().__init__()
        self.delta_norm = nn.LayerNorm(dim)
        self.rbf_centers = nn.Parameter(torch.randn(rbf_count, dim) * 0.02)
        self.log_sigma = nn.Parameter(torch.zeros(rbf_count))
        self.rbf_to_logits = nn.Linear(rbf_count, num_classes)
        self.delta_mlp = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, num_classes),
        )

    def forward(self, z_base: Tensor, z_mod: Tensor) -> Tuple[Tensor, Tensor]:
        delta = self.delta_norm(z_mod - z_base)
        delta_mean = delta.mean(dim=1)
        delta_max = delta.max(dim=1).values

        # Cosine-distance RBF responses between per-token deltas and the learned
        # centers, averaged over tokens.
        q = F.normalize(delta, dim=-1)
        c = F.normalize(self.rbf_centers, dim=-1)
        dist = 1.0 - torch.matmul(q, c.t())
        sigma = F.softplus(self.log_sigma).view(1, 1, -1) + 1e-4
        rbf = torch.exp(-dist / sigma).mean(dim=1)
        expert_logits = self.rbf_to_logits(rbf) + self.delta_mlp(torch.cat([delta_mean, delta_max], dim=-1))
        return expert_logits, rbf


class LogitFusion(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(0.35))
        self.calibrate = nn.Sequential(
            nn.LayerNorm(num_classes * 2),
            nn.Linear(num_classes * 2, num_classes),
        )

    def forward(self, main_logits: Tensor, expert_logits: Tensor) -> Tensor:
        residual = self.calibrate(torch.cat([main_logits, expert_logits], dim=-1))
        return main_logits + self.alpha.sigmoid() * expert_logits + 0.1 * residual


class HardCaseGate(nn.Module):
    """Deterministic inference-time gate driven by probability-confidence signals."""

    def __init__(self, pmax_threshold: float, margin_threshold: float, entropy_threshold: float):
        super().__init__()
        self.pmax_threshold = pmax_threshold
        self.margin_threshold = margin_threshold
        self.entropy_threshold = entropy_threshold

    def forward(self, logits: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
        probs = logits.softmax(dim=-1)
        top_probs = probs.topk(k=min(2, probs.shape[-1]), dim=-1).values
        if top_probs.shape[-1] == 1:
            p1 = top_probs[:, 0]
            p2 = torch.zeros_like(p1)
        else:
            p1, p2 = top_probs[:, 0], top_probs[:, 1]
        margin = p1 - p2
        entropy = -(probs * probs.clamp_min(1e-8).log()).sum(dim=-1)
        # A sample is "hard" when the top probability or margin falls below its
        # threshold, or the entropy rises above its threshold.
        hard = (p1 < self.pmax_threshold) | (margin < self.margin_threshold) | (entropy > self.entropy_threshold)
        return hard, {"pmax": p1, "margin": margin, "entropy": entropy}
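
# Worked example (hypothetical thresholds pmax=0.70, margin=0.15, entropy=1.0):
# a sample with top-2 probabilities (0.55, 0.40) is routed to the hard-case path
# (pmax < 0.70 and margin < 0.15), while one with (0.95, 0.03) keeps the main
# classifier's logits unchanged.
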
class ProtoMorphHead(nn.Module):
    def __init__(self, cfg: ProtoMorphConfig):
        super().__init__()
        self.cfg = cfg
        d = cfg.embed_dim
        self.input_norm = nn.LayerNorm(d)
        self.block1 = ProtoMorphBlock(d, cfg.proto_count, cfg.dropout)
        self.mem1 = LayerMemoryAttention(d, cfg.memory_tokens, cfg.num_heads, cfg.dropout)
        self.block2 = ProtoMorphBlock(d, cfg.proto_count, cfg.dropout)
        self.mem2 = LayerMemoryAttention(d, cfg.memory_tokens, cfg.num_heads, cfg.dropout)
        self.main = MainClassifier(d, cfg.num_classes, cfg.dropout)
        self.gate = HardCaseGate(cfg.hard_pmax_threshold, cfg.hard_margin_threshold, cfg.hard_entropy_threshold)
        self.feedback = Top2FeedbackModulator(d, cfg.num_classes)
        self.hard_expert = DeltaRBFHardExpert(d, cfg.rbf_count, cfg.num_classes, cfg.dropout)
        self.fusion = LogitFusion(cfg.num_classes)

    def forward(self, cls: Tensor, patches: Tensor, force_hard: bool = False) -> Dict[str, Tensor]:
        # Pass 1: prototype/memory blocks followed by the main classifier.
        z0 = self.input_norm(patches)
        z, assign1 = self.block1(z0)
        z, mem_attn1 = self.mem1(z)
        z, assign2 = self.block2(z)
        z, mem_attn2 = self.mem2(z)

        main_logits = self.main(cls, z)
        hard_mask, gate_stats = self.gate(main_logits)
        if force_hard:
            hard_mask = torch.ones_like(hard_mask, dtype=torch.bool)

        # Pass 2: top-2 feedback modulation plus the hard-case expert; fused logits
        # replace the main logits only where the gate flags a sample as hard.
        z_mod, fb_stats = self.feedback(z0, main_logits)
        expert_logits, rbf = self.hard_expert(z0, z_mod)
        fused_logits = self.fusion(main_logits, expert_logits)
        final_logits = torch.where(hard_mask[:, None], fused_logits, main_logits)

        out = {
            "logits": final_logits,
            "main_logits": main_logits,
            "expert_logits": expert_logits,
            "hard_mask": hard_mask,
            "rbf": rbf,
            "assign1_mean": assign1.mean(dim=1),
            "assign2_mean": assign2.mean(dim=1),
            "mem_attn1_mean": mem_attn1.mean(dim=1),
            "mem_attn2_mean": mem_attn2.mean(dim=1),
        }
        out.update({f"gate_{k}": v for k, v in gate_stats.items()})
        out.update({f"fb_{k}": v for k, v in fb_stats.items() if isinstance(v, Tensor)})
        return out


class ProtoMorphDINOv3(nn.Module):
    """Full inference graph: frozen DINOv3 backbone plus the custom ProtoMorph head."""

    def __init__(self, cfg: ProtoMorphConfig, local_files_only: bool = False):
        super().__init__()
        self.cfg = cfg
        self.backbone = FrozenDINOv3(cfg.dino_model_name, image_size=cfg.image_size, local_files_only=local_files_only)
        actual_dim = self.backbone.hidden_size
        if actual_dim and actual_dim != cfg.embed_dim:
            raise ValueError(
                f"Config embed_dim={cfg.embed_dim} but DINO hidden_size={actual_dim}. "
                f"Use the matching config or run scripts/create_random_head.py with --embed-dim {actual_dim}."
            )
        self.head = ProtoMorphHead(cfg)

    @torch.no_grad()
    def forward(
        self,
        images: Image.Image | Sequence[Image.Image],
        device: torch.device | str,
        force_hard: bool = False,
        use_bf16_autocast: Optional[bool] = None,
    ) -> Dict[str, Tensor | Tuple[int, int]]:
        use_amp = self.cfg.use_bf16_autocast if use_bf16_autocast is None else use_bf16_autocast
        device_obj = torch.device(device)
        amp_enabled = bool(use_amp and device_obj.type == "cuda")
        amp_dtype = torch.bfloat16

        # Autocast is only enabled on CUDA; elsewhere the context is a no-op.
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp_enabled):
            feats = self.backbone(images, device=device_obj)
            cls = feats.cls
            patches = feats.patches
            if self.cfg.normalize_patch_tokens:
                cls = F.layer_norm(cls, cls.shape[-1:])
                patches = F.layer_norm(patches, patches.shape[-1:])
            head_out = self.head(cls, patches, force_hard=force_hard)
        head_out["patch_hw"] = feats.patch_hw
        head_out["pixel_hw"] = feats.pixel_hw
        return head_out

    def save_custom_head(self, checkpoint_path: str | Path) -> None:
        if safe_save_file is None:
            raise ImportError("safetensors is required: pip install safetensors")
        p = Path(checkpoint_path)
        p.parent.mkdir(parents=True, exist_ok=True)
        safe_save_file(self.head.state_dict(), str(p))

    def load_custom_head(self, checkpoint_path: str | Path, strict: bool = True) -> None:
        if safe_load_file is None:
            raise ImportError("safetensors is required: pip install safetensors")
        sd = safe_load_file(str(checkpoint_path), device="cpu")
        self.head.load_state_dict(sd, strict=strict)
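
# Checkpointing sketch: only the ProtoMorph head is serialized; the frozen DINOv3
# backbone is always reloaded via AutoModel.from_pretrained. Paths below are
# illustrative.
#
#     model.save_custom_head("checkpoints/protomorph_head.safetensors")
#     model.load_custom_head("checkpoints/protomorph_head.safetensors", strict=True)
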
def infer_embed_dim_from_model_name(model_name: str) -> int:
    """Useful defaults for DINOv3 ViT checkpoints."""
    name = model_name.lower()
    if "vits" in name:
        return 384
    if "vitb" in name:
        return 768
    if "vitl" in name:
        return 1024
    if "vith" in name:
        return 1280
    return 384
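
# Minimal end-to-end sketch, not part of the library API. It assumes that
# ProtoMorphConfig supplies defaults for every field; because this file uses
# relative imports, invoke it as a module, e.g.
# `python -m <your_package>.<this_module> path/to/image.jpg`.
if __name__ == "__main__":
    import sys

    cfg = ProtoMorphConfig()
    model = ProtoMorphDINOv3(cfg)
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(run_device)

    out = model(Image.open(sys.argv[1]).convert("RGB"), device=run_device)
    print("logits:", tuple(out["logits"].shape))
    print("hard cases:", int(out["hard_mask"].sum().item()))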