| """ |
| Persistent cache for teacher update-decisions. |
| |
| Used for two-stage training: |
| Stage 1 (label generation) – teacher API is called once per unique frame pair, |
| and the result is saved to disk. |
| Stage 2 (training) – the cached decision is read directly, avoiding any |
| online teacher API call during the forward pass. |
| |
| Keys are (seq_name, frame_id_A, frame_id_B) – the two template-candidate frames |
| that Qwen compares. Values are Python bools, or ``null`` (JSON) when the teacher |
| failed for that pair. |
| |
| The cache is persisted as a single JSON file. Because the file may be written |
| by multiple DDP ranks, all writes use an atomic-rename pattern with fcntl |
| locking (best-effort on the local filesystem). |
| """ |
|
|
from __future__ import annotations

import fcntl
import json
import os
import warnings
from typing import Dict, List, Optional, Tuple
|
|
|
|
class TeacherLabelCache:
    """Persistent cache for teacher update decisions, safe to save from several processes.

    Entries map ``(seq_name, frame_a, frame_b)`` to ``True`` / ``False``, or to
    ``None`` when the teacher failed for that pair. Lookups and updates are
    plain in-memory dict operations. ``save()`` merges with whatever other
    processes have already written and persists atomically (see ``save``).
    In-memory access is not guarded by a lock, so share one instance across
    threads only with external synchronisation.

    Usage::

        cache = TeacherLabelCache("./output/teacher_cache")
        dec = cache.get("airplane-1", 120, 150)  # → True / False / None
        cache.set("airplane-1", 120, 150, True)
        cache.save()
    """

    def __init__(self, cache_dir: str):
        """Open the cache stored in ``cache_dir``, creating the directory if needed.

        Labels are loaded from ``<cache_dir>/teacher_labels.json``. A missing
        file yields an empty cache. An unreadable or malformed file yields an
        empty cache plus a ``RuntimeWarning``.
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self._cache_path = os.path.join(cache_dir, "teacher_labels.json")
        self._cache: Dict[str, Optional[bool]] = {}
        # True while the in-memory cache holds changes not yet written by save().
        self._dirty = False
        self._load()

    # ------------------------------------------------------------------ keys

    @staticmethod
    def make_key(seq_name: str, frame_a: int, frame_b: int) -> str:
        """Deterministic key; ORDER MATTERS.

        ``frame_a`` = old template frame (template[-2])
        ``frame_b`` = new candidate frame (template[-1])

        The teacher is asked: *should we update FROM frame_a TO frame_b?*
        This is a directional question, so the key preserves the order.
        The frame ids are coerced with ``int()``, so scalar tensors / NumPy
        integers produce the same key as plain ints. The last two ``__``-separated
        fields are always integers, so keys stay unambiguous even if
        ``seq_name`` itself contains ``__``.
        """
        fa = int(frame_a)
        fb = int(frame_b)
        return f"{seq_name}__{fa}__{fb}"

    # ---------------------------------------------------------- single access

    def get(self, seq_name: str, frame_a: int, frame_b: int) -> Optional[bool]:
        """Return the cached decision for the directed pair ``frame_a → frame_b``.

        Returns ``None`` both on a cache miss and when the teacher failed for
        this pair; the two cases are not distinguished.
        """
        return self._cache.get(self.make_key(seq_name, frame_a, frame_b))

    def set(self, seq_name: str, frame_a: int, frame_b: int, decision: Optional[bool]) -> None:
        """Store a decision and mark the cache dirty. ``None`` means the teacher failed.

        Non-``None`` values are normalised with ``bool()``. This stores e.g.
        NumPy booleans, which ``json`` cannot serialise, as plain Python bools.
        """
        key = self.make_key(seq_name, frame_a, frame_b)
        self._cache[key] = None if decision is None else bool(decision)
        self._dirty = True

    # ----------------------------------------------------------- batch access

    @staticmethod
    def _check_same_length(*columns) -> None:
        """Raise ``ValueError`` unless all batch columns have the same length."""
        lengths = [len(col) for col in columns]
        if min(lengths) != max(lengths):
            raise ValueError(
                f"TeacherLabelCache: batch inputs have mismatched lengths {lengths}"
            )

    def get_batch(
        self,
        seq_names: List[str],
        frame_ids_a: List[int],
        frame_ids_b: List[int],
    ) -> List[Optional[bool]]:
        """Look up a whole batch; the result has the same length as the inputs.

        Raises:
            ValueError: if the input lists differ in length. Silently truncating
                would misalign results with the batch.
        """
        self._check_same_length(seq_names, frame_ids_a, frame_ids_b)
        return [
            self.get(seq, fa, fb)
            for seq, fa, fb in zip(seq_names, frame_ids_a, frame_ids_b)
        ]

    def set_batch(
        self,
        seq_names: List[str],
        frame_ids_a: List[int],
        frame_ids_b: List[int],
        decisions: List[Optional[bool]],
    ) -> None:
        """Store a whole batch of decisions (see ``set``).

        Raises:
            ValueError: if the input lists differ in length. Silently truncating
                would drop labels.
        """
        self._check_same_length(seq_names, frame_ids_a, frame_ids_b, decisions)
        for seq, fa, fb, dec in zip(seq_names, frame_ids_a, frame_ids_b, decisions):
            self.set(seq, fa, fb, dec)

    def hit_rate(self) -> float:
        """Fraction of cache entries that hold a real decision (not ``None``).

        Despite the name, this is not a lookup hit rate. It measures how often
        the teacher succeeded among the cached pairs. Returns 0.0 for an empty
        cache.
        """
        if not self._cache:
            return 0.0
        return sum(1 for v in self._cache.values() if v is not None) / len(self._cache)

    # ------------------------------------------------------------ persistence

    def save(self) -> None:
        """Merge the in-memory cache into the on-disk file and write it atomically.

        Does nothing when there are no unsaved changes.

        Concurrency:
            - Writers are serialised with an exclusive ``fcntl.flock`` on a
              sidecar ``teacher_labels.json.lock`` file, which is left on disk.
            - While the lock is held, the on-disk contents are re-read and merged
              with this instance's entries, so ranks saving in turn do not erase
              each other's labels.
            - For a key present in both, this instance's value wins, unless it is
              ``None`` (teacher failure) and the file already has a real
              decision.
            - The merged dict is written to a per-process temp file, fsynced, and
              moved into place with ``os.replace``, so readers never see a
              half-written file.
            - ``flock`` is advisory and may be ineffective on some network
              filesystems.

        After a successful save, the in-memory cache equals the merged result.
        Failures are best-effort: an ``OSError`` is reported as a
        ``RuntimeWarning``, and the cache stays dirty so a later call can retry.
        """
        if not self._dirty:
            return
        lock_path = self._cache_path + ".lock"
        # Unique per process, so concurrent writers never truncate each other's temp file.
        tmp_path = f"{self._cache_path}.{os.getpid()}.tmp"
        try:
            with open(lock_path, "a") as lock_file:
                fcntl.flock(lock_file, fcntl.LOCK_EX)
                try:
                    merged = self._read_disk()
                    for key, value in self._cache.items():
                        # Keep the on-disk value if we only recorded a failure.
                        if value is not None or key not in merged:
                            merged[key] = value
                    self._write_atomic(merged, self._cache_path, tmp_path)
                finally:
                    fcntl.flock(lock_file, fcntl.LOCK_UN)
        except OSError as err:
            warnings.warn(
                f"TeacherLabelCache: could not save {self._cache_path!r}: {err}",
                RuntimeWarning,
                stacklevel=2,
            )
            return
        self._cache = merged
        self._dirty = False

    @staticmethod
    def _write_atomic(data: Dict[str, Optional[bool]], dest_path: str, tmp_path: str) -> None:
        """Write ``data`` as JSON to ``tmp_path``, fsync it, then rename it onto ``dest_path``.

        On any error the temp file is removed and the exception is re-raised.
        """
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2, sort_keys=True)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, dest_path)
        except BaseException:
            # Do not leave a partial temp file behind.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _read_disk(self) -> Dict[str, Optional[bool]]:
        """Return the on-disk cache as a fresh dict.

        Returns ``{}`` if the file is missing. If it is unreadable or malformed,
        returns ``{}`` and emits a ``RuntimeWarning``. Entries whose value is
        neither a bool nor ``null`` are dropped.
        """
        try:
            with open(self._cache_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return {}
        except (OSError, ValueError) as err:  # ValueError covers JSON and Unicode decode errors
            warnings.warn(
                f"TeacherLabelCache: ignoring unreadable cache {self._cache_path!r}: {err}",
                RuntimeWarning,
                stacklevel=2,
            )
            return {}
        if not isinstance(data, dict):
            warnings.warn(
                f"TeacherLabelCache: ignoring malformed cache {self._cache_path!r} "
                f"(expected a JSON object, got {type(data).__name__})",
                RuntimeWarning,
                stacklevel=2,
            )
            return {}
        return {k: v for k, v in data.items() if v is None or isinstance(v, bool)}

    def _load(self) -> None:
        """Replace the in-memory cache with the current on-disk contents."""
        self._cache = self._read_disk()

    # ----------------------------------------------------------------- dunder

    def __len__(self) -> int:
        """Number of cached pairs, including teacher failures."""
        return len(self._cache)

    def __repr__(self) -> str:
        return (
            f"TeacherLabelCache({len(self)} entries, "
            f"hit_rate={self.hit_rate():.1%}, "
            f"path={self._cache_path})"
        )
|
|