Yusufarsh committed on
Commit
c8d0576
·
verified ·
1 Parent(s): 358b88c

Upload 13 files

Browse files
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
agents/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agent implementations for ReproAgent.
3
+ """
4
+
5
+ from agents.paper_parser import PaperParser
6
+ from agents.repo_analyzer import RepoAnalyzer
7
+ from agents.debugger import Debugger
8
+ from agents.reasoning_agent import ReasoningAgent
9
+
10
# Public API of the ``agents`` package: the agent classes imported above.
__all__ = [
    "PaperParser",
    "RepoAnalyzer",
    "Debugger",
    "ReasoningAgent",
]
agents/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (482 Bytes). View file
 
agents/__pycache__/debugger.cpython-312.pyc ADDED
Binary file (9.31 kB). View file
 
agents/__pycache__/paper_parser.cpython-312.pyc ADDED
Binary file (11.9 kB). View file
 
agents/__pycache__/reasoning_agent.cpython-312.pyc ADDED
Binary file (24.6 kB). View file
 
agents/__pycache__/repo_analyzer.cpython-312.pyc ADDED
Binary file (12.9 kB). View file
 
agents/debugger.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Debugging agent - analyzes and fixes code errors.
3
+ """
4
+
5
+ import re
6
+ from typing import Dict, Any, List, Optional, Tuple
7
+
8
+ from reproagent.models import LLMClient
9
+
10
+
11
class Debugger:
    """
    Debugging agent that analyzes errors and proposes fixes.

    Capabilities:
    1. Classifies error messages/tracebacks with regex patterns
    2. Asks the LLM for a root cause and candidate fixes
    3. Falls back to built-in heuristics when the LLM call fails
    4. Proposes a fix as a shell command, code snippet or comment

    Fixes are only returned as strings; nothing in this class applies them.
    """

    # A module name must fully match this before it is spliced into a
    # ``pip install`` command, so arbitrary error text never ends up in a
    # shell line.
    _SAFE_MODULE_NAME = re.compile(r'[A-Za-z0-9_][A-Za-z0-9_.\-]*')

    def __init__(self, llm_client: "LLMClient"):
        """
        Args:
            llm_client: LLM for error analysis
        """
        self.llm = llm_client

        # Common error patterns; group(1) is the interesting detail.
        # The module-name groups exclude quotes so that messages such as
        # "No module named 'a.b'; 'a' is not a package" capture only 'a.b'.
        self.error_patterns = {
            'ImportError': r'ImportError: No module named [\'"]([^\'"]+)[\'"]',
            'ModuleNotFoundError': r'ModuleNotFoundError: No module named [\'"]([^\'"]+)[\'"]',
            'FileNotFoundError': r'FileNotFoundError: \[Errno 2\] No such file or directory: [\'"](.+)[\'"]',
            'RuntimeError': r'RuntimeError: (.+)',
            'ValueError': r'ValueError: (.+)',
            'TypeError': r'TypeError: (.+)',
            'AttributeError': r'AttributeError: (.+)',
        }

    def analyze_error(self, error_message: str, code_context: Optional[str] = None) -> Dict[str, Any]:
        """
        Analyze error and determine cause.

        Args:
            error_message: Full error message/traceback
            code_context: Relevant code snippet (optional)

        Returns:
            Analysis dict with keys 'error_type', 'error_details',
            'root_cause', 'suggested_fixes' and 'confidence'
        """
        print("🔍 Analyzing error...")

        # Classify error type
        error_type = self._classify_error(error_message)

        # Extract error details
        error_details = self._extract_error_details(error_message, error_type)

        # Get LLM analysis (or the heuristic fallback)
        llm_analysis = self._llm_analyze_error(error_message, code_context)

        analysis = {
            'error_type': error_type,
            'error_details': error_details,
            'root_cause': llm_analysis.get('root_cause', 'Unknown'),
            'suggested_fixes': llm_analysis.get('fixes', []),
            'confidence': llm_analysis.get('confidence', 0.5)
        }

        print(f"✅ Error analyzed: {error_type}")
        print(f"   Cause: {analysis['root_cause']}")

        return analysis

    def _classify_error(self, error_message: str) -> str:
        """Classify error type: regex patterns first, then keyword heuristics."""
        for error_type, pattern in self.error_patterns.items():
            if re.search(pattern, error_message):
                return error_type

        # Check for common error types in message
        lowered = error_message.lower()
        if 'import' in lowered:
            return 'ImportError'
        if 'file' in lowered and 'not found' in lowered:
            return 'FileNotFoundError'
        if 'cuda' in lowered or 'gpu' in lowered:
            return 'CUDAError'
        if 'memory' in lowered:
            return 'MemoryError'

        return 'UnknownError'

    def _extract_error_details(self, error_message: str, error_type: str) -> Dict[str, str]:
        """
        Extract specific details from error.

        Returns a dict that may contain 'detail' (group 1 of the pattern for
        ``error_type``) and 'file'/'line' (first ``File "...", line N`` entry
        found in the message).
        """
        details = {}

        pattern = self.error_patterns.get(error_type)
        if pattern is not None:
            match = re.search(pattern, error_message)
            if match:
                details['detail'] = match.group(1)

        # Extract file and line number
        file_match = re.search(r'File "(.+)", line (\d+)', error_message)
        if file_match:
            details['file'] = file_match.group(1)
            details['line'] = file_match.group(2)

        return details

    def _llm_analyze_error(self, error_message: str, code_context: Optional[str]) -> Dict[str, Any]:
        """
        Use LLM to analyze error.

        Falls back to :meth:`_fallback_analysis` when the LLM call raises,
        returns something that is not a dict, or returns a dict carrying an
        'error' key (treated here as a failed call).
        """

        prompt = f"""
Analyze this Python error and provide solutions.

Error:
{error_message[:1000]}
"""

        if code_context:
            prompt += f"\n\nRelevant code:\n{code_context[:500]}"

        prompt += """

Respond with JSON:
{
"root_cause": "explanation of what caused the error",
"fixes": ["fix 1", "fix 2", "fix 3"],
"confidence": 0.9
}
"""

        try:
            result = self.llm.generate_structured(prompt)
        except Exception:
            # Best effort: any LLM failure degrades to the heuristic analysis.
            return self._fallback_analysis(error_message)

        if not isinstance(result, dict) or 'error' in result:
            return self._fallback_analysis(error_message)
        return result

    def _fallback_analysis(self, error_message: str) -> Dict[str, Any]:
        """Fallback analysis without LLM, based on keywords in the message."""

        # Common fixes for common errors
        fixes = []
        lowered = error_message.lower()

        if 'ModuleNotFoundError' in error_message or 'ImportError' in error_message:
            match = re.search(r"module named ['\"]([^'\"]+)['\"]", error_message)
            if match:
                module = match.group(1)
                fixes = [
                    f"Install missing package: pip install {module}",
                    f"Check if {module} is in requirements.txt",
                    "Activate correct virtual environment"
                ]

        elif 'FileNotFoundError' in error_message:
            fixes = [
                "Check if file path is correct",
                "Ensure data is downloaded",
                "Check working directory"
            ]

        # Case-insensitive, matching the keyword check in _classify_error.
        elif 'cuda' in lowered or 'gpu' in lowered:
            fixes = [
                "Check CUDA installation",
                "Verify GPU availability",
                "Try running on CPU: device='cpu'"
            ]

        elif 'memory' in lowered:
            fixes = [
                "Reduce batch size",
                "Use gradient accumulation",
                "Clear GPU cache: torch.cuda.empty_cache()"
            ]

        return {
            'root_cause': 'Error detected',
            'fixes': fixes or ['Debug manually'],
            'confidence': 0.6
        }

    def generate_fix(self, error_analysis: Dict[str, Any]) -> str:
        """
        Generate code fix based on error analysis.

        Args:
            error_analysis: Output from analyze_error()

        Returns:
            Fix as code or command
        """
        error_type = error_analysis['error_type']
        details = error_analysis['error_details']

        # Generate specific fix based on error type
        if error_type in ('ImportError', 'ModuleNotFoundError'):
            module = (details.get('detail') or '').strip()
            # Only emit a pip command for a plausible module name; an empty or
            # odd-looking value falls through to the LLM-generated fix.
            if module and self._SAFE_MODULE_NAME.fullmatch(module):
                return f"pip install {module}"

        elif error_type == 'FileNotFoundError':
            file_path = details.get('detail', '')
            return f"# Check if {file_path} exists or download it"

        elif error_type == 'CUDAError':
            return "# Try: model.to('cpu') or install CUDA"

        elif error_type == 'MemoryError':
            return "# Reduce batch_size or use gradient accumulation"

        # Use LLM for complex fixes
        return self._llm_generate_fix(error_analysis)

    def _llm_generate_fix(self, error_analysis: Dict[str, Any]) -> str:
        """Use LLM to generate code fix; returns a placeholder comment on failure."""

        prompt = f"""
Generate a code fix for this error:

Error Type: {error_analysis.get('error_type', 'UnknownError')}
Root Cause: {error_analysis.get('root_cause', 'Unknown')}

Provide the fix as Python code or shell command.
"""

        try:
            fix = self.llm.generate(prompt, max_tokens=200)
            return fix.strip() or "# Manual fix required"
        except Exception:
            return "# Manual fix required"

    def search_solution(self, error_message: str) -> List[str]:
        """
        Search for solutions to error.
        Simulates searching StackOverflow, documentation, etc.

        Args:
            error_message: Error message

        Returns:
            List of solution suggestions
        """
        # In full implementation, would search:
        # - StackOverflow API
        # - GitHub Issues
        # - Documentation

        # For now, use LLM to generate solutions
        prompt = f"""
This error occurred: {error_message[:500]}

List 3 common solutions to this error.
Respond with JSON:
{{
"solutions": ["solution 1", "solution 2", "solution 3"]
}}
"""

        try:
            result = self.llm.generate_structured(prompt)
            solutions = result.get('solutions', [])
        except Exception:
            return ["Check dependencies", "Review code", "Search documentation"]

        # Keep the declared return type even if the LLM sends a bare string.
        if isinstance(solutions, str):
            return [solutions]
        return solutions
