"""
Main reasoning agent - orchestrates the entire reproduction workflow.

Uses a hypothesis-driven approach to intelligently navigate the
reproduction process. Also provides an RL-trainable agent, a random
baseline agent, and the ``create_agent`` factory.
"""

from typing import Dict, Any, Optional, Tuple, List

import numpy as np

from reproagent.environment import ReproAgentEnv
from reproagent.state import ReproductionState, Phase
from reproagent.actions import ActionSpace, ActionType, Action
from reproagent.models import LLMClient
from agents.paper_parser import PaperParser
from agents.repo_analyzer import RepoAnalyzer
from agents.debugger import Debugger


# Phases whose completion ReasoningAgent tracks in ``phase_progress``.
# Shared by __init__ and reset() so both produce the same key set.
_TRACKED_PHASES = (
    Phase.PARSING,
    Phase.REPO_ANALYSIS,
    Phase.SETUP,
    Phase.EXECUTION,
    Phase.DEBUGGING,
    Phase.EXPERIMENTATION,
)


class ReasoningAgent:
    """
    Main intelligent agent for paper reproduction.

    Strategy:
    1. Parse paper → understand what to reproduce
    2. Find & analyze repo → understand how to reproduce
    3. Setup environment → prepare for execution
    4. Execute & debug → run code, fix errors
    5. Experiment → tune hyperparameters
    6. Compare → validate reproduction
    """

    # Once fix attempts + tried solutions reach this count, debugging gives up
    # and the agent moves on to comparing whatever results it has.
    MAX_DEBUG_ATTEMPTS = 3

    def __init__(self, env: ReproAgentEnv, use_llm: bool = True):
        """
        Args:
            env: ReproAgent environment
            use_llm: Whether to use LLM for reasoning. If constructing the real
                LLM client raises, the agent falls back to a mock client and
                rule-based mode (``self.use_llm`` becomes False).
        """
        self.env = env
        self.action_space = ActionSpace()
        self.use_llm = use_llm

        # Initialize LLM and sub-agents
        if use_llm:
            try:
                self.llm = LLMClient()
            except Exception:
                # Any client-construction failure degrades to rule-based mode;
                # KeyboardInterrupt / SystemExit are still allowed to propagate.
                print("⚠️ LLM not available, using rule-based mode")
                self.llm = LLMClient(provider="mock")
                self.use_llm = False
        else:
            self.llm = LLMClient(provider="mock")

        self.paper_parser = PaperParser(self.llm)
        self.repo_analyzer = RepoAnalyzer(self.llm)
        self.debugger = Debugger(self.llm)

        # Agent state
        self.current_strategy = "systematic"  # systematic, debugging, experimenting
        self.hypotheses = []
        self.phase_progress = self._fresh_phase_progress()

    @staticmethod
    def _fresh_phase_progress() -> Dict[Phase, bool]:
        """Return a completion map with every tracked phase marked False."""
        return {phase: False for phase in _TRACKED_PHASES}

    def _action_id(self, action_type: ActionType) -> int:
        """Map an ``ActionType`` to its integer id in this agent's action space."""
        return self.action_space.get_id_by_action(action_type)

    def select_action(
        self,
        observation: Dict[str, np.ndarray],
        info: Dict[str, Any]
    ) -> int:
        """
        Select next action based on current state.

        The decision is driven entirely by ``self.env.state``; ``observation``
        and ``info`` are accepted for interface compatibility with the other
        agents but are not read here.

        Args:
            observation: Environment observation
            info: Additional info

        Returns:
            Action ID
        """
        # Get current state from environment
        state = self.env.state
        phase = state.meta.phase

        # Determine strategy based on phase
        if phase in (Phase.IDLE, Phase.PARSING):
            return self._parsing_phase_action(state)
        if phase == Phase.REPO_ANALYSIS:
            return self._repo_analysis_action(state)
        if phase == Phase.SETUP:
            return self._setup_phase_action(state)
        if phase == Phase.EXECUTION:
            return self._execution_phase_action(state)
        if phase == Phase.DEBUGGING:
            return self._debugging_phase_action(state)
        if phase == Phase.EXPERIMENTATION:
            return self._experimentation_action(state)
        if phase == Phase.COMPARISON:
            # ``report_generated`` is read defensively: it may not exist on meta.
            if not getattr(state.meta, 'report_generated', False):
                return self._action_id(ActionType.GENERATE_REPORT)
            return self._action_id(ActionType.STOP_PROCESS)

        # Default: random exploration
        return self.env.action_space.sample()

    def _parsing_phase_action(self, state: ReproductionState) -> int:
        """Actions for paper parsing phase."""
        if not state.paper.parsed:
            return self._action_id(ActionType.PARSE_PDF)
        if not state.paper.github_links:
            return self._action_id(ActionType.EXTRACT_GITHUB)

        # Parsing is complete — move to repo cloning
        if not state.repo.cloned:
            return self._action_id(ActionType.CLONE_REPO)
        return self._action_id(ActionType.READ_README)

    def _repo_analysis_action(self, state: ReproductionState) -> int:
        """Actions for repository analysis phase."""
        repo = state.repo
        if not repo.cloned and state.paper.github_links:
            return self._action_id(ActionType.CLONE_REPO)
        if repo.cloned and not repo.readme_content:
            return self._action_id(ActionType.READ_README)
        if repo.readme_content and not repo.entry_point:
            return self._action_id(ActionType.FIND_ENTRY_POINT)
        if repo.entry_point and not repo.dependencies:
            return self._action_id(ActionType.EXTRACT_DEPS)

        # Repo fully analyzed — move to environment setup (CREATE_VENV first!)
        return self._action_id(ActionType.CREATE_VENV)

    def _setup_phase_action(self, state: ReproductionState) -> int:
        """Actions for environment setup phase."""
        if state.environment.setup_complete:
            # Setup complete — move to execution
            return self._action_id(ActionType.RUN_TRAINING)
        if state.repo.dependencies:
            return self._action_id(ActionType.INSTALL_REQUIREMENTS)
        # Even with no explicit deps listed, verify setup
        return self._action_id(ActionType.VERIFY_SETUP)

    def _execution_phase_action(self, state: ReproductionState) -> int:
        """Actions for code execution phase."""
        if state.execution.last_error:
            # Transition to debugging
            return self._action_id(ActionType.ANALYZE_ERROR)

        experiment = state.experiment
        if experiment.current_metric > 0 and experiment.gap > 0.05:
            # Has some results but gap is large — move to experimentation
            return self._action_id(ActionType.RUN_EXPERIMENT)
        if experiment.current_metric > 0 and experiment.gap <= 0.05:
            # Close enough — compare
            return self._action_id(ActionType.COMPARE_RESULTS)

        # No results yet — run training
        return self._action_id(ActionType.RUN_TRAINING)

    def _debugging_phase_action(self, state: ReproductionState) -> int:
        """
        Actions for debugging phase.

        NOTE: when the attempt cap (``MAX_DEBUG_ATTEMPTS``) is reached this
        method clears ``state.debug.current_error`` on the live environment
        state. This side effect is deliberate: it breaks the debug loop.
        """
        debug = state.debug
        total_debug_actions = len(debug.fix_attempts) + len(debug.solutions_tried)

        # Cap: after too many debug attempts, give up and compare what we have
        if total_debug_actions >= self.MAX_DEBUG_ATTEMPTS:
            debug.current_error = ""  # clear to break loop
            return self._action_id(ActionType.COMPARE_RESULTS)

        if debug.current_error and not debug.last_hypothesis:
            return self._action_id(ActionType.ANALYZE_ERROR)

        # A fresh hypothesis with no fix tried yet, or an error still present,
        # both lead to applying a fix.
        if (debug.last_hypothesis and len(debug.fix_attempts) == 0) or debug.current_error:
            return self._action_id(ActionType.APPLY_FIX)

        # Error resolved — back to execution
        return self._action_id(ActionType.RUN_TRAINING)

    def _experimentation_action(self, state: ReproductionState) -> int:
        """
        Actions for hyperparameter tuning phase.

        The gap to the target selects which knob to tune:
        > 0.3 learning rate; > 0.15 batch size / optimizer; > 0.05
        regularization; otherwise run an experiment.

        NOTE(review): the alternation below keys off ``experiments_run``. If the
        env only increments it on RUN_EXPERIMENT, an even count keeps yielding
        experiments and other counts keep yielding the same tuning action until
        the gap changes — confirm against the env's step logic.
        """
        gap = state.experiment.gap
        experiments_run = state.experiment.experiments_run

        # Use LLM for intelligent hyperparameter selection if available
        if self.use_llm and experiments_run > 0:
            action = self._llm_suggest_hyperparameter_action(state)
            if action is not None:
                return action

        # Rule-based: alternate between tuning a param and running an experiment
        if experiments_run > 0 and experiments_run % 2 == 0:
            # Every other step, run an experiment to measure progress
            return self._action_id(ActionType.RUN_EXPERIMENT)

        if gap > 0.3:
            return self._action_id(ActionType.MODIFY_LR)
        if gap > 0.15:
            if experiments_run % 4 < 2:
                return self._action_id(ActionType.MODIFY_BATCH)
            return self._action_id(ActionType.MODIFY_OPTIMIZER)
        if gap > 0.05:
            return self._action_id(ActionType.ADD_REGULARIZATION)

        # Very close — run experiment to lock in
        return self._action_id(ActionType.RUN_EXPERIMENT)

    def _llm_suggest_hyperparameter_action(self, state: ReproductionState) -> Optional[int]:
        """
        Use LLM to suggest next hyperparameter action.

        Returns:
            Action id for the suggested action. Returns None when the LLM call
            fails, its reply names an unknown action, or the prompt cannot be
            built from the current state. The caller then falls back to
            rule-based selection.
        """
        action_map = {
            'learning_rate': ActionType.MODIFY_LR,
            'batch_size': ActionType.MODIFY_BATCH,
            'optimizer': ActionType.MODIFY_OPTIMIZER,
            'epochs': ActionType.MODIFY_EPOCHS,
            'regularization': ActionType.ADD_REGULARIZATION,
            'run_experiment': ActionType.RUN_EXPERIMENT,
        }

        try:
            # Built inside the try: a missing/non-numeric metric would otherwise
            # raise from the format spec and crash action selection.
            prompt = f"""
You are tuning hyperparameters to reproduce a paper's results.

Current state:
- Target metric: {state.paper.target_metric:.3f}
- Current metric: {state.experiment.current_metric:.3f}
- Gap: {state.experiment.gap:.3f}
- Experiments run: {state.experiment.experiments_run}
- Current config: {state.experiment.current_config}

What should be adjusted next? Options:
1. learning_rate
2. batch_size
3. optimizer
4. epochs
5. regularization
6. run_experiment (test current config)

Respond with JSON:
{{
    "action": "learning_rate",
    "reasoning": "why this action"
}}
"""
            result = self.llm.generate_structured(prompt)
            action_name = result.get('action', '')

            if action_name in action_map:
                return self._action_id(action_map[action_name])
        except Exception as e:
            print(f"⚠️ LLM suggestion failed: {e}")

        return None

    def form_hypothesis(self, state: ReproductionState) -> str:
        """
        Form hypothesis about what's preventing reproduction.

        Args:
            state: Current state

        Returns:
            Hypothesis string
        """
        if not state.paper.parsed:
            return "Need to parse paper to understand target"
        if not state.repo.cloned:
            return "Need to find and clone repository"
        if state.debug.current_error:
            return f"Need to fix error: {state.debug.current_error[:50]}"
        if state.experiment.gap > 0.2:
            return "Hyperparameters are significantly off from optimal"
        if state.experiment.gap > 0.05:
            return "Need fine-tuning of hyperparameters"
        return "Close to target, validating reproduction"

    def get_reasoning(self, state: ReproductionState, action_id: int) -> str:
        """
        Generate human-readable reasoning for action.

        Args:
            state: Current state
            action_id: Selected action

        Returns:
            Reasoning string. Falls back to ``"Executing <action value>"`` for
            actions without a dedicated message.
        """
        action_type = self.action_space.get_action_by_id(action_id)

        # Builders are lambdas so only the selected action's message is
        # formatted; building every message eagerly would read (and could fail
        # on) state fields unrelated to this action.
        reasoning_builders = {
            ActionType.PARSE_PDF: lambda: "📄 Parsing paper to extract methodology",
            ActionType.EXTRACT_GITHUB: lambda: "🔍 Looking for implementation repository",
            ActionType.CLONE_REPO: lambda: (
                f"📥 Cloning repository: "
                f"{state.paper.github_links[0] if state.paper.github_links else 'unknown'}"
            ),
            ActionType.READ_README: lambda: "📖 Reading setup instructions",
            ActionType.INSTALL_REQUIREMENTS: lambda: (
                f"📦 Installing {len(state.repo.dependencies)} dependencies"
            ),
            ActionType.RUN_TRAINING: lambda: "🚀 Executing training script",
            ActionType.ANALYZE_ERROR: lambda: (
                f"🔍 Analyzing error: {state.debug.current_error[:30]}..."
            ),
            ActionType.APPLY_FIX: lambda: (
                f"🔧 Applying fix attempt #{len(state.debug.fix_attempts) + 1}"
            ),
            ActionType.RUN_EXPERIMENT: lambda: (
                f"🧪 Running experiment #{state.experiment.experiments_run + 1}"
            ),
            ActionType.MODIFY_LR: lambda: (
                f"⚙️ Adjusting learning rate (gap: {state.experiment.gap:.3f})"
            ),
            ActionType.COMPARE_RESULTS: lambda: (
                f"📊 Comparing results: {state.experiment.current_metric:.3f} "
                f"vs {state.paper.target_metric:.3f}"
            ),
        }

        builder = reasoning_builders.get(action_type)
        if builder is None:
            return f"Executing {action_type.value}"
        return builder()

    def reset(self):
        """Reset agent for new episode."""
        self.current_strategy = "systematic"
        self.hypotheses = []
        # Same tracked-phase key set as __init__ (previously every Phase member).
        self.phase_progress = self._fresh_phase_progress()

    def get_stats(self) -> Dict[str, Any]:
        """Get agent statistics."""
        return {
            'strategy': self.current_strategy,
            'hypotheses_formed': len(self.hypotheses),
            'phases_completed': sum(self.phase_progress.values())
        }


