"""Persistent cache for teacher update-decisions.

Used for two-stage training:

* Stage 1 (label generation) - the teacher API is called once per unique
  frame pair, and the result is saved to disk.
* Stage 2 (training) - the cached decision is read directly, avoiding any
  online teacher API call during the forward pass.

Keys are ``(seq_name, frame_id_A, frame_id_B)`` - the two template-candidate
frames that Qwen compares.  Values are Python bools, or ``null`` (JSON) when
the teacher failed for that pair.

The cache is persisted as a single JSON file.  Because the file may be
written by multiple DDP ranks, ``save()`` takes an exclusive ``fcntl`` lock
on a sidecar ``.lock`` file, merges the entries already on disk with the
in-memory ones and publishes the result with an atomic rename (best-effort
on the local filesystem).
"""
from __future__ import annotations

import fcntl
import json
import logging
import os
from typing import Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


class TeacherLabelCache:
    """Persistent cache for teacher update decisions.

    Lookups and stores operate on an in-memory dict; ``save()`` writes that
    dict to ``<cache_dir>/teacher_labels.json``.  Saving is serialised across
    processes with an advisory ``flock`` so that concurrent writers do not
    overwrite each other's entries.

    Usage::

        cache = TeacherLabelCache("./output/teacher_cache")
        dec = cache.get("airplane-1", 120, 150)   # -> True / False / None
        cache.set("airplane-1", 120, 150, True)
        cache.save()
    """

    def __init__(self, cache_dir: str):
        """Create ``cache_dir`` if needed and load any existing cache file.

        Args:
            cache_dir: Directory that holds ``teacher_labels.json``.
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self._cache_path = os.path.join(cache_dir, "teacher_labels.json")
        self._cache: Dict[str, Optional[bool]] = {}
        self._dirty = False
        self._load()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    def make_key(seq_name: str, frame_a: int, frame_b: int) -> str:
        """Deterministic key; ORDER MATTERS.

        ``frame_a`` = old template frame  (template[-2])
        ``frame_b`` = new candidate frame (template[-1])

        The teacher is asked: *should we update FROM frame_a TO frame_b?*
        This is a directional question, so the key preserves the order.
        """
        fa = int(frame_a)
        fb = int(frame_b)
        return f"{seq_name}__{fa}__{fb}"

    def get(self, seq_name: str, frame_a: int, frame_b: int) -> Optional[bool]:
        """Return cached decision, or ``None`` on cache miss / teacher failure."""
        return self._cache.get(self.make_key(seq_name, frame_a, frame_b))

    def set(
        self,
        seq_name: str,
        frame_a: int,
        frame_b: int,
        decision: Optional[bool],
    ) -> None:
        """Store a decision.  ``decision`` may be ``None`` (= teacher failed).

        Non-``None`` values are coerced with ``bool()`` so that bool-like
        objects (e.g. ``numpy.bool_``) do not make ``json.dump`` fail later
        in ``save()``.
        """
        value = None if decision is None else bool(decision)
        self._cache[self.make_key(seq_name, frame_a, frame_b)] = value
        self._dirty = True

    def get_batch(
        self,
        seq_names: List[str],
        frame_ids_a: List[int],
        frame_ids_b: List[int],
    ) -> List[Optional[bool]]:
        """Look up a whole batch.

        Returns one entry per input triple.  The inputs are combined with
        ``zip``, so they are expected to have equal length; extra items in a
        longer input are ignored.
        """
        return [
            self.get(seq, int(fa), int(fb))
            for seq, fa, fb in zip(seq_names, frame_ids_a, frame_ids_b)
        ]

    def set_batch(
        self,
        seq_names: List[str],
        frame_ids_a: List[int],
        frame_ids_b: List[int],
        decisions: List[Optional[bool]],
    ) -> None:
        """Store a whole batch (inputs are combined with ``zip``)."""
        for seq, fa, fb, dec in zip(seq_names, frame_ids_a, frame_ids_b, decisions):
            self.set(seq, int(fa), int(fb), dec)

    def hit_rate(self) -> float:
        """Fraction of cache entries that are not ``None`` (0.0 when empty)."""
        if not self._cache:
            return 0.0
        return sum(1 for v in self._cache.values() if v is not None) / len(self._cache)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self) -> None:
        """Atomically write the cache to disk (if dirty).

        Entries already present in the file (for example written by another
        rank) are kept; entries held in memory win on key conflicts.  The
        in-memory cache itself is not modified.

        I/O failures are non-critical: they are logged, the in-memory cache
        stays valid and the write is retried on the next ``save()``.
        """
        if not self._dirty:
            return
        try:
            # The lock lives on a sidecar file that is never renamed, so all
            # writers contend on the same inode for the whole
            # read-merge-write sequence.
            with open(self._cache_path + ".lock", "a") as lock_file:
                fcntl.flock(lock_file, fcntl.LOCK_EX)
                try:
                    merged = self._read_disk()
                    merged.update(self._cache)
                    self._write_atomic(merged)
                finally:
                    fcntl.flock(lock_file, fcntl.LOCK_UN)
        except OSError as exc:
            logger.warning(
                "Could not persist teacher label cache to %s: %s",
                self._cache_path,
                exc,
            )
            return
        self._dirty = False

    def _write_atomic(self, payload: Dict[str, Optional[bool]]) -> None:
        """Write ``payload`` to a per-process temp file and rename it into place."""
        tmp_path = f"{self._cache_path}.{os.getpid()}.tmp"
        published = False
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(payload, f, indent=2, sort_keys=True)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, self._cache_path)
            published = True
        finally:
            if not published:
                # Do not leave a partial temp file behind on failure.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    def _read_disk(self) -> Dict[str, Optional[bool]]:
        """Return the mapping stored in the cache file, or ``{}`` if unusable."""
        try:
            with open(self._cache_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return {}
        except (ValueError, OSError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning(
                "Ignoring unreadable teacher label cache %s: %s",
                self._cache_path,
                exc,
            )
            return {}
        if not isinstance(data, dict):
            logger.warning(
                "Ignoring teacher label cache %s: top-level JSON value is %s, "
                "expected an object",
                self._cache_path,
                type(data).__name__,
            )
            return {}
        return data

    def _load(self) -> None:
        """Populate the in-memory cache from disk (empty if missing/corrupt)."""
        self._cache = self._read_disk()

    # ------------------------------------------------------------------
    # Info
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        return len(self._cache)

    def __repr__(self) -> str:
        return (
            f"TeacherLabelCache({len(self)} entries, "
            f"hit_rate={self.hit_rate():.1%}, "
            f"path={self._cache_path})"
        )