File size: 5,156 Bytes
b3f019f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
Persistent cache for teacher update-decisions.

Used for two-stage training:
  Stage 1 (label generation) – teacher API is called once per unique frame pair,
      and the result is saved to disk.
  Stage 2 (training) – the cached decision is read directly, avoiding any
      online teacher API call during the forward pass.

Keys are (seq_name, frame_id_A, frame_id_B) – the two template-candidate frames
that Qwen compares.  Values are Python bools, or ``null`` (JSON) when the teacher
failed for that pair.

The cache is persisted as a single JSON file.  Because the file may be written
by multiple DDP ranks, all writes use an atomic-rename pattern with fcntl
locking (best-effort on the local filesystem).
"""

from __future__ import annotations

import fcntl
import json
import os
from typing import Dict, List, Optional, Tuple


class TeacherLabelCache:
    """Persistent JSON-backed cache for teacher update decisions.

    Entries map ``"{seq_name}__{frame_a}__{frame_b}"`` to ``True``, ``False``
    or ``None`` (teacher failed for that pair).  The whole cache lives in
    ``<cache_dir>/teacher_labels.json``.

    Disk writes are serialized across processes with an advisory ``flock`` on
    a sidecar ``teacher_labels.json.lock`` file, and each ``save()`` merges
    the entries already on disk before atomically replacing the file, so
    several writers sharing one cache directory do not discard each other's
    labels.  The in-memory dictionary itself is not guarded by any lock.

    Usage::

        cache = TeacherLabelCache("./output/teacher_cache")
        dec = cache.get("airplane-1", 120, 150)   # → True / False / None
        cache.set("airplane-1", 120, 150, True)
        cache.save()
    """

    def __init__(self, cache_dir: str):
        """Create ``cache_dir`` if needed and load any existing cache file."""
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self._cache_path = os.path.join(cache_dir, "teacher_labels.json")
        self._cache: Dict[str, Optional[bool]] = {}
        # True while the in-memory cache holds changes not yet written to disk.
        self._dirty = False
        self._load()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    def make_key(seq_name: str, frame_a: int, frame_b: int) -> str:
        """Deterministic key; ORDER MATTERS.

        ``frame_a`` = old template frame (template[-2])
        ``frame_b`` = new candidate frame (template[-1])

        The teacher is asked: *should we update FROM frame_a TO frame_b?*
        This is a directional question, so the key preserves the order.
        Both frame ids are passed through ``int()`` so integer-like inputs
        produce the same key as plain ints.
        """
        fa = int(frame_a)
        fb = int(frame_b)
        return f"{seq_name}__{fa}__{fb}"

    def get(self, seq_name: str, frame_a: int, frame_b: int) -> Optional[bool]:
        """Return cached decision, or ``None`` on cache miss / teacher failure."""
        return self._cache.get(self.make_key(seq_name, frame_a, frame_b))

    def set(self, seq_name: str, frame_a: int, frame_b: int, decision: Optional[bool]):
        """Store a decision.  ``decision`` may be ``None`` (= teacher failed).

        Any non-``None`` value is coerced with ``bool()`` so the cache only
        ever holds plain Python bools, which ``json.dump`` can serialize.
        """
        value = None if decision is None else bool(decision)
        self._cache[self.make_key(seq_name, frame_a, frame_b)] = value
        self._dirty = True

    def get_batch(
        self,
        seq_names: List[str],
        frame_ids_a: List[int],
        frame_ids_b: List[int],
    ) -> List[Optional[bool]]:
        """Look up a whole batch, one result per zipped input triple.

        The inputs are combined with ``zip``, so the result is as long as the
        shortest input list.
        """
        return [
            self.get(seq, int(fa), int(fb))
            for seq, fa, fb in zip(seq_names, frame_ids_a, frame_ids_b)
        ]

    def set_batch(
        self,
        seq_names: List[str],
        frame_ids_a: List[int],
        frame_ids_b: List[int],
        decisions: List[Optional[bool]],
    ):
        """Store a whole batch (inputs are zipped; extras beyond the shortest are ignored)."""
        for seq, fa, fb, dec in zip(seq_names, frame_ids_a, frame_ids_b, decisions):
            self.set(seq, int(fa), int(fb), dec)

    def hit_rate(self) -> float:
        """Fraction of cache entries that are not ``None`` (0.0 when empty)."""
        if not self._cache:
            return 0.0
        return sum(1 for v in self._cache.values() if v is not None) / len(self._cache)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self):
        """Merge with the on-disk cache and atomically write it (if dirty).

        Under an exclusive ``flock`` on the sidecar lock file, the current
        file content is re-read, overlaid with this instance's entries (the
        in-memory value wins when a key exists in both), written to a
        per-process temp file and moved over the cache file with
        ``os.replace``.  On success the in-memory cache also adopts the
        entries that were merged in from disk.

        I/O failures are best-effort: they are logged as a warning, the
        in-memory cache stays valid and dirty, and the write is retried on
        the next ``save()``.
        """
        if not self._dirty:
            return
        lock_path = self._cache_path + ".lock"
        # Per-process temp name: even where flock is ineffective, two
        # processes never write into the same temp file.
        tmp_path = f"{self._cache_path}.{os.getpid()}.tmp"
        try:
            with open(lock_path, "a") as lock_file:
                fcntl.flock(lock_file, fcntl.LOCK_EX)
                try:
                    # Start from what other writers saved, then overlay ours.
                    merged = self._read_disk()
                    merged.update(self._cache)
                    with open(tmp_path, "w", encoding="utf-8") as f:
                        json.dump(merged, f, indent=2, sort_keys=True)
                    os.replace(tmp_path, self._cache_path)
                except BaseException:
                    # Never leave a partial temp file behind.
                    self._remove_quietly(tmp_path)
                    raise
                finally:
                    fcntl.flock(lock_file, fcntl.LOCK_UN)
        except OSError as exc:
            # Non-critical – the in-memory cache is still valid; disk write
            # will be retried on the next ``save()``.
            import logging

            logging.getLogger(__name__).warning(
                "TeacherLabelCache: could not save %s: %s", self._cache_path, exc
            )
            return
        self._cache = merged
        self._dirty = False

    def _load(self):
        """Replace the in-memory cache with the on-disk content (or ``{}``)."""
        self._cache = self._read_disk()

    def _read_disk(self) -> Dict[str, Optional[bool]]:
        """Return the cache file's content, or ``{}`` if missing, unreadable or malformed."""
        try:
            with open(self._cache_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Missing file, I/O error, or invalid JSON / undecodable bytes.
            return {}
        # Valid JSON that is not an object (e.g. a list) cannot serve as the cache.
        return data if isinstance(data, dict) else {}

    @staticmethod
    def _remove_quietly(path: str) -> None:
        """Delete ``path`` if it exists, ignoring any filesystem error."""
        try:
            os.remove(path)
        except OSError:
            pass

    # ------------------------------------------------------------------
    # Info
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        """Number of entries currently held in memory (including ``None`` values)."""
        return len(self._cache)

    def __repr__(self) -> str:
        return (
            f"TeacherLabelCache({len(self)} entries, "
            f"hit_rate={self.hit_rate():.1%}, "
            f"path={self._cache_path})"
        )