class RLAgent:
    """
    RL-trainable agent (for PPO/DPO training).

    Uses neural network policy.
    """

    # Observation entries concatenated, in this order, into the policy input.
    _OBS_KEYS = (
        'paper_features',
        'repo_features',
        'execution_features',
        'experiment_features',
        'meta_features',
    )

    def __init__(self, env: ReproAgentEnv, policy_network=None):
        """
        Args:
            env: Environment
            policy_network: Pre-trained policy (optional). When omitted, a fresh
                MLP policy is built (or None if PyTorch is unavailable).
        """
        self.env = env
        self.policy = policy_network

        if policy_network is None:
            self._init_policy()

    def _init_policy(self):
        """Initialize policy network (sets ``self.policy`` to None without PyTorch)."""
        try:
            import torch.nn as nn

            # Simple MLP policy.
            # NOTE(review): assumes 5 feature vectors x 5 dims each; confirm this
            # matches the env's observation space.
            obs_dim = 25
            action_dim = self.env.action_space.n

            self.policy = nn.Sequential(
                nn.Linear(obs_dim, 128),
                nn.ReLU(),
                nn.Linear(128, 128),
                nn.ReLU(),
                nn.Linear(128, action_dim),
                nn.Softmax(dim=-1)
            )
        except ImportError:
            print("⚠️ PyTorch not installed, using random policy")
            self.policy = None

    def select_action(
        self,
        observation: Dict[str, np.ndarray],
        info: Dict[str, Any]
    ) -> int:
        """
        Select action using policy network.

        Samples from the policy's output distribution. Falls back to a random
        action when there is no policy, or when the forward pass or sampling
        raises.
        """
        if self.policy is None:
            return self.env.action_space.sample()

        try:
            import torch

            # Flatten observation
            obs_vec = np.concatenate([observation[key] for key in self._OBS_KEYS])
            obs_tensor = torch.FloatTensor(obs_vec).unsqueeze(0)

            with torch.no_grad():
                action_probs = self.policy(obs_tensor)

            # Sample action
            return torch.multinomial(action_probs, 1).item()
        except Exception:
            # Best-effort: any policy failure degrades to random exploration.
            # Narrowed from a bare except so KeyboardInterrupt still propagates.
            return self.env.action_space.sample()

    def get_reasoning(self, state, action_id: int) -> str:
        """Describe the chosen action (mirrors the other agents' interface)."""
        return f"RL policy action: {action_id}"

    def reset(self):
        """Reset agent."""
        pass

    def get_stats(self) -> Dict[str, Any]:
        """Get stats."""
        return {'type': 'RL'}


