File size: 5,156 Bytes
b3f019f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | """
Persistent cache for teacher update-decisions.
Used for two-stage training:
Stage 1 (label generation) – teacher API is called once per unique frame pair,
and the result is saved to disk.
Stage 2 (training) – the cached decision is read directly, avoiding any
online teacher API call during the forward pass.
Keys are built from (seq_name, frame_id_A, frame_id_B) – the two template-candidate
frames that Qwen compares – and stored as the string ``seq_name__A__B``. Values are
Python bools, or ``null`` (JSON) when the teacher failed for that pair.
The cache is persisted as a single JSON file. Because the file may be written
by multiple DDP ranks, all writes use an atomic-rename pattern with fcntl
locking (best-effort on the local filesystem).
"""
from __future__ import annotations
import fcntl
import json
import os
from typing import Dict, List, Optional, Tuple
class TeacherLabelCache:
    """Persistent, JSON-backed cache for teacher update decisions.

    Entries live in an in-memory dict and are flushed to
    ``<cache_dir>/teacher_labels.json`` by :meth:`save`. Saving is serialised
    across processes with an advisory ``flock`` on a side-car lock file and
    merges with whatever is already on disk, so several ranks can write to
    the same cache directory without discarding each other's labels.

    Usage::

        cache = TeacherLabelCache("./output/teacher_cache")
        dec = cache.get("airplane-1", 120, 150)   # → True / False / None
        cache.set("airplane-1", 120, 150, True)
        cache.save()
    """

    def __init__(self, cache_dir: str):
        """Create ``cache_dir`` if needed and load any existing cache file.

        Args:
            cache_dir: Directory that holds ``teacher_labels.json``.
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self._cache_path = os.path.join(cache_dir, "teacher_labels.json")
        self._cache: Dict[str, Optional[bool]] = {}
        # True when the in-memory dict holds changes not yet written to disk.
        self._dirty = False
        self._load()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    @staticmethod
    def make_key(seq_name: str, frame_a: int, frame_b: int) -> str:
        """Build the deterministic cache key; ORDER MATTERS.

        ``frame_a`` = old template frame (template[-2])
        ``frame_b`` = new candidate frame (template[-1])

        The teacher is asked: *should we update FROM frame_a TO frame_b?*
        This is a directional question, so the key preserves the order.

        Returns:
            The string ``"{seq_name}__{frame_a}__{frame_b}"`` with both frame
            ids converted through ``int``.
        """
        fa = int(frame_a)
        fb = int(frame_b)
        return f"{seq_name}__{fa}__{fb}"

    def get(self, seq_name: str, frame_a: int, frame_b: int) -> Optional[bool]:
        """Return the cached decision, or ``None`` on cache miss / teacher failure."""
        return self._cache.get(self.make_key(seq_name, frame_a, frame_b))

    def set(self, seq_name: str, frame_a: int, frame_b: int, decision: Optional[bool]):
        """Store a decision and mark the cache dirty.

        ``decision`` may be ``None`` (= teacher failed). Any other value is
        coerced with ``bool()`` so that bool-like objects which ``json``
        cannot serialise do not make a later :meth:`save` fail.
        """
        value = None if decision is None else bool(decision)
        self._cache[self.make_key(seq_name, frame_a, frame_b)] = value
        self._dirty = True

    def get_batch(
        self,
        seq_names: List[str],
        frame_ids_a: List[int],
        frame_ids_b: List[int],
    ) -> List[Optional[bool]]:
        """Look up a whole batch.

        The inputs are zipped together, so the result has one entry per
        aligned triple (as long as the shortest input).
        """
        return [
            self.get(seq, int(fa), int(fb))
            for seq, fa, fb in zip(seq_names, frame_ids_a, frame_ids_b)
        ]

    def set_batch(
        self,
        seq_names: List[str],
        frame_ids_a: List[int],
        frame_ids_b: List[int],
        decisions: List[Optional[bool]],
    ):
        """Store a whole batch; the inputs are zipped element-wise."""
        for seq, fa, fb, dec in zip(seq_names, frame_ids_a, frame_ids_b, decisions):
            self.set(seq, int(fa), int(fb), dec)

    def hit_rate(self) -> float:
        """Return the fraction of cache entries that are not ``None``.

        An empty cache yields ``0.0``.
        """
        if not self._cache:
            return 0.0
        return sum(1 for v in self._cache.values() if v is not None) / len(self._cache)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def save(self):
        """Atomically write the cache to disk (if dirty).

        While holding an exclusive ``flock`` on ``<cache file>.lock`` the
        method re-reads the file on disk, overlays the in-memory entries
        (in-memory values win on key conflicts), writes the result to a
        per-process temporary file and renames it over the cache file.
        On success the merged view also replaces the in-memory dict.

        I/O errors are swallowed on purpose: the in-memory cache stays valid
        and remains dirty, so the write is retried on the next ``save()``.
        """
        if not self._dirty:
            return
        lock_path = self._cache_path + ".lock"
        # Per-process temp name so concurrent savers never share a temp file.
        tmp_path = f"{self._cache_path}.{os.getpid()}.tmp"
        try:
            # The lock lives on a side-car file: locking the temp file itself
            # would not serialise anything, because opening it with "w"
            # truncates it before the lock can be taken.
            with open(lock_path, "a") as lock_file:
                fcntl.flock(lock_file, fcntl.LOCK_EX)
                try:
                    merged = self._read_disk()
                    merged.update(self._cache)
                    with open(tmp_path, "w", encoding="utf-8") as f:
                        json.dump(merged, f, indent=2, sort_keys=True)
                        f.flush()
                        os.fsync(f.fileno())
                    os.replace(tmp_path, self._cache_path)
                finally:
                    fcntl.flock(lock_file, fcntl.LOCK_UN)
            self._cache = merged
            self._dirty = False
        except OSError:
            # Non-critical – the in-memory cache is still valid; disk write
            # will be retried on the next ``save()``. Remove any partial
            # temp file so it does not accumulate.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def _read_disk(self) -> Dict[str, Optional[bool]]:
        """Return the JSON object stored on disk, or ``{}`` if unusable.

        A missing, unreadable, malformed, or non-object JSON file all yield
        an empty dict.
        """
        try:
            with open(self._cache_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and decoding errors.
            return {}
        return data if isinstance(data, dict) else {}

    def _load(self):
        """Replace the in-memory cache with the contents of the cache file."""
        self._cache = self._read_disk()

    # ------------------------------------------------------------------
    # Info
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        """Return the number of cached entries (including ``None`` values)."""
        return len(self._cache)

    def __repr__(self) -> str:
        """Return a summary with entry count, hit rate and file path."""
        return (
            f"TeacherLabelCache({len(self)} entries, "
            f"hit_rate={self.hit_rate():.1%}, "
            f"path={self._cache_path})"
        )
|