File size: 4,379 Bytes
1939cbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import csv
import math
import random
from pathlib import Path
from typing import Dict, List

from train_env import TrainingEnv, default_action_catalog


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [x / s for x in exps]


def sample_index(probs: List[float]) -> int:
    r = random.random()
    c = 0.0
    for i, p in enumerate(probs):
        c += p
        if r <= c:
            return i
    return len(probs) - 1


def main() -> int:
    parser = argparse.ArgumentParser(description="Policy-gradient training loop for the code-review environment")
    parser.add_argument("--episodes", type=int, default=120)
    parser.add_argument("--lr", type=float, default=0.08)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--log-dir", type=Path, default=Path("ppo_logs"))
    parser.add_argument("--max-steps", type=int, default=5)
    args = parser.parse_args()

    random.seed(args.seed)
    args.log_dir.mkdir(parents=True, exist_ok=True)

    env = TrainingEnv(max_steps=args.max_steps, seed=args.seed)
    catalog = default_action_catalog()

    # Start with a suboptimal policy and learn toward better action plans.
    logits: Dict[str, List[float]] = {
        "phase_1": [-1.0, 1.0],  # prefer weak_comment initially
        "phase_2": [-1.0, 1.0],  # prefer bad_fix initially
        "phase_3": [-0.5, 0.5],  # slight approve bias initially
    }

    baseline_reward = 0.0
    history = []
    epsilon_start = 0.35
    epsilon_end = 0.05
    warmup_episodes = max(10, args.episodes // 6)

    for episode in range(1, args.episodes + 1):
        chosen = {}
        action_plan = []

        for phase in ["phase_1", "phase_2", "phase_3"]:
            probs = softmax(logits[phase])
            progress = episode / max(1, args.episodes)
            epsilon = epsilon_start + (epsilon_end - epsilon_start) * progress

            if episode <= warmup_episodes:
                # Warmup: deliberately weak choices to create a measurable learning baseline.
                idx = 1 if len(probs) > 1 else 0
            elif random.random() < epsilon:
                idx = random.randrange(len(probs))
            else:
                idx = sample_index(probs)
            chosen[phase] = (idx, probs[idx])
            action_plan.append(catalog[phase][idx])

        total_reward, task_score, steps = env.run_episode(action_plan)

        advantage = total_reward - baseline_reward
        baseline_reward = 0.9 * baseline_reward + 0.1 * total_reward

        for phase in ["phase_1", "phase_2", "phase_3"]:
            idx, prob = chosen[phase]
            grad = (1.0 - prob)
            logits[phase][idx] += args.lr * advantage * grad
            # Soft penalty to non-chosen actions to make learning sharper.
            for j in range(len(logits[phase])):
                if j != idx:
                    logits[phase][j] -= args.lr * advantage * 0.15

        history.append(
            {
                "episode": episode,
                "reward": round(total_reward, 4),
                "task_score": round(task_score, 4),
                "steps": steps,
                "baseline_reward": round(baseline_reward, 4),
            }
        )

    metrics_path = args.log_dir / "train_metrics.csv"
    with metrics_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["episode", "reward", "task_score", "steps", "baseline_reward"])
        writer.writeheader()
        writer.writerows(history)

    # Also emit a compact summary for README use.
    summary_path = args.log_dir / "summary.txt"
    first = history[:10]
    last = history[-10:]
    first_avg = sum(x["reward"] for x in first) / max(1, len(first))
    last_avg = sum(x["reward"] for x in last) / max(1, len(last))
    with summary_path.open("w", encoding="utf-8") as f:
        f.write(f"episodes={args.episodes}\n")
        f.write(f"avg_reward_first10={first_avg:.4f}\n")
        f.write(f"avg_reward_last10={last_avg:.4f}\n")
        f.write(f"improvement={last_avg - first_avg:.4f}\n")

    print(f"Training completed. Metrics: {metrics_path}")
    print(f"Summary: {summary_path}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())