adityanaikhpt commited on
Commit
44ca509
·
verified ·
1 Parent(s): e680fbd

Deploy: inference.py

Browse files
Files changed (1) hide show
  1. inference.py +149 -0
inference.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests, csv, os, sys, time
2
+ from datetime import datetime
3
+
4
+ # Load config
5
+ sys.path.insert(0, os.path.dirname(__file__))
6
+ import config
7
+
8
+ LOG_FILE = os.path.join(os.path.dirname(__file__), "rewards_log.csv")
9
+ os.makedirs(os.path.join(os.path.dirname(__file__), "results"), exist_ok=True)
10
+
11
+ def get_fix(buggy_code: str) -> str:
12
+ prompt_system = (
13
+ "You are a Python debugging agent. "
14
+ "You will be given broken Python code. "
15
+ "Find the bug and fix it. "
16
+ "Return ONLY the corrected Python code. "
17
+ "No explanation. No markdown. No code blocks. Just raw Python."
18
+ )
19
+
20
+ if config.MODEL_PROVIDER == "openai":
21
+ import openai
22
+ client = openai.OpenAI(api_key=config.API_KEY, base_url=config.API_BASE_URL)
23
+ response = client.chat.completions.create(
24
+ model=config.MODEL_NAME,
25
+ messages=[
26
+ {"role": "system", "content": prompt_system},
27
+ {"role": "user", "content": f"Fix this code:\n\n{buggy_code}"}
28
+ ],
29
+ temperature=0.2,
30
+ max_tokens=512
31
+ )
32
+ return response.choices[0].message.content.strip()
33
+
34
+ elif config.MODEL_PROVIDER == "huggingface":
35
+ from transformers import pipeline
36
+ pipe = pipeline("text-generation", model=config.MODEL_NAME, max_new_tokens=256)
37
+ result = pipe(f"Fix this Python bug:\n{buggy_code}\nFixed code:")
38
+ return result[0]["generated_text"].split("Fixed code:")[-1].strip()
39
+
40
+ elif config.MODEL_PROVIDER == "ollama":
41
+ response = requests.post(
42
+ "http://localhost:11434/api/generate",
43
+ json={"model": config.MODEL_NAME,
44
+ "prompt": f"{prompt_system}\n\nFix this code:\n{buggy_code}",
45
+ "stream": False}
46
+ )
47
+ return response.json()["response"].strip()
48
+
49
+ else:
50
+ raise ValueError(f"Unknown provider: {config.MODEL_PROVIDER}")
51
+
52
+ def run_training():
53
+ print(f"\n{'='*50}")
54
+ print(f"CodeArena Training Run")
55
+ print(f"Model: {config.MODEL_NAME} via {config.MODEL_PROVIDER}")
56
+ print(f"Episodes: {config.EPISODES} x {config.STEPS_PER_EPISODE} steps")
57
+ print(f"{'='*50}\n")
58
+
59
+ # Write CSV header
60
+ with open(LOG_FILE, "w", newline="") as f:
61
+ writer = csv.DictWriter(f, fieldnames=[
62
+ "timestamp", "episode", "step", "task_id",
63
+ "reward", "compile_score", "test_pass_ratio"
64
+ ])
65
+ writer.writeheader()
66
+
67
+ all_rewards = []
68
+
69
+ for episode in range(config.EPISODES):
70
+ # Alternate between easy and medium for variety
71
+ difficulty = "easy" if episode % 3 != 2 else "medium"
72
+
73
+ reset_resp = requests.post(
74
+ f"{config.ENVIRONMENT_URL}/reset",
75
+ json={"task_id": difficulty}
76
+ ).json()
77
+
78
+ obs = reset_resp["observation"]
79
+ task_id = reset_resp["task_id"]
80
+ episode_rewards = []
81
+
82
+ for step_num in range(config.STEPS_PER_EPISODE):
83
+ try:
84
+ fix = get_fix(obs["buggy_code"])
85
+ except Exception as e:
86
+ print(f" Model error: {e}")
87
+ fix = obs["buggy_code"] # fallback: send buggy code back
88
+
89
+ try:
90
+ result = requests.post(
91
+ f"{config.ENVIRONMENT_URL}/step",
92
+ json={"proposed_fix": fix},
93
+ timeout=30
94
+ ).json()
95
+ except Exception as e:
96
+ print(f" Environment error: {e}")
97
+ break
98
+
99
+ reward = result["reward"]
100
+ components = result.get("reward_components", {})
101
+ episode_rewards.append(reward)
102
+ all_rewards.append(reward)
103
+
104
+ # Log to CSV
105
+ with open(LOG_FILE, "a", newline="") as f:
106
+ writer = csv.DictWriter(f, fieldnames=[
107
+ "timestamp", "episode", "step", "task_id",
108
+ "reward", "compile_score", "test_pass_ratio"
109
+ ])
110
+ writer.writerow({
111
+ "timestamp": datetime.now().isoformat(),
112
+ "episode": episode,
113
+ "step": step_num,
114
+ "task_id": task_id,
115
+ "reward": reward,
116
+ "compile_score": components.get("compile_score", 0),
117
+ "test_pass_ratio": components.get("test_pass_ratio", 0)
118
+ })
119
+
120
+ print(f" Ep {episode:02d} Step {step_num} | "
121
+ f"reward={reward:.3f} | "
122
+ f"compile={components.get('compile_score',0):.1f} | "
123
+ f"tests={components.get('test_pass_ratio',0):.2f} | "
124
+ f"done={result['done']}")
125
+
126
+ if result["done"]:
127
+ break
128
+
129
+ obs = result["observation"]
130
+ time.sleep(0.5) # be polite to API
131
+
132
+ ep_avg = sum(episode_rewards) / len(episode_rewards) if episode_rewards else 0
133
+ print(f"Episode {episode:02d} done. Avg reward: {ep_avg:.3f}\n")
134
+
135
+ # Final summary
136
+ if all_rewards:
137
+ first10 = sum(all_rewards[:10]) / min(10, len(all_rewards))
138
+ last10 = sum(all_rewards[-10:]) / min(10, len(all_rewards))
139
+ improvement = last10 - first10
140
+ print(f"\n{'='*50}")
141
+ print(f"Training Complete")
142
+ print(f"First 10 steps avg reward : {first10:.3f}")
143
+ print(f"Last 10 steps avg reward : {last10:.3f}")
144
+ print(f"Improvement : {improvement:+.3f}")
145
+ print(f"Rewards logged to : {LOG_FILE}")
146
+ print(f"{'='*50}\n")
147
+
148
+ if __name__ == "__main__":
149
+ run_training()