# Factory function
def create_agent(env: ReproAgentEnv, agent_type: str = "reasoning", **kwargs):
    """
    Factory function to create agents.

    Args:
        env: Environment
        agent_type: 'reasoning', 'rl', or 'random'
        **kwargs: Additional arguments (``use_llm`` for 'reasoning',
            ``policy`` for 'rl')

    Returns:
        Agent instance

    Raises:
        ValueError: If ``agent_type`` is not recognised.
    """
    if agent_type == "reasoning":
        return ReasoningAgent(env, use_llm=kwargs.get('use_llm', True))

    if agent_type == "rl":
        return RLAgent(env, policy_network=kwargs.get('policy', None))

    if agent_type == "random":
        # Simple random agent for baseline
        class RandomAgent:
            def __init__(self, env):
                self.env = env

            def select_action(self, obs, info):
                return self.env.action_space.sample()

            def reset(self):
                pass

            def get_stats(self):
                return {'type': 'random'}

            def get_reasoning(self, state, action_id):
                return f"Random action: {action_id}"

        return RandomAgent(env)

    raise ValueError(f"Unknown agent type: {agent_type}")


# Test
if __name__ == "__main__":
    # Create environment
    env = ReproAgentEnv(difficulty="easy", use_llm=False)

    # Create agent
    agent = create_agent(env, agent_type="reasoning", use_llm=False)

    # Run episode
    obs, info = env.reset()

    for step in range(20):
        action = agent.select_action(obs, info)
        obs, reward, terminated, truncated, info = env.step(action)

        print(f"Step {step + 1}: {info.get('action_type', 'unknown')} | Reward: {reward:.2f}")

        if terminated or truncated:
            break

    print(f"\nFinal metric: {info.get('current_metric', 0.0):.3f}")