"""Unified-diff application utilities.

The Repair Agent submits a unified diff. We need a permissive applier
because LLM diffs are often malformed (wrong line numbers, missing
context, extra prose). We try the strict applier first, then fall
back to applying hunks via plain string replacement.

The agent may also submit a full Python script instead of a diff
(common when the model's diff format breaks). We check for this before
attempting any diff application and treat such a submission as a
complete replacement.
"""
from __future__ import annotations

import difflib
import re
|
|
|
|
# Matches a unified-diff hunk header such as "@@ -1,3 +1,4 @@" at the start
# of any line of a multi-line string.
_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)

# Statement prefixes whose presence suggests plain Python source.
_SCRIPT_MARKERS = ("import ", "from ", "def ", "class ", "print(")


def looks_like_full_script(text: str) -> bool:
    """Heuristic: text is probably a full python script, not a diff.

    Only the first 30 lines (after leading whitespace is dropped) are
    inspected. The text is classified as a script when no diff header is
    found and at least two distinct entries of ``_SCRIPT_MARKERS`` begin a
    line, ignoring indentation.

    Markers are anchored to the start of a line so that prose mentioning
    "import" or "from", and diff body lines such as ``-import os``, are not
    mistaken for Python statements.

    Args:
        text: Raw submission from the agent.

    Returns:
        True when the text should be treated as a complete script.
    """
    lines = text.lstrip().splitlines()
    if not lines:
        return False

    # "---" / "+++" are only trusted near the top: further down, such a line
    # may simply be an underline inside a docstring of a real script.
    if any(line.startswith(("---", "+++", "@@")) for line in lines[:5]):
        return False

    head = lines[:30]
    # A hunk header is unambiguous, so it disqualifies the text even when it
    # follows a longer prose preamble.
    if any(line.startswith("@@") for line in head):
        return False

    statements = [line.lstrip() for line in head]
    hits = sum(
        1
        for marker in _SCRIPT_MARKERS
        if any(stmt.startswith(marker) for stmt in statements)
    )
    return hits >= 2
|
|
|
|
def _find_line_block(haystack: list[str], needle: list[str], start: int) -> int | None:
    """Return the first index >= start where needle occurs as whole lines."""
    if not needle:
        # A hunk with nothing to match is applied at the current position.
        return start
    width = len(needle)
    for idx in range(start, len(haystack) - width + 1):
        if haystack[idx:idx + width] == needle:
            return idx
    return None


def _strict_apply(broken_script: str, diff_text: str) -> str | None:
    """Apply a unified diff strictly. Returns None on any failure.

    Each hunk's old side (context plus removed lines) must occur in the
    source as a run of complete lines at or after the end of the previous
    hunk; line numbers in the ``@@`` headers are ignored. Matching is done
    line by line rather than on the joined text, so a hunk can never match
    in the middle of a source line.

    Line terminators are ignored when comparing, which lets hunks match a
    final line without a trailing newline and sources using CRLF. Lines
    written by a hunk always end with a terminator (CRLF when the first
    source line uses it, LF otherwise).

    Args:
        broken_script: Source text to patch.
        diff_text: Unified diff, possibly preceded by prose or file headers.

    Returns:
        The patched text, ``broken_script`` unchanged when the diff has no
        hunks, or None when some hunk's old side cannot be found.
    """
    src_raw = broken_script.splitlines(keepends=True)
    src_text = broken_script.splitlines()
    newline = "\r\n" if src_raw and src_raw[0].endswith("\r\n") else "\n"

    out: list[str] = []
    cursor = 0  # index of the first source line not yet copied to `out`
    in_hunk = False
    hunk_old: list[str] = []
    hunk_new: list[str] = []

    def flush_hunk() -> bool:
        """Apply the pending hunk; return False when it does not match."""
        nonlocal cursor
        start = _find_line_block(src_text, hunk_old, cursor)
        if start is None:
            return False
        out.extend(src_raw[cursor:start])
        out.extend(text + newline for text in hunk_new)
        cursor = start + len(hunk_old)
        hunk_old.clear()
        hunk_new.clear()
        return True

    diff_lines = diff_text.splitlines()
    for pos, line in enumerate(diff_lines):
        if not in_hunk:
            # File headers and prose before the first hunk are ignored.
            in_hunk = line.startswith("@@")
            continue

        if line.startswith("@@"):
            if not flush_hunk():
                return None
            continue

        # Inside a hunk, "--- x" is a file header only when "+++ y" follows
        # directly; otherwise it is a removed line whose text starts with
        # "--" and must not be dropped.
        next_line = diff_lines[pos + 1] if pos + 1 < len(diff_lines) else ""
        if line.startswith("--- ") and next_line.startswith("+++ "):
            if not flush_hunk():
                return None
            in_hunk = False
            continue

        if line.startswith("\\ No newline"):
            # Marker about the preceding line; it is not file content.
            continue

        if line.startswith("+"):
            hunk_new.append(line[1:])
        elif line.startswith("-"):
            hunk_old.append(line[1:])
        else:
            # Context line; tolerate a missing leading space.
            ctx = line[1:] if line.startswith(" ") else line
            hunk_old.append(ctx)
            hunk_new.append(ctx)

    if in_hunk and (hunk_old or hunk_new):
        if not flush_hunk():
            return None

    out.extend(src_raw[cursor:])
    return "".join(out)
