Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import math | |
| import random | |
| from pathlib import Path | |
| from typing import Dict, List | |
| from train_env import TrainingEnv, default_action_catalog | |
def softmax(xs: List[float]) -> List[float]:
    """Return the softmax of ``xs`` as a list of probabilities summing to 1.

    The maximum logit is subtracted before exponentiating so large values do
    not overflow ``math.exp``; the largest term is then ``exp(0) == 1``, so the
    normaliser is never zero for finite input.

    Args:
        xs: Logits, one per action.

    Returns:
        A list of the same length as ``xs``; an empty input yields ``[]``.
    """
    if not xs:
        # Softmax of an empty vector is the empty vector (max() would raise).
        return []
    peak = max(xs)
    exps = [math.exp(x - peak) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]
def sample_index(probs: List[float]) -> int:
    """Draw an index ``i`` with probability ``probs[i]`` (inverse-CDF sampling).

    Exactly one ``random.random()`` value is consumed per call, so runs seeded
    via ``random.seed`` stay reproducible.

    Args:
        probs: Non-negative weights that are expected to sum to ~1.

    Returns:
        The sampled index. If floating-point rounding leaves the cumulative sum
        below the drawn value, the last index with positive probability is
        returned (or the last index if none is positive).

    Raises:
        ValueError: If ``probs`` is empty.
    """
    if not probs:
        raise ValueError("sample_index() requires at least one probability")
    r = random.random()
    cumulative = 0.0
    fallback = len(probs) - 1
    for i, p in enumerate(probs):
        if p > 0.0:
            fallback = i
        cumulative += p
        # Strict '<' makes each bucket the half-open interval [prev, cumulative),
        # so a zero-probability entry is never chosen even when r == 0.0.
        if r < cumulative:
            return i
    # Rounding left the total slightly below r: pick the last entry that can
    # actually occur rather than a possibly zero-probability tail entry.
    return fallback