263
+
264
+
265
+ # Test
266
if __name__ == "__main__":
    # Manual smoke test: analyze a canned traceback and print the proposed fix.
    from reproagent.models import LLMClient

    debugger = Debugger(LLMClient())

    # Test error
    sample_traceback = """
Traceback (most recent call last):
File "train.py", line 10, in <module>
import torch
ModuleNotFoundError: No module named 'torch'
"""

    report = debugger.analyze_error(sample_traceback)
    print(report)

    proposed_fix = debugger.generate_fix(report)
    print(f"\nFix: {proposed_fix}")
agents/paper_parser.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Paper parsing agent - extracts structured information from PDFs.
3
+ """
4
+
5
+ import re
6
+ from typing import Dict, Any, List, Optional
7
+ from pathlib import Path
8
+
9
+ from reproagent.models import LLMClient
10
+ from reproagent.state import PaperState
11
+
12
+
13
class PaperParser:
    """
    Parses research papers and extracts key information.
    Uses LLM to extract structured data from paper text, with a regex-based
    fallback when the LLM call fails.
    """

    # Sentinel returned by _extract_text when every PDF backend failed.
    _EXTRACTION_ERROR = "Error: Could not extract text from PDF"
    # Only the first pages are read from a PDF.
    _MAX_PAGES = 10
    # Seconds to wait for the ArXiv API before giving up.
    _ARXIV_TIMEOUT_S = 30

    def __init__(self, llm_client: "LLMClient"):
        """
        Args:
            llm_client: LLM client for extraction
        """
        self.llm = llm_client

    @staticmethod
    def _to_float(value: Any, default: float = 0.0) -> float:
        """
        Coerce an LLM-provided metric value to float.

        Accepts numbers and numeric strings; for strings such as "95.2%"
        the first number found is used. Returns ``default`` for anything
        else (None, lists, non-numeric text).
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            pass

        if isinstance(value, str):
            number = re.search(r'-?\d+(?:\.\d+)?', value)
            if number:
                return float(number.group(0))
        return default

    def parse_paper(self, pdf_path: str) -> "PaperState":
        """
        Parse paper and extract structured information.

        Args:
            pdf_path: Path to PDF file

        Returns:
            PaperState with extracted info; when no text could be extracted,
            a PaperState carrying only ``pdf_path``
        """
        print(f"📄 Parsing paper: {pdf_path}")

        # Extract text from PDF
        text = self._extract_text(pdf_path)

        # Compare against the exact sentinel: a prefix check would reject any
        # paper whose text happens to begin with "Error".
        if not text or not text.strip() or text == self._EXTRACTION_ERROR:
            print("❌ Failed to extract text from PDF")
            return PaperState(pdf_path=pdf_path)

        print(f"✅ Extracted {len(text)} characters")

        # Extract structured info with LLM
        extracted = self._extract_with_llm(text)

        # Build PaperState
        state = PaperState(
            pdf_path=pdf_path,
            title=extracted.get('title', ''),
            abstract=extracted.get('abstract', ''),
            dataset=extracted.get('dataset', ''),
            model=extracted.get('model', ''),
            target_metric=self._to_float(extracted.get('target_metric', 0.0)),
            metric_name=extracted.get('metric_name', 'accuracy'),
            github_links=extracted.get('github_links', []),
            key_claims=extracted.get('key_claims', []),
            parsed=True,
            confidence=extracted.get('confidence', 0.8)
        )

        print(f"✅ Paper parsed: {state.title}")
        print(f"   Dataset: {state.dataset}")
        print(f"   Model: {state.model}")
        print(f"   Target: {state.target_metric} {state.metric_name}")

        return state

    def _extract_text(self, pdf_path: str) -> str:
        """
        Extract text from PDF.
        Tries PyPDF2 first, then pdfplumber if PyPDF2 fails or yields no text.

        Returns:
            Text of the first pages, or the ``_EXTRACTION_ERROR`` sentinel
            string when both backends fail.
        """
        try:
            # Try PyPDF2 first (faster)
            import PyPDF2

            with open(pdf_path, 'rb') as file:
                reader = PyPDF2.PdfReader(file)
                # extract_text() may yield None for a page without text.
                text = "".join(
                    (page.extract_text() or "") + "\n"
                    for page in reader.pages[:self._MAX_PAGES]
                )
            if text.strip():
                return text
            print("⚠️ PyPDF2 returned no text, trying pdfplumber")

        except Exception as e:
            print(f"⚠️ PyPDF2 failed: {e}")

        try:
            # Try pdfplumber (more accurate)
            import pdfplumber

            with pdfplumber.open(pdf_path) as pdf:
                return "".join(
                    (page.extract_text() or "") + "\n"
                    for page in pdf.pages[:self._MAX_PAGES]
                )

        except Exception as e2:
            print(f"⚠️ pdfplumber failed: {e2}")
            return self._EXTRACTION_ERROR

    def _extract_with_llm(self, text: str) -> Dict[str, Any]:
        """
        Use LLM to extract structured information.

        Args:
            text: Paper text

        Returns:
            Extracted information dict; the regex fallback result when the
            LLM call raises, reports an 'error', or returns a non-dict
        """
        # Truncate text to fit in context
        text_sample = text[:5000]

        prompt = f"""
Extract the following information from this research paper:

1. **Title**: Full paper title
2. **Abstract**: Paper abstract (if present)
3. **Dataset**: Dataset used (e.g., "CIFAR-10", "ImageNet", "COCO")
4. **Model**: Model architecture (e.g., "ResNet-50", "BERT", "GPT-2")
5. **Target Metric**: Best reported performance value as a number. Extract exactly what is in the text.
6. **Metric Name**: Type of metric (e.g., "FID", "accuracy", "CLIP score", "BLEU"). DO NOT default to accuracy!
7. **GitHub Links**: Any GitHub URLs mentioned (full URLs)
8. **Key Claims**: Main performance claims (list)

Paper excerpt:
{text_sample}

Respond with ONLY valid JSON in this exact format:
{{
"title": "paper title here",
"abstract": "abstract text here",
"dataset": "dataset name",
"model": "model name",
"target_metric": 12.34,
"metric_name": "FID",
"github_links": ["https://github.com/user/repo"],
"key_claims": ["claim 1", "claim 2"],
"confidence": 0.9
}}
"""

        try:
            result = self.llm.generate_structured(prompt)

            if not isinstance(result, dict):
                print(f"⚠️ LLM extraction failed: unexpected response type {type(result).__name__}")
            elif 'error' in result:
                print(f"⚠️ LLM extraction failed: {result.get('error')}")
            else:
                # Ensure github_links is a list
                if isinstance(result.get('github_links'), str):
                    result['github_links'] = [result['github_links']]

                # Extract GitHub links from text if none found
                if not result.get('github_links'):
                    result['github_links'] = self._extract_github_links(text)

                return result

        except Exception as e:
            print(f"⚠️ LLM error: {e}")

        # Fallback: regex extraction
        return self._fallback_extraction(text)

    def _extract_github_links(self, text: str) -> List[str]:
        """
        Extract GitHub URLs using regex.

        Returns unique ``https://github.com/<owner>/<repo>`` links in order
        of first appearance. Repository names may contain dots; a trailing
        sentence period is stripped.
        """
        pattern = r'https?://github\.com/[\w\-]+/[\w.\-]+'
        links = (match.rstrip('.') for match in re.findall(pattern, text))
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        return list(dict.fromkeys(links))

    def _fallback_extraction(self, text: str) -> Dict[str, Any]:
        """
        Fallback extraction using simple heuristics.
        Used when LLM fails.
        """
        print("⚠️ Using fallback extraction")

        # Extract title: first of the leading 20 lines longer than 10 chars
        title = ""
        for line in text.split('\n')[:20]:
            candidate = line.strip()
            if len(candidate) > 10:
                title = candidate
                break

        # Extract dataset mentions
        dataset = ""
        dataset_patterns = [
            r'(CIFAR-10|CIFAR-100|ImageNet|COCO|MNIST|Fashion-MNIST)',
            # \b keeps "on" from matching the tail of words like "recognition".
            r'\b(?:on|using|dataset)\s+(\w+)',
        ]
        for pattern in dataset_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                dataset = match.group(1)
                break

        # Extract model mentions
        model = ""
        model_patterns = [
            r'(ResNet-\d+|VGG-\d+|BERT|GPT-\d+|Transformer)',
            r'(AlexNet|DenseNet|MobileNet|EfficientNet)',
        ]
        for pattern in model_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                model = match.group(1)
                break

        # Extract metrics
        metric_pattern = r'(\d+\.?\d*)\s*%?\s*(accuracy|precision|recall|F1|BLEU)'
        metric_match = re.search(metric_pattern, text, re.IGNORECASE)

        target_metric = 0.0
        metric_name = "accuracy"

        if metric_match:
            target_metric = float(metric_match.group(1))
            metric_name = metric_match.group(2).lower()

            # Convert percentage to decimal
            if target_metric > 1.0:
                target_metric = target_metric / 100.0

        # GitHub links
        github_links = self._extract_github_links(text)

        return {
            'title': title or "Unknown Paper",
            'abstract': "",
            'dataset': dataset or "Unknown",
            'model': model or "Unknown",
            'target_metric': target_metric,
            'metric_name': metric_name,
            'github_links': github_links,
            'key_claims': [],
            'confidence': 0.5
        }

    def parse_from_arxiv(self, arxiv_id: str) -> "PaperState":
        """
        Parse paper from ArXiv ID.

        Args:
            arxiv_id: ArXiv paper ID (e.g., "2103.00020")

        Returns:
            PaperState; on any fetch/parse failure a PaperState carrying only
            ``pdf_path="arxiv:<id>"``
        """
        print(f"📄 Fetching paper from ArXiv: {arxiv_id}")

        atom = '{http://www.w3.org/2005/Atom}'

        try:
            import requests

            # Fetch ArXiv metadata; the timeout keeps a stalled request from
            # blocking the caller indefinitely.
            url = f"http://export.arxiv.org/api/query?id_list={arxiv_id}"
            response = requests.get(url, timeout=self._ARXIV_TIMEOUT_S)

            if response.status_code == 200:
                # Parse XML response
                import xml.etree.ElementTree as ET
                root = ET.fromstring(response.content)

                # Extract metadata
                entry = root.find(f'{atom}entry')

                # Explicit None check: an Element's truth value reflects its
                # child count, not whether find() succeeded.
                if entry is not None:
                    title = entry.find(f'{atom}title').text.strip()
                    abstract = entry.find(f'{atom}summary').text.strip()

                    # Use LLM to extract technical details from abstract
                    extracted = self._extract_with_llm(f"Title: {title}\n\nAbstract: {abstract}")

                    return PaperState(
                        pdf_path=f"arxiv:{arxiv_id}",
                        title=title,
                        abstract=abstract,
                        dataset=extracted.get('dataset', ''),
                        model=extracted.get('model', ''),
                        target_metric=self._to_float(extracted.get('target_metric', 0.0)),
                        metric_name=extracted.get('metric_name', 'accuracy'),
                        github_links=extracted.get('github_links', []),
                        key_claims=extracted.get('key_claims', []),
                        parsed=True,
                        confidence=0.7
                    )

        except Exception as e:
            print(f"❌ ArXiv fetch failed: {e}")

        return PaperState(pdf_path=f"arxiv:{arxiv_id}")
