Spaces:
Runtime error
Runtime error
| """ | |
| ReproAgent Training Script using Hugging Face TRL (PPOTrainer). | |
| This script demonstrates how to train a language model agent to interact with | |
| the ReproAgent environment using Proximal Policy Optimization (PPO). | |
| """ | |
| import os | |
| import sys | |
| import torch | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
| # Ensure project root is importable | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from reproagent.environment import ReproAgentEnv | |
| from reproagent.actions import ActionSpace | |
| try: | |
| from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead | |
| from transformers import AutoTokenizer | |
| from datasets import Dataset | |
| except ImportError: | |
| print("Please install trl and transformers: pip install trl transformers") | |
| sys.exit(1) | |
def format_observation(obs, num_actions=35):
    """Render an environment observation as a text prompt for the policy LLM.

    Args:
        obs: Observation mapping. Reads ``obs['paper_features'][2]`` (shown as
            the paper target), ``obs['experiment_features'][0]`` (current
            metric), ``obs['experiment_features'][3]`` (gap) and
            ``obs['meta_features'][1]`` (phase index).
        num_actions: Number of discrete actions; the prompt advertises IDs
            ``0`` .. ``num_actions - 1``. The default of 35 reproduces the
            previously hard-coded range ``[0-34]``.

    Returns:
        A multi-line prompt string ending with ``"Select action ID:"``.

    Raises:
        ValueError: If ``num_actions`` is less than 1.
    """
    if num_actions < 1:
        raise ValueError(f"num_actions must be >= 1, got {num_actions}")
    lines = [
        "Current state:",
        f"Paper Target: {obs['paper_features'][2]:.3f}",
        f"Current Metric: {obs['experiment_features'][0]:.3f}",
        f"Gap: {obs['experiment_features'][3]:.3f}",
        f"Phase index: {obs['meta_features'][1]}",
        # Derived from num_actions so the prompt cannot drift from the action space.
        f"Action options: [0-{num_actions - 1}]",
        "Select action ID:",
    ]
    return "\n".join(lines)
def _parse_action(response_text, action_space):
    """Map generated text to a valid discrete action ID.

    The first run of decimal digits in ``response_text`` is used as the action
    ID. If there is none, or it lies outside ``[0, action_space.n)``, a random
    action from ``action_space.sample()`` is returned instead.
    """
    import re  # local import keeps this helper self-contained

    match = re.search(r"\d+", response_text)
    if match is not None:
        action_id = int(match.group())
        if 0 <= action_id < action_space.n:
            return action_id
    return action_space.sample()


def _ppo_update(ppo_trainer, batch):
    """Run one PPO step on ``batch`` of (query, response, reward) tuples.

    Returns the step's ``'ppo/loss/total'`` stat as a float, or NaN when the
    stat is missing or the step raised (the error is printed and training
    continues, so a single bad update does not abort the run).
    """
    queries, responses, scores = (list(column) for column in zip(*batch))
    try:
        stats = ppo_trainer.step(queries, responses, scores)
    except Exception as e:  # best-effort: log and keep training
        print(f"Skipping PPO update due to error: {e}")
        # NaN (not a made-up number) so the loss plot shows a gap, not fake data.
        return float("nan")
    return float(stats.get('ppo/loss/total', float("nan")))


def _save_plots(reward_history, loss_history, out_dir="assets"):
    """Save reward-per-episode and loss-per-update line plots as PNGs in ``out_dir``."""
    os.makedirs(out_dir, exist_ok=True)
    reward_path = os.path.join(out_dir, "reward_plot.png")
    loss_path = os.path.join(out_dir, "loss_plot.png")

    plt.figure(figsize=(10, 5))
    plt.plot(reward_history, label='Total Reward', color='green')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('ReproAgent PPO Training - Reward per Episode')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(reward_path)
    plt.close()

    plt.figure(figsize=(10, 5))
    plt.plot(loss_history, label='PPO Loss', color='red')
    # One point per PPO update (one full batch), not per episode.
    plt.xlabel('PPO update')
    plt.ylabel('Loss')
    plt.title('ReproAgent PPO Training - Loss per Update')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(loss_path)
    plt.close()

    print(f"Plots saved to {reward_path} and {loss_path}")


