# server/env.py
import os
import re
import shutil
import tempfile
import subprocess
from pathlib import Path
from typing import Tuple, Dict, Any
import sys
from openenv.core.env_server import Environment, State
from src.jira_to_code.models import JiraCodeAction, JiraCodeObservation
class JiraToCodeEnv(Environment):
    """Environment in which an agent resolves a Jira-style ticket by editing code.

    Each episode copies one task directory from ``TASKS`` into a fresh
    temporary workspace. The agent then interacts through ``step()`` with the
    actions ``list_files``, ``read_file``, ``write_file``, ``run_tests`` and
    ``submit``. Tests are executed with pytest inside the workspace and the
    reward is derived from the pytest exit code and the pass ratio.

    The task is selected on every ``reset()`` from the ``JIRA_TASK_LEVEL``
    environment variable.
    """

    # Maps a task level name to the directory holding its starter files
    # (resolved against the process working directory) and its ticket text.
    TASKS = {
        "easy": {
            "dir": "src/jira_to_code/tasks/easy",
            "ticket": (
                "TICKET-101: Fix the off-by-one bug in calculator.add() function. "
                "It should correctly sum two numbers."
            ),
        },
        "easy_2": {
            "dir": "src/jira_to_code/tasks/easy_2",
            "ticket": (
                "TICKET-102: Fix the bug in string_utils.count_vowels(). "
                "It currently only counts lowercase vowels but should be case-insensitive."
            ),
        },
        "easy_3": {"dir": "src/jira_to_code/tasks/easy_3", "ticket": "TICKET-E3: The API endpoint crashes with a KeyError when a user payload doesn't contain an optional 'phone_number' field. Change dictionary indexing to .get() with a fallback."},
        "easy_4": {"dir": "src/jira_to_code/tasks/easy_4", "ticket": "TICKET-E4: Off-by-One Pagination. get_page_bounds(page, size) misses the 10th item on page 1. Fix the math index logic."},
        "easy_5": {"dir": "src/jira_to_code/tasks/easy_5", "ticket": "TICKET-E5: FastAPI Route Typo. Route signature is id instead of user_id. Fix the parameter mismatch."},
        "medium": {
            "dir": "src/jira_to_code/tasks/medium",
            "ticket": (
                "TICKET-201: Implement format_user_data in formatter.py. "
                "It should format dictionary data to 'LAST_NAME, First_name (Age: X)'. "
                "Handle missing age by defaulting to 'Unknown'."
            ),
        },
        "medium_2": {
            "dir": "src/jira_to_code/tasks/medium_2",
            "ticket": (
                "TICKET-202: Implement validate_email() and validate_password() in validator.py. "
                "Email: must have exactly one '@', at least 1 char before '@', a '.' after '@' with chars around it. "
                "Password: at least 8 chars, one uppercase, one lowercase, one digit."
            ),
        },
        "medium_3": {"dir": "src/jira_to_code/tasks/medium_3", "ticket": "TICKET-M3: Missing Authentication Middleware. A sensitive endpoint (/api/billing) is exposed. Import @require_auth from auth.py and apply it to the route in routes.py."},
        "medium_4": {"dir": "src/jira_to_code/tasks/medium_4", "ticket": "TICKET-M4: N+1 Database Problem. Rewrite the ORM query to use a JOIN (e.g., select_related)."},
        "medium_5": {"dir": "src/jira_to_code/tasks/medium_5", "ticket": "TICKET-M5: Flawed Regex Validation. validate_email rejects emails with a plus sign. Update regex to allow user+test@gmail.com."},
        "medium_6": {"dir": "src/jira_to_code/tasks/medium_6", "ticket": "TICKET-M6: Incomplete Error Handling. fetching currency rates crashes on timeout. Wrap in try/except and return a cached fallback value."},
        "medium_7": {"dir": "src/jira_to_code/tasks/medium_7", "ticket": "TICKET-M7: Stale Cache Bug. update_user_profile updates DB but forgets to call redis.delete('user:id'). Invalidate the cache."},
        "medium_8": {"dir": "src/jira_to_code/tasks/medium_8", "ticket": "TICKET-M8: Timezone Naive Conversion. Event scheduling function creates naive datetimes. Make them UTC aware."},
        "medium_9": {"dir": "src/jira_to_code/tasks/medium_9", "ticket": "TICKET-M9: State Machine Loophole. Cart state machine allows CANCELLED to SHIPPED. Add transition guards."},
        "medium_10": {"dir": "src/jira_to_code/tasks/medium_10", "ticket": "TICKET-M10: Config Merge Overwrite. YAML merge completely overwrites nested dictionaries. Fix recursion logic."},
        "hard": {
            "dir": "src/jira_to_code/tasks/hard",
            "ticket": (
                "TICKET-301: Implement an LRUCache class in lru_cache.py with put() and get() methods. "
                "O(1) time complexity expected. Evict least recently used when capacity is reached."
            ),
        },
        "hard_2": {
            "dir": "src/jira_to_code/tasks/hard_2",
            "ticket": (
                "TICKET-302: Implement a DirectedGraph class in graph.py with add_edge(), "
                "has_path() (BFS/DFS), and topological_sort() methods. "
                "topological_sort() must return an empty list if a cycle is detected."
            ),
        },
        "hard_3": {"dir": "src/jira_to_code/tasks/hard_3", "ticket": "TICKET-H3: Circular Dependency Resolution. models.py, utils.py, config.py. Extract shared logic into base.py."},
        "hard_4": {"dir": "src/jira_to_code/tasks/hard_4", "ticket": "TICKET-H4: Race Condition in Thread Worker. Refactor the architecture to use queue.Queue."},
        "hard_5": {"dir": "src/jira_to_code/tasks/hard_5", "ticket": "TICKET-H5: OOM Generator Fix. Readlines causes crash on 5GB file. Rewrite to yield generators."},
        "hard_6": {"dir": "src/jira_to_code/tasks/hard_6", "ticket": "TICKET-H6: Implement Abstract Base Class. Implement StripeGateway matching PaymentGateway abstract class."},
        "hard_7": {"dir": "src/jira_to_code/tasks/hard_7", "ticket": "TICKET-H7: Deadlock in Asyncio. Route acquires threading.Lock but forgets to release on exception. Use async context managers."},
    }

    # Reward configuration
    STEP_PENALTY = -0.01  # Added to the reward of every step after the grace period
    GRACE_STEPS = 3  # Number of initial steps that earn GRACE_BONUS instead of the penalty
    GRACE_BONUS = 0.02  # Added to the reward of each step within the grace period
    MIN_REWARD = 0.01  # Lower clamp applied to every returned reward
    MAX_REWARD = 0.99  # Upper clamp applied to every returned reward
    TEST_TIMEOUT_SECONDS = 30  # Wall-clock limit for a single pytest run

    # Directories holding tool-generated artifacts, hidden from the file tree.
    _IGNORED_DIRS = frozenset({"__pycache__", ".pytest_cache"})

    def __init__(self):
        """Create an environment with no workspace; call ``reset()`` to start an episode."""
        super().__init__()
        self.step_count = 0
        self.workspace_dir = None
        self.task_level = "easy"
        self.task_source_dir = None
        self.jira_ticket = ""

    def _get_file_tree(self) -> list[str]:
        """Return the sorted workspace-relative paths of all files in the workspace.

        Compiled ``.pyc`` files and the contents of ``__pycache__`` and
        ``.pytest_cache`` directories are omitted so that running the tests
        does not pollute the listing shown to the agent. Returns an empty
        list when no workspace exists.
        """
        if not self.workspace_dir:
            return []
        workspace = Path(self.workspace_dir)
        tree = []
        for root, dirs, files in os.walk(workspace):
            # Prune in place so os.walk never descends into ignored directories.
            dirs[:] = [d for d in dirs if d not in self._IGNORED_DIRS]
            for name in files:
                if name.endswith(".pyc"):
                    continue
                tree.append(str((Path(root) / name).relative_to(workspace)))
        # os.walk order is filesystem-dependent; sort for a stable observation.
        return sorted(tree)

    @staticmethod
    def _parse_pytest_results(output: str) -> tuple[int, int]:
        """Extract (passed, total) from pytest output for partial-credit scoring.

        ``total`` is passed + failed + errors, using the first "<N> passed",
        "<N> failed" and "<N> error" occurrences in ``output``. It is never
        less than 1 so callers can safely divide by it.
        """
        counts = {}
        for label in ("passed", "failed", "error"):
            match = re.search(rf"(\d+) {label}", output)
            counts[label] = int(match.group(1)) if match else 0
        total = sum(counts.values())
        return counts["passed"], max(total, 1)

    def _run_pytest(self) -> Tuple[str, int, int, int]:
        """Run pytest in the workspace.

        Returns:
            (output, returncode, passed, total) where ``output`` is stdout and
            stderr joined by a newline.

        Raises:
            subprocess.TimeoutExpired: if the run exceeds TEST_TIMEOUT_SECONDS.
        """
        result = subprocess.run(
            [sys.executable, "-m", "pytest", "-v"],
            cwd=self.workspace_dir,
            capture_output=True,
            text=True,
            timeout=self.TEST_TIMEOUT_SECONDS,
        )
        output = result.stdout + "\n" + result.stderr
        passed, total = self._parse_pytest_results(output)
        return output, result.returncode, passed, total

    def reset(self) -> JiraCodeObservation:
        """Start a new episode and return its initial observation.

        Deletes the previous workspace (if any), re-reads ``JIRA_TASK_LEVEL``
        (default ``"medium"``; unknown values fall back to ``"easy"``), and
        copies the task's source directory into a new temporary workspace.
        If the source directory does not exist a warning is printed and the
        workspace is left empty.
        """
        self.step_count = 0
        if self.workspace_dir and Path(self.workspace_dir).exists():
            shutil.rmtree(self.workspace_dir)

        # Re-read task level from environment variable on every reset
        self.task_level = os.getenv("JIRA_TASK_LEVEL", "medium").lower()
        if self.task_level not in self.TASKS:
            self.task_level = "easy"

        task = self.TASKS[self.task_level]
        self.task_source_dir = Path(task["dir"]).resolve()
        self.jira_ticket = task["ticket"]
        self.workspace_dir = tempfile.mkdtemp(prefix=f"jira_env_{self.task_level}_")

        if self.task_source_dir.exists():
            shutil.copytree(self.task_source_dir, self.workspace_dir, dirs_exist_ok=True)
        else:
            print(f"Warning: Task directory {self.task_source_dir} not found!")

        return JiraCodeObservation(
            jira_ticket=self.jira_ticket,
            file_tree=self._get_file_tree(),
        )

    def step(self, action: JiraCodeAction) -> Tuple[JiraCodeObservation, float, bool, Dict[str, Any]]:
        """Apply one agent action and return ``(observation, reward, done, info)``.

        Supported ``action.action_type`` values:

        * ``list_files``  - observation content is the newline-joined file tree.
        * ``read_file``   - returns the content of ``action.file_path``.
        * ``write_file``  - writes ``action.content`` to ``action.file_path``
          (creating parent directories) and earns a small shaping reward.
        * ``run_tests``   - runs pytest; reward depends on exit code and pass ratio.
        * ``submit``      - runs pytest and ends the episode once the run completes.

        Any other action type is reported through the observation's ``error``
        field. Paths that resolve outside the workspace are rejected. A test
        timeout or unexpected exception is reported in ``error`` rather than
        raised. The final reward is always clamped to [MIN_REWARD, MAX_REWARD],
        so the negative penalties below only matter relative to that floor.

        Raises:
            RuntimeError: if called before ``reset()`` (or after ``close()``).
        """
        if not self.workspace_dir:
            raise RuntimeError("Environment not initialized: call reset() before step().")

        self.step_count += 1
        reward = 0.0
        done = False
        current_file_content = None
        test_output = None
        error = None
        workspace_path = Path(self.workspace_dir).resolve()
        action_type = action.action_type

        try:
            if action_type == "list_files":
                current_file_content = "\n".join(self._get_file_tree())
            elif action_type in ("read_file", "write_file"):
                if not action.file_path:
                    error = "file_path must be provided for read/write actions."
                else:
                    # resolve() collapses ".." and symlinks before the containment check.
                    target_path = (workspace_path / action.file_path).resolve()
                    if not target_path.is_relative_to(workspace_path):
                        error = "Access denied: cannot access files outside workspace."
                    elif action_type == "read_file":
                        # is_file() so that a directory path is reported as
                        # not found instead of failing inside read_text().
                        if target_path.is_file():
                            current_file_content = target_path.read_text(encoding="utf-8")
                        else:
                            error = f"File not found: {action.file_path}"
                    elif action.content is None:
                        error = "content must be provided for write_file action."
                    else:
                        target_path.parent.mkdir(parents=True, exist_ok=True)
                        target_path.write_text(action.content, encoding="utf-8")
                        current_file_content = action.content
                        reward = 0.05  # Small shaping reward for taking action
            elif action_type == "run_tests":
                test_output, returncode, passed, total = self._run_pytest()
                if returncode == 0:
                    # All tests pass - strong positive signal
                    reward = 0.1 + 0.4 * (passed / total)
                elif returncode == 1:
                    # Some tests fail - partial credit
                    reward = 0.1 * (passed / total)
                else:
                    # Collection error / crash
                    reward = -0.1
            elif action_type == "submit":
                test_output, returncode, passed, total = self._run_pytest()
                # The episode only ends when the test run actually completed;
                # a timeout leaves it open.
                done = True
                if returncode == 0:
                    reward = 1.0  # Full marks
                else:
                    reward = 0.5 * (passed / total)  # Partial credit on submit
            else:
                error = f"Unknown action_type: {action_type!r}"
        except subprocess.TimeoutExpired:
            error = f"Tests timed out after {self.TEST_TIMEOUT_SECONDS} seconds."
            test_output = "TIMEOUT"
            reward = -0.1
        except Exception as e:
            # Top-level boundary: surface the failure to the agent instead of
            # crashing the environment.
            error = f"System error: {str(e)}"
            reward = -0.2

        # Apply shaping rewards based on step count
        if self.step_count <= self.GRACE_STEPS:
            reward += self.GRACE_BONUS
        else:
            reward += self.STEP_PENALTY

        # Enforce strictly bounded rewards for OpenEnv requirements (between 0.01 and 0.99)
        reward = max(self.MIN_REWARD, min(self.MAX_REWARD, reward))

        obs = JiraCodeObservation(
            jira_ticket=self.jira_ticket,
            file_tree=self._get_file_tree(),
            current_file_content=current_file_content,
            test_output=test_output,
            error=error,
        )
        return obs, reward, done, {}

    def state(self) -> State:
        """Return a ``State`` built from the current task level and step count.

        Note that ``episode_id`` embeds ``step_count`` and therefore changes
        on every step.
        """
        # NOTE(review): some OpenEnv versions may declare `state` as a property
        # on the base class - confirm against the installed Environment API.
        return State(
            episode_id=f"jira-{self.task_level}-{self.step_count}",
            step_count=self.step_count,
        )

    def close(self):
        """Delete the temporary workspace, if one exists, and forget its path."""
        if self.workspace_dir and Path(self.workspace_dir).exists():
            shutil.rmtree(self.workspace_dir)
        # Clear the path so a later step() fails fast instead of operating on
        # a directory that no longer exists.
        self.workspace_dir = None