|
|
|
|
def _pair_changed_lines(removed: list[str], added: list[str]) -> list[tuple[str, str]]:
    """Pair the removed and added lines of one change block."""
    if len(removed) == len(added):
        return list(zip(removed, added))
    # Unequal counts: give each added line the most similar removed line
    # that is still unpaired (earliest wins on ties); leftovers are dropped.
    remaining = list(removed)
    pairs: list[tuple[str, str]] = []
    for new in added:
        if not remaining:
            break
        best = max(
            remaining,
            key=lambda old: difflib.SequenceMatcher(None, old, new).ratio(),
        )
        remaining.remove(best)
        pairs.append((best, new))
    return pairs


def _replace_once(script: str, old: str, new: str) -> str:
    """Replace one occurrence of old, preferring a whole-line match."""
    lines = script.splitlines(keepends=True)
    for idx, raw in enumerate(lines):
        if raw.strip() == old:
            # Keeps the line's own indentation and terminator.
            lines[idx] = raw.replace(old, new, 1)
            return "".join(lines)
    # No line equals `old`; fall back to the first substring occurrence.
    return script.replace(old, new, 1)


def _permissive_apply(broken_script: str, diff_text: str) -> str:
    """Apply a malformed diff by extracting (-,+) line pairs and doing
    a tolerant search-and-replace.

    A change block is a run of ``-`` lines followed by a run of ``+`` lines.
    Within a block the lines are paired positionally when both runs have the
    same length, and by textual similarity otherwise. ``+`` lines with no
    preceding ``-`` line and unpaired ``-`` lines are ignored. Lines are
    compared with surrounding whitespace stripped.

    Each pair is applied once: to the first source line whose stripped text
    equals the removed text, or failing that to the first substring
    occurrence.

    Args:
        broken_script: Source text to patch.
        diff_text: Diff-like text; headers and hunk markers are optional.

    Returns:
        The patched text, or ``broken_script`` unchanged when nothing matched.
    """
    pairs: list[tuple[str, str]] = []
    removed: list[str] = []
    added: list[str] = []

    def close_block() -> None:
        pairs.extend(_pair_changed_lines(removed, added))
        removed.clear()
        added.clear()

    for line in diff_text.splitlines():
        if line.startswith(("---", "+++", "@@")):
            close_block()
        elif line.startswith("-"):
            if added:
                # A "-" after "+" lines starts a new change block.
                close_block()
            removed.append(line[1:].strip())
        elif line.startswith("+"):
            if removed:
                added.append(line[1:].strip())
        elif line.startswith(" "):
            # Context between "-" and "+" lines is tolerated; context after
            # "+" lines ends the block.
            if added:
                close_block()
        else:
            close_block()
    close_block()

    repaired = broken_script
    for old, new in pairs:
        if old and old in repaired:
            repaired = _replace_once(repaired, old, new)
    return repaired
|
|
|
|
def _strip_code_fence(text: str) -> str:
    """Return the inside of a Markdown code fence wrapping the whole text."""
    stripped = text.strip()
    if not (stripped.startswith("```") and stripped.endswith("```")):
        return text
    lines = stripped.splitlines(keepends=True)
    # Require a separate opening line and a bare closing fence line.
    if len(lines) < 2 or lines[-1].strip() != "```":
        return text
    return "".join(lines[1:-1])


def apply_unified_diff(broken_script: str, diff_text: str) -> str:
    """Try every strategy in order and return the first that produces a change.

    A Markdown code fence that wraps the entire submission is removed first,
    so the fence lines neither end up in a returned script nor get read as
    hunk content.

    Strategies:
      1. If `diff_text` looks like a full script, return it directly.
      2. Try strict diff application.
      3. Fall back to permissive (-,+) line-pair replacement.
      4. As last resort, return the broken script unchanged.

    Args:
        broken_script: Source text to repair.
        diff_text: Agent submission: a unified diff, a malformed diff, or a
            complete replacement script. None or blank text means no change.

    Returns:
        The repaired script, or ``broken_script`` when nothing applied.
    """
    diff_text = _strip_code_fence(diff_text or "")
    if not diff_text.strip():
        return broken_script

    if looks_like_full_script(diff_text):
        return diff_text

    if _HUNK_RE.search(diff_text) or "---" in diff_text or "+++" in diff_text:
        strict = _strict_apply(broken_script, diff_text)
        # A strict result that changes nothing is treated as a miss so the
        # permissive pass still gets a chance.
        if strict is not None and strict != broken_script:
            return strict

    return _permissive_apply(broken_script, diff_text)
|
|
|
|
def make_unified_diff(before: str, after: str, path: str = "train.py") -> str:
    """Produce a canonical unified diff from before -> after.

    Uses two lines of context and ``a/<path>`` / ``b/<path>`` file headers.
    When the last line of either input has no trailing newline, difflib
    yields that diff line unterminated; joining it as-is would fuse it with
    the next line (``-b+c``). Such lines are terminated here and followed by
    the conventional ``\\ No newline at end of file`` marker.

    Args:
        before: Original text.
        after: Modified text.
        path: File name used in the diff headers.

    Returns:
        The unified diff as one string, or "" when the texts are equal.
    """
    diff = difflib.unified_diff(
        before.splitlines(keepends=True),
        after.splitlines(keepends=True),
        fromfile=f"a/{path}",
        tofile=f"b/{path}",
        n=2,
    )
    return "".join(
        line
        if line.endswith(("\n", "\r"))
        else f"{line}\n\\ No newline at end of file\n"
        for line in diff
    )
|
|