298
+
299
+
300
+ # Test
301
if __name__ == "__main__":
    # Manual smoke test: run the LLM extraction on a short hand-written excerpt.
    from reproagent.models import LLMClient

    parser = PaperParser(LLMClient())

    # Test with sample text
    excerpt = """
Deep Residual Learning for Image Recognition

Abstract: We present a residual learning framework to ease the training of networks
that are substantially deeper than those used previously. We achieve 95.2% accuracy
on CIFAR-10 dataset using ResNet-50 architecture.

Code: https://github.com/example/resnet-cifar10
"""

    print(parser._extract_with_llm(excerpt))
agents/reasoning_agent.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main reasoning agent - orchestrates the entire reproduction workflow.
3
+ Uses hypothesis-driven approach to intelligently navigate the reproduction process.
4
+ """
5
+
6
+ from typing import Dict, Any, Optional, Tuple, List
7
+ import numpy as np
8
+
9
+ from reproagent.environment import ReproAgentEnv
10
+ from reproagent.state import ReproductionState, Phase
11
+ from reproagent.actions import ActionSpace, ActionType, Action
12
+ from reproagent.models import LLMClient
13
+ from agents.paper_parser import PaperParser
14
+ from agents.repo_analyzer import RepoAnalyzer
15
+ from agents.debugger import Debugger
16
+
17
+
18
+ class ReasoningAgent:
19
+ """
20
+ Main intelligent agent for paper reproduction.
21
+
22
+ Strategy:
23
+ 1. Parse paper → understand what to reproduce
24
+ 2. Find & analyze repo → understand how to reproduce
25
+ 3. Setup environment → prepare for execution
26
+ 4. Execute & debug → run code, fix errors
27
+ 5. Experiment → tune hyperparameters
28
+ 6. Compare → validate reproduction
29
+ """
30
+
31
def __init__(self, env: ReproAgentEnv, use_llm: bool = True):
    """
    Build the agent, its LLM client and its sub-agents.

    Args:
        env: ReproAgent environment
        use_llm: Whether to use LLM for reasoning. If constructing the
            default LLM client fails, the agent switches to the "mock"
            provider and sets ``self.use_llm`` to False.
    """
    self.env = env
    self.action_space = ActionSpace()
    self.use_llm = use_llm

    # Initialize LLM and sub-agents
    if use_llm:
        try:
            self.llm = LLMClient()
        except Exception:
            # Catch Exception rather than everything, so Ctrl-C and
            # SystemExit still propagate during start-up.
            print("⚠️ LLM not available, using rule-based mode")
            self.llm = LLMClient(provider="mock")
            self.use_llm = False
    else:
        self.llm = LLMClient(provider="mock")

    self.paper_parser = PaperParser(self.llm)
    self.repo_analyzer = RepoAnalyzer(self.llm)
    self.debugger = Debugger(self.llm)

    # Agent state
    self.current_strategy = "systematic"  # systematic, debugging, experimenting
    self.hypotheses = []
    # One completion flag per workflow phase, all initially False.
    self.phase_progress = {
        Phase.PARSING: False,
        Phase.REPO_ANALYSIS: False,
        Phase.SETUP: False,
        Phase.EXECUTION: False,
        Phase.DEBUGGING: False,
        Phase.EXPERIMENTATION: False,
    }
67
+
68
def select_action(
    self,
    observation: Dict[str, np.ndarray],
    info: Dict[str, Any]
) -> int:
    """
    Pick the next action ID for the environment's current phase.

    The decision is driven by ``self.env.state``; ``observation`` and
    ``info`` are accepted but not read here.

    Args:
        observation: Environment observation
        info: Additional info

    Returns:
        Action ID
    """
    state = self.env.state
    phase = state.meta.phase

    # Each phase delegates to its own helper.
    if phase == Phase.IDLE or phase == Phase.PARSING:
        return self._parsing_phase_action(state)

    if phase == Phase.REPO_ANALYSIS:
        return self._repo_analysis_action(state)

    if phase == Phase.SETUP:
        return self._setup_phase_action(state)

    if phase == Phase.EXECUTION:
        return self._execution_phase_action(state)

    if phase == Phase.DEBUGGING:
        return self._debugging_phase_action(state)

    if phase == Phase.EXPERIMENTATION:
        return self._experimentation_action(state)

    if phase == Phase.COMPARISON:
        # Generate the report once, then stop.
        report_done = getattr(state.meta, 'report_generated', False)
        closing_step = ActionType.STOP_PROCESS if report_done else ActionType.GENERATE_REPORT
        return self.action_space.get_id_by_action(closing_step)

    # Unrecognised phase: random exploration
    return self.env.action_space.sample()
114
+
115
def _parsing_phase_action(self, state: ReproductionState) -> int:
    """Choose the next step while the paper is being parsed."""
    paper = state.paper

    if not paper.parsed:
        next_step = ActionType.PARSE_PDF
    elif not paper.github_links:
        next_step = ActionType.EXTRACT_GITHUB
    elif not state.repo.cloned:
        # Parsing is complete: move on to cloning the repository
        next_step = ActionType.CLONE_REPO
    else:
        next_step = ActionType.READ_README

    return self.action_space.get_id_by_action(next_step)
130
+
131
def _repo_analysis_action(self, state: ReproductionState) -> int:
    """Choose the next step while the repository is being analyzed."""
    repo = state.repo

    if not repo.cloned and state.paper.github_links:
        next_step = ActionType.CLONE_REPO
    elif repo.cloned and not repo.readme_content:
        next_step = ActionType.READ_README
    elif repo.readme_content and not repo.entry_point:
        next_step = ActionType.FIND_ENTRY_POINT
    elif repo.entry_point and not repo.dependencies:
        next_step = ActionType.EXTRACT_DEPS
    else:
        # Nothing left to analyze: environment setup starts with CREATE_VENV
        next_step = ActionType.CREATE_VENV

    return self.action_space.get_id_by_action(next_step)
149
+
150
def _setup_phase_action(self, state: ReproductionState) -> int:
    """Choose the next step while the environment is being set up."""
    if state.environment.setup_complete:
        # Setup finished: start executing the training code
        return self.action_space.get_id_by_action(ActionType.RUN_TRAINING)

    # Install listed dependencies; with none listed, still verify the setup.
    pending_step = (
        ActionType.INSTALL_REQUIREMENTS
        if state.repo.dependencies
        else ActionType.VERIFY_SETUP
    )
    return self.action_space.get_id_by_action(pending_step)
163
+
164
def _execution_phase_action(self, state: ReproductionState) -> int:
    """Choose the next step while code is being executed."""
    lookup = self.action_space.get_id_by_action

    if state.execution.last_error:
        # A run failed: hand over to error analysis
        return lookup(ActionType.ANALYZE_ERROR)

    experiment = state.experiment
    if experiment.current_metric > 0:
        if experiment.gap > 0.05:
            # Results exist but the gap is still large: keep experimenting
            return lookup(ActionType.RUN_EXPERIMENT)
        if experiment.gap <= 0.05:
            # Close enough to the target: compare
            return lookup(ActionType.COMPARE_RESULTS)

    # No usable metric yet: run training
    return lookup(ActionType.RUN_TRAINING)
182
+
183
def _debugging_phase_action(self, state: ReproductionState) -> int:
    """
    Choose the next step while debugging.

    Side effect: once three debug actions have been recorded,
    ``state.debug.current_error`` is reset to "" before COMPARE_RESULTS
    is returned.
    """
    debug = state.debug
    lookup = self.action_space.get_id_by_action

    attempts_so_far = len(debug.fix_attempts) + len(debug.solutions_tried)

    # Cap: after 3 debug attempts, give up and compare what we have
    if attempts_so_far >= 3:
        debug.current_error = ""  # clear to break loop
        return lookup(ActionType.COMPARE_RESULTS)

    if debug.current_error and not debug.last_hypothesis:
        return lookup(ActionType.ANALYZE_ERROR)

    if debug.last_hypothesis and len(debug.fix_attempts) == 0:
        return lookup(ActionType.APPLY_FIX)

    if debug.current_error:
        return lookup(ActionType.APPLY_FIX)

    # No error left to work on: back to execution
    return lookup(ActionType.RUN_TRAINING)
205
+
206
def _experimentation_action(self, state: ReproductionState) -> int:
    """
    Choose the next hyperparameter-tuning step.

    After the first experiment an LLM suggestion is tried first (when
    enabled); otherwise a rule table keyed on the metric gap decides.
    """
    experiment = state.experiment
    gap = experiment.gap
    runs = experiment.experiments_run

    # Use LLM for intelligent hyperparameter selection if available
    if self.use_llm and runs > 0:
        suggested = self._llm_suggest_hyperparameter_action(state)
        if suggested is not None:
            return suggested

    # Rule-based: on a positive, even run count, measure progress
    if runs > 0 and runs % 2 == 0:
        next_step = ActionType.RUN_EXPERIMENT
    elif gap > 0.3:
        next_step = ActionType.MODIFY_LR
    elif gap > 0.15:
        next_step = (
            ActionType.MODIFY_BATCH if runs % 4 < 2 else ActionType.MODIFY_OPTIMIZER
        )
    elif gap > 0.05:
        next_step = ActionType.ADD_REGULARIZATION
    else:
        # Very close: run experiment to lock in
        next_step = ActionType.RUN_EXPERIMENT

    return self.action_space.get_id_by_action(next_step)
235
+
236
+ def _llm_suggest_hyperparameter_action(self, state: ReproductionState) -> Optional[int]:
237
+ """Use LLM to suggest next hyperparameter action."""
238
+
239
+ prompt = f"""
240
+ You are tuning hyperparameters to reproduce a paper's results.
241
+
242
+ Current state:
243
+ - Target metric: {state.paper.target_metric:.3f}
244
+ - Current metric: {state.experiment.current_metric:.3f}
245
+ - Gap: {state.experiment.gap:.3f}
246
+ - Experiments run: {state.experiment.experiments_run}
247
+ - Current config: {state.experiment.current_config}
248
+
249
+ What should be adjusted next?
250
+
251
+ Options:
252
+ 1. learning_rate
253
+ 2. batch_size
254
+ 3. optimizer
255
+ 4. epochs
256
+ 5. regularization
257
+ 6. run_experiment (test current config)
258
+
259
+ Respond with JSON:
260
+ {{
261
+ "action": "learning_rate",
262
+ "reasoning": "why this action"
263
+ }}
264
+ """
265
+
266
+ try:
267
+ result = self.llm.generate_structured(prompt)
268
+ action_name = result.get('action', '')
269
+
270
+ action_map = {
271
+ 'learning_rate': ActionType.MODIFY_LR,
272
+ 'batch_size': ActionType.MODIFY_BATCH,
273
+ 'optimizer': ActionType.MODIFY_OPTIMIZER,
274
+ 'epochs': ActionType.MODIFY_EPOCHS,
275
+ 'regularization': ActionType.ADD_REGULARIZATION,
276
+ 'run_experiment': ActionType.RUN_EXPERIMENT
277
+ }
278
+
279
+ if action_name in action_map:
280
+ action_type = action_map[action_name]
281
+ return self.action_space.get_id_by_action(action_type)
282
+
283
+ except Exception as e:
284
+ print(f"⚠️ LLM suggestion failed: {e}")
285
+
286
+ return None
287
+
288
def form_hypothesis(self, state: ReproductionState) -> str:
    """
    Describe the most pressing obstacle to reproduction.

    The checks run in a fixed order: unparsed paper, uncloned repository,
    pending error, then the size of the metric gap.

    Args:
        state: Current state

    Returns:
        Hypothesis string
    """
    if not state.paper.parsed:
        return "Need to parse paper to understand target"

    if not state.repo.cloned:
        return "Need to find and clone repository"

    if state.debug.current_error:
        # Only the first 50 characters of the error are quoted.
        return f"Need to fix error: {state.debug.current_error[:50]}"

    if state.experiment.gap > 0.2:
        return "Hyperparameters are significantly off from optimal"

    if state.experiment.gap > 0.05:
        return "Need fine-tuning of hyperparameters"

    return "Close to target, validating reproduction"
315
+
316
def get_reasoning(self, state: ReproductionState, action_id: int) -> str:
    """
    Produce a short human-readable explanation for the chosen action.

    Action types without a dedicated message get a generic
    "Executing <value>" string.

    Args:
        state: Current state
        action_id: Selected action

    Returns:
        Reasoning string
    """
    action_type = self.action_space.get_action_by_id(action_id)

    # Pieces interpolated into the messages below.
    repo_label = state.paper.github_links[0] if state.paper.github_links else 'unknown'
    dependency_count = len(state.repo.dependencies)
    error_preview = state.debug.current_error[:30]
    next_fix_number = len(state.debug.fix_attempts) + 1
    next_experiment_number = state.experiment.experiments_run + 1

    reasoning_map = {
        ActionType.PARSE_PDF: "📄 Parsing paper to extract methodology",
        ActionType.EXTRACT_GITHUB: "🔍 Looking for implementation repository",
        ActionType.CLONE_REPO: f"📥 Cloning repository: {repo_label}",
        ActionType.READ_README: "📖 Reading setup instructions",
        ActionType.INSTALL_REQUIREMENTS: f"📦 Installing {dependency_count} dependencies",
        ActionType.RUN_TRAINING: "🚀 Executing training script",
        ActionType.ANALYZE_ERROR: f"🔍 Analyzing error: {error_preview}...",
        ActionType.APPLY_FIX: f"🔧 Applying fix attempt #{next_fix_number}",
        ActionType.RUN_EXPERIMENT: f"🧪 Running experiment #{next_experiment_number}",
        ActionType.MODIFY_LR: f"⚙️ Adjusting learning rate (gap: {state.experiment.gap:.3f})",
        ActionType.COMPARE_RESULTS: f"📊 Comparing results: {state.experiment.current_metric:.3f} vs {state.paper.target_metric:.3f}",
    }

    fallback = f"Executing {action_type.value}"
    return reasoning_map.get(action_type, fallback)
344
+
345
def reset(self):
    """Restore the agent's per-episode bookkeeping to its initial values."""
    self.current_strategy = "systematic"
    self.hypotheses = []
    # Every phase starts out marked as not completed.
    self.phase_progress = dict.fromkeys(Phase, False)
