"""Unified-diff application utilities.

The Repair Agent submits a unified diff. We need a permissive applier because
LLM diffs are often malformed (wrong line numbers, missing context, extra
prose). We try the strict applier first, then fall back to applying hunks via
plain string replacement.

The agent may also submit a full Python script instead of a diff (common when
the model's diff format breaks). We detect this and treat it as a complete
replacement.
"""

from __future__ import annotations

import difflib
import re

_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
_SCRIPT_MARKERS = ("import ", "from ", "def ", "class ", "print(")


def looks_like_full_script(text: str) -> bool:
    """Heuristic: text is probably a full python script, not a diff.

    Returns False for empty text and for text with a diff header line
    (``---``, ``+++`` or ``@@``) among its first five lines. Otherwise returns
    True when at least two distinct entries of ``_SCRIPT_MARKERS`` occur
    anywhere in the first 30 lines.
    """
    lines = text.lstrip().splitlines()
    if not lines:
        return False

    if any(line.startswith(("---", "+++", "@@")) for line in lines[:5]):
        return False

    # If we see two or more script-style markers in the first 30 lines,
    # treat as a full replacement script.
    head = "\n".join(lines[:30])
    hits = sum(1 for marker in _SCRIPT_MARKERS if marker in head)
    return hits >= 2


def _parse_hunks(diff_text: str) -> list[tuple[list[str], list[str]]]:
    """Split *diff_text* into ``(old_lines, new_lines)`` pairs, one per hunk.

    Lines are returned without their diff prefix and without a terminator.
    Everything before the first ``@@`` header is ignored, and hunks that carry
    no lines at all are dropped.
    """
    hunks: list[tuple[list[str], list[str]]] = []
    old: list[str] = []
    new: list[str] = []
    in_hunk = False

    for line in diff_text.splitlines():
        # NOTE: file headers are skipped wherever they occur, so a removed or
        # added source line that itself starts with "--" / "++" is ignored.
        if line.startswith(("---", "+++")):
            continue
        if line.startswith("@@"):
            if old or new:
                hunks.append((old, new))
            old, new = [], []
            in_hunk = True
            continue
        if not in_hunk:
            continue
        # "\ No newline at end of file" markers and Markdown code fences are
        # never diff content; treating them as context would make the hunk
        # impossible to locate.
        if line.startswith(("\\", "```")):
            continue
        if line.startswith("+"):
            new.append(line[1:])
        elif line.startswith("-"):
            old.append(line[1:])
        else:
            # Context line; tolerate a missing leading space.
            context = line[1:] if line.startswith(" ") else line
            old.append(context)
            new.append(context)

    if old or new:
        hunks.append((old, new))
    return hunks


def _find_block(haystack: list[str], needle: list[str], start: int) -> int:
    """Return the first index ``>= start`` where *needle* occurs in *haystack*.

    The comparison is line by line, so a match can only begin at a line
    boundary. An empty *needle* matches at *start*. Returns -1 when there is
    no match.
    """
    if not needle:
        return start
    width = len(needle)
    for index in range(start, len(haystack) - width + 1):
        if haystack[index : index + width] == needle:
            return index
    return -1


def _strict_apply(broken_script: str, diff_text: str) -> str | None:
    """Apply a unified diff strictly. Returns None on any failure.

    Hunks are applied in order. Each hunk's old lines (context plus removals)
    must appear as a run of whole lines at or after the end of the previous
    hunk; the line numbers in the ``@@`` header are not used. Matching ignores
    line terminators, so a final line without a trailing newline and CRLF
    sources can still be matched.

    Lines written from a hunk use the terminator of the first source line
    ("\\r\\n" or "\\n"), which means a hunk that consumes an unterminated final
    line leaves the result newline-terminated.
    """
    source_lines = broken_script.splitlines(keepends=True)
    # Same split without terminators: index-aligned with ``source_lines``.
    bare_lines = broken_script.splitlines()
    uses_crlf = bool(source_lines) and source_lines[0].endswith("\r\n")
    eol = "\r\n" if uses_crlf else "\n"

    out: list[str] = []
    cursor = 0
    for old, new in _parse_hunks(diff_text):
        at = _find_block(bare_lines, old, cursor)
        if at == -1:
            return None
        out.extend(source_lines[cursor:at])
        out.extend(text + eol for text in new)
        cursor = at + len(old)

    out.extend(source_lines[cursor:])
    return "".join(out)


def _permissive_apply(broken_script: str, diff_text: str) -> str:
    """Apply a malformed diff by extracting (-,+) line pairs and doing a
    tolerant search-and-replace.

    A pair is a ``-`` line followed by a ``+`` line, optionally separated by
    space-prefixed context lines. Both sides are stripped of surrounding
    whitespace, and the first occurrence of the old text in the script is
    replaced with the new text. Header lines (``---``, ``+++``, ``@@``) and
    any other non-context line discard a pending ``-`` line.
    """
    repaired = broken_script
    pairs: list[tuple[str, str]] = []
    pending_minus: str | None = None

    for line in diff_text.splitlines():
        if line.startswith(("---", "+++", "@@")):
            pending_minus = None
            continue
        if line.startswith("-"):
            pending_minus = line[1:].strip()
        elif line.startswith("+") and pending_minus is not None:
            pairs.append((pending_minus, line[1:].strip()))
            pending_minus = None
        elif pending_minus is not None and not line.startswith(" "):
            # standalone deletion — skip in permissive mode (we can't
            # reliably know what to delete without context)
            pending_minus = None

    for old, new in pairs:
        if old and old in repaired:
            repaired = repaired.replace(old, new, 1)
    return repaired


def apply_unified_diff(broken_script: str, diff_text: str) -> str:
    """Try every strategy in order and return the first that produces a change.

    Strategies:
      1. If `diff_text` looks like a full script, return it directly.
      2. Try strict diff application.
      3. Fall back to permissive (-,+) line-pair replacement.
      4. As last resort, return the broken script unchanged.

    An empty, whitespace-only or None `diff_text` returns `broken_script`
    unchanged.
    """
    diff_text = diff_text or ""
    if not diff_text.strip():
        return broken_script

    if looks_like_full_script(diff_text):
        return diff_text

    if _HUNK_RE.search(diff_text) or "---" in diff_text or "+++" in diff_text:
        strict = _strict_apply(broken_script, diff_text)
        if strict is not None and strict != broken_script:
            return strict

    # Permissive mode returns the script unchanged when nothing matches.
    return _permissive_apply(broken_script, diff_text)


def make_unified_diff(before: str, after: str, path: str = "train.py") -> str:
    """Produce a canonical unified diff from before -> after.

    The diff uses two lines of context and ``a/<path>`` / ``b/<path>`` file
    headers. Returns an empty string when the inputs are identical.
    """
    diff = difflib.unified_diff(
        before.splitlines(keepends=True),
        after.splitlines(keepends=True),
        fromfile=f"a/{path}",
        tofile=f"b/{path}",
        n=2,
    )
    # difflib passes through input lines that lack a terminator, which would
    # glue e.g. "-b" and "+c" into "-b+c" when an input has no trailing
    # newline. Terminate every emitted line so the diff stays parseable.
    return "".join(
        line if line.endswith("\n") else line + "\n" for line in diff
    )