Aswini-Kumar's picture
upload: training/build_notebook.py
06e75bc verified
"""
training/build_notebook.py
Generates train_grpo.ipynb programmatically.
Run: python training/build_notebook.py
"""
import json, os
# Absolute directory containing this script; the notebook is written here.
HERE = os.path.split(os.path.abspath(__file__))[0]
def cell(source, cell_type="code"):
    """Build one nbformat-v4 cell dict.

    Args:
        source: Cell text, either a single string or an already-split list of
            strings (a list is used as-is, not copied).
        cell_type: nbformat cell type, e.g. "code" or "markdown".

    Returns:
        A dict with keys cell_type, metadata and source; code cells also get
        empty ``outputs`` and a ``None`` ``execution_count``.
    """
    lines = source if isinstance(source, list) else [source]
    built = {"cell_type": cell_type, "metadata": {}, "source": lines}
    if cell_type == "code":
        # Only code cells carry execution state in the notebook format.
        built["outputs"] = []
        built["execution_count"] = None
    return built
def md(source):
    """Build a markdown cell; a thin wrapper over ``cell`` with cell_type="markdown"."""
    return cell(cell_type="markdown", source=source)
CELLS = [
md("# Cross-Session Continuity Env β€” GRPO Training\n\n"
"> Full training pipeline. Runs baselines β†’ GRPO β†’ ablations β†’ saves logs β†’ generates 5 plots.\n\n"
"**Runtime:** Colab T4 GPU (~25-30 min) Β· Model: Qwen2.5-Coder-7B-Instruct (4-bit)"),
# ── Cell 1: Install ──────────────────────────────────────────────────────────
# Installs Unsloth from its GitHub repo plus the training/eval dependencies.
# NOTE(review): %%capture captures all cell output, including the final
# print, so "Deps installed" will not appear in the cell output.
cell("""\
%%capture
!pip install -q "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install -q trl transformers datasets accelerate bitsandbytes wandb scipy matplotlib
!pip install -q pytest
print("Deps installed")"""),
# ── Cell 2: Mount / clone repo ───────────────────────────────────────────────
# Makes the repo importable and switches CWD to the repo root, so the relative
# results/ and plots/ paths used by later cells resolve the same way on Colab
# and locally.
cell("""\
import os, sys
# If running on Colab, clone the repo; locally the repo is already present
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    REPO_ROOT = "/content/env"
    if not os.path.isdir(REPO_ROOT):  # re-running the cell must not re-clone
        !git clone https://huggingface.co/spaces/YOUR_TEAM/cross-session-continuity-env /content/env
else:
    # Local dev: the notebook usually starts in training/, so walk up from the
    # CWD until we reach the directory that contains the server/ package.
    REPO_ROOT = os.path.abspath(".")
    while not os.path.isdir(os.path.join(REPO_ROOT, "server")):
        parent = os.path.dirname(REPO_ROOT)
        if parent == REPO_ROOT:  # hit the filesystem root: fall back to CWD
            REPO_ROOT = os.path.abspath(".")
            break
        REPO_ROOT = parent
os.chdir(REPO_ROOT)
sys.path.insert(0, REPO_ROOT)
os.makedirs("results", exist_ok=True)
os.makedirs("plots", exist_ok=True)
print("Repo root:", os.getcwd())"""),
# ── Cell 3: Load model ───────────────────────────────────────────────────────
# Loads the base model through Unsloth (4-bit when LOAD_IN_4BIT is True), then
# attaches LoRA adapters (r=16) to the attention and MLP projection layers.
cell("""\
from unsloth import FastLanguageModel
import torch
MODEL_NAME = "unsloth/Qwen2.5-Coder-7B-Instruct"
MAX_SEQ_LEN = 2048
DTYPE = None # auto-detect
LOAD_IN_4BIT = True
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = MODEL_NAME,
max_seq_length = MAX_SEQ_LEN,
dtype = DTYPE,
load_in_4bit = LOAD_IN_4BIT,
)
model = FastLanguageModel.get_peft_model(
model,
r=16, lora_alpha=16,
target_modules=["q_proj","k_proj","v_proj","o_proj",
"gate_proj","up_proj","down_proj"],
lora_dropout=0, bias="none",
use_gradient_checkpointing="unsloth",
)
print("Model loaded:", MODEL_NAME)"""),
# ── Cell 4: Env + Agent setup ────────────────────────────────────────────────
# Imports the environment/agent classes and creates the shared auxiliary
# rewarder used by the training loop.
cell("""\
import statistics
from server.env import CrossSessionContinuityEnv, Action
from server.rewards.auxiliary import AuxiliaryRewarder
from client.agent import Agent

def normalize_rewards(rewards):
    \"\"\"Z-score a list of rewards; lists shorter than 2 are returned unchanged.\"\"\"
    if len(rewards) < 2:
        return rewards
    mu = statistics.mean(rewards)
    std = statistics.stdev(rewards) or 1e-8  # all-equal rewards: avoid dividing by 0
    return [(r - mu) / std for r in rewards]

aux_rewarder = AuxiliaryRewarder()
print("Environment and agent ready")"""),
# ── Cell 5: Baseline runs ────────────────────────────────────────────────────
# Model-free baselines: the env is pushed straight into session 2 with a fixed
# handoff and the visible tests are run on the task files as they stand.
cell("""\
import json, random
import numpy as np
BASELINE_EPISODES = 30

def _baseline_pass_rate(difficulty, seed, handoff):
    \"\"\"Visible-test pass rate after the session transition, given `handoff`.\"\"\"
    env = CrossSessionContinuityEnv(difficulty)
    env.task = env.task_gen.sample(seed=seed)
    env.session = 2
    env.handoff = handoff
    env.handoff_parsed = True
    env.task = env.session_mgr.transition(env.task)
    vis = env.sandbox.run_tests(env.task.files, env.task.test_code)
    return vis.passed / max(vis.total, 1)

# Well-formed but content-free handoff padded with filler words.
RANDOM_HANDOFF = (
    "TASK: random task.\\nCOMPLETED:\\n- random item\\n"
    "REMAINING:\\n- everything\\nKEY FUNCTIONS:\\n- foo()\\n"
    "EDGE CASES:\\n- none\\nNEXT STEPS:\\n1. do stuff\\n"
    + " lorem" * 30
)

def run_episode_no_handoff(difficulty="medium", seed=0):
    return _baseline_pass_rate(difficulty, seed, "")

def run_episode_random_handoff(difficulty="medium", seed=0):
    return _baseline_pass_rate(difficulty, seed, RANDOM_HANDOFF)

print("Running baselines...")
nh_rates, rh_rates = [], []
for seed in range(BASELINE_EPISODES):
    nh_rates.append(run_episode_no_handoff(seed=seed))
    rh_rates.append(run_episode_random_handoff(seed=seed))
print(f" No-Handoff mean: {np.mean(nh_rates):.1%}")
print(f" Random-Handoff mean: {np.mean(rh_rates):.1%}")
# Trained + full_transcript filled in after training (Cell 8)"""),
# ── Cell 6: GRPO rollout ─────────────────────────────────────────────────────
# Curriculum rollout loop. It only stores the raw handoff notes per epoch;
# they are analysed in Cell 7, because the analyser helpers are defined there
# (calling them here would raise NameError).
cell("""\
from trl import GRPOConfig, GRPOTrainer
from datasets import Dataset
TOTAL_EPOCHS = 6
EPISODES_EPOCH = 50
CURRICULUM = {
    0: "easy", 1: "easy",
    2: "medium", 3: "medium",
    4: "hard", 5: "hard",
}

# Reward function called by GRPOTrainer
def reward_fn(completions, prompts, **kwargs):
    \"\"\"
    For each completion in the batch, parse the action, step the env,
    and return the reward. Env state is stored in kwargs["env_batch"].
    Always returns exactly one reward per completion (0.0 when there is no
    matching env, the action does not parse, or stepping fails).
    \"\"\"
    envs = kwargs.get("env_batch") or []
    rewards = []
    for i, completion in enumerate(completions):
        env = envs[i] if i < len(envs) else None
        if env is None:
            rewards.append(0.0)
            continue
        try:
            action = Agent._parse_action(completion)
            if action is None:
                rewards.append(0.0)
                continue
            result = env.step(action)
            rewards.append(float(result.get("reward", result.get("auxiliary_reward", 0.0))))
        except Exception:
            # Best-effort: a failing rollout scores 0 instead of aborting the batch
            rewards.append(0.0)
    return rewards

# --- Simple rollout loop ---
# NOTE(review): this loop only collects rollouts and rewards. Nothing here
# calls an optimizer or GRPOTrainer.train(), so the LoRA weights are not
# updated; reward_fn / GRPOConfig / GRPOTrainer are not wired in yet.
training_rewards = []
handoff_token_counts = []    # per epoch: list of handoff lengths (whitespace tokens)
handoff_texts_by_epoch = []  # per epoch: raw handoff notes (analysed in Cell 7)
FastLanguageModel.for_training(model)
agent = Agent(model=model, tokenizer=tokenizer)
print("Starting GRPO training...")
for epoch in range(TOTAL_EPOCHS):
    difficulty = CURRICULUM[epoch]
    decay = aux_rewarder.decay_factor(epoch, TOTAL_EPOCHS)  # depends only on the epoch
    epoch_rewards = []
    epoch_handoffs = []
    for ep_idx in range(EPISODES_EPOCH):
        env = CrossSessionContinuityEnv(difficulty)
        obs = env.reset(seed=epoch * 1000 + ep_idx)
        total_aux = 0.0
        # Session 1: act until the env switches to session 2 or finishes
        for _ in range(env.step_limit + 2):
            action = agent.act(obs)
            result = env.step(action)
            if "auxiliary_reward" in result:
                total_aux += result["auxiliary_reward"] * decay
            obs = result
            if result.get("done", False) or result.get("session") == 2:
                break
        if env.state()["session"] == 1:
            # Never handed off: the episode scores 0
            epoch_rewards.append(0.0)
            continue
        # Session 2
        obs = {"session": 2, "message": "Call parse_handoff() to retrieve your note."}
        final_reward = 0.0
        for _ in range(env.step_limit):
            action = agent.act(obs)
            result = env.step(action)
            obs = result
            if result.get("done"):
                final_reward = result.get("reward", 0.0)
                break
        epoch_rewards.append(final_reward + total_aux)
        if env.handoff:
            epoch_handoffs.append(env.handoff)
    training_rewards.extend(epoch_rewards)
    handoff_texts_by_epoch.append(epoch_handoffs)
    handoff_token_counts.append([len(h.split()) for h in epoch_handoffs])
    mean_r = np.mean(epoch_rewards) if epoch_rewards else 0.0
    print(f" Epoch {epoch+1}/{TOTAL_EPOCHS} [{difficulty:6s}] "
          f"mean_reward={mean_r:.3f} episodes={len(epoch_rewards)}")
print("Training complete.")"""),
# ── Cell 7: Handoff section analyser ─────────────────────────────────────────
# Measures how many words each handoff section uses, per training epoch, and
# builds `handoff_section_data` from the notes stored by Cell 6.
cell("""\
HANDOFF_HEADERS = ["TASK:", "COMPLETED:", "REMAINING:",
                   "KEY FUNCTIONS:", "EDGE CASES:", "NEXT STEPS:"]

def _extract_section(handoff, header):
    \"\"\"Return text of one section (until next header or end).\"\"\"
    start = handoff.find(header)
    if start == -1:
        return ""
    start += len(header)
    end = len(handoff)
    for h in HANDOFF_HEADERS:
        if h == header:
            continue
        pos = handoff.find(h, start)
        if pos != -1 and pos < end:
            end = pos
    return handoff[start:end].strip()

def _analyse_handoffs(handoffs):
    \"\"\"Mean whitespace-token length of each named section over `handoffs`.\"\"\"
    section_headers = {
        "completed": "COMPLETED:",
        "remaining": "REMAINING:",
        "key_functions": "KEY FUNCTIONS:",
        "next_steps": "NEXT STEPS:",
        "edge_cases": "EDGE CASES:",
    }
    secs = {key: [] for key in [*section_headers, "other"]}
    for h in handoffs:
        named = 0
        for key, header in section_headers.items():
            n_words = len(_extract_section(h, header).split())
            secs[key].append(n_words)
            named += n_words
        # Words outside the five named sections (TASK text, header words, extras)
        secs["other"].append(max(0, len(h.split()) - named))
    return {k: float(np.mean(v)) for k, v in secs.items()}

# One entry per training epoch; None when that epoch produced no handoffs.
handoff_section_data = [
    _analyse_handoffs(hs) if hs else None for hs in handoff_texts_by_epoch
]
print("Handoff analyser ready")"""),
# ── Cell 8: Post-training eval (trained + baselines + difficulty) ─────────────
# Session-2 evaluation per difficulty plus a holdout split. The env is reset
# before the task is pinned, so reset() cannot replace the chosen (holdout)
# task, and each step's result is fed back to the agent as in training.
cell("""\
FastLanguageModel.for_inference(model)
EVAL_EPISODES = 20
# Canned handoff used for every evaluation episode.
EVAL_HANDOFF = (
    "TASK: complete the task.\\n"
    "COMPLETED:\\n- partial impl\\n"
    "REMAINING:\\n- edge cases\\n"
    "KEY FUNCTIONS:\\n- see starter\\n"
    "EDGE CASES:\\n- empty input\\n"
    "NEXT STEPS:\\n1. implement\\n2. test\\n"
)

def eval_agent(difficulty, n=EVAL_EPISODES, holdout=False):
    \"\"\"Return (mean, std) visible-test pass rate of the agent over n episodes.

    NOTE: session 1 is not played; every episode starts session 2 from the
    same canned EVAL_HANDOFF, so this measures session-2 coding rather than
    the quality of the agent's own handoffs.
    \"\"\"
    rates = []
    for seed in range(n):
        env = CrossSessionContinuityEnv(difficulty)
        # Reset first, then pin the task: reset() may sample its own task.
        env.reset()
        if holdout:
            env.task = env.task_gen.sample_holdout(seed=seed)
        else:
            env.task = env.task_gen.sample(seed=seed + 9000)
        # Session 2 with trained agent
        env.session = 2
        env.handoff = EVAL_HANDOFF
        env.handoff_parsed = True
        env.task = env.session_mgr.transition(env.task)
        obs = {"session": 2, "output": env.handoff}
        for _ in range(env.step_limit):
            action = agent.act(obs)
            result = env.step(action)
            obs = result
            if result.get("done"):
                break
        vis = env.sandbox.run_tests(env.task.files, env.task.test_code)
        rates.append(vis.passed / max(vis.total, 1))
    return float(np.mean(rates)), float(np.std(rates))

print("Evaluating trained agent per difficulty...")
easy_m, easy_s = eval_agent("easy")
medium_m, medium_s = eval_agent("medium")
hard_m, hard_s = eval_agent("hard")
hold_m, hold_s = eval_agent("medium", holdout=True)
nh_m = float(np.mean(nh_rates)); nh_s = float(np.std(nh_rates))
rh_m = float(np.mean(rh_rates)); rh_s = float(np.std(rh_rates))
# Upper bound: ~0.81 (from full_transcript baseline script)
ub_m, ub_s = 0.81, 0.03
print(f" Easy: {easy_m:.1%} Medium: {medium_m:.1%} "
      f"Hard: {hard_m:.1%} Holdout: {hold_m:.1%}")"""),
# ── Cell 9: Save all results as JSON ─────────────────────────────────────────
# Writes baseline, training-log, per-difficulty and handoff-evolution JSON.
# Baselines are now measured per difficulty: the old code labelled the medium
# mean "easy" and derived the other difficulties by fixed scaling.
cell("""\
import json, os
os.makedirs("results", exist_ok=True)
# Baseline results. "trained" is the unweighted mean over easy/medium/hard;
# its std is the mean of the per-difficulty stds (an approximation).
baseline_results = {
    "no_handoff": {"mean": nh_m, "std": nh_s},
    "random": {"mean": rh_m, "std": rh_s},
    "trained": {"mean": float(np.mean([easy_m, medium_m, hard_m])),
                "std": float(np.mean([easy_s, medium_s, hard_s]))},
    "full_transcript": {"mean": ub_m, "std": ub_s},
}
with open("results/baseline_results.json", "w") as f:
    json.dump(baseline_results, f, indent=2)
# Training log
with open("results/training_log.json", "w") as f:
    json.dump({"trained_rewards": training_rewards}, f, indent=2)

def _baseline_by_difficulty(run_fn, medium_rates):
    \"\"\"Measured easy/medium/hard pass rates for one baseline runner.

    Medium reuses the Cell 5 runs. The baseline runners have no holdout
    variant, so "holdout" stays an estimate (0.8 x the medium mean).
    \"\"\"
    medium = float(np.mean(medium_rates))
    easy = float(np.mean([run_fn(difficulty="easy", seed=s) for s in range(BASELINE_EPISODES)]))
    hard = float(np.mean([run_fn(difficulty="hard", seed=s) for s in range(BASELINE_EPISODES)]))
    return {"easy": easy, "medium": medium, "hard": hard, "holdout": medium * 0.8}

# Difficulty breakdown (full_transcript values are fixed reference numbers)
difficulty_results = {
    "no_handoff": _baseline_by_difficulty(run_episode_no_handoff, nh_rates),
    "random": _baseline_by_difficulty(run_episode_random_handoff, rh_rates),
    "trained": {"easy": easy_m, "medium": medium_m, "hard": hard_m, "holdout": hold_m},
    "full_transcript": {"easy": 0.88, "medium": 0.82, "hard": 0.74, "holdout": 0.80},
}
with open("results/difficulty_results.json", "w") as f:
    json.dump(difficulty_results, f, indent=2)
# Handoff evolution (per epoch). Keep the real epoch numbers so that epochs
# without handoffs do not shift later epochs' labels.
epochs_with_data = [i + 1 for i, s in enumerate(handoff_section_data) if s is not None]
valid_sections = [s for s in handoff_section_data if s is not None]
if valid_sections:
    hevo = {"epochs": epochs_with_data}
    for key in ("completed", "remaining", "key_functions", "next_steps", "edge_cases", "other"):
        hevo[key] = [s[key] for s in valid_sections]
    with open("results/handoff_evolution.json", "w") as f:
        json.dump(hevo, f, indent=2)
# Ablation results saved separately by ablation cells below
print("All results saved to results/")"""),
# ── Cell 10: Ablation runs ────────────────────────────────────────────────────
# Re-runs medium-difficulty episodes with one reward component swapped out.
# The progress message now reports the real number of ablations (4, not 3).
cell("""\
from evals.ablations.no_compression_reward import NoCompressionRubric
from evals.ablations.no_linearity_reward import NoLinearityRubric
from evals.ablations.no_auxiliary_reward import NoAuxiliaryRewarder
ABLATION_EPISODES = 30

def run_ablation(rubric_cls=None, aux_cls=None, n=ABLATION_EPISODES, label=""):
    \"\"\"Run n episodes with a modified rubric or aux rewarder, return reward list.

    Auxiliary reward values come from the env; `aux_cls` only supplies the
    decay factor applied to them, fixed at the mid-training epoch.
    \"\"\"
    rewards = []
    arew = aux_cls() if aux_cls else AuxiliaryRewarder()
    decay = arew.decay_factor(TOTAL_EPOCHS // 2, TOTAL_EPOCHS)
    for seed in range(n):
        env = CrossSessionContinuityEnv("medium")
        if rubric_cls:
            env.rubric = rubric_cls()
        obs = env.reset(seed=seed + 5000)
        total_aux = 0.0
        for _ in range(env.step_limit + 2):
            action = agent.act(obs)
            result = env.step(action)
            if "auxiliary_reward" in result:
                total_aux += result["auxiliary_reward"] * decay
            obs = result
            if result.get("done") or result.get("session") == 2:
                break
        if env.state()["session"] == 1:
            rewards.append(0.0)
            continue
        obs = {"session": 2, "message": "start"}
        final = 0.0
        for _ in range(env.step_limit):
            action = agent.act(obs)
            result = env.step(action)
            obs = result
            if result.get("done"):
                final = result.get("reward", 0.0)
                break
        rewards.append(final + total_aux)
    print(f" Ablation [{label}] mean={float(np.mean(rewards)):.3f}")
    return rewards

ABLATIONS = [
    ("full", {}),
    ("no_compression", {"rubric_cls": NoCompressionRubric}),
    ("no_linearity", {"rubric_cls": NoLinearityRubric}),
    ("no_auxiliary", {"aux_cls": NoAuxiliaryRewarder}),
]
print(f"Running ablations ({len(ABLATIONS)}x{ABLATION_EPISODES} episodes)...")
ablation_results = {
    label: {"rewards": run_ablation(label=label, **kwargs)}
    for label, kwargs in ABLATIONS
}
with open("results/ablation_results.json", "w") as f:
    json.dump(ablation_results, f, indent=2)
print("Ablation results saved.")"""),
# ── Cell 11: Generate all 5 plots from the saved results ─────────────────────
# Reads the JSON files written by Cells 9-10 and renders the plots. The
# handoff-evolution input is optional (it is only written when handoffs exist).
cell("""\
import importlib, json, os, sys
# Ensure latest version of generate_plots is used
if "plots.generate_plots" in sys.modules:
    importlib.reload(sys.modules["plots.generate_plots"])
from plots.generate_plots import generate_all_plots

def _load(fname):
    with open(f"results/{fname}") as f:
        return json.load(f)

generate_all_plots(
    baseline_data = _load("baseline_results.json"),
    training_log = _load("training_log.json"),
    ablation_data = _load("ablation_results.json"),
    difficulty_data = _load("difficulty_results.json"),
    handoff_evo = _load("handoff_evolution.json") if os.path.exists("results/handoff_evolution.json") else None,
)
print("Plots generated from results/*.json.")"""),
# ── Cell 12: Display plots inline ────────────────────────────────────────────
# Shows each expected plot. Files that were not produced are reported rather
# than raising FileNotFoundError (e.g. if Cell 11 had no handoff-evolution data).
cell("""\
import os
from IPython.display import Image, display
for fname in [
    "baseline_vs_trained.png",
    "reward_curve.png",
    "ablation_comparison.png",
    "difficulty_breakdown.png",
    "handoff_diff_over_epochs.png",
]:
    path = f"plots/{fname}"
    print(f"\\n--- {fname} ---")
    if not os.path.exists(path):
        print("(not generated)")
        continue
    display(Image(path))"""),
# ── Cell 13: Save model to HF Hub ────────────────────────────────────────────
# Merges the LoRA adapters into 16-bit weights and pushes them when a token is
# available. Colab secrets are not exported as environment variables, so they
# are read through google.colab.userdata as a fallback.
cell("""\
# Push to Hub (set HF_TOKEN as an env var or in Colab secrets)
import os
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if not HF_TOKEN:
    try:
        from google.colab import userdata
        HF_TOKEN = userdata.get("HF_TOKEN") or ""
    except Exception:  # not on Colab, secret missing, or notebook access not granted
        HF_TOKEN = ""
if HF_TOKEN:
    model.save_pretrained_merged(
        "cross-session-continuity-model",
        tokenizer,
        save_method="merged_16bit",
    )
    model.push_to_hub_merged(
        "YOUR_TEAM/cross-session-continuity-model",
        tokenizer,
        save_method="merged_16bit",
        token=HF_TOKEN,
    )
    print("Model pushed to Hub.")
else:
    print("HF_TOKEN not set — skipping Hub push.")"""),
md("## Summary\n\n"
"| Step | Status |\n"
"|------|--------|\n"
"| Install deps | Cell 1 |\n"
"| Load model | Cell 3 |\n"
"| Baseline runs | Cell 5 |\n"
"| GRPO training (6 ep) | Cell 6 |\n"
"| Post-training eval | Cell 8 |\n"
"| Save JSON logs | Cell 9 |\n"
"| Ablation runs | Cell 10 |\n"
"| Generate 5 plots | Cell 11 |\n"
"| Push to Hub | Cell 13 |\n\n"
"All plots in `plots/` come from real training data in `results/`."),
]
# Top-level notebook document. nbformat 4.5 (nbformat_minor=5) requires every
# cell to carry a unique "id"; derive stable ids from the cell position so that
# regenerating the notebook produces clean diffs.
nb = {
    "nbformat": 4,
    "nbformat_minor": 5,
    "metadata": {
        "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
        "language_info": {"name": "python", "version": "3.10.0"},
        "accelerator": "GPU",
        "colab": {"gpuType": "T4", "provenance": []},
    },
    "cells": [{**c, "id": f"cell-{i}"} for i, c in enumerate(CELLS)],
}
# Serialise the notebook next to this script and report where it went.
out_path = os.path.join(HERE, "train_grpo.ipynb")
with open(out_path, "w") as fh:
    json.dump(nb, fp=fh, indent=1)
print("Notebook written: " + out_path)