3v324v23's picture
feat: refresh dashboard UX and backend integration
dcee3d3
import uuid
import difflib
import re
from models import PRAction, PRObservation, PRState, ReviewDecision
from server.tasks import single_pass, iterative, escalation, custom
from server import graders
# Registry of available tasks, keyed by the name accepted by
# PRReviewEnvironment.reset(); each value is the TASK dict exported by the
# corresponding module imported from server.tasks.
TASKS = {
    task_name: task_module.TASK
    for task_name, task_module in (
        ("single-pass-review", single_pass),
        ("iterative-negotiation", iterative),
        ("escalation-judgment", escalation),
        ("custom-review", custom),
    )
}
class PRReviewEnvironment:
    """Episodic pull-request review environment.

    ``reset`` loads one of the tasks registered in ``TASKS`` and returns the
    initial observation. ``step`` scores a reviewer action with
    ``graders.compute_step_reward``, records it in the review history and,
    while the episode continues, replays the next scripted author response.
    If that response contains a fenced code block, the displayed diff is
    regenerated as a unified diff from the PR's original code to the author's
    proposed code.
    """

    # Fenced block explicitly tagged as Python (```python, ```py, ```python3,
    # any case). Tolerates spaces after the tag, CRLF line endings and an
    # indented closing fence.
    _PY_FENCE_RE = re.compile(
        r"```(?:python|py)3?[ \t]*\r?\n(.*?)\r?\n[ \t]*```",
        re.DOTALL | re.IGNORECASE,
    )
    # Untagged fenced block; only consulted when no Python-tagged block exists.
    _BARE_FENCE_RE = re.compile(r"```[ \t]*\r?\n(.*?)\r?\n[ \t]*```", re.DOTALL)
    # Unified-diff file/hunk headers only count at the START of a line, so code
    # that merely contains "--- " (e.g. print("--- x ---")) is not a diff.
    _DIFF_MARKER_RE = re.compile(r"^(?:--- |\+\+\+ |@@ )", re.MULTILINE)
    # Metadata lines of (git-style) unified diffs that carry no file content.
    _DIFF_METADATA_PREFIXES = (
        "diff --git ",
        "index ",
        "--- ",
        "+++ ",
        "@@ ",
        "\\",  # "\ No newline at end of file"
        "new file mode",
        "deleted file mode",
        "old mode",
        "new mode",
        "similarity index",
        "rename from",
        "rename to",
    )

    def __init__(self):
        self._state = None  # PRState of the active episode; None until reset()
        self._task = None  # task definition dict taken from TASKS
        self._rewards = []  # rewards returned by the grader, one per step
        self._current_diff = ""  # diff shown in the latest observation
        self._initial_diff = ""  # the task's original PR diff

    def _extract_code(self, text: "str | None") -> "str | None":
        """Return the stripped code inside a fenced block of ``text``.

        A Python-tagged fence is preferred over an untagged one even when the
        untagged fence appears first. Returns ``None`` when ``text`` is not a
        non-empty string or contains no recognised fenced block.
        """
        if not isinstance(text, str) or not text:
            return None
        for pattern in (self._PY_FENCE_RE, self._BARE_FENCE_RE):
            match = pattern.search(text)
            if match:
                return match.group(1).strip()
        return None

    def _generate_unified_diff(self, old_code: str, new_code: str, filename: str = "file.py") -> str:
        """Return a unified diff turning ``old_code`` into ``new_code``.

        Lines are compared without their terminators and every emitted line is
        newline-terminated. Feeding difflib lines that keep their endings
        breaks when the final line has no trailing newline: the ``-``/``+``
        lines get glued together and ``"x"`` vs ``"x\\n"`` counts as a change.
        Returns ``""`` when both inputs have identical lines.
        """
        diff_lines = list(
            difflib.unified_diff(
                old_code.splitlines(),
                new_code.splitlines(),
                fromfile=f"a/{filename}",
                tofile=f"b/{filename}",
                lineterm="",
            )
        )
        if not diff_lines:
            return ""
        return "\n".join(diff_lines) + "\n"

    def _get_base_code(self, diff_text: str) -> str:
        """Best-effort reconstruction of the post-change code in ``diff_text``.

        If no line starts with ``--- ``, ``+++ `` or ``@@ ``, the text is taken
        to be a plain code snippet and returned stripped. Otherwise the "new"
        side of the diff is rebuilt:
        - metadata lines and removed (``-``) lines are dropped;
        - the one-character prefix of added (``+``) and context (`` ``) lines
          is removed;
        - any other line (e.g. a blank context line whose leading space was
          trimmed) is kept verbatim.

        Hunks are concatenated with no gap marker, and multi-file diffs are
        flattened into one text.
        """
        if not self._DIFF_MARKER_RE.search(diff_text):
            return diff_text.strip()
        result_lines = []
        for line in diff_text.splitlines():
            if line.startswith(self._DIFF_METADATA_PREFIXES) or line.startswith("-"):
                continue
            if line.startswith(("+", " ")):
                result_lines.append(line[1:])
            else:
                result_lines.append(line)
        return "\n".join(result_lines).strip()

    def reset(self, task_name: str = "single-pass-review") -> PRObservation:
        """Start a new episode on ``task_name`` and return the first observation.

        Raises:
            KeyError: if ``task_name`` is not a key of ``TASKS``.
        """
        self._task = TASKS[task_name]
        self._rewards = []
        self._initial_diff = self._task["diff"]
        self._current_diff = self._task["diff"]
        self._state = PRState(
            episode_id=str(uuid.uuid4()),
            task_name=task_name,
            turn=0,
            max_turns=self._task["max_turns"],
            review_history=[],
            done=False,
            success=False,
            cumulative_reward=0.0,
        )
        return PRObservation(
            turn=0,
            diff=self._current_diff,
            pr_title=self._task["pr_title"],
            pr_description=self._task["pr_description"],
            review_history=[],
            author_response=None,
            done=False,
            message="New PR ready for review. Read the diff carefully. Identify the root cause of any issues, not just the symptom. Submit your decision.",
        )

    def step(self, action: PRAction) -> tuple[PRObservation, float, bool, dict]:
        """Apply one reviewer action and advance the episode by one turn.

        Returns ``(observation, reward, done, info)``, where ``info`` holds the
        episode id and task name.

        The episode ends when the reviewer approves or escalates, or when
        ``max_turns`` is reached. Otherwise the next scripted author response,
        if any, is appended to the history and may replace the shown diff.

        Raises:
            AssertionError: if ``reset()`` was not called or the episode is
                already done.
        """
        assert self._state is not None, "Call reset() first"
        assert not self._state.done, "Episode is already done"
        t = self._state
        task = self._task
        gt = task["ground_truth"]
        turn = t.turn + 1

        # A per-turn expectation ("correct_decision_turn_N") overrides the
        # task-wide "correct_decision", which itself defaults to "request_changes".
        correct_decision = gt.get(
            f"correct_decision_turn_{turn}",
            gt.get("correct_decision", "request_changes"),
        )
        bug_still_present = correct_decision != ReviewDecision.APPROVE.value
        reward = graders.compute_step_reward(
            action=action,
            correct_decision=correct_decision,
            root_cause_keywords=gt.get("root_cause_keywords", []),
            correct_issue_category=gt.get("correct_issue_category", "logic"),
            # Set only when the reviewer approved although approval was not
            # the expected decision for this turn.
            bug_still_present=bug_still_present and action.decision == ReviewDecision.APPROVE,
            turn=turn,
            max_turns=task["max_turns"],
            symptom_only_keywords=gt.get("symptom_only_keywords"),
            false_fix_keywords=gt.get("false_fix_keywords"),
            escalation_required=(
                gt.get("escalation_required", False)
                and correct_decision == ReviewDecision.ESCALATE.value
            ),
        )
        self._rewards.append(reward)
        t.cumulative_reward = round(sum(self._rewards), 2)
        t.turn = turn
        t.review_history.append(
            {"role": "reviewer", "content": f"[{action.decision.value}] {action.comment}"}
        )

        done = turn >= task["max_turns"] or action.decision in (
            ReviewDecision.APPROVE,
            ReviewDecision.ESCALATE,
        )
        t.done = done

        author_resp = None
        author_responses = task.get("author_responses", [])
        if not done and turn <= len(author_responses):
            author_resp = author_responses[turn - 1]
            t.review_history.append({"role": "author", "content": author_resp})
            # --- DYNAMIC DIFF INJECTION ---
            # If the author attached code, show it as a fresh diff against the
            # ORIGINAL PR code (not against a previous turn's proposal).
            proposed_fix = self._extract_code(author_resp)
            if proposed_fix:
                base_code = self._get_base_code(self._initial_diff)
                self._current_diff = self._generate_unified_diff(base_code, proposed_fix)

        if done:
            final_score = graders.compute_final_score(self._rewards, task["max_turns"])
            t.success = final_score >= 0.5
            message = f"Episode complete. Final score: {final_score:.3f}"
        else:
            message = "Author has responded. Re-read the diff. Has the actual root cause been addressed, or just the symptom?"

        return PRObservation(
            turn=turn,
            diff=self._current_diff,
            pr_title=task["pr_title"],
            pr_description=task["pr_description"],
            review_history=list(t.review_history),
            author_response=author_resp,
            done=done,
            message=message,
        ), reward, done, {"episode_id": t.episode_id, "task": t.task_name}

    def state(self) -> PRState:
        """Return the live PRState of the current episode (None before reset())."""
        return self._state

    def get_rewards(self):
        """Return the list of per-step rewards recorded so far (not a copy)."""
        return self._rewards

    def get_final_score(self):
        """Return the grader's final score for the rewards collected so far."""
        return graders.compute_final_score(self._rewards, self._task["max_turns"])