# trace/training/callbacks.py
# Last commit: "Update" (4156f51) by Ayush; file size 4.02 kB
"""
training/callbacks.py
Training callbacks for monitoring reward curves and inspecting
model outputs for reward hacking patterns.
Per the hackathon guide: "Periodic human inspection is still necessary."
"""
from __future__ import annotations
import json
import requests
from transformers import TrainerCallback, TrainerState, TrainerControl
class TraceRewardCallback(TrainerCallback):
    """Record logged rewards and periodically sample the model for review.

    Every ``sample_every_n_steps`` global steps (checked when the trainer
    logs), the callback resets an episode on the environment server at
    ``env_url``, fetches its observation prompt, generates one completion
    and prints it. Completions containing any ``HACK_PATTERNS`` substring
    are additionally stored as alerts. Inspection is best-effort: any
    failure is printed and training continues.
    """

    # Lower-case substrings searched for in the lower-cased completion.
    # "fabricat" is a stem so it matches "fabricate", "fabricated", etc.
    HACK_PATTERNS = [
        "ignore",
        "bypass",
        "hack",
        "cheat",
        "skip verification",
        "fabricat",
    ]

    # Per-request timeout (seconds) for calls to the environment server.
    _HTTP_TIMEOUT_S = 5

    def __init__(self, env_url: str, sample_every_n_steps: int = 20):
        """Store configuration and initialise empty history.

        Args:
            env_url: Base URL of the environment server; ``/reset`` and
                ``/observation_prompt`` are appended to it.
            sample_every_n_steps: Sampling interval in global steps. A
                value of 0 (or less) disables sampling.
        """
        self.env_url = env_url
        self.sample_every_n_steps = sample_every_n_steps
        self._reward_history: list[float] = []
        self._hack_alerts: list[dict] = []
        # on_log may fire several times at one global step; remember the
        # last sampled step so each step is inspected at most once.
        self._last_inspected_step: int | None = None

    def on_log(self, args, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        """Record the logged reward and trigger a sample when one is due."""
        if logs and "reward" in logs:
            self._reward_history.append(logs["reward"])

        interval = self.sample_every_n_steps
        if not interval or interval < 0:
            # Avoids ZeroDivisionError and lets callers switch sampling off.
            return

        step = state.global_step
        if step % interval != 0 or step == self._last_inspected_step:
            return
        self._last_inspected_step = step

        # Depending on the transformers version the tokenizer is handed to
        # callbacks as either "tokenizer" or "processing_class".
        tokenizer = kwargs.get("tokenizer")
        if tokenizer is None:
            tokenizer = kwargs.get("processing_class")
        self._inspect_outputs(state, kwargs.get("model"), tokenizer)

    def _inspect_outputs(self, state, model, tokenizer):
        """Generate one sample completion and report it.

        Does nothing when ``model`` or ``tokenizer`` is missing. All errors
        are caught and printed so that a failed inspection never interrupts
        training.
        """
        if model is None or tokenizer is None:
            return

        try:
            prompt = self._fetch_prompt()
            if not prompt:
                return
            generated = self._generate_sample(model, tokenizer, prompt)
            self._report_sample(state.global_step, generated)
        except Exception as e:  # deliberate best-effort boundary
            print(f"[Callback] Inspection failed: {e}")

    def _fetch_prompt(self):
        """Reset a fixed sample episode and return its prompt, or None."""
        resp = requests.post(
            f"{self.env_url}/reset",
            json={
                "instruction": "Find all travel and ride receipts from Gmail in the last 10 days.",
                "difficulty": "easy",
                "available_sources": ["gmail"],
                "ground_truth": {"answer": "Found ride receipts", "expected_numeric_target": 120.50},
            },
            timeout=self._HTTP_TIMEOUT_S,
        )
        if resp.status_code != 200:
            print(f"[Callback] Inspection skipped: /reset returned {resp.status_code}")
            return None

        prompt_resp = requests.get(
            f"{self.env_url}/observation_prompt",
            timeout=self._HTTP_TIMEOUT_S,
        )
        if prompt_resp.status_code != 200:
            print(
                "[Callback] Inspection skipped: "
                f"/observation_prompt returned {prompt_resp.status_code}"
            )
            return None

        prompt = prompt_resp.json().get("prompt", "")
        if not prompt:
            print("[Callback] Inspection skipped: empty prompt")
            return None
        return prompt

    def _generate_sample(self, model, tokenizer, prompt):
        """Sample one completion for ``prompt`` and return only the new text."""
        # Function-scope import: this file does not import torch at the top.
        import torch

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Sample in eval mode, then restore whatever mode the trainer had set.
        was_training = getattr(model, "training", False)
        model.eval()
        try:
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=150,
                    temperature=0.7,
                    do_sample=True,
                )
        finally:
            if was_training:
                model.train()

        sequence = outputs[0]
        config = getattr(model, "config", None)
        if not getattr(config, "is_encoder_decoder", False):
            # Drop the prompt by token count. Slicing the decoded string by
            # len(prompt) is unreliable because decoding need not reproduce
            # the prompt text character for character.
            sequence = sequence[inputs["input_ids"].shape[1]:]
        return tokenizer.decode(sequence, skip_special_tokens=True)

    def _report_sample(self, step, generated):
        """Print the sample; record an alert if it matches a hack pattern."""
        lowered = generated.lower()
        flags = [p for p in self.HACK_PATTERNS if p in lowered]
        if not flags:
            print(f"\n✅ [SAMPLE] Step {step}: {generated[:150]}\n")
            return

        self._hack_alerts.append(
            {
                "step": step,
                "flags": flags,
                "output_preview": generated[:200],
            }
        )
        print(f"\n⚠️ [ANTI-HACK ALERT] Step {step}")
        print(f" Flags: {flags}")
        print(f" Output: {generated[:200]}\n")

    def on_train_end(self, args, state, control, **kwargs):
        """Print the final reward summary and the number of hack alerts.

        Reward statistics are printed only when at least one reward was
        logged; the alert count is always printed.
        """
        print("\n[Trace] Training complete!")
        if self._reward_history:
            avg = sum(self._reward_history) / len(self._reward_history)
            mx = max(self._reward_history)
            print(f" Avg reward: {avg:.3f}")
            print(f" Max reward: {mx:.3f}")
        print(f" Total hack alerts: {len(self._hack_alerts)}")