from __future__ import annotations from dataclasses import dataclass from pathlib import Path from typing import Dict, List, Optional, Tuple import torch @dataclass class PrototypeDB: centers: torch.Tensor # [N,D] float32 labels: torch.Tensor # [N] int64 label_names: List[str] # id -> name source_path: Optional[Path] = None @property def dim(self) -> int: return int(self.centers.shape[1]) def id_to_name(self, idx: int) -> str: if 0 <= idx < len(self.label_names): return self.label_names[idx] return str(idx) def ensure_label_id(self, name: str) -> int: name = str(name).strip() if not name: raise ValueError("Empty label name.") try: i = self.label_names.index(name) return int(i) except ValueError: self.label_names.append(name) return len(self.label_names) - 1 def add_center(self, label_name: str, center: torch.Tensor) -> int: if center.ndim != 1: raise ValueError("center must be 1D embedding vector.") center = torch.nn.functional.normalize(center.float(), dim=0).view(1, -1) lid = self.ensure_label_id(label_name) self.centers = torch.cat([self.centers, center], dim=0) self.labels = torch.cat([self.labels, torch.tensor([lid], dtype=torch.long)], dim=0) return lid def save(self, path: Optional[str | Path] = None) -> Path: out = Path(path) if path is not None else self.source_path if out is None: raise ValueError("No output path specified for saving prototype DB.") out.parent.mkdir(parents=True, exist_ok=True) torch.save( dict( centers=self.centers.detach().cpu(), labels=self.labels.detach().cpu(), label_names=list(self.label_names), ), str(out), ) self.source_path = out return out def _infer_label_names_from_dataset(dataset_root: Path) -> Optional[List[str]]: # `train_style_ddp.TriViewDataset` assigns IDs based on sorted directory names under dataset/. if not dataset_root.exists(): return None artists = sorted([p.name for p in dataset_root.iterdir() if p.is_dir()]) return artists if artists else None def load_prototype_db(path: str | Path, *, try_dataset_dir: str | Path = "dataset") -> PrototypeDB: p = Path(path) if not p.exists(): raise FileNotFoundError(str(p)) obj = torch.load(str(p), map_location="cpu") if not isinstance(obj, dict) or "centers" not in obj or "labels" not in obj: raise ValueError(f"Unsupported prototype file format: {p}") centers = obj["centers"].float() labels = obj["labels"].long() label_names = obj.get("label_names") if not isinstance(label_names, list) or not all(isinstance(x, str) for x in label_names): inferred = _infer_label_names_from_dataset(Path(try_dataset_dir)) if inferred is None: max_id = int(labels.max().item()) if labels.numel() else -1 label_names = [str(i) for i in range(max_id + 1)] else: label_names = inferred return PrototypeDB(centers=centers, labels=labels, label_names=label_names, source_path=p) def topk_predictions( db: PrototypeDB, z: torch.Tensor, *, topk: int = 5, ) -> List[Tuple[str, float]]: """ Returns [(label_name, score)] sorted by score desc (cosine similarity). `z` is 1D embedding (D). """ if z.ndim != 1: raise ValueError("z must be 1D.") Z = torch.nn.functional.normalize(z.float(), dim=0).view(1, -1) C = torch.nn.functional.normalize(db.centers.float(), dim=1) sim = (Z @ C.t()).squeeze(0) # [N] k = int(max(1, min(topk, sim.numel()))) if sim.numel() else 0 if k == 0: return [] vals, idxs = torch.topk(sim, k=k) out: List[Tuple[str, float]] = [] for v, i in zip(vals.tolist(), idxs.tolist()): lid = int(db.labels[i].item()) out.append((db.id_to_name(lid), float(v))) return out def topk_predictions_unique_labels( db: PrototypeDB, z: torch.Tensor, *, topk: int = 5, ) -> List[Tuple[str, float]]: """ Like topk_predictions(), but dedupes by label: if a label has multiple prototypes, only the highest score is kept. """ if z.ndim != 1: raise ValueError("z must be 1D.") Z = torch.nn.functional.normalize(z.float(), dim=0).view(1, -1) C = torch.nn.functional.normalize(db.centers.float(), dim=1) sim = (Z @ C.t()).squeeze(0) # [N] if sim.numel() == 0: return [] best_by_label: dict[int, float] = {} # iterate all prototypes once; keep max per label id for i in range(sim.numel()): lid = int(db.labels[i].item()) s = float(sim[i].item()) prev = best_by_label.get(lid) if prev is None or s > prev: best_by_label[lid] = s items = sorted(best_by_label.items(), key=lambda kv: kv[1], reverse=True) items = items[: max(1, int(topk))] return [(db.id_to_name(lid), float(score)) for (lid, score) in items]