File size: 3,091 Bytes
19f7f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
Stage-4 entry safety: Pool-B variance gate + r_cross warmup schedule.

The brief's bootstrapping constraint:

  - If `Var[r_code | context(τ_1)]` is too high (the code agent itself is
    unstable on this task class), then `r_cross` is mostly noise and
    Stage 4 will inject that noise into Phase-1 gradients.  Block entry
    until the variance (tracked here as the sample standard deviation of
    `r_code` over a sliding window of recent values) falls below a
    threshold.

  - Even after the gate opens, ramp `r_cross_weight` from 0 → 1 over
    `warmup_steps` so Phase-1 learning isn't suddenly knocked off-track.

This module is small and pure — the curriculum runner queries it on
each Stage-4 step.
"""

from __future__ import annotations

import statistics
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Optional


@dataclass
class VarianceGate:
    """
    Per-task sliding window of recent `r_code` values that decides Stage-4 entry.

    Every task keeps at most `window` of its most recent `r_code` values
    (older values are evicted automatically).  A task's gate is open once
    at least `min_samples` values are in its window AND the sample standard
    deviation of that window is at most `max_acceptable_stdev`.

    Note: since a window never holds more than `window` values, configuring
    `min_samples > window` means the gate can never open.
    """

    window:                int      = 64
    max_acceptable_stdev:  float    = 0.15     # standard deviation of r_code
    min_samples:           int      = 16
    history: Dict[str, Deque[float]] = field(default_factory=dict)

    def record(self, task_name: str, r_code: float) -> None:
        """Append one `r_code` observation to `task_name`'s window.

        The window (a bounded deque of length `window`) is created lazily
        the first time a task is seen.
        """
        already_tracked = task_name in self.history
        if not already_tracked:
            self.history[task_name] = deque(maxlen=self.window)
        bucket = self.history[task_name]
        bucket.append(float(r_code))

    def stdev(self, task_name: str) -> float:
        """Sample standard deviation of `task_name`'s window.

        Returns ``float("inf")`` when fewer than two values are available
        (including for tasks never recorded), so such tasks never pass the
        threshold check.
        """
        samples = self.history.get(task_name) or []
        return statistics.stdev(samples) if len(samples) >= 2 else float("inf")

    def is_open_for(self, task_name: str) -> bool:
        """True iff `task_name` has enough samples and a low enough stdev."""
        samples = self.history.get(task_name) or []
        has_enough = len(samples) >= self.min_samples
        return has_enough and self.stdev(task_name) <= self.max_acceptable_stdev

    def open_tasks(self) -> List[str]:
        """Names of all tracked tasks whose gate is currently open, in insertion order."""
        return list(filter(self.is_open_for, self.history))

    def status(self) -> Dict[str, Dict[str, float]]:
        """Per-task diagnostic snapshot for logging.

        Each entry maps to ``{"n": sample count, "stdev": stdev rounded to
        4 places, "open": gate state}``.
        """
        snapshot: Dict[str, Dict[str, float]] = {}
        for name, samples in self.history.items():
            snapshot[name] = {
                "n":     len(samples),
                "stdev": round(self.stdev(name), 4),
                "open":  self.is_open_for(name),
            }
        return snapshot


# ──────────────────────────────────────────────────────────────────────
# r_cross warmup schedule
# ──────────────────────────────────────────────────────────────────────


@dataclass
class RCrossWarmup:
    """
    Linear warmup of the r_cross coefficient over the first
    `warmup_steps` Stage-4 optimizer steps, clamped to ``[0, cap]``.

        weight(step) = min(cap, max(0, step / warmup_steps))

    With the default ``cap=1.0`` the weight reaches 1.0 exactly at
    ``step == warmup_steps``.  Any other ``cap`` clips the same ramp (slope
    ``1 / warmup_steps``) rather than rescaling it, so the ceiling is reached
    at ``step == cap * warmup_steps``.  A non-positive ``warmup_steps``
    disables the warmup and always returns ``cap``.
    """
    warmup_steps: int   = 500
    cap:          float = 1.0

    def weight(self, step: int) -> float:
        """Return the r_cross weight to use at Stage-4 optimizer step `step`.

        Negative steps are treated like step 0 (weight 0.0): a negative
        coefficient would flip the sign of the r_cross term instead of
        merely damping it.
        """
        if self.warmup_steps <= 0:
            # No warmup configured: apply the full (capped) weight immediately.
            return self.cap
        # Clamp below at 0 so out-of-range steps can never invert r_cross.
        ramp = max(0.0, step / self.warmup_steps)
        return min(self.cap, ramp)