def train():
    """Train a GPT-2 policy on ``ReproAgentEnv`` with TRL's PPO trainer.

    Runs 50 episodes. Each environment step becomes one PPO sample:
    (prompt tokens, generated tokens, scalar reward). Samples are buffered
    across episodes, and an update runs every time ``config.batch_size``
    samples are available. Leftover samples at the end are discarded.
    Finally, ``assets/reward_plot.png`` and ``assets/loss_plot.png`` are
    written.
    """
    # NOTE: this uses TRL's legacy PPO API (PPOConfig(model_name=...),
    # PPOTrainer.generate / PPOTrainer.step). Newer TRL releases (0.12+)
    # replaced that API, so pin trl<0.12 to run this script.

    # 1. Initialize Configuration
    config = PPOConfig(
        model_name="gpt2",  # Using small model for demonstration
        learning_rate=1.41e-5,
        batch_size=8,
        mini_batch_size=4,
        gradient_accumulation_steps=2,
        optimize_cuda_cache=True,
    )

    # 2. Load Model & Tokenizer
    print("Loading model and tokenizer...")
    model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # 3. Initialize PPO Trainer
    # Rollouts come from the environment, not a dataset; this one-row
    # placeholder dataset is never iterated by this script.
    dummy_dataset = Dataset.from_dict({"query": ["dummy"], "input_ids": [[0]]})
    ppo_trainer = PPOTrainer(
        config=config,
        model=model,
        tokenizer=tokenizer,
        dataset=dummy_dataset,
    )
    device = ppo_trainer.accelerator.device

    # 4. Initialize Environment
    print("Initializing ReproAgent Environment...")
    env = ReproAgentEnv(difficulty="easy", max_steps=20, use_llm=False)

    episodes = 50
    batch_size = config.batch_size
    reward_history = []  # total reward per episode
    loss_history = []  # total PPO loss per update (NaN when an update failed)
    buffer = []  # (query, response, reward) samples not yet trained on

    print("Starting PPO Training Loop...")
    # Note: In a real scenario, we'd batch environments. Here we do sequential for clarity.
    for _episode in tqdm(range(episodes), desc="Training"):
        obs, info = env.reset()
        terminated = truncated = False
        episode_reward = 0.0

        while not (terminated or truncated):
            prompt = format_observation(obs)
            query_tensor = tokenizer.encode(prompt, return_tensors="pt").squeeze(0).to(device)

            with torch.no_grad():
                # The legacy PPOTrainer.generate takes a 1-D query (it adds the
                # batch dimension itself) and raises ValueError on 2-D input.
                full_output = ppo_trainer.generate(
                    query_tensor,
                    max_new_tokens=5,
                    pad_token_id=tokenizer.eos_token_id,
                ).squeeze(0)
            # The output contains prompt + continuation; keep only the new tokens.
            response_tensor = full_output[len(query_tensor):]
            response_text = tokenizer.decode(response_tensor, skip_special_tokens=True).strip()

            # NOTE(review): when this falls back to a random action, the reward is
            # still credited to the generated text, which did not choose the action.
            action_id = _parse_action(response_text, env.action_space)

            obs, reward, terminated, truncated, info = env.step(action_id)
            episode_reward += reward
            buffer.append(
                (query_tensor, response_tensor, torch.tensor(reward, dtype=torch.float, device=device))
            )

        # PPOTrainer.step requires exactly batch_size samples, so train on every
        # full batch collected so far and carry the remainder into the next episode.
        while len(buffer) >= batch_size:
            batch, buffer = buffer[:batch_size], buffer[batch_size:]
            loss_history.append(_ppo_update(ppo_trainer, batch))

        reward_history.append(episode_reward)

    if buffer:
        print(f"Discarding {len(buffer)} leftover samples (fewer than batch_size={batch_size}).")

    # 5. Generate and Save Plots
    print("Training complete. Generating plots...")
    _save_plots(reward_history, loss_history)
# Script entry point: run training only when executed directly, not when imported.
if __name__ == "__main__":
    train()