350
+
351
def get_stats(self) -> Dict[str, Any]:
    """Summarize the agent's strategy, hypothesis count and completed phases."""
    stats: Dict[str, Any] = {}
    stats['strategy'] = self.current_strategy
    stats['hypotheses_formed'] = len(self.hypotheses)
    # phase_progress values are summed, so each truthy flag counts as one.
    stats['phases_completed'] = sum(self.phase_progress.values())
    return stats
358
+
359
+
360
class RLAgent:
    """
    RL-trainable agent (for PPO/DPO training).
    Uses neural network policy.

    When no policy is available (PyTorch missing) or policy inference fails,
    the agent falls back to ``env.action_space.sample()``.
    """

    # Observation entries concatenated, in this order, into the policy input.
    _OBSERVATION_KEYS = (
        'paper_features',
        'repo_features',
        'execution_features',
        'experiment_features',
        'meta_features',
    )

    # Input width of the default policy: 5 feature vectors × 5 dims each.
    _OBS_DIM = 25

    def __init__(self, env: ReproAgentEnv, policy_network=None):
        """
        Args:
            env: Environment
            policy_network: Pre-trained policy (optional); a default MLP is
                built when omitted
        """
        self.env = env
        self.policy = policy_network

        if policy_network is None:
            self._init_policy()

    def _init_policy(self):
        """Build the default MLP policy, or leave it unset without PyTorch."""
        try:
            import torch.nn as nn
        except ImportError:
            print("⚠️ PyTorch not installed, using random policy")
            self.policy = None
            return

        action_dim = self.env.action_space.n

        # Simple MLP policy ending in a softmax over the discrete actions.
        self.policy = nn.Sequential(
            nn.Linear(self._OBS_DIM, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim),
            nn.Softmax(dim=-1)
        )

    def select_action(
        self,
        observation: Dict[str, np.ndarray],
        info: Dict[str, Any]
    ) -> int:
        """
        Select action using policy network.

        Args:
            observation: Mapping holding the five feature arrays listed in
                ``_OBSERVATION_KEYS``
            info: Environment info dict (not used by this agent)

        Returns:
            Action sampled from the policy's output distribution, or a sample
            from the environment's action space when there is no policy or
            inference fails
        """
        if self.policy is None:
            return self.env.action_space.sample()

        try:
            # Flatten observation
            obs_vec = np.concatenate(
                [observation[key] for key in self._OBSERVATION_KEYS]
            )

            import torch

            obs_tensor = torch.FloatTensor(obs_vec).unsqueeze(0)

            with torch.no_grad():
                action_probs = self.policy(obs_tensor)

            # Sample action
            return torch.multinomial(action_probs, 1).item()
        except Exception as e:
            # Best-effort fallback. Deliberately not a bare ``except`` so that
            # KeyboardInterrupt / SystemExit still propagate.
            print(f"⚠️ Policy inference failed, sampling random action: {e}")
            return self.env.action_space.sample()

    def reset(self):
        """Reset agent (no per-episode state is kept)."""
        pass

    def get_stats(self) -> Dict[str, Any]:
        """Get stats."""
        return {'type': 'RL'}
