# jira-to-code / inference.py
# Last commit by Navigam (6392732): "feat: expand task suite to 22 challenges and update reward signal mechanics"
# (Exported from the Hugging Face file viewer; original size 14.9 kB.)
# inference.py — ReAct Agent for Jira-to-Code Environment
#
# Architecture:
# Phase 1: Episodic Memory — persistent messages[] across the episode
# Phase 2: ReAct Pattern — "thought" key forces reasoning before action
# Phase 3: Robust Parsing — JSON extraction with markdown-fence stripping
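# Phase 4: Self-Correction — low (<= 0.01) rewards or environment errors inject corrective prompts
# Phase 5: Multi-Task Loop — runs several tasks in one invocation (default: one random easy, medium and hard task; choose others with --tasks)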
import argparse
import json
import os
import re
import textwrap
import time
from typing import List, Optional
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
# Our environment for local/direct testing
from server.env import JiraToCodeEnv
from src.jira_to_code.models import JiraCodeAction
# --- HACKATHON MANDATORY CONFIGURATION ---
# Endpoint, model and credentials. Each falls back when its environment
# variable is unset *or* set to an empty string (the `or` treats "" as unset).
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY")  # HF_TOKEN takes precedence

# Benchmark identity and the catalogue of runnable task names.
BENCHMARK = "jira-to-code"
ALL_TASKS = list(JiraToCodeEnv.TASKS.keys())

# An episode is a success when its clamped score reaches this value. The
# per-task step budget is chosen inside run_agent_episode, not here.
SUCCESS_SCORE_THRESHOLD = 0.9

# Conversation-size guard: history is trimmed once it exceeds this many messages.
MAX_HISTORY_MESSAGES = 30

# Rate-limit handling: at most MAX_RETRIES attempts per LLM call, sleeping
# RETRY_BASE_DELAY * 2**attempt seconds between them.
MAX_RETRIES = 5
RETRY_BASE_DELAY = 2
# --- SYSTEM PROMPT (ReAct + Reward-Aware) ---
# System message placed first in every episode's conversation (built in
# run_agent_episode). It instructs the model to reply with one JSON object
# whose "thought" key carries its reasoning; parse_action drops that key
# before constructing the JiraCodeAction.
# NOTE(review): the "Reward Structure" section is guidance text for the model
# only — nothing in this file enforces it. Confirm it still matches the reward
# logic in server.env (this agent additionally floors each logged step reward
# at 0.01), and reconcile rule 6's "small penalty" wording with the stated
# 0.01 minimum per-step reward.
SYSTEM_PROMPT = textwrap.dedent("""\
You are an expert software engineer resolving Jira tickets.
You operate in a sandboxed workspace. You can read files, write code, list files, run tests, and submit your solution.
## Rules
1. ALWAYS respond with ONLY a valid JSON object. No markdown fences, no explanations outside JSON.
2. You MUST include a "thought" key FIRST to reason about your plan before acting.
3. Work step-by-step: list files, read the code, understand the bug/requirement, write a fix, run tests, then submit.
4. If tests fail, carefully read the traceback and fix your code before re-submitting.
5. Only use "submit" when you are confident all tests will pass.
6. Be efficient — each step has a small penalty. Aim to solve in the fewest steps possible.
7. Read the test file to understand exactly what is expected before writing code.
## Valid action_types
- "list_files" — List all files in the workspace (file_path and content should be null)
- "read_file" — Read a file's contents (requires file_path, content should be null)
- "write_file" — Write/overwrite a file (requires file_path and content)
- "run_tests" — Run pytest on the workspace (file_path and content should be null)
- "submit" — Final submission, runs tests and ends the episode (file_path and content should be null)
## Reward Structure
- list_files / read_file: 0.01 (initial exploration)
- write_file: +0.05 (reward for taking action)
- run_tests (all pass): +0.5 | run_tests (partial): proportional | run_tests (crash): 0.01
- submit (all pass): +1.0 | submit (partial): proportional
- Every step: 0.01 minimum reward (be efficient!)
## JSON Schema
{
"thought": "Your reasoning about what to do next and why",
"action_type": "one of: list_files, read_file, write_file, run_tests, submit",
"file_path": "string or null",
"content": "string or null"
}
## Strategy Guide
1. First, list_files to see the workspace structure.
2. Read the test file to understand the exact expected behavior.
3. Read the source file to understand the current (buggy/incomplete) code.
4. Write the fix/implementation.
5. Run tests to verify.
6. If tests pass, submit. If not, read the error, fix, and retry.
""").strip()
# --- MANDATORY LOGGING FUNCTIONS ---
def log_start(task: str, env: str, model: str) -> None:
    """Print the mandatory ``[START]`` record that opens an episode's log."""
    header = "[START]"
    details = " ".join((f"task={task}", f"env={env}", f"model={model}"))
    print(f"{header} {details}", flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print the mandatory single-line ``[STEP]`` record for one agent step.

    Args:
        step: 1-based step index within the episode.
        action: Serialized action (or an error marker) taken at this step.
        reward: Reward credited for the step; printed with two decimals.
        done: Whether the environment reported the episode as finished;
            printed lower-case (``true``/``false``).
        error: Error text for the step; ``None`` or empty is printed as
            ``null``.

    The record format is line-oriented, so newlines in ``action`` and
    ``error`` are escaped as a literal backslash-n and carriage returns are
    dropped. Otherwise a multi-line traceback would split one record across
    several lines. The escaping is idempotent, so already-escaped input is
    printed unchanged.
    """
    def one_line(text) -> str:
        return str(text).replace("\n", "\\n").replace("\r", "")

    error_val = one_line(error) if error else "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={one_line(action)} reward={reward:.2f} "
        f"done={done_val} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the mandatory ``[END]`` record summarising a finished episode.

    ``success`` is rendered lower-case, ``score`` with three decimals, and
    each reward with two decimals, comma-separated (empty when no rewards).
    """
    joined = ",".join(map("{:.2f}".format, rewards))
    pieces = [
        f"success={str(success).lower()}",
        f"steps={steps}",
        f"score={score:.3f}",
        f"rewards={joined}",
    ]
    print("[END] " + " ".join(pieces), flush=True)
# --- PHASE 3: ROBUST JSON PARSING ---
def extract_json(raw_text: str) -> dict:
    """Extract the first JSON object from raw LLM output.

    Handles:
      * an optional surrounding Markdown fence (```json ... ```),
      * output that is exactly one JSON object,
      * a JSON object embedded in surrounding prose, including prose that
        contains stray braces before the real object, and trailing text or
        a second object after it.

    Decoding uses ``strict=False`` so literal newlines/tabs inside string
    values are accepted rather than rejected. Models often emit these when
    writing multi-line file contents without escaping them.

    Args:
        raw_text: The raw completion text.

    Returns:
        The first decoded JSON value that is an object (``dict``).

    Raises:
        ValueError: If no JSON object can be decoded from the text.
    """
    cleaned = raw_text.strip()
    cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
    cleaned = re.sub(r"\s*```\s*$", "", cleaned).strip()

    decoder = json.JSONDecoder(strict=False)

    # Fast path: the whole (de-fenced) response is a single JSON object.
    try:
        parsed = decoder.decode(cleaned)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Fallback: treat every '{' as a candidate start. raw_decode parses one
    # complete JSON value there (string-aware, so braces inside strings are
    # handled) and ignores whatever text follows it.
    start = cleaned.find("{")
    if start == -1:
        raise ValueError("No JSON object found in response")
    first_error = None
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(cleaned, start)
        except json.JSONDecodeError as exc:
            # The outermost candidate is the likeliest intended object, so
            # its error is the most useful one to report.
            if first_error is None:
                first_error = exc
        else:
            return candidate
        start = cleaned.find("{", start + 1)
    raise ValueError(f"No valid JSON object found in response: {first_error}")
def parse_action(raw_text: str) -> JiraCodeAction:
    """Parse raw LLM output into a JiraCodeAction.

    The JSON object is located with ``extract_json``. Its ``"thought"``
    entry (the ReAct reasoning the system prompt requires) is dropped
    because it is not an action field. The remaining keys are passed to
    ``JiraCodeAction``.

    Raises:
        ValueError: If the output holds no JSON object, or the decoded value
            is not an object. The message is shown back to the model, so it
            states what was expected.
        Exception: Anything ``JiraCodeAction`` raises for invalid fields
            propagates unchanged (presumably a pydantic validation error,
            since the model exposes ``model_dump_json`` — TODO confirm).
    """
    action_dict = extract_json(raw_text)
    if not isinstance(action_dict, dict):
        raise ValueError(
            "Expected a JSON object with an 'action_type' key, "
            f"got {type(action_dict).__name__}"
        )
    action_dict.pop("thought", None)  # reasoning only; not part of the action model
    return JiraCodeAction(**action_dict)
# --- PHASE 1 & 2: BUILD OBSERVATION MESSAGE ---
def build_observation_message(step: int, obs, reward: float) -> str:
    """Render one environment observation as the next user turn of the chat.

    The step header, ticket text and workspace file list are always included.
    The file-content section appears when ``current_file_content`` is not
    ``None``. The test-output and error sections appear only when those
    fields are truthy. The message ends with the reward and a request for
    the next JSON action.
    """
    file_list = ", ".join(obs.file_tree) if obs.file_tree else "None"
    sections = [
        f"--- Step {step} Observation ---",
        f"Ticket: {obs.jira_ticket}",
        f"Files in workspace: {file_list}",
    ]

    content = obs.current_file_content
    if content is not None:
        sections += [f"File Content:\n```\n{content}\n```"]

    test_output = obs.test_output
    if test_output:
        sections += [f"Test Output:\n```\n{test_output}\n```"]

    if obs.error:
        sections.append(f"Error: {obs.error}")

    sections.extend([f"Reward: {reward:.2f}", "Respond with your next action as JSON."])
    return "\n".join(sections)
def trim_history(messages: list, max_messages: int = MAX_HISTORY_MESSAGES) -> None:
    """Trim the oldest conversation turns in place to bound context size.

    The system prompt at index 0 is always kept. When ``messages`` holds more
    than ``max_messages`` entries, the oldest non-system messages are dropped
    until it fits. Any assistant messages then left directly after the
    system prompt are also removed, because the user turn that prompted them
    was just dropped. The retained history therefore resumes on a user turn.
    The final message is never removed by that second pass.

    Args:
        messages: Chat history as a list of ``{"role", "content"}`` dicts,
            with the system prompt first. Mutated in place.
        max_messages: Maximum number of messages to keep. Values below 1 are
            treated as 1 (keep only the system prompt) instead of raising.
    """
    limit = max(max_messages, 1)
    excess = len(messages) - limit
    if excess <= 0:
        return
    # One slice deletion instead of repeated pop(1), each of which shifts the list.
    del messages[1:1 + excess]
    # An orphaned assistant reply right after the system prompt is confusing
    # context. Some chat templates also reject it because they require
    # user/assistant alternation after the system turn.
    while len(messages) > 2 and messages[1].get("role") == "assistant":
        del messages[1]
# --- MAIN AGENT LOOP FOR ONE TASK ---
# Provider error text that signals HTTP 429 / rate limiting. This is narrower
# than a bare "rate" substring, which also matches "generate", "accurate", etc.
_RATE_LIMIT_PATTERN = re.compile(r"\b429\b|rate[\s_-]?limit|too many requests", re.IGNORECASE)

# Floor applied to every per-step reward this agent records and logs.
_MIN_STEP_REWARD = 0.01


def _to_single_line(text) -> str:
    """Return ``str(text)`` with newlines escaped as ``\\n`` and CRs removed."""
    return str(text).replace("\n", "\\n").replace("\r", "")


def _is_rate_limit_error(exc: Exception) -> bool:
    """Return True when ``exc`` looks like an HTTP 429 / rate-limit failure."""
    # OpenAI SDK status errors expose the HTTP status directly. Other clients
    # and proxies may only mention it in the message text.
    if getattr(exc, "status_code", None) == 429:
        return True
    return _RATE_LIMIT_PATTERN.search(str(exc)) is not None


def _request_completion(client: OpenAI, messages: list) -> tuple:
    """Query the chat model, retrying rate-limited calls with exponential backoff.

    Returns:
        ``(text, None)`` on success, where ``text`` is the stripped reply
        (``""`` if the model returned no content). Returns ``(None, exc)``
        with the last exception when the call fails with a non-rate-limit
        error or the retry budget is exhausted.
    """
    last_exc: Exception = RuntimeError("MAX_RETRIES must be at least 1")
    for attempt in range(MAX_RETRIES):
        try:
            completion = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=0.2,
                max_tokens=2048,
            )
            # Kept inside the try: a malformed response (e.g. no choices) is
            # treated like any other failed call instead of crashing the run.
            return (completion.choices[0].message.content or "").strip(), None
        except Exception as exc:
            last_exc = exc
            if not (_is_rate_limit_error(exc) and attempt < MAX_RETRIES - 1):
                return None, exc
            delay = RETRY_BASE_DELAY * (2 ** attempt)
            print(f" [RATE LIMIT] Retry {attempt + 1}/{MAX_RETRIES} in {delay}s...", flush=True)
            time.sleep(delay)
    return None, last_exc


def _feedback_hint(reward: float, error) -> str:
    """Phase 4: guidance appended to an observation message based on its outcome."""
    if reward <= _MIN_STEP_REWARD or error:
        return (
            f"\n\nLOW/NEGATIVE RESULT (reward={reward:.2f})."
            "\nCarefully analyze the error/test output above."
            "\nIdentify the root cause and write a fix."
            "\nDo NOT repeat the same action that just failed."
        )
    if reward >= 0.4:
        return "\n\nTests are passing! If all tests pass, use 'submit' to finalize."
    return ""


def run_agent_episode(client: OpenAI, task_name: str) -> tuple:
    """Run one full ReAct episode of the agent on a single task.

    The task is selected through the ``JIRA_TASK_LEVEL`` environment variable
    and a fresh ``JiraToCodeEnv`` is created. The loop then alternates LLM
    calls and environment steps. It runs for up to 10 steps when the task
    name contains "easy" and 20 otherwise, stopping early when the
    environment reports ``done``.

    The mandatory ``[START]``/``[STEP]``/``[END]`` lines are printed along
    the way. Once the episode has started, ``[END]`` and ``env.close()``
    happen even if a step raises.

    Args:
        client: OpenAI-compatible client used for chat completions.
        task_name: Key of ``JiraToCodeEnv.TASKS`` to run.

    Returns:
        ``(score, steps_taken, rewards, success)``. ``score`` is the sum of
        step rewards clamped to [0.01, 0.99], and ``success`` is
        ``score >= SUCCESS_SCORE_THRESHOLD``.
    """
    os.environ["JIRA_TASK_LEVEL"] = task_name
    env = JiraToCodeEnv()

    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)

    try:
        obs = env.reset()
        task_max_steps = 10 if "easy" in task_name else 20

        # Phase 1: episodic memory — the whole conversation is resent each step.
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_observation_message(0, obs, 0.0)},
        ]

        for step in range(1, task_max_steps + 1):
            trim_history(messages)

            raw_text, api_error = _request_completion(client, messages)
            if raw_text is None:
                # The step is consumed. Tell the model and try again next step.
                messages.append({
                    "role": "user",
                    "content": f"API ERROR: {api_error}. Please try again with a valid JSON action.",
                })
                log_step(
                    step=step,
                    action=_to_single_line(f"API_ERROR: {api_error}"),
                    reward=_MIN_STEP_REWARD,
                    done=False,
                    error=_to_single_line(api_error),
                )
                rewards.append(_MIN_STEP_REWARD)
                steps_taken = step
                continue

            # Phase 1: keep the model's own reply in the history.
            messages.append({"role": "assistant", "content": raw_text})

            # Phase 3: robust parsing. On failure, fall back to a harmless
            # list_files step and (Phase 4) tell the model what went wrong.
            try:
                action = parse_action(raw_text)
                action_log = action.model_dump_json()
            except Exception as exc:
                action = JiraCodeAction(action_type="list_files")
                action_log = f"PARSE_ERROR: {exc}"
                messages.append({
                    "role": "user",
                    "content": (
                        "ERROR: Your last response was not valid JSON.\n"
                        f"Parse error: {exc}\n"
                        "You MUST respond with ONLY a valid JSON object. "
                        "No markdown, no explanations.\nTry again."
                    ),
                })

            obs, reward, done, _ = env.step(action)
            # Keep every recorded step reward strictly positive.
            reward = max(reward, _MIN_STEP_REWARD)
            rewards.append(reward)
            steps_taken = step

            log_step(
                step=step,
                action=_to_single_line(action_log),
                reward=reward,
                done=done,
                error=_to_single_line(obs.error) if obs.error else None,
            )
            if done:
                break

            messages.append({
                "role": "user",
                "content": build_observation_message(step, obs, reward) + _feedback_hint(reward, obs.error),
            })

        # Clamp the summed reward strictly inside (0, 1).
        score = min(max(sum(rewards), 0.01), 0.99)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        env.close()
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return score, steps_taken, rewards, success
# --- PHASE 5: MULTI-TASK EVALUATION ---
def _select_tasks(tasks_arg: Optional[str]) -> Optional[List[str]]:
    """Resolve the ``--tasks`` option into the list of task names to run.

    An explicit comma-separated list is validated against ``ALL_TASKS``.
    Blank entries, for example from a trailing comma, are ignored. Unknown
    names are reported and ``None`` is returned.

    Without the option, one task is sampled at random from each of the
    easy/medium/hard groups, matched by substring of the task name. Empty
    groups are skipped.
    """
    import random

    if tasks_arg:
        tasks = [name.strip() for name in tasks_arg.split(",") if name.strip()]
        invalid = [name for name in tasks if name not in ALL_TASKS]
        if invalid:
            print(f"ERROR: Unknown tasks: {invalid}", flush=True)
            print(f"Available: {ALL_TASKS}", flush=True)
            return None
        return tasks

    # Baseline inference: 1 easy, 1 medium, 1 hard randomly sampled.
    tasks = []
    for level in ("easy", "medium", "hard"):
        pool = [name for name in ALL_TASKS if level in name]
        if pool:
            tasks.append(random.choice(pool))
    return tasks


def _print_summary(results: List[dict]) -> None:
    """Print the per-task result table followed by average and total scores."""
    total_score = sum(r["score"] for r in results)
    print("\n" + "=" * 50, flush=True)
    print("EVALUATION SUMMARY", flush=True)
    print("=" * 50, flush=True)
    for r in results:
        status = "PASS" if r["success"] else "FAIL"
        print(
            f" {r['task']:10s} | score={r['score']:.3f} | "
            f"steps={r['steps']:2d} | {status}",
            flush=True,
        )
    # Guarded so an empty result list cannot raise ZeroDivisionError.
    avg_score = total_score / len(results) if results else 0.0
    print(f" {'AVERAGE':10s} | score={avg_score:.3f}", flush=True)
    print(f" {'TOTAL':10s} | score={total_score:.3f} / {len(results):.1f}", flush=True)
    print("=" * 50, flush=True)


def main() -> None:
    """CLI entry point: run the selected tasks and print a summary table.

    Exits with status 1 when ``--tasks`` names an unknown task or when no
    task could be selected.
    """
    parser = argparse.ArgumentParser(description="Jira-to-Code ReAct Agent")
    parser.add_argument(
        "--tasks",
        type=str,
        default=None,
        help=(
            "Comma-separated list of tasks to run. "
            f"Available: {', '.join(ALL_TASKS)}. "
            "Default: all tasks."
        ),
    )
    args = parser.parse_args()

    tasks = _select_tasks(args.tasks)
    if tasks is None:
        raise SystemExit(1)
    if not tasks:
        print("ERROR: No tasks selected", flush=True)
        raise SystemExit(1)
    print(f"Running tasks: {tasks}", flush=True)

    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    results = []
    for index, task in enumerate(tasks):
        score, steps, _rewards, success = run_agent_episode(client, task)
        results.append({
            "task": task,
            "score": score,
            "steps": steps,
            "success": success,
        })
        # Throttle between tasks only; nothing follows the last one.
        if index < len(tasks) - 1:
            print("Waiting 20 seconds before next task to respect API limits...", flush=True)
            time.sleep(20)

    _print_summary(results)


if __name__ == "__main__":
    main()