mo35 committed on
Commit
2137240
·
verified ·
1 Parent(s): e7dffe0

Upload server/code_refactor_gym_environment.py with huggingface_hub

Browse files
server/code_refactor_gym_environment.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Code Refactor Gym Environment Implementation.
9
+
10
+ An environment that teaches agents to refactor legacy code into modern,
11
+ maintainable code with improved quality metrics.
12
+ """
13
+
14
+ import ast
15
+ import random
16
+ from typing import Dict, Any
17
+ from uuid import uuid4
18
+
19
+ from openenv.core.env_server.interfaces import Environment
20
+ from openenv.core.env_server.types import State
21
+
22
+ from models import CodeRefactorGymAction, CodeRefactorGymObservation
23
+
24
+
25
# Legacy code samples for refactoring.
# Each entry is a self-contained, syntactically valid Python snippet that is
# intentionally written in a poor style; the flaws are the point, so do not
# "clean up" the text inside the string literals.
LEGACY_CODE_SAMPLES = [
    # Sample 1: Poor naming, no type hints, complex logic
    """
def f(x, y):
    result = []
    for i in range(len(x)):
        if x[i] > y:
            result.append(x[i])
    return result
""",
    # Sample 2: Global variables, poor structure
    """
total = 0
def add(x):
    global total
    total = total + x
    return total
""",
    # Sample 3: Nested conditions, poor readability
    """
def check(data):
    if data != None:
        if len(data) > 0:
            if type(data) == list:
                return True
    return False
""",
    # Sample 4: No error handling, magic numbers
    """
def process(items):
    result = []
    for i in items:
        if i % 2 == 0:
            result.append(i * 3.14159)
    return result
""",
    # Sample 5: Repetitive code, no abstraction
    """
def calc1(x):
    return x * 2 + 10

def calc2(x):
    return x * 3 + 10

def calc3(x):
    return x * 4 + 10
