Spaces:
Sleeping
Sleeping
| import requests, csv, os, sys, time | |
| from datetime import datetime | |
| # Load config | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| import config | |
# Per-step rewards are written to a CSV that lives next to this script.
LOG_FILE = os.path.join(os.path.dirname(__file__), "rewards_log.csv")

# Ensure a sibling "results" directory exists (same parent directory as the
# log file); exist_ok makes repeated imports/runs harmless.
os.makedirs(
    os.path.join(os.path.dirname(LOG_FILE), "results"),
    exist_ok=True,
)
# Hugging Face pipelines keyed by model name. Building a pipeline loads the
# model, so it is done once per process instead of once per call.
_HF_PIPELINES = {}

# Seconds to wait for the local Ollama server. With "stream": False nothing is
# sent back until generation finishes, so the limit is deliberately generous.
_OLLAMA_TIMEOUT = 120


def get_fix(buggy_code: str) -> str:
    """Ask the configured model to repair ``buggy_code``.

    The backend is selected by ``config.MODEL_PROVIDER``: ``"openai"``,
    ``"huggingface"`` or ``"ollama"``.

    Args:
        buggy_code: Source text of the broken Python program.

    Returns:
        The model's reply with surrounding whitespace stripped.

    Raises:
        ValueError: If the provider is unknown, or the OpenAI-style backend
            returns a completion with no text content.
        requests.RequestException: If the Ollama request times out, fails,
            or returns a non-2xx status.
    """
    prompt_system = (
        "You are a Python debugging agent. "
        "You will be given broken Python code. "
        "Find the bug and fix it. "
        "Return ONLY the corrected Python code. "
        "No explanation. No markdown. No code blocks. Just raw Python."
    )
    provider = config.MODEL_PROVIDER

    if provider == "openai":
        import openai

        client = openai.OpenAI(
            api_key=config.API_KEY, base_url=config.API_BASE_URL
        )
        response = client.chat.completions.create(
            model=config.MODEL_NAME,
            messages=[
                {"role": "system", "content": prompt_system},
                {"role": "user", "content": f"Fix this code:\n\n{buggy_code}"},
            ],
            temperature=0.2,
            max_tokens=512,
        )
        content = response.choices[0].message.content
        # The content field may be None; fail with a clear message instead
        # of an AttributeError on .strip().
        if content is None:
            raise ValueError("Model returned a completion with no content")
        return content.strip()

    if provider == "huggingface":
        pipe = _HF_PIPELINES.get(config.MODEL_NAME)
        if pipe is None:
            from transformers import pipeline

            pipe = pipeline(
                "text-generation", model=config.MODEL_NAME, max_new_tokens=256
            )
            _HF_PIPELINES[config.MODEL_NAME] = pipe
        result = pipe(f"Fix this Python bug:\n{buggy_code}\nFixed code:")
        # Keep only what follows the last "Fixed code:" marker, which drops
        # the echoed prompt if it is part of the generated text.
        return result[0]["generated_text"].split("Fixed code:")[-1].strip()

    if provider == "ollama":
        response = requests.post(
            "http://localhost:11434/api/generate",
            json={
                "model": config.MODEL_NAME,
                "prompt": f"{prompt_system}\n\nFix this code:\n{buggy_code}",
                "stream": False,
            },
            timeout=_OLLAMA_TIMEOUT,
        )
        # Surface HTTP errors directly rather than as a KeyError on the
        # missing "response" field below.
        response.raise_for_status()
        return response.json()["response"].strip()

    raise ValueError(f"Unknown provider: {config.MODEL_PROVIDER}")
# Column order of the rewards CSV, shared by the header and every row so the
# two cannot drift apart.
_LOG_FIELDS = [
    "timestamp", "episode", "step", "task_id",
    "reward", "compile_score", "test_pass_ratio",
]

# Seconds to wait for a single environment HTTP call.
_ENV_TIMEOUT = 30


def _append_log_row(row: dict) -> None:
    """Append one step record to LOG_FILE."""
    with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
        csv.DictWriter(f, fieldnames=_LOG_FIELDS).writerow(row)


