File size: 5,565 Bytes
44ca509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import requests, csv, os, sys, time
from datetime import datetime

# Load config: put this file's directory at the front of sys.path so that
# `import config` searches this directory first.
sys.path.insert(0, os.path.dirname(__file__))
import config

# CSV file, located next to this script, that run_training() writes rewards to.
LOG_FILE = os.path.join(os.path.dirname(__file__), "rewards_log.csv")
# Make sure a "results" directory exists next to this script (no error if it
# already does).
# NOTE(review): nothing visible in this file writes into results/ -- LOG_FILE
# lives in the script directory itself. Confirm whether other code relies on it.
os.makedirs(os.path.join(os.path.dirname(__file__), "results"), exist_ok=True)

# Loaded Hugging Face pipelines keyed by model name, so the model is loaded
# once per process instead of once per get_fix() call.
_HF_PIPELINES = {}

# Seconds to wait for a local Ollama generation before giving up. Generous on
# purpose: local models can be slow, but a stalled server must not hang forever.
_OLLAMA_TIMEOUT = 120


def get_fix(buggy_code: str) -> str:
    """Ask the configured model to repair ``buggy_code`` and return its answer.

    The backend is selected by ``config.MODEL_PROVIDER``:

    * ``"openai"``      -- chat completion through the ``openai`` client, using
      ``config.API_KEY``, ``config.API_BASE_URL`` and ``config.MODEL_NAME``.
    * ``"huggingface"`` -- a local ``transformers`` text-generation pipeline
      for ``config.MODEL_NAME`` (cached after the first call).
    * ``"ollama"``      -- HTTP POST to the Ollama server on localhost:11434.

    Args:
        buggy_code: Source code to be fixed.

    Returns:
        The model's reply with surrounding whitespace stripped.

    Raises:
        ValueError: If ``config.MODEL_PROVIDER`` is not one of the providers above.
        RuntimeError: If the OpenAI-style backend returns a message without content.
        requests.RequestException: If the Ollama request times out or returns
            an HTTP error status.
    """
    prompt_system = (
        "You are a Python debugging agent. "
        "You will be given broken Python code. "
        "Find the bug and fix it. "
        "Return ONLY the corrected Python code. "
        "No explanation. No markdown. No code blocks. Just raw Python."
    )

    if config.MODEL_PROVIDER == "openai":
        import openai
        client = openai.OpenAI(api_key=config.API_KEY, base_url=config.API_BASE_URL)
        response = client.chat.completions.create(
            model=config.MODEL_NAME,
            messages=[
                {"role": "system", "content": prompt_system},
                {"role": "user", "content": f"Fix this code:\n\n{buggy_code}"}
            ],
            temperature=0.2,
            max_tokens=512
        )
        content = response.choices[0].message.content
        if content is None:
            # The API may return a message with no text; fail with a clear
            # error instead of an AttributeError on None.
            raise RuntimeError("Model returned no content")
        return content.strip()

    elif config.MODEL_PROVIDER == "huggingface":
        pipe = _HF_PIPELINES.get(config.MODEL_NAME)
        if pipe is None:
            # Imported lazily so the other providers work without transformers.
            from transformers import pipeline
            pipe = pipeline("text-generation", model=config.MODEL_NAME, max_new_tokens=256)
            _HF_PIPELINES[config.MODEL_NAME] = pipe
        result = pipe(f"Fix this Python bug:\n{buggy_code}\nFixed code:")
        # The generated text echoes the prompt; keep only what follows the
        # last "Fixed code:" marker.
        return result[0]["generated_text"].split("Fixed code:")[-1].strip()

    elif config.MODEL_PROVIDER == "ollama":
        response = requests.post(
            "http://localhost:11434/api/generate",
            json={"model": config.MODEL_NAME,
                  "prompt": f"{prompt_system}\n\nFix this code:\n{buggy_code}",
                  "stream": False},
            timeout=_OLLAMA_TIMEOUT
        )
        # Surface HTTP errors (e.g. unknown model) directly rather than as a
        # KeyError on the missing "response" field.
        response.raise_for_status()
        return response.json()["response"].strip()

    else:
        raise ValueError(f"Unknown provider: {config.MODEL_PROVIDER}")