_PHASES = ("phase_1", "phase_2", "phase_3")
_METRIC_FIELDS = ["episode", "reward", "task_score", "steps", "baseline_reward"]


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for ``main`` and reject a non-positive ``--episodes``."""
    parser = argparse.ArgumentParser(description="Policy-gradient training loop for the code-review environment")
    parser.add_argument("--episodes", type=int, default=120)
    parser.add_argument("--lr", type=float, default=0.08)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--log-dir", type=Path, default=Path("ppo_logs"))
    parser.add_argument("--max-steps", type=int, default=5)
    args = parser.parse_args()
    if args.episodes < 1:
        # Zero/negative runs would write an empty CSV and a meaningless summary.
        parser.error("--episodes must be at least 1")
    return args


def _choose_action(probs: List[float], *, warmup: bool, epsilon: float) -> int:
    """Return the action index to take for one phase.

    During warmup index 1 is forced (the weak action according to the initial
    logits set up in ``main``). Afterwards a uniformly random action is taken
    with probability ``epsilon``; otherwise one is sampled from ``probs``.
    """
    if warmup:
        # Warmup: deliberately weak choices to create a measurable learning baseline.
        return 1 if len(probs) > 1 else 0
    if random.random() < epsilon:
        return random.randrange(len(probs))
    return sample_index(probs)


def _update_policy(
    logits: Dict[str, List[float]],
    chosen: Dict[str, tuple],
    advantage: float,
    lr: float,
    off_action_penalty: float = 0.15,
) -> None:
    """Apply one in-place policy-gradient step to ``logits``.

    ``chosen`` maps each phase to ``(action_index, probability_when_chosen)``.
    """
    for phase, (idx, prob) in chosen.items():
        row = logits[phase]
        # REINFORCE term for the chosen action: d log softmax_a / d logit_a = 1 - p_a.
        row[idx] += lr * advantage * (1.0 - prob)
        # Soft penalty to non-chosen actions to make learning sharper.
        # NOTE: heuristic constant; the exact softmax gradient would subtract p_j.
        for j in range(len(row)):
            if j != idx:
                row[j] -= lr * advantage * off_action_penalty


def _write_metrics(history: List[Dict[str, float]], path: Path) -> None:
    """Write one CSV row per episode to ``path``, overwriting any existing file."""
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=_METRIC_FIELDS)
        writer.writeheader()
        writer.writerows(history)


def _write_summary(history: List[Dict[str, float]], episodes: int, path: Path) -> None:
    """Write first-10 vs last-10 mean reward to ``path`` as ``key=value`` lines.

    NOTE: ``main`` uses a warmup of at least 10 episodes, so the first-10 window
    always covers forced weak plans; ``improvement`` therefore compares the
    learned policy against that fixed plan. With fewer than 20 episodes the
    two windows overlap.
    """
    first = history[:10]
    last = history[-10:]
    first_avg = sum(x["reward"] for x in first) / max(1, len(first))
    last_avg = sum(x["reward"] for x in last) / max(1, len(last))
    with path.open("w", encoding="utf-8") as f:
        f.write(f"episodes={episodes}\n")
        f.write(f"avg_reward_first10={first_avg:.4f}\n")
        f.write(f"avg_reward_last10={last_avg:.4f}\n")
        f.write(f"improvement={last_avg - first_avg:.4f}\n")


def main() -> int:
    """Train a per-phase softmax policy on ``TrainingEnv`` and write logs.

    Each episode picks one action per phase, runs the resulting plan through
    ``env.run_episode`` and, once warmup is over, nudges the logits with a
    REINFORCE-style update against an exponential-moving-average reward
    baseline. Writes ``train_metrics.csv`` and ``summary.txt`` under
    ``--log-dir``.

    Returns:
        Process exit status (0).
    """
    args = _parse_args()
    random.seed(args.seed)
    args.log_dir.mkdir(parents=True, exist_ok=True)
    env = TrainingEnv(max_steps=args.max_steps, seed=args.seed)
    catalog = default_action_catalog()
    # Start with a suboptimal policy and learn toward better action plans.
    logits: Dict[str, List[float]] = {
        "phase_1": [-1.0, 1.0],  # prefer weak_comment initially
        "phase_2": [-1.0, 1.0],  # prefer bad_fix initially
        "phase_3": [-0.5, 0.5],  # slight approve bias initially
    }
    baseline_reward = 0.0
    history: List[Dict[str, float]] = []
    epsilon_start = 0.35
    epsilon_end = 0.05
    warmup_episodes = max(10, args.episodes // 6)

    for episode in range(1, args.episodes + 1):
        in_warmup = episode <= warmup_episodes
        # Linear epsilon decay; identical for every phase, so computed once per episode.
        progress = episode / max(1, args.episodes)
        epsilon = epsilon_start + (epsilon_end - epsilon_start) * progress

        chosen: Dict[str, tuple] = {}
        action_plan = []
        for phase in _PHASES:
            probs = softmax(logits[phase])
            idx = _choose_action(probs, warmup=in_warmup, epsilon=epsilon)
            chosen[phase] = (idx, probs[idx])
            # Assumes catalog[phase] is indexed like logits[phase] -- TODO confirm.
            action_plan.append(catalog[phase][idx])

        total_reward, task_score, steps = env.run_episode(action_plan)
        # Advantage uses the baseline from *before* this episode's reward.
        advantage = total_reward - baseline_reward
        baseline_reward = 0.9 * baseline_reward + 0.1 * total_reward
        if not in_warmup:
            # Warmup actions are forced, not sampled from the policy, so a
            # policy-gradient step on them is biased and would reinforce the
            # weak actions whenever reward beats the still-rising baseline.
            # Warmup therefore only calibrates the baseline.
            # NOTE(review): epsilon-random actions are still updated as if
            # on-policy (no importance weighting).
            _update_policy(logits, chosen, advantage, args.lr)

        history.append(
            {
                "episode": episode,
                "reward": round(total_reward, 4),
                "task_score": round(task_score, 4),
                "steps": steps,
                "baseline_reward": round(baseline_reward, 4),
            }
        )

    metrics_path = args.log_dir / "train_metrics.csv"
    _write_metrics(history, metrics_path)
    # Also emit a compact summary for README use.
    summary_path = args.log_dir / "summary.txt"
    _write_summary(history, args.episodes, summary_path)
    print(f"Training completed. Metrics: {metrics_path}")
    print(f"Summary: {summary_path}")
    return 0
# Script entry point: run training and propagate main()'s status as the exit code.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)