def run_training():
    """Run the configured episodes against the environment and log rewards.

    For each of ``config.EPISODES`` episodes the environment at
    ``config.ENVIRONMENT_URL`` is reset, then up to
    ``config.STEPS_PER_EPISODE`` fixes from ``get_fix`` are posted to
    ``/step``. Each step's reward is appended to ``LOG_FILE`` (which is
    truncated at the start of the run) and echoed to stdout, followed by a
    first-10 / last-10 summary.

    Environment failures (network error, non-2xx status, malformed payload)
    end the current episode and the run continues with the next one. Model
    failures fall back to submitting the buggy code unchanged.
    """
    banner = "=" * 50
    print(f"\n{banner}")
    print("CodeArena Training Run")
    print(f"Model: {config.MODEL_NAME} via {config.MODEL_PROVIDER}")
    print(f"Episodes: {config.EPISODES} x {config.STEPS_PER_EPISODE} steps")
    print(f"{banner}\n")

    # Start a fresh log containing only the header.
    with open(LOG_FILE, "w", newline="", encoding="utf-8") as f:
        csv.DictWriter(f, fieldnames=_LOG_FIELDS).writeheader()

    all_rewards = []
    for episode in range(config.EPISODES):
        # Every third episode (2, 5, 8, ...) is "medium"; the rest are "easy".
        difficulty = "medium" if episode % 3 == 2 else "easy"

        try:
            reset_http = requests.post(
                f"{config.ENVIRONMENT_URL}/reset",
                json={"task_id": difficulty},
                timeout=_ENV_TIMEOUT,
            )
            reset_http.raise_for_status()
            reset_resp = reset_http.json()
            obs = reset_resp["observation"]
            task_id = reset_resp["task_id"]
        except Exception as e:
            # Same policy as a failed step: give up on this episode only.
            print(f" Environment error: {e}")
            continue

        episode_rewards = []
        for step_num in range(config.STEPS_PER_EPISODE):
            buggy_code = obs["buggy_code"]
            try:
                fix = get_fix(buggy_code)
            except Exception as e:
                # Provider SDKs raise unrelated exception types, so this
                # stays broad on purpose.
                print(f" Model error: {e}")
                fix = buggy_code  # fallback: send buggy code back

            try:
                step_http = requests.post(
                    f"{config.ENVIRONMENT_URL}/step",
                    json={"proposed_fix": fix},
                    timeout=_ENV_TIMEOUT,
                )
                step_http.raise_for_status()
                result = step_http.json()
                reward = result["reward"]
                done = result["done"]
                # The next observation is only needed if the episode goes on.
                next_obs = None if done else result["observation"]
            except Exception as e:
                print(f" Environment error: {e}")
                break

            components = result.get("reward_components") or {}
            compile_score = components.get("compile_score", 0)
            test_pass_ratio = components.get("test_pass_ratio", 0)

            episode_rewards.append(reward)
            all_rewards.append(reward)

            # Written per step so a crash mid-run keeps earlier rows.
            _append_log_row({
                "timestamp": datetime.now().isoformat(),
                "episode": episode,
                "step": step_num,
                "task_id": task_id,
                "reward": reward,
                "compile_score": compile_score,
                "test_pass_ratio": test_pass_ratio,
            })

            print(f" Ep {episode:02d} Step {step_num} | "
                  f"reward={reward:.3f} | "
                  f"compile={compile_score:.1f} | "
                  f"tests={test_pass_ratio:.2f} | "
                  f"done={done}")

            if done:
                break
            obs = next_obs
            time.sleep(0.5)  # be polite to API

        ep_avg = sum(episode_rewards) / len(episode_rewards) if episode_rewards else 0
        print(f"Episode {episode:02d} done. Avg reward: {ep_avg:.3f}\n")

    # Final summary
    if all_rewards:
        # With fewer than 10 rewards both windows cover the whole list.
        window = min(10, len(all_rewards))
        first10 = sum(all_rewards[:10]) / window
        last10 = sum(all_rewards[-10:]) / window
        improvement = last10 - first10
        print(f"\n{banner}")
        print("Training Complete")
        print(f"First 10 steps avg reward : {first10:.3f}")
        print(f"Last 10 steps avg reward : {last10:.3f}")
        print(f"Improvement : {improvement:+.3f}")
        print(f"Rewards logged to : {LOG_FILE}")
        print(f"{banner}\n")
# Start training only when this file is executed as a script, so importing
# the module does not kick off a run.
if __name__ == "__main__":
    run_training()