File size: 5,656 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""Unified-diff application utilities.

The Repair Agent submits a unified diff. We need a permissive applier
because LLM diffs are often malformed (wrong line numbers, missing
context, extra prose). We try the strict applier first, then fall
back to applying hunks via plain string replacement.

The agent may also submit a full Python script instead of a diff
(common when the model's diff format breaks). We detect this and
treat it as a complete replacement.
"""
from __future__ import annotations

import difflib
import re


# Matches a unified-diff hunk header: any line that begins with "@@" and
# contains a second "@@" later on the same line (MULTILINE makes "^" anchor
# at every line start, not only at the start of the text).
_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
# Substrings whose presence hints that a text is Python source rather than a
# diff; looks_like_full_script() counts how many distinct ones appear.
_SCRIPT_MARKERS = ("import ", "from ", "def ", "class ", "print(")


def looks_like_full_script(text: str) -> bool:
    """Guess whether ``text`` is a complete Python script instead of a diff.

    Leading whitespace is ignored. The text is rejected as a script when it
    is empty or when any of its first five lines starts with a diff marker
    ("---", "+++" or "@@"). Otherwise it counts as a script when at least
    two distinct entries of ``_SCRIPT_MARKERS`` occur somewhere in its first
    thirty lines.
    """
    body_lines = text.lstrip().splitlines()
    if not body_lines:
        return False

    # A diff header near the top means this is a patch, not a script.
    for candidate in body_lines[:5]:
        if candidate.startswith(("---", "+++", "@@")):
            return False

    # Each marker counts at most once, however often it appears.
    preview = "\n".join(body_lines[:30])
    matched_markers = [marker for marker in _SCRIPT_MARKERS if marker in preview]
    return len(matched_markers) >= 2


def _strict_apply(broken_script: str, diff_text: str) -> str | None:
    """Apply a unified diff strictly. Returns None on any failure."""
    lines = broken_script.splitlines(keepends=True)
    out: list[str] = []
    diff_lines = diff_text.splitlines()
    i = 0
    src_idx = 0
    in_hunk = False
    hunk_old: list[str] = []
    hunk_new: list[str] = []

    while i < len(diff_lines):
        line = diff_lines[i]
        if line.startswith(("---", "+++")):
            i += 1
            continue
        if line.startswith("@@"):
            # Flush previous hunk
            if in_hunk:
                # Find the hunk_old block in the source starting at src_idx.
                target = "".join(hunk_old)
                source_remainder = "".join(lines[src_idx:])
                pos = source_remainder.find(target)
                if pos == -1:
                    return None
                out.append(source_remainder[:pos])
                out.append("".join(hunk_new))
                src_idx += len(source_remainder[: pos + len(target)].splitlines(keepends=True))
                hunk_old, hunk_new = [], []
            in_hunk = True
            i += 1
            continue
        if in_hunk:
            if line.startswith("+"):
                hunk_new.append(line[1:] + "\n")
            elif line.startswith("-"):
                hunk_old.append(line[1:] + "\n")
            else:
                # context line
                ctx = line[1:] if line.startswith(" ") else line
                hunk_old.append(ctx + "\n")
                hunk_new.append(ctx + "\n")
        i += 1

    # Flush trailing hunk
    if in_hunk and (hunk_old or hunk_new):
        target = "".join(hunk_old)
        source_remainder = "".join(lines[src_idx:])
        pos = source_remainder.find(target)
        if pos == -1:
            return None
        out.append(source_remainder[:pos])
        out.append("".join(hunk_new))
        consumed = source_remainder[: pos + len(target)]
        src_idx += len(consumed.splitlines(keepends=True))

    out.append("".join(lines[src_idx:]))
    return "".join(out)


def _permissive_apply(broken_script: str, diff_text: str) -> str:
    """Apply a malformed diff by extracting (-,+) line pairs and doing
    a tolerant search-and-replace.
    """
    repaired = broken_script
    pairs: list[tuple[str, str]] = []
    lines = diff_text.splitlines()
    pending_minus: str | None = None

    for line in lines:
        if line.startswith("---") or line.startswith("+++") or line.startswith("@@"):
            pending_minus = None
            continue
        if line.startswith("-"):
            pending_minus = line[1:].strip()
        elif line.startswith("+") and pending_minus is not None:
            pairs.append((pending_minus, line[1:].strip()))
            pending_minus = None
        elif pending_minus is not None and not line.startswith(" "):
            # standalone deletion — skip in permissive mode (we can't
            # reliably know what to delete without context)
            pending_minus = None

    for old, new in pairs:
        if old and old in repaired:
            repaired = repaired.replace(old, new, 1)

    return repaired


def apply_unified_diff(broken_script: str, diff_text: str) -> str:
    """Patch ``broken_script`` with ``diff_text`` using the first strategy that fits.

    Order of attempts:
      1. Empty or whitespace-only ``diff_text`` (or a falsy value): the
         script is returned untouched.
      2. ``diff_text`` looks like a full script: it is returned as the
         complete replacement.
      3. ``diff_text`` contains a hunk header, "---" or "+++": strict
         application is tried and its result used if it succeeded and
         actually changed the script.
      4. Otherwise the permissive (-,+) line-pair replacement is used; it
         returns the script unchanged when nothing could be applied.
    """
    patch = diff_text if diff_text else ""
    if patch.strip() == "":
        return broken_script

    if looks_like_full_script(patch):
        return patch

    has_diff_syntax = (
        _HUNK_RE.search(patch) is not None or "---" in patch or "+++" in patch
    )
    if has_diff_syntax:
        candidate = _strict_apply(broken_script, patch)
        if candidate is not None and candidate != broken_script:
            return candidate

    return _permissive_apply(broken_script, patch)


def make_unified_diff(before: str, after: str, path: str = "train.py") -> str:
    """Produce a canonical unified diff from before -> after.

    Args:
        before: Original text.
        after: Modified text.
        path: File name used in the "--- a/<path>" / "+++ b/<path>" headers.

    Returns:
        The diff with two lines of context, or "" when the texts are equal.
        Every line of the result ends with a newline.
    """
    diff = difflib.unified_diff(
        before.splitlines(keepends=True),
        after.splitlines(keepends=True),
        fromfile=f"a/{path}",
        tofile=f"b/{path}",
        n=2,
    )
    # difflib copies input lines verbatim, so a final line without a line
    # ending would be glued to the next diff line (e.g. "-a+b"). Terminate
    # such lines so each diff entry stays on its own line. The missing
    # terminator itself is not recorded in the diff.
    return "".join(
        entry if entry.endswith(("\n", "\r")) else entry + "\n"
        for entry in diff
    )