WCNegentropy committed on
Commit
3658b97
·
verified ·
1 Parent(s): 86d8d49

Remove nested directory: BitTransformerLM/bit_transformer/safety.py

Browse files
BitTransformerLM/bit_transformer/safety.py DELETED
@@ -1,149 +0,0 @@
1
- import logging
2
- import time
3
- import torch
4
- from typing import Dict, Optional, Tuple
5
-
6
- from .model import BitTransformerLM
7
-
8
-
9
class SafetyGate:
    """Exponential-moving-average safety gate with a burn-in period.

    The gate smooths incoming complexity (C) and symbiosis (S) scores with
    an EMA and only begins enforcing the floors once ``burn_in`` updates
    have been observed, so early noisy readings cannot trip it.
    """

    def __init__(
        self,
        *,
        c_floor: float = 0.3,
        s_floor: float = 0.5,
        decay: float = 0.9,
        burn_in: int = 10,
    ) -> None:
        # Floors below which the smoothed scores are considered unsafe.
        self.c_floor = c_floor
        self.s_floor = s_floor
        # EMA weight on the previous estimate; (1 - decay) weights the new sample.
        self.decay = decay
        # Number of initial updates during which the gate never fires.
        self.burn_in = burn_in
        self.step = 0
        # EMAs start unset; the first update seeds them with the raw scores.
        self._c_ema: Optional[float] = None
        self._s_ema: Optional[float] = None

    def should_trigger(self, c_val: float, s_val: float) -> bool:
        """Fold new scores into the EMAs and report whether to gate.

        Scores are always accumulated — including during burn-in — but the
        floor comparison is only performed once ``step`` exceeds ``burn_in``.
        """
        self.step += 1
        if self._c_ema is None or self._s_ema is None:
            # First observation seeds both EMAs directly.
            self._c_ema, self._s_ema = c_val, s_val
        else:
            keep = self.decay
            self._c_ema = keep * self._c_ema + (1 - keep) * c_val
            self._s_ema = keep * self._s_ema + (1 - keep) * s_val
        if self.step <= self.burn_in:
            return False
        return self._c_ema <= self.c_floor or self._s_ema <= self.s_floor
41
-
42
-
43
def hil_safe_inference(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    strict: bool = True,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run inference with telemetry gating.

    Parameters
    ----------
    model:
        Model to run inference with.
    bit_seq:
        Input bit sequences.
    c_floor, s_floor:
        Minimum LZ complexity and symbiosis score required for safe output.
    causal:
        Whether to run the model in causal (autoregressive) mode. When ``False``
        the model performs full-context Diffusion LM inference.
    strict:
        If ``False`` the function returns model outputs even when the floors are
        not met instead of raising ``RuntimeError``.
    gate:
        Optional :class:`SafetyGate` that applies EMA smoothing and burn-in
        before enforcing the floors.
    """
    model.eval()
    with torch.no_grad():
        logits, telemetry = model(bit_seq, causal=causal)

    def _score(key: str) -> float:
        # Mean telemetry value, clamped into [0, 1] before comparison.
        return min(1.0, max(0.0, float(telemetry[key].mean().item())))

    c_val = _score("lz_complexity_logits")
    s_val = _score("symbiosis_score")

    if gate is None:
        # No smoothing requested: compare the instantaneous scores directly.
        triggered = c_val <= c_floor or s_val <= s_floor
    else:
        triggered = gate.should_trigger(c_val, s_val)

    if triggered and strict:
        raise RuntimeError(
            f"Safety gate triggered: C={c_val:.3f}, S={s_val:.3f}"
        )
    return logits.argmax(-1), telemetry
89
-
90
-
91
def demo_hil_safety() -> None:
    """Demonstrate gating on random bits with a tiny throwaway model."""
    sample = torch.randint(0, 2, (1, 8), dtype=torch.long)
    tiny = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    try:
        # Floors of 0.0 make the check trivially passable for the demo.
        bits_out, _ = hil_safe_inference(tiny, sample, c_floor=0.0, s_floor=0.0)
    except RuntimeError as err:
        print("Gate triggered:", err)
    else:
        print("Safe output bits:", bits_out.squeeze(0).tolist())
100
-
101
-
102
def safe_sample_with_retry(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    max_retries: int = 3,
    backoff: float = 0.1,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run :func:`hil_safe_inference` with automatic retries.

    The helper retries failed safety checks by toggling diffusion mode and
    refreshing the input bits. An exponential backoff is applied between
    attempts and warnings are logged for each retry.

    Parameters
    ----------
    gate:
        Optional :class:`SafetyGate` instance shared across retries to apply
        EMA smoothing and burn-in.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        The sampled bits and associated telemetry.

    Raises
    ------
    ValueError
        If ``max_retries`` is less than 1 (the original silently returned
        ``None`` in that case, violating the declared return type).
    RuntimeError
        If the safety gate still triggers on the final attempt.
    """
    # Guard: with max_retries <= 0 the loop body would never execute and the
    # function would fall off the end, returning None despite its signature.
    if max_retries < 1:
        raise ValueError("max_retries must be >= 1")

    for attempt in range(max_retries):
        try:
            return hil_safe_inference(
                model,
                bit_seq,
                c_floor,
                s_floor,
                causal=causal,
                strict=True,
                gate=gate,
            )
        except RuntimeError as exc:  # safety gate triggered
            logging.warning(
                "Safety gate failed (attempt %d/%d): %s",
                attempt + 1,
                max_retries,
                exc,
            )
            # Re-raise on the last attempt; no retries left.
            if attempt >= max_retries - 1:
                raise
            time.sleep(backoff * (2 ** attempt))
            causal = False  # retry in diffusion mode
            # Resample fresh input bits, preserving shape/dtype/device.
            bit_seq = torch.randint(
                0, 2, bit_seq.shape, dtype=bit_seq.dtype, device=bit_seq.device
            )
    raise AssertionError("unreachable: loop always returns or raises")
149
-