codearena / inference.py
adityanaikhpt's picture
Deploy: inference.py
44ca509 verified
import requests, csv, os, sys, time
from datetime import datetime
# Put this file's directory first on sys.path, then import the ``config``
# module; every model and environment setting below is read from it.
_SCRIPT_DIR = os.path.dirname(__file__)
sys.path.insert(0, _SCRIPT_DIR)
import config

# CSV file, located next to this script, that receives one row per step.
LOG_FILE = os.path.join(_SCRIPT_DIR, "rewards_log.csv")

# Create the ``results`` directory beside this script if it is missing.
os.makedirs(os.path.join(_SCRIPT_DIR, "results"), exist_ok=True)
# Seconds to wait for the local Ollama server before abandoning a request.
# Without a timeout ``requests`` blocks forever on a stalled server.
_OLLAMA_TIMEOUT_S = 300

# Hugging Face pipelines keyed by model name, so the model is loaded once per
# process instead of once per call.
_HF_PIPELINES = {}


def get_fix(buggy_code: str) -> str:
    """Ask the configured model for a corrected version of *buggy_code*.

    The backend is selected by ``config.MODEL_PROVIDER``:

    * ``"openai"`` -- chat completion through the ``openai`` client, using
      ``config.API_KEY``, ``config.API_BASE_URL`` and ``config.MODEL_NAME``.
    * ``"huggingface"`` -- a ``transformers`` text-generation pipeline for
      ``config.MODEL_NAME``.
    * ``"ollama"`` -- a non-streaming request to the Ollama HTTP API on
      ``localhost:11434``.

    Args:
        buggy_code: Source code the model should repair.

    Returns:
        The model's reply with surrounding whitespace stripped.

    Raises:
        ValueError: If ``config.MODEL_PROVIDER`` is not one of the above.
        RuntimeError: If the OpenAI-compatible backend returns no text.
        requests.RequestException: If the Ollama request fails, times out
            or comes back with an HTTP error status.
    """
    prompt_system = (
        "You are a Python debugging agent. "
        "You will be given broken Python code. "
        "Find the bug and fix it. "
        "Return ONLY the corrected Python code. "
        "No explanation. No markdown. No code blocks. Just raw Python."
    )
    provider = config.MODEL_PROVIDER

    if provider == "openai":
        import openai

        client = openai.OpenAI(
            api_key=config.API_KEY, base_url=config.API_BASE_URL
        )
        response = client.chat.completions.create(
            model=config.MODEL_NAME,
            messages=[
                {"role": "system", "content": prompt_system},
                {"role": "user", "content": f"Fix this code:\n\n{buggy_code}"},
            ],
            temperature=0.2,
            max_tokens=512,
        )
        content = response.choices[0].message.content
        # The API may return a message without text; fail with a clear
        # error instead of an AttributeError on ``None.strip()``.
        if content is None:
            raise RuntimeError("Model returned an empty response")
        return content.strip()

    if provider == "huggingface":
        pipe = _HF_PIPELINES.get(config.MODEL_NAME)
        if pipe is None:
            from transformers import pipeline

            pipe = pipeline(
                "text-generation",
                model=config.MODEL_NAME,
                max_new_tokens=256,
            )
            _HF_PIPELINES[config.MODEL_NAME] = pipe
        result = pipe(f"Fix this Python bug:\n{buggy_code}\nFixed code:")
        # Keep only what follows the last "Fixed code:" marker.
        return result[0]["generated_text"].split("Fixed code:")[-1].strip()

    if provider == "ollama":
        response = requests.post(
            "http://localhost:11434/api/generate",
            json={
                "model": config.MODEL_NAME,
                "prompt": f"{prompt_system}\n\nFix this code:\n{buggy_code}",
                "stream": False,
            },
            timeout=_OLLAMA_TIMEOUT_S,
        )
        # Surface HTTP failures directly rather than as a KeyError on a
        # body that has no "response" field.
        response.raise_for_status()
        return response.json()["response"].strip()

    raise ValueError(f"Unknown provider: {provider}")
# Column order of the rewards CSV; shared by the header and every row.
_LOG_FIELDS = (
    "timestamp", "episode", "step", "task_id",
    "reward", "compile_score", "test_pass_ratio",
)

# Seconds to wait for the environment server on each HTTP call.
_ENV_TIMEOUT_S = 30


def _write_log_header():
    """Create (or truncate) the rewards CSV and write its header row."""
    with open(LOG_FILE, "w", newline="", encoding="utf-8") as f:
        csv.DictWriter(f, fieldnames=_LOG_FIELDS).writeheader()