441
+
442
+
443
+ # Factory function
444
def create_agent(env: ReproAgentEnv, agent_type: str = "reasoning", **kwargs):
    """
    Build an agent of the requested kind for the given environment.

    Args:
        env: Environment
        agent_type: 'reasoning', 'rl', or 'random'
        **kwargs: Optional settings; 'use_llm' is read for the reasoning
            agent (default True) and 'policy' for the RL agent (default None)

    Returns:
        Agent instance

    Raises:
        ValueError: If agent_type is not one of the supported names
    """
    if agent_type == "reasoning":
        llm_enabled = kwargs.get('use_llm', True)
        return ReasoningAgent(env, use_llm=llm_enabled)

    if agent_type == "rl":
        pretrained = kwargs.get('policy', None)
        return RLAgent(env, policy_network=pretrained)

    if agent_type == "random":
        # Simple random agent for baseline
        class RandomAgent:
            def __init__(self, env):
                self.env = env

            def select_action(self, obs, info):
                return self.env.action_space.sample()

            def reset(self):
                pass

            def get_stats(self):
                return {'type': 'random'}

            def get_reasoning(self, state, action_id):
                return f"Random action: {action_id}"

        return RandomAgent(env)

    raise ValueError(f"Unknown agent type: {agent_type}")
484
+
485
+
486
+ # Test
487
if __name__ == "__main__":
    from reproagent.environment import ReproAgentEnv

    # Smoke test: drive a rule-based reasoning agent (LLM disabled) through
    # one episode of at most 20 steps on the easy environment.
    env = ReproAgentEnv(difficulty="easy", use_llm=False)
    agent = create_agent(env, agent_type="reasoning", use_llm=False)

    obs, info = env.reset()

    for step in range(20):
        action = agent.select_action(obs, info)
        obs, reward, terminated, truncated, info = env.step(action)

        print(f"Step {step + 1}: {info.get('action_type', 'unknown')} | Reward: {reward:.2f}")

        episode_over = terminated or truncated
        if episode_over:
            break

    print(f"\nFinal metric: {info.get('current_metric', 0.0):.3f}")
