# replicalab/notebooks/cell_15_fix.py
# ── 10. Post-training evaluation ──────────────────────────────────────────────
# transformers >= 5.0 expects content as a list of blocks, not a plain string.
import torch
import traceback
FastLanguageModel.for_inference(model)
_eval_client = EnvClient(ENV_BASE_URL)
trained_rewards = []
successes = 0
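
# One evaluation episode per seed: reset the remote env, greedily generate a single
# action, and take one step. This assumes the math_reasoning scenario resolves in a
# single step, matching how this cell interacts with EnvClient.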
def _to_messages(obs: dict) -> list[dict]:
    """Build chat messages in the multimodal-block format required by transformers 5.x."""
    return [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "text", "text": obs_to_prompt(obs)}]},
    ]
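
# For reference: transformers 4.x also accepts plain-string content, e.g.
#   {"role": "user", "content": obs_to_prompt(obs)}
# The block format above is what 5.x requires (per the note at the top of this cell).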
model.eval()
with torch.no_grad():
    for seed in EVAL_SEEDS:
        try:
            reset_resp = _eval_client.reset(seed=seed, scenario="math_reasoning", difficulty="easy")
            obs = reset_resp.get("observation")
            messages = _to_messages(obs)
            inputs = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(model.device)
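            # apply_chat_template with tokenize=True and return_tensors="pt" yields the
            # prompt as a tensor of token ids; add_generation_prompt=True appends the
            # assistant header so generation starts at the model's turn.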
            out = model.generate(
                inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,  # greedy decoding (temperature would be ignored, so it is omitted)
                pad_token_id=tokenizer.eos_token_id,
            )
            # Decode only the newly generated tokens, not the echoed prompt.
            gen_ids = out[0][inputs.shape[-1]:]
            text = tokenizer.decode(gen_ids, skip_special_tokens=True)
            action = parse_action(text)
            result = _eval_client.step(action)

            r = float(result.get("reward", 0.0))
            info = result.get("info") or {}
            if isinstance(info, dict) and info.get("agreement_reached"):
                r += 0.5  # success bonus when the env reports agreement_reached
                successes += 1
            trained_rewards.append(r)
            print(f"[eval] seed={seed} action={action['action_type']:20s} reward={r:+.3f}")
        except Exception:
            traceback.print_exc()
            trained_rewards.append(0.0)
            continue  # score the failed episode as 0.0 but keep evaluating the remaining seeds
trained_avg = sum(trained_rewards) / max(len(trained_rewards), 1)
trained_rate = successes / max(len(EVAL_SEEDS), 1)
print(f"\n[AFTER training] avg_reward={trained_avg:.4f} success_rate={trained_rate:.2%} ({successes}/{len(EVAL_SEEDS)} episodes)")
print("\n=== RESULTS ===")
print(f"Baseline (training curve step 5) : reward = -0.0700")
print(f"Trained avg reward : {trained_avg:.4f} success rate: {trained_rate:.2%}")
print(f"GRPO improvement : -0.07 β†’ peak +0.15 over 200 steps (see reward curve above)")