def run_training():
    """Run the configured episodes against the environment and log rewards.

    For each of ``config.EPISODES`` episodes the environment at
    ``config.ENVIRONMENT_URL`` is reset (two "easy" tasks followed by one
    "medium" task, repeating), then up to ``config.STEPS_PER_EPISODE`` fixes
    from ``get_fix`` are submitted to ``/step``. Every step's reward is
    appended to ``LOG_FILE`` (which is truncated at the start of the run) and
    printed; a first-10 / last-10 reward summary is printed at the end.

    Environment failures do not abort the run: a failed reset skips the
    episode, and a failed step ends the current episode early. A model failure
    falls back to submitting the buggy code unchanged.
    """
    # Single definition of the CSV columns, shared by the header and every row.
    fieldnames = [
        "timestamp", "episode", "step", "task_id",
        "reward", "compile_score", "test_pass_ratio",
    ]

    print(f"\n{'='*50}")
    print("CodeArena Training Run")
    print(f"Model: {config.MODEL_NAME} via {config.MODEL_PROVIDER}")
    print(f"Episodes: {config.EPISODES} x {config.STEPS_PER_EPISODE} steps")
    print(f"{'='*50}\n")

    # Write CSV header (mode "w" discards any log from a previous run)
    with open(LOG_FILE, "w", newline="", encoding="utf-8") as f:
        csv.DictWriter(f, fieldnames=fieldnames).writeheader()

    all_rewards = []

    for episode in range(config.EPISODES):
        # Alternate between easy and medium for variety
        difficulty = "easy" if episode % 3 != 2 else "medium"

        try:
            reset_http = requests.post(
                f"{config.ENVIRONMENT_URL}/reset",
                json={"task_id": difficulty},
                timeout=30
            )
            reset_http.raise_for_status()
            reset_resp = reset_http.json()
            obs = reset_resp["observation"]
            task_id = reset_resp["task_id"]
        except Exception as e:
            # Same policy as step failures: report and move on instead of
            # crashing the whole run.
            print(f"  Environment error during reset: {e}")
            print(f"Episode {episode:02d} skipped.\n")
            continue

        episode_rewards = []

        for step_num in range(config.STEPS_PER_EPISODE):
            try:
                fix = get_fix(obs["buggy_code"])
            except Exception as e:
                print(f"  Model error: {e}")
                fix = obs["buggy_code"]  # fallback: send buggy code back

            try:
                step_http = requests.post(
                    f"{config.ENVIRONMENT_URL}/step",
                    json={"proposed_fix": fix},
                    timeout=30
                )
                # An HTTP error payload has no "reward"; treat it as an
                # environment error rather than crashing on a KeyError.
                step_http.raise_for_status()
                result = step_http.json()
                reward = result["reward"]
                done = result["done"]
            except Exception as e:
                print(f"  Environment error: {e}")
                break

            components = result.get("reward_components", {})
            compile_score = components.get("compile_score", 0)
            test_pass_ratio = components.get("test_pass_ratio", 0)
            episode_rewards.append(reward)
            all_rewards.append(reward)

            # Log to CSV (reopened per step so rows are on disk even if the
            # run is interrupted)
            with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
                csv.DictWriter(f, fieldnames=fieldnames).writerow({
                    "timestamp": datetime.now().isoformat(),
                    "episode": episode,
                    "step": step_num,
                    "task_id": task_id,
                    "reward": reward,
                    "compile_score": compile_score,
                    "test_pass_ratio": test_pass_ratio
                })

            print(f"  Ep {episode:02d} Step {step_num} | "
                  f"reward={reward:.3f} | "
                  f"compile={compile_score:.1f} | "
                  f"tests={test_pass_ratio:.2f} | "
                  f"done={done}")

            if done:
                break

            obs = result["observation"]
            time.sleep(0.5)  # be polite to API

        ep_avg = sum(episode_rewards) / len(episode_rewards) if episode_rewards else 0
        print(f"Episode {episode:02d} done. Avg reward: {ep_avg:.3f}\n")

    # Final summary
    if all_rewards:
        # With fewer than 10 rewards both slices cover the whole list, so the
        # divisor is capped at the number of rewards actually collected.
        window = min(10, len(all_rewards))
        first10 = sum(all_rewards[:10]) / window
        last10 = sum(all_rewards[-10:]) / window
        improvement = last10 - first10
        print(f"\n{'='*50}")
        print("Training Complete")
        print(f"First 10 steps avg reward : {first10:.3f}")
        print(f"Last  10 steps avg reward : {last10:.3f}")
        print(f"Improvement               : {improvement:+.3f}")
        print(f"Rewards logged to         : {LOG_FILE}")
        print(f"{'='*50}\n")

# Start the training run only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_training()