agents/repo_analyzer.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Repository analyzer - analyzes GitHub repositories.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ from typing import Dict, Any, List, Optional
8
+ from pathlib import Path
9
+ import subprocess
10
+
11
+ from reproagent.models import LLMClient
12
+ from reproagent.state import RepoState
13
+
14
+
15
class RepoAnalyzer:
    """
    Analyzes GitHub repositories to understand:
    - Code structure
    - Dependencies
    - Entry points
    - Setup instructions
    """

    # Leading distribution name of a requirement specifier, i.e. everything
    # before a version operator, extras bracket, marker or whitespace.
    _REQUIREMENT_NAME_RE = re.compile(r'[A-Za-z0-9][A-Za-z0-9._-]*')

    def __init__(self, llm_client: LLMClient):
        """
        Args:
            llm_client: LLM for code analysis
        """
        self.llm = llm_client

    def analyze_repo(self, repo_url: str, local_path: Optional[str] = None) -> RepoState:
        """
        Analyze a GitHub repository.

        Args:
            repo_url: GitHub URL
            local_path: Local path (if already cloned)

        Returns:
            RepoState with analysis; only the URL is set when the repository
            could not be cloned or the local path does not exist
        """
        print(f"📦 Analyzing repository: {repo_url}")

        # Clone if needed
        if not local_path:
            local_path = self._clone_repo(repo_url)

        if not local_path or not Path(local_path).exists():
            print("❌ Failed to access repository")
            return RepoState(url=repo_url)

        # Analyze components
        readme_content = self._read_readme(local_path)
        dependencies = self._extract_dependencies(local_path)
        entry_point = self._find_entry_point(local_path)
        framework = self._detect_framework(local_path, dependencies)
        setup_instructions = self._extract_setup_instructions(readme_content)

        state = RepoState(
            url=repo_url,
            cloned=True,
            local_path=local_path,
            readme_content=readme_content,
            setup_instructions=setup_instructions,
            dependencies=dependencies,
            entry_point=entry_point,
            framework=framework,
            repo_quality_score=self._calculate_quality_score(local_path, readme_content)
        )

        print("✅ Repository analyzed")
        print(f" Framework: {state.framework}")
        print(f" Entry point: {state.entry_point}")
        print(f" Dependencies: {len(state.dependencies)}")

        return state

    def _clone_repo(self, repo_url: str) -> Optional[str]:
        """
        Shallow-clone a repository into a fresh temporary directory.

        The temporary directory is removed again when the clone fails.

        Args:
            repo_url: GitHub URL

        Returns:
            Local path or None if failed
        """
        import shutil
        import tempfile

        temp_dir = None
        try:
            # Create temp directory
            temp_dir = tempfile.mkdtemp(prefix="reproagent_")

            print(f"📥 Cloning to {temp_dir}...")

            # '--' ends option parsing, so a URL beginning with '-' cannot be
            # interpreted by git as a command-line option.
            result = subprocess.run(
                ['git', 'clone', '--depth', '1', '--', repo_url, temp_dir],
                capture_output=True,
                text=True,
                timeout=60
            )

            if result.returncode == 0:
                print("✅ Repository cloned")
                return temp_dir

            print(f"❌ Clone failed: {result.stderr}")

        except Exception as e:
            # Covers a missing git binary, timeouts and bad arguments alike;
            # callers only need to know that the clone did not succeed.
            print(f"❌ Clone error: {e}")

        # Failure path: do not leave the (possibly partial) checkout behind.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        return None

    def _read_readme(self, repo_path: str) -> str:
        """Return the text of the first readable README in the repo root, or ""."""
        readme_files = ['README.md', 'README.rst', 'README.txt', 'README']

        for readme_name in readme_files:
            readme_path = Path(repo_path) / readme_name
            if not readme_path.exists():
                continue
            try:
                with open(readme_path, 'r', encoding='utf-8') as f:
                    return f.read()
            except (OSError, UnicodeDecodeError) as e:
                # Try the next candidate name instead of giving up.
                print(f"⚠️ Error reading {readme_name}: {e}")

        return ""

    @classmethod
    def _requirement_name(cls, spec: str) -> str:
        """Return the package name of one requirement specifier, or "" if none.

        Comment lines and pip option lines (``-r``, ``-e``, ``--index-url``)
        yield "". Version specifiers, extras, environment markers and inline
        comments are stripped. Bare URLs are only named through ``#egg=``.
        """
        spec = spec.strip()
        if not spec or spec.startswith(('#', '-')):
            return ""

        egg = re.search(r'#egg=([A-Za-z0-9._-]+)', spec)
        # Drop inline comment / URL fragment and environment marker.
        head = spec.split('#', 1)[0].split(';', 1)[0].strip()

        # "git+https://..." has no leading name; "name @ https://..." does.
        if '://' in head and '@' not in head.split('://', 1)[0]:
            return egg.group(1) if egg else ""

        match = cls._REQUIREMENT_NAME_RE.match(head)
        return match.group(0) if match else ""

    def _extract_dependencies(self, repo_path: str) -> List[str]:
        """Collect unique package names from requirements.txt, setup.py and pyproject.toml.

        Args:
            repo_path: Root directory of the repository

        Returns:
            Package names in first-seen order, without version specifiers
        """
        dependencies: List[str] = []

        def add(spec: str) -> None:
            name = self._requirement_name(spec)
            if name and name not in dependencies:
                dependencies.append(name)

        # Check requirements.txt
        req_path = Path(repo_path) / 'requirements.txt'
        if req_path.exists():
            try:
                with open(req_path, 'r', encoding='utf-8', errors='replace') as f:
                    for line in f:
                        add(line)
            except OSError as e:
                print(f"⚠️ Error reading requirements.txt: {e}")

        # Check setup.py
        setup_path = Path(repo_path) / 'setup.py'
        if setup_path.exists():
            try:
                with open(setup_path, 'r', encoding='utf-8', errors='replace') as f:
                    content = f.read()
                # Look for a literal install_requires=[...] list. The match
                # stops at the first ']', so entries after an extras bracket
                # such as "pkg[extra]" are not seen.
                match = re.search(r'install_requires\s*=\s*\[(.*?)\]', content, re.DOTALL)
                if match:
                    for dep in re.findall(r'["\']([^"\']+)["\']', match.group(1)):
                        add(dep)
            except OSError as e:
                print(f"⚠️ Error reading setup.py: {e}")

        # Check pyproject.toml
        pyproject_path = Path(repo_path) / 'pyproject.toml'
        if pyproject_path.exists():
            # Prefer the standard-library parser (Python 3.11+), then tomli.
            try:
                import tomllib as toml_reader
            except ImportError:
                try:
                    import tomli as toml_reader
                except ImportError:
                    toml_reader = None

            if toml_reader is not None:
                try:
                    with open(pyproject_path, 'rb') as f:
                        data = toml_reader.load(f)
                    project = data.get('project', {})
                    deps = project.get('dependencies', []) if isinstance(project, dict) else []
                    for dep in deps:
                        if isinstance(dep, str):
                            add(dep)
                except (OSError, ValueError) as e:
                    # TOML decode errors derive from ValueError.
                    print(f"⚠️ Error reading pyproject.toml: {e}")

        return dependencies

    def _find_entry_point(self, repo_path: str) -> str:
        """Find main entry point script.

        Well-known script names in the repository root win. Otherwise the
        shallowest file named train/main/run anywhere in the tree is used,
        with ties broken alphabetically so the result is deterministic.

        Returns:
            Path relative to the repository root, or "" when nothing matches
        """
        # Common entry point names
        candidates = [
            'train.py',
            'main.py',
            'run.py',
            'train_model.py',
            'finetune.py',
            'run_training.py'
        ]

        repo_dir = Path(repo_path)

        for candidate in candidates:
            if (repo_dir / candidate).exists():
                return candidate

        # Search in subdirectories
        try:
            matches = [
                py_file.relative_to(repo_dir)
                for py_file in repo_dir.rglob('*.py')
                if py_file.stem in ('train', 'main', 'run') and py_file.is_file()
            ]
        except OSError:
            matches = []

        if matches:
            best = min(matches, key=lambda rel: (len(rel.parts), str(rel)))
            return str(best)

        return ""

    def _detect_framework(self, repo_path: str, dependencies: List[str]) -> str:
        """Detect ML framework used.

        Dependency names are checked first; if they are inconclusive, the
        first 1000 characters of each Python file are scanned for imports.

        Returns:
            'pytorch', 'tensorflow', 'jax', 'keras' or 'unknown'
        """
        dep_str = ' '.join(dependencies).lower()
        # 'tf' must be a whole name component: "tf-nightly" counts, "ftfy"
        # does not.
        name_parts = set(re.split(r'[\s._-]+', dep_str))

        if 'torch' in dep_str or 'pytorch' in dep_str:
            return 'pytorch'
        elif 'tensorflow' in dep_str or 'tf' in name_parts:
            return 'tensorflow'
        elif 'jax' in dep_str:
            return 'jax'
        elif 'keras' in dep_str:
            return 'keras'

        # Check imports in Python files
        try:
            for py_file in Path(repo_path).rglob('*.py'):
                try:
                    with open(py_file, 'r', encoding='utf-8', errors='ignore') as f:
                        content = f.read(1000)  # First 1000 chars
                except OSError:
                    # One unreadable file must not end the whole scan.
                    continue
                if 'import torch' in content or 'from torch' in content:
                    return 'pytorch'
                if 'import tensorflow' in content or 'from tensorflow' in content:
                    return 'tensorflow'
        except OSError:
            pass

        return "unknown"

    def _extract_setup_instructions(self, readme_content: str) -> List[str]:
        """
        Extract setup instructions from README using LLM.

        Falls back to regex-based extraction when the LLM call fails or does
        not return a list of steps.

        Args:
            readme_content: README text

        Returns:
            List of setup steps
        """
        if not readme_content:
            return []

        # Truncate README
        readme_sample = readme_content[:3000]

        prompt = f"""
Extract step-by-step setup/installation instructions from this README.

README:
{readme_sample}

Respond with JSON:
{{
"setup_steps": ["step 1", "step 2", ...]
}}
"""

        try:
            result = self.llm.generate_structured(prompt)
            steps = result.get('setup_steps', [])
        except Exception as e:
            print(f"⚠️ LLM setup extraction failed: {e}")
            return self._simple_setup_extraction(readme_content)

        if not isinstance(steps, list):
            # Malformed answer: use the regex fallback instead.
            return self._simple_setup_extraction(readme_content)
        return steps

    def _simple_setup_extraction(self, readme: str) -> List[str]:
        """Simple regex-based setup extraction.

        Returns pip install commands first, then the text of numbered list
        items, capped at 10 entries.
        """
        steps = []

        # Look for pip install commands
        pip_pattern = r'pip install (.+)'
        for match in re.finditer(pip_pattern, readme):
            steps.append(f"pip install {match.group(1).strip()}")

        # Look for numbered steps
        step_pattern = r'^\d+\.\s+(.+)$'
        for line in readme.split('\n'):
            match = re.match(step_pattern, line.strip())
            if match:
                steps.append(match.group(1))

        return steps[:10]  # Max 10 steps

    def _calculate_quality_score(self, repo_path: str, readme: str) -> float:
        """
        Calculate repository quality score.

        Factors:
        - Has README (0.3), plus 0.05 when longer than 500 characters
        - Has requirements.txt (0.2)
        - Has setup.py or pyproject.toml (0.2)
        - Has tests/ or test/ (0.15)
        - Has LICENSE (0.05)
        - Has .gitignore (0.05)

        Returns:
            Score capped at 1.0
        """
        root = Path(repo_path)
        score = 0.0

        # Has README (0.3)
        if readme:
            score += 0.3

        # Has requirements (0.2)
        if (root / 'requirements.txt').exists():
            score += 0.2

        # Has setup.py or pyproject.toml (0.2)
        if (root / 'setup.py').exists() or (root / 'pyproject.toml').exists():
            score += 0.2

        # Has tests (0.15)
        if (root / 'tests').exists() or (root / 'test').exists():
            score += 0.15

        # Has LICENSE (0.05)
        if (root / 'LICENSE').exists():
            score += 0.05

        # Has .gitignore (0.05)
        if (root / '.gitignore').exists():
            score += 0.05

        # Good README length (0.05)
        if len(readme) > 500:
            score += 0.05

        return min(1.0, score)
327
+
328
+
329
+ # Test
330
if __name__ == "__main__":
    from reproagent.models import LLMClient

    # Manual check against a real public repository; needs git and network
    # access, since the repository is cloned before it is analyzed.
    analyzer = RepoAnalyzer(LLMClient())

    # Test with a real repo
    state = analyzer.analyze_repo("https://github.com/pytorch/examples")
    print(state.to_dict())
assets/loss_plot.png ADDED
assets/reward_plot.png ADDED