# ATCTrack-VLM / lib/train/data/teacher_label_cache.py
# Last commit b3f019f by SunXiang2025 (verified):
#   "Update: two-stage training, per-channel FiLM gate, cosine scheduler, 9B config"
"""
Persistent cache for teacher update-decisions.
Used for two-stage training:
Stage 1 (label generation) – teacher API is called once per unique frame pair,
and the result is saved to disk.
Stage 2 (training) – the cached decision is read directly, avoiding any
online teacher API call during the forward pass.
Keys are (seq_name, frame_id_A, frame_id_B) – the two template-candidate frames
that Qwen compares. Values are Python bools, or ``null`` (JSON) when the teacher
failed for that pair.
The cache is persisted as a single JSON file. Because the file may be written
by multiple DDP ranks, all writes use an atomic-rename pattern with fcntl
locking (best-effort on the local filesystem).
"""
from __future__ import annotations

import fcntl
import json
import os
import warnings
from typing import Dict, List, Optional, Tuple
class TeacherLabelCache:
    """Persistent cache for teacher update decisions, shared across processes.

    Decisions live in an in-memory dict and are written to
    ``<cache_dir>/teacher_labels.json`` by :meth:`save`.

    Saves from several processes sharing ``cache_dir`` (e.g. DDP ranks) are
    serialized through an ``fcntl`` lock on ``teacher_labels.json.lock``. Each
    save is merged with the current on-disk contents, so one process never
    discards labels another process already wrote. The file is swapped in with
    ``os.replace``, so readers see either the old or the new version, never a
    partial write.

    ``flock`` is advisory and may not serialize writers on different hosts of a
    network filesystem. A single instance is NOT guarded against concurrent
    mutation from multiple threads.

    Usage::

        cache = TeacherLabelCache("./output/teacher_cache")
        dec = cache.get("airplane-1", 120, 150)  # -> True / False / None
        cache.set("airplane-1", 120, 150, True)
        cache.save()
    """

    def __init__(self, cache_dir: str):
        """Create ``cache_dir`` if needed and load any existing cache file.

        Args:
            cache_dir: Directory that holds ``teacher_labels.json``.
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self._cache_path = os.path.join(cache_dir, "teacher_labels.json")
        self._cache: Dict[str, Optional[bool]] = {}
        # True while the in-memory cache holds changes not yet saved to disk.
        self._dirty = False
        self._load()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    def make_key(seq_name: str, frame_a: int, frame_b: int) -> str:
        """Deterministic key; ORDER MATTERS.

        ``frame_a`` = old template frame (template[-2])
        ``frame_b`` = new candidate frame (template[-1])

        The teacher is asked: *should we update FROM frame_a TO frame_b?*
        This is a directional question, so the key preserves the order.
        Frame ids are passed through ``int()``, so tensor/numpy scalars
        produce the same key as plain ints.
        """
        fa = int(frame_a)
        fb = int(frame_b)
        return f"{seq_name}__{fa}__{fb}"

    def get(self, seq_name: str, frame_a: int, frame_b: int) -> Optional[bool]:
        """Return the cached decision for the (ordered) frame pair.

        Returns ``None`` both on a cache miss and when the teacher failed for
        the pair (stored as ``None``); this method does not distinguish them.
        """
        return self._cache.get(self.make_key(seq_name, frame_a, frame_b))

    def set(self, seq_name: str, frame_a: int, frame_b: int, decision: Optional[bool]):
        """Store a decision. ``decision`` may be ``None`` (= teacher failed).

        Non-``None`` values are coerced with ``bool()`` (e.g. ``numpy.bool_``),
        so the cache always stays JSON-serializable for :meth:`save`.
        """
        value = None if decision is None else bool(decision)
        self._cache[self.make_key(seq_name, frame_a, frame_b)] = value
        self._dirty = True

    def get_batch(
        self,
        seq_names: List[str],
        frame_ids_a: List[int],
        frame_ids_b: List[int],
    ) -> List[Optional[bool]]:
        """Look up a whole batch. Returns a list the same length as the inputs.

        Raises:
            ValueError: if the input sequences differ in length.
        """
        seqs, fas, fbs = self._as_columns(seq_names, frame_ids_a, frame_ids_b)
        return [self.get(seq, int(fa), int(fb)) for seq, fa, fb in zip(seqs, fas, fbs)]

    def set_batch(
        self,
        seq_names: List[str],
        frame_ids_a: List[int],
        frame_ids_b: List[int],
        decisions: List[Optional[bool]],
    ):
        """Store a whole batch.

        Raises:
            ValueError: if the input sequences differ in length (nothing is
                stored in that case).
        """
        seqs, fas, fbs, decs = self._as_columns(
            seq_names, frame_ids_a, frame_ids_b, decisions
        )
        for seq, fa, fb, dec in zip(seqs, fas, fbs, decs):
            self.set(seq, int(fa), int(fb), dec)

    def hit_rate(self) -> float:
        """Fraction of cache entries that are not ``None``.

        Despite the name, this is the share of stored pairs for which the
        teacher produced a real decision, not a lookup hit rate.
        Returns 0.0 for an empty cache.
        """
        if not self._cache:
            return 0.0
        return sum(1 for v in self._cache.values() if v is not None) / len(self._cache)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self):
        """Merge the cache into the on-disk file and replace it atomically.

        This is a no-op when nothing changed since the last successful save.
        On I/O failure, a ``RuntimeWarning`` is issued and the cache stays
        dirty, so the next ``save()`` retries; the in-memory cache is
        unaffected.
        """
        if not self._dirty:
            return
        # Snapshot so the dict being serialized cannot change size mid-dump.
        snapshot = dict(self._cache)
        lock_path = self._cache_path + ".lock"
        # Per-process temp name: writers never share (and truncate) one file.
        tmp_path = f"{self._cache_path}.{os.getpid()}.tmp"
        try:
            # Lock a dedicated lock file *before* touching any data file.
            with open(lock_path, "a") as lock_file:
                fcntl.flock(lock_file, fcntl.LOCK_EX)
                try:
                    self._write_merged(snapshot, tmp_path)
                finally:
                    fcntl.flock(lock_file, fcntl.LOCK_UN)
        except OSError as exc:
            warnings.warn(
                f"TeacherLabelCache: could not write {self._cache_path} ({exc}); "
                "will retry on the next save()",
                RuntimeWarning,
                stacklevel=2,
            )
            return
        self._dirty = False

    def _write_merged(self, entries: Dict[str, Optional[bool]], tmp_path: str) -> None:
        """Overlay ``entries`` on the on-disk cache and atomically replace it.

        Must be called while holding the lock on ``<cache file>.lock``.
        In-memory entries win over on-disk entries for the same key.
        """
        merged = self._read_disk()
        merged.update(entries)
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(merged, f, indent=2, sort_keys=True)
                f.flush()
                # Make the data durable before the rename makes it visible.
                os.fsync(f.fileno())
            os.replace(tmp_path, self._cache_path)
        except Exception:
            # Never leave a half-written temp file behind.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _read_disk(self) -> Dict[str, Optional[bool]]:
        """Return the valid entries of the on-disk cache.

        Returns ``{}`` when the file is missing, unreadable, or not a JSON
        object. Entries whose value is not a bool or ``null`` are dropped.
        """
        try:
            with open(self._cache_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Missing file / I/O error, or corrupt content
            # (JSONDecodeError and UnicodeDecodeError are ValueErrors).
            return {}
        if not isinstance(data, dict):
            return {}
        return {k: v for k, v in data.items() if v is None or isinstance(v, bool)}

    def _load(self):
        """Replace the in-memory cache with the on-disk contents."""
        self._cache = self._read_disk()

    @staticmethod
    def _as_columns(*columns) -> List[list]:
        """Materialize batch arguments as lists and check they are equally long."""
        lists = [list(col) for col in columns]
        if len({len(col) for col in lists}) > 1:
            raise ValueError(
                "batch arguments must all have the same length, got lengths "
                f"{[len(col) for col in lists]}"
            )
        return lists

    # ------------------------------------------------------------------
    # Info
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        return len(self._cache)

    def __repr__(self) -> str:
        return (
            f"TeacherLabelCache({len(self)} entries, "
            f"hit_rate={self.hit_rate():.1%}, "
            f"path={self._cache_path})"
        )