# inference.py — ReAct Agent for Jira-to-Code Environment
#
# Architecture:
# Phase 1: Episodic Memory — persistent messages[] across the episode
# Phase 2: ReAct Pattern — "thought" key forces reasoning before action
# Phase 3: Robust Parsing — JSON extraction with markdown-fence stripping
# Phase 4: Self-Correction — negative rewards inject corrective prompts
# Phase 5: Multi-Task Loop — evaluates all 6 tasks in one run
import argparse
import json
import os
import re
import textwrap
import time
from typing import List, Optional
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
# Our environment for local/direct testing
from server.env import JiraToCodeEnv
from src.jira_to_code.models import JiraCodeAction
# --- HACKATHON MANDATORY CONFIGURATION ---
# Endpoint, model and credentials are read from the environment. Using `or`
# (rather than a getenv default) means a variable that is set but empty also
# falls back to the value on the right.
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
# No hard-coded fallback: this is None when neither HF_TOKEN nor API_KEY is set.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
# Reported as `env=` in the [START] log record.
BENCHMARK = "jira-to-code"
# MAX_STEPS is now dynamic based on task level
SUCCESS_SCORE_THRESHOLD = 0.9 # Account for step penalties
# Task names come from the environment class, in its declaration order.
ALL_TASKS = list(JiraToCodeEnv.TASKS.keys())
MAX_HISTORY_MESSAGES = 30 # Context-window safety: trim if exceeded
MAX_RETRIES = 5 # Rate limit retry attempts
# Retry `attempt` (0-based) waits RETRY_BASE_DELAY * 2**attempt seconds.
RETRY_BASE_DELAY = 2 # Base delay in seconds for exponential backoff
# --- SYSTEM PROMPT (ReAct + Reward-Aware) ---
# Sent verbatim as the system message of every episode. The body is indented
# for readability; textwrap.dedent removes the common margin and .strip()
# drops the trailing newline, so the model sees flush-left text.
# The separators below are real em dashes: the previous text carried a
# mis-encoded character in their place, which leaked into the prompt.
SYSTEM_PROMPT = textwrap.dedent("""\
    You are an expert software engineer resolving Jira tickets.
    You operate in a sandboxed workspace. You can read files, write code, list files, run tests, and submit your solution.
    ## Rules
    1. ALWAYS respond with ONLY a valid JSON object. No markdown fences, no explanations outside JSON.
    2. You MUST include a "thought" key FIRST to reason about your plan before acting.
    3. Work step-by-step: list files, read the code, understand the bug/requirement, write a fix, run tests, then submit.
    4. If tests fail, carefully read the traceback and fix your code before re-submitting.
    5. Only use "submit" when you are confident all tests will pass.
    6. Be efficient — each step has a small penalty. Aim to solve in the fewest steps possible.
    7. Read the test file to understand exactly what is expected before writing code.
    ## Valid action_types
    - "list_files" — List all files in the workspace (file_path and content should be null)
    - "read_file" — Read a file's contents (requires file_path, content should be null)
    - "write_file" — Write/overwrite a file (requires file_path and content)
    - "run_tests" — Run pytest on the workspace (file_path and content should be null)
    - "submit" — Final submission, runs tests and ends the episode (file_path and content should be null)
    ## Reward Structure
    - list_files / read_file: 0.01 (initial exploration)
    - write_file: +0.05 (reward for taking action)
    - run_tests (all pass): +0.5 | run_tests (partial): proportional | run_tests (crash): 0.01
    - submit (all pass): +1.0 | submit (partial): proportional
    - Every step: 0.01 minimum reward (be efficient!)
    ## JSON Schema
    {
    "thought": "Your reasoning about what to do next and why",
    "action_type": "one of: list_files, read_file, write_file, run_tests, submit",
    "file_path": "string or null",
    "content": "string or null"
    }
    ## Strategy Guide
    1. First, list_files to see the workspace structure.
    2. Read the test file to understand the exact expected behavior.
    3. Read the source file to understand the current (buggy/incomplete) code.
    4. Write the fix/implementation.
    5. Run tests to verify.
    6. If tests pass, submit. If not, read the error, fix, and retry.
""").strip()
# --- MANDATORY LOGGING FUNCTIONS ---
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens one episode's log.

    Args:
        task: Name of the task being attempted.
        env: Benchmark/environment identifier.
        model: Name of the model driving the agent.
    """
    fields = (f"task={task}", f"env={env}", f"model={model}")
    # print() joins its arguments with single spaces, giving one flat record.
    print("[START]", *fields, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record describing a single agent step.

    Args:
        step: 1-based step number within the episode.
        action: Text describing the action taken (already single-line).
        reward: Reward credited for the step; printed with two decimals.
        done: Whether the episode ended on this step; printed lower-case.
        error: Error text for the step; any falsy value is printed as ``null``.
    """
    record = " ".join((
        "[STEP]",
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    ))
    print(record, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes one episode's log.

    Args:
        success: Whether the episode counted as solved; printed lower-case.
        steps: Number of steps recorded for the episode.
        score: Final episode score; printed with three decimals.
        rewards: Per-step rewards; printed comma-separated with two decimals
            each (an empty list prints as an empty value).
    """
    formatted_rewards = [format(value, ".2f") for value in rewards]
    success_text = str(success).lower()
    print(
        f"[END] success={success_text} steps={steps} "
        f"score={score:.3f} rewards={','.join(formatted_rewards)}",
        flush=True,
    )
# --- PHASE 3: ROBUST JSON PARSING ---
def extract_json(raw_text: str) -> dict:
    """Extract the first JSON *object* from LLM output.

    Handles:
    - Markdown code fences (```json ... ```)
    - Leading/trailing whitespace and prose around the object
    - Nested braces and braces inside string values
    - Brace-delimited text that is not JSON appearing before the real object
    - A top-level JSON value that is not an object (for example a list that
      wraps the object); the object inside is returned instead

    Args:
        raw_text: Raw completion text returned by the model.

    Returns:
        The first decodable JSON object in the text, as a dict.

    Raises:
        ValueError: If the text contains no ``{`` at all, or if no ``{`` in
            the text starts a decodable JSON object.
    """
    cleaned = raw_text.strip()
    cleaned = re.sub(r'^```(?:json)?\s*', '', cleaned)
    cleaned = re.sub(r'\s*```\s*$', '', cleaned)
    cleaned = cleaned.strip()

    # Fast path: the whole payload is JSON. Only an object is acceptable,
    # because callers treat the result as a dict; any other value (number,
    # string, list, null) falls through to the object scan below.
    try:
        whole = json.loads(cleaned)
    except json.JSONDecodeError:
        whole = None
    if isinstance(whole, dict):
        return whole

    start = cleaned.find('{')
    if start == -1:
        raise ValueError("No JSON object found in response")

    # Let the stdlib decoder find where each candidate object ends; it
    # already understands strings, escapes and nesting. A candidate that is
    # not valid JSON is skipped and the next '{' is tried.
    decoder = json.JSONDecoder()
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(cleaned, start)
        except json.JSONDecodeError:
            start = cleaned.find('{', start + 1)
            continue
        return candidate
    raise ValueError("No valid JSON object found in response")
def parse_action(raw_text: str) -> JiraCodeAction:
    """Turn raw LLM output into a JiraCodeAction.

    The JSON object is pulled out with extract_json; the "thought" entry is
    the model's reasoning only and is not passed on to the action model.
    Any error from extraction or from constructing the action propagates.
    """
    payload = extract_json(raw_text)
    action_fields = {
        key: value for key, value in payload.items() if key != "thought"
    }
    return JiraCodeAction(**action_fields)
# --- PHASE 1 & 2: BUILD OBSERVATION MESSAGE ---
def build_observation_message(step: int, obs, reward: float) -> str:
    """Render an environment observation as the text of a user chat message.

    Args:
        step: Step number shown in the header line.
        obs: Observation object; this function reads its ``jira_ticket``,
            ``file_tree``, ``current_file_content``, ``test_output`` and
            ``error`` attributes.
        reward: Reward for the step, shown with two decimals.

    Returns:
        Newline-joined text: header, ticket, workspace listing, then the
        optional file-content / test-output / error sections, the reward,
        and a closing instruction to answer in JSON.
    """
    workspace = ', '.join(obs.file_tree) if obs.file_tree else 'None'
    lines = [
        f"--- Step {step} Observation ---",
        f"Ticket: {obs.jira_ticket}",
        f"Files in workspace: {workspace}",
    ]

    # An empty file is still shown; only a missing one (None) is skipped.
    file_text = obs.current_file_content
    if file_text is not None:
        lines.append(f"File Content:\n```\n{file_text}\n```")

    test_text = obs.test_output
    if test_text:
        lines.append(f"Test Output:\n```\n{test_text}\n```")

    failure = obs.error
    if failure:
        lines.append(f"Error: {failure}")

    lines.extend([
        f"Reward: {reward:.2f}",
        "Respond with your next action as JSON.",
    ])
    return "\n".join(lines)
def trim_history(messages: list, max_messages: int = MAX_HISTORY_MESSAGES) -> None:
    """Drop the oldest non-system messages so the history fits the cap.

    The list is modified in place. Index 0 (the system prompt) is always
    kept; the messages immediately after it are the oldest and go first.

    Args:
        messages: Conversation history, system prompt first.
        max_messages: Maximum number of messages to keep. Values below 1 are
            treated as 1, because the system prompt is never removed.
    """
    # Clamp the cap so a tiny/zero cap cannot ask us to delete index 0's
    # neighbours that do not exist (which previously raised IndexError).
    excess = len(messages) - max(max_messages, 1)
    if excess > 0:
        # One slice deletion instead of repeated pop(1): each pop shifts the
        # whole tail, so the loop form was quadratic in the overflow.
        del messages[1:1 + excess]
# --- MAIN AGENT LOOP FOR ONE TASK ---
def _single_line(text: str) -> str:
    """Escape newlines so *text* cannot split a one-line log record."""
    return text.replace('\n', '\\n').replace('\r', '')


def _request_completion(client: OpenAI, messages: list) -> tuple:
    """Ask the model for its next reply, retrying rate-limit failures.

    Rate-limit errors (recognised by "429" or "rate" in the exception text)
    are retried with exponential backoff up to MAX_RETRIES attempts. Any
    other exception, or a rate-limit error on the last attempt, ends the
    call immediately.

    Returns:
        ``(raw_text, None)`` on success, ``(None, exc)`` when the call gave
        up, or ``(None, None)`` if no attempt was made (MAX_RETRIES < 1).
    """
    for attempt in range(MAX_RETRIES):
        try:
            completion = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=0.2,
                max_tokens=2048,
            )
            return (completion.choices[0].message.content or "").strip(), None
        except Exception as exc:
            exc_str = str(exc)
            is_rate_limit = "429" in exc_str or "rate" in exc_str.lower()
            if is_rate_limit and attempt < MAX_RETRIES - 1:
                delay = RETRY_BASE_DELAY * (2 ** attempt)
                print(f" [RATE LIMIT] Retry {attempt + 1}/{MAX_RETRIES} in {delay}s...", flush=True)
                time.sleep(delay)
                continue
            # Non-rate-limit error or final attempt: give up on this step.
            return None, exc
    return None, None


def run_agent_episode(client: OpenAI, task_name: str) -> tuple:
    """Run a full agent episode for one task.

    The task is selected through the JIRA_TASK_LEVEL environment variable,
    which this function sets before creating the environment. The episode
    keeps one growing chat history, asks the model for a JSON action each
    step, applies it to the environment, and feeds the observation back.

    Args:
        client: OpenAI-compatible client used for chat completions.
        task_name: Task to run; names containing "easy" get 10 steps,
            all others 20.

    Returns:
        ``(score, steps_taken, rewards, success)`` where ``score`` is the
        reward sum clamped to [0.01, 0.99] and ``success`` is whether it
        reaches SUCCESS_SCORE_THRESHOLD.
    """
    os.environ["JIRA_TASK_LEVEL"] = task_name
    env = JiraToCodeEnv()
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
    try:
        obs = env.reset()
        task_max_steps = 10 if "easy" in task_name else 20
        # Phase 1: Episodic memory - persistent conversation history
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_observation_message(0, obs, 0.0)},
        ]
        for step in range(1, task_max_steps + 1):
            trim_history(messages)

            # Call the LLM with rate-limit retry + exponential backoff
            raw_text, api_error = _request_completion(client, messages)
            if raw_text is None:
                if api_error is not None:
                    messages.append({
                        "role": "user",
                        "content": f"API ERROR: {api_error}. Please try again with a valid JSON action.",
                    })
                    # Escaped like every other record: exception text may
                    # span several lines and would otherwise break the
                    # one-record-per-line log format.
                    log_step(
                        step=step,
                        action=_single_line(f"API_ERROR: {api_error}"),
                        reward=0.0,
                        done=False,
                        error=_single_line(str(api_error)),
                    )
                    rewards.append(0.0)
                    steps_taken = step
                continue  # The step is consumed; move on to the next one

            # Phase 1: Append assistant response to history
            messages.append({"role": "assistant", "content": raw_text})

            # Phase 3: Robust parsing with safe fallback
            try:
                action = parse_action(raw_text)
                action_log = action.model_dump_json()
            except Exception as exc:
                # Parse failure: fall back to a harmless list_files action
                action = JiraCodeAction(action_type="list_files")
                action_log = f"PARSE_ERROR: {exc}"
                # Phase 4: Inject corrective message
                messages.append({
                    "role": "user",
                    "content": (
                        f"ERROR: Your last response was not valid JSON.\n"
                        f"Parse error: {exc}\n"
                        f"You MUST respond with ONLY a valid JSON object. "
                        f"No markdown, no explanations.\nTry again."
                    ),
                })

            # Take step in environment
            obs, reward, done, _ = env.step(action)
            # Ensure individual step rewards are strictly positive (min 0.01)
            reward = max(reward, 0.01)
            rewards.append(reward)
            steps_taken = step

            # Escape newlines for single-line logging (action and error alike)
            error = _single_line(str(obs.error)) if obs.error else None
            log_step(step=step, action=_single_line(action_log), reward=reward, done=done, error=error)
            if done:
                break

            # Phase 1: Append observation to conversation history
            obs_message = build_observation_message(step, obs, reward)
            # Phase 4: Self-correction prompt injection on low/negative reward or error
            if reward <= 0.01 or obs.error:
                obs_message += (
                    f"\n\nLOW/NEGATIVE RESULT (reward={reward:.2f})."
                    f"\nCarefully analyze the error/test output above."
                    f"\nIdentify the root cause and write a fix."
                    f"\nDo NOT repeat the same action that just failed."
                )
            elif reward >= 0.4:
                obs_message += (
                    "\n\nTests are passing! If all tests pass, use 'submit' to finalize."
                )
            messages.append({"role": "user", "content": obs_message})

        # Calculate final score (clamp strictly between 0 and 1)
        score = min(max(sum(rewards), 0.01), 0.99)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        # The [END] record must be written even if closing the environment
        # fails, so it sits in its own finally.
        try:
            env.close()
        finally:
            log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return score, steps_taken, rewards, success
# --- PHASE 5: MULTI-TASK EVALUATION ---
def main() -> None:
    """Run the agent on the selected tasks and print a score summary.

    With ``--tasks a,b,c`` the named tasks are run in the given order
    (blank entries such as a trailing comma are ignored; unknown names
    abort with an error message). Without it, one task is sampled at random
    from each of the "easy", "medium" and "hard" groups of ALL_TASKS.
    A 20-second pause is inserted between consecutive tasks.
    """
    import random  # Local import: only the default sampling path needs it

    parser = argparse.ArgumentParser(description="Jira-to-Code ReAct Agent")
    parser.add_argument(
        "--tasks",
        type=str,
        default=None,
        help=(
            "Comma-separated list of tasks to run. "
            f"Available: {', '.join(ALL_TASKS)}. "
            "Default: all tasks."
        ),
    )
    args = parser.parse_args()

    # Determine which tasks to run
    if args.tasks:
        tasks = [name.strip() for name in args.tasks.split(",") if name.strip()]
        invalid = [t for t in tasks if t not in ALL_TASKS]
        if invalid:
            print(f"ERROR: Unknown tasks: {invalid}", flush=True)
            print(f"Available: {ALL_TASKS}", flush=True)
            return
    else:
        # Baseline inference: 1 easy, 1 medium, 1 hard randomly sampled
        tasks = []
        for level in ("easy", "medium", "hard"):
            candidates = [t for t in ALL_TASKS if level in t]
            if candidates:
                tasks.append(random.choice(candidates))

    # Nothing selected (e.g. "--tasks ," or no task matches a level name):
    # stop here instead of dividing by zero in the summary.
    if not tasks:
        print("ERROR: No tasks to run.", flush=True)
        print(f"Available: {ALL_TASKS}", flush=True)
        return

    print(f"Running tasks: {tasks}", flush=True)
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    total_score = 0.0
    results = []
    for index, task in enumerate(tasks):
        # Pause only between tasks; waiting after the last one is pure delay.
        if index:
            print("Waiting 20 seconds before next task to respect API limits...", flush=True)
            time.sleep(20)
        score, steps, _rewards, success = run_agent_episode(client, task)
        results.append({
            "task": task,
            "score": score,
            "steps": steps,
            "success": success,
        })
        total_score += score

    # Summary
    print("\n" + "=" * 50, flush=True)
    print("EVALUATION SUMMARY", flush=True)
    print("=" * 50, flush=True)
    for r in results:
        status = "PASS" if r["success"] else "FAIL"
        print(
            f" {r['task']:10s} | score={r['score']:.3f} | "
            f"steps={r['steps']:2d} | {status}",
            flush=True,
        )
    avg_score = total_score / len(tasks)
    print(f" {'AVERAGE':10s} | score={avg_score:.3f}", flush=True)
    print(f" {'TOTAL':10s} | score={total_score:.3f} / {len(tasks):.1f}", flush=True)
    print("=" * 50, flush=True)
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()