""",
]
74
+
75
+
76
class CodeRefactorGymEnvironment(Environment):
    """
    Environment for learning code refactoring.

    The agent receives legacy code and must refactor it to improve:
    - Code readability (naming, structure)
    - Type safety (type hints)
    - Best practices (avoiding globals, proper error handling)
    - Code metrics (complexity, maintainability)

    Rewards are based on improvement in code quality metrics and syntax validity.

    Episode flow: ``reset()`` picks a random legacy sample and records its
    metrics as the baseline; ``step()`` scores one refactoring attempt against
    that baseline. A syntactically valid attempt ends the episode; an attempt
    that fails to parse is penalized and leaves the episode open.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Score `_calculate_improvement` yields when no metric changed; anything
    # above it means the refactoring was a net improvement.
    _NEUTRAL_SCORE: float = 50.0

    def __init__(self):
        """Initialize the code_refactor_gym environment."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_legacy_code = ""
        self._baseline_metrics = {}

    def reset(self) -> CodeRefactorGymObservation:
        """
        Reset the environment with a new legacy code sample.

        A fresh episode id is generated, the step counter is zeroed, and the
        metrics of the chosen sample are stored as the baseline for `step`.

        Returns:
            CodeRefactorGymObservation with the legacy code to refactor
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_legacy_code = random.choice(LEGACY_CODE_SAMPLES)
        self._baseline_metrics = self._calculate_metrics(self._current_legacy_code)

        return CodeRefactorGymObservation(
            legacy_code=self._current_legacy_code,
            test_results={},
            quality_metrics=self._baseline_metrics,
            syntax_valid=True,
            error_message="",
            improvement_score=0.0,
            done=False,
            reward=0.0,
        )

    def step(self, action: CodeRefactorGymAction) -> CodeRefactorGymObservation:  # type: ignore[override]
        """
        Evaluate the refactored code.

        Args:
            action: CodeRefactorGymAction containing the refactored code

        Returns:
            CodeRefactorGymObservation with evaluation results. For code that
            does not parse, the observation carries a reward of -10.0, empty
            quality metrics and ``done=False``. Otherwise the reward is the
            improvement score divided by 10, plus 5.0 when the score exceeds
            70, and ``done=True``.
        """
        self._state.step_count += 1
        refactored_code = action.refactored_code

        # Check syntax validity
        syntax_valid, error_message = self._check_syntax(refactored_code)

        if not syntax_valid:
            return CodeRefactorGymObservation(
                legacy_code=self._current_legacy_code,
                test_results={"syntax_check": "failed"},
                quality_metrics={},
                syntax_valid=False,
                error_message=error_message,
                improvement_score=0.0,
                done=False,
                reward=-10.0,  # Penalty for syntax errors
                metadata={"step": self._state.step_count, "reasoning": action.reasoning},
            )

        # Calculate quality metrics
        new_metrics = self._calculate_metrics(refactored_code)
        improvement_score = self._calculate_improvement(self._baseline_metrics, new_metrics)

        # Calculate reward based on improvement
        reward = improvement_score / 10.0  # Scale to reasonable range

        # Bonus for significant improvements
        if improvement_score > 70:
            reward += 5.0

        return CodeRefactorGymObservation(
            legacy_code=self._current_legacy_code,
            test_results={
                "syntax_check": "passed",
                # The score starts at the neutral value, so "improved" means
                # strictly above neutral. Comparing against 0 was always true.
                "metrics_improved": improvement_score > self._NEUTRAL_SCORE,
            },
            quality_metrics=new_metrics,
            syntax_valid=True,
            error_message="",
            improvement_score=improvement_score,
            done=True,  # Episode ends after one (valid) refactoring attempt
            reward=reward,
            metadata={
                "step": self._state.step_count,
                "reasoning": action.reasoning,
                "baseline_metrics": self._baseline_metrics,
                "improvement_details": {
                    "lines_change": new_metrics.get("lines", 0) - self._baseline_metrics.get("lines", 0),
                    "complexity_change": new_metrics.get("complexity", 0) - self._baseline_metrics.get("complexity", 0),
                },
            },
        )

    def _check_syntax(self, code: str) -> tuple[bool, str]:
        """
        Check if the code has valid Python syntax.

        Returns:
            ``(True, "")`` when `code` parses, otherwise ``(False, message)``
            where `message` describes the failure.
        """
        try:
            ast.parse(code)
        except SyntaxError as e:
            return False, f"Syntax error at line {e.lineno}: {e.msg}"
        except Exception as e:
            # Agent-supplied input: any other parser failure is reported
            # rather than allowed to crash the environment.
            return False, f"Parse error: {str(e)}"
        return True, ""

    def _calculate_metrics(self, code: str) -> Dict[str, Any]:
        """
        Calculate code quality metrics.

        Metrics include:
        - lines: Number of non-empty lines
        - complexity: Cyclomatic complexity estimate
        - has_type_hints: Whether type hints are present
        - has_docstring: Whether docstring is present
        - avg_line_length: Average line length
        - has_globals: Whether a ``global`` statement is used
        - has_magic_numbers: Whether unnamed numeric literals are present

        The structural metrics are derived from the parsed syntax tree so that
        text inside strings, comments and identifiers cannot skew them. If the
        code cannot be parsed, the substring heuristics in `_text_metrics` are
        used instead.
        """
        lines = [line for line in code.strip().split('\n') if line.strip()]
        num_lines = len(lines)
        avg_line_length = sum(len(line) for line in lines) / max(num_lines, 1)

        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            structural = self._text_metrics(code)
        else:
            structural = self._ast_metrics(tree)

        return {
            "lines": num_lines,
            "complexity": structural["complexity"],
            "has_type_hints": structural["has_type_hints"],
            "has_docstring": structural["has_docstring"],
            "avg_line_length": avg_line_length,
            "has_globals": structural["has_globals"],
            "has_magic_numbers": structural["has_magic_numbers"],
        }

    @staticmethod
    def _ast_metrics(tree: ast.AST) -> Dict[str, Any]:
        """Derive the structural metrics from a parsed syntax tree."""
        branch_nodes = (ast.If, ast.IfExp, ast.For, ast.AsyncFor, ast.While, ast.ExceptHandler)
        function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)
        documented_nodes = function_nodes + (ast.ClassDef, ast.Module)

        complexity = 0
        has_type_hints = False
        has_docstring = False
        has_globals = False
        numeric_literals = []
        # ids of literal nodes bound directly to an UPPER_CASE name; those are
        # named constants, not magic numbers.
        named_constant_ids = set()

        for node in ast.walk(tree):
            # Control flow: one point per branch/loop/handler; a comprehension
            # counts as its loop plus each of its filter conditions.
            if isinstance(node, branch_nodes):
                complexity += 1
            elif isinstance(node, ast.comprehension):
                complexity += 1 + len(node.ifs)

            # Type hints: an annotated parameter, return type or variable.
            if isinstance(node, function_nodes):
                signature = node.args
                params = signature.posonlyargs + signature.args + signature.kwonlyargs
                params += [signature.vararg, signature.kwarg]
                if node.returns is not None or any(
                    param is not None and param.annotation is not None for param in params
                ):
                    has_type_hints = True
            elif isinstance(node, ast.AnnAssign):
                has_type_hints = True

            if isinstance(node, documented_nodes) and ast.get_docstring(node, clean=False) is not None:
                has_docstring = True

            if isinstance(node, ast.Global):
                has_globals = True

            if isinstance(node, (ast.Assign, ast.AnnAssign)):
                targets = node.targets if isinstance(node, ast.Assign) else [node.target]
                value = node.value
                if isinstance(value, ast.UnaryOp):
                    value = value.operand  # e.g. OFFSET = -10
                if isinstance(value, ast.Constant) and all(
                    isinstance(target, ast.Name) and target.id.isupper() for target in targets
                ):
                    named_constant_ids.add(id(value))
            elif isinstance(node, ast.Constant):
                literal = node.value
                # bool is a subclass of int, so exclude it explicitly.
                if isinstance(literal, (int, float)) and not isinstance(literal, bool) and literal not in (0, 1):
                    numeric_literals.append(node)

        # ast.walk visits an assignment before its value, so the named-constant
        # set is complete only after the walk; filter the literals afterwards.
        has_magic_numbers = any(id(node) not in named_constant_ids for node in numeric_literals)

        return {
            "complexity": complexity,
            "has_type_hints": has_type_hints,
            "has_docstring": has_docstring,
            "has_globals": has_globals,
            "has_magic_numbers": has_magic_numbers,
        }

    @staticmethod
    def _text_metrics(code: str) -> Dict[str, Any]:
        """Approximate the structural metrics with substring checks (fallback for unparseable code)."""
        return {
            "complexity": code.count('if ') + code.count('for ') + code.count('while ') + code.count('except'),
            "has_type_hints": '->' in code or ': ' in code,
            "has_docstring": '"""' in code or "'''" in code,
            "has_globals": 'global ' in code,
            "has_magic_numbers": any(c.isdigit() for c in code if c not in ['0', '1']),
        }

    def _calculate_improvement(self, baseline: Dict[str, Any], new: Dict[str, Any]) -> float:
        """
        Calculate improvement score based on metric changes.

        The score starts at the neutral value (50) and is adjusted per metric;
        the result is clamped to 0-100. With the adjustments below, the
        reachable range is 35 to 100.

        Higher score = better refactoring.
        """
        score = self._NEUTRAL_SCORE  # Start at neutral

        # Penalize if code gets longer (should be more concise)
        baseline_lines = baseline.get("lines", 0)
        new_lines = new.get("lines", 0)
        if new_lines > baseline_lines:
            score -= 5
        elif new_lines < baseline_lines:
            score += 5

        # Penalize increased complexity
        baseline_complexity = baseline.get("complexity", 0)
        new_complexity = new.get("complexity", 0)
        if new_complexity > baseline_complexity:
            score -= 10
        elif new_complexity < baseline_complexity:
            score += 10

        # Reward adding type hints
        if new.get("has_type_hints") and not baseline.get("has_type_hints"):
            score += 15

        # Reward adding docstrings
        if new.get("has_docstring") and not baseline.get("has_docstring"):
            score += 10

        # Reward removing globals
        if baseline.get("has_globals") and not new.get("has_globals"):
            score += 15

        # Reward fixing magic numbers
        if baseline.get("has_magic_numbers") and not new.get("has_magic_numbers"):
            score += 10

        # Ensure score is in valid range
        return max(0.0, min(100.0, score))

    @property
    def state(self) -> State:
        """Get the current environment state."""
        return self._state