def _append_log_row(row):
    """Append one step record to the rewards CSV."""
    # Re-opened per step so every row is on disk even if the run dies later.
    with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
        csv.DictWriter(f, fieldnames=_LOG_FIELDS).writerow(row)


def _print_summary(all_rewards):
    """Print the mean reward of the first and last (up to) ten steps."""
    head = all_rewards[:10]
    tail = all_rewards[-10:]
    first10 = sum(head) / len(head)
    last10 = sum(tail) / len(tail)
    improvement = last10 - first10
    print(f"\n{'='*50}")
    print("Training Complete")
    print(f"First 10 steps avg reward : {first10:.3f}")
    print(f"Last 10 steps avg reward : {last10:.3f}")
    print(f"Improvement : {improvement:+.3f}")
    print(f"Rewards logged to : {LOG_FILE}")
    print(f"{'='*50}\n")


def run_training():
    """Run every configured episode against the environment and log rewards.

    For each of ``config.EPISODES`` episodes the environment at
    ``config.ENVIRONMENT_URL`` is reset, then up to
    ``config.STEPS_PER_EPISODE`` fixes produced by ``get_fix`` are submitted
    to ``/step``. Each step's reward is appended to ``LOG_FILE`` and printed;
    a summary is printed at the end if any reward was collected.

    Environment failures (network error, HTTP error status, malformed
    payload) are reported and end the current episode instead of aborting
    the whole run.
    """
    print(f"\n{'='*50}")
    print("CodeArena Training Run")
    print(f"Model: {config.MODEL_NAME} via {config.MODEL_PROVIDER}")
    print(f"Episodes: {config.EPISODES} x {config.STEPS_PER_EPISODE} steps")
    print(f"{'='*50}\n")

    _write_log_header()

    all_rewards = []
    for episode in range(config.EPISODES):
        # Two easy episodes followed by one medium, repeating.
        difficulty = "easy" if episode % 3 != 2 else "medium"

        try:
            reset_http = requests.post(
                f"{config.ENVIRONMENT_URL}/reset",
                json={"task_id": difficulty},
                timeout=_ENV_TIMEOUT_S,
            )
            reset_http.raise_for_status()
            reset_resp = reset_http.json()
            obs = reset_resp["observation"]
            task_id = reset_resp["task_id"]
        except Exception as e:
            # Same best-effort policy as for /step: report and move on.
            print(f" Environment error: {e}")
            continue

        episode_rewards = []
        for step_num in range(config.STEPS_PER_EPISODE):
            # Read the code once so the fallback below cannot raise again.
            buggy_code = obs["buggy_code"]
            try:
                fix = get_fix(buggy_code)
            except Exception as e:
                print(f" Model error: {e}")
                fix = buggy_code  # fallback: send buggy code back

            try:
                step_http = requests.post(
                    f"{config.ENVIRONMENT_URL}/step",
                    json={"proposed_fix": fix},
                    timeout=_ENV_TIMEOUT_S,
                )
                step_http.raise_for_status()
                result = step_http.json()
                # Validate the payload here so an error body ends the
                # episode instead of crashing the run with a KeyError.
                reward = result["reward"]
                done = result["done"]
            except Exception as e:
                print(f" Environment error: {e}")
                break

            components = result.get("reward_components") or {}
            compile_score = components.get("compile_score", 0)
            test_pass_ratio = components.get("test_pass_ratio", 0)
            episode_rewards.append(reward)
            all_rewards.append(reward)

            _append_log_row({
                "timestamp": datetime.now().isoformat(),
                "episode": episode,
                "step": step_num,
                "task_id": task_id,
                "reward": reward,
                "compile_score": compile_score,
                "test_pass_ratio": test_pass_ratio,
            })

            print(f" Ep {episode:02d} Step {step_num} | "
                  f"reward={reward:.3f} | "
                  f"compile={compile_score:.1f} | "
                  f"tests={test_pass_ratio:.2f} | "
                  f"done={done}")

            if done:
                break

            obs = result["observation"]
            time.sleep(0.5)  # be polite to API

        ep_avg = sum(episode_rewards) / len(episode_rewards) if episode_rewards else 0
        print(f"Episode {episode:02d} done. Avg reward: {ep_avg:.3f}\n")

    # Final summary
    if all_rewards:
        _print_summary(all_rewards)
# Script entry point: start the run only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    run_training()