File size: 6,337 Bytes
03a907a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import ast
import unidiff
import diff_match_patch as dmp_module
from dataclasses import dataclass
from typing import List, Tuple, Optional


def validate_python_syntax(code: str) -> Tuple[bool, Optional[str]]:
    """Check whether *code* parses as valid Python source.

    Args:
        code: Source text to validate.

    Returns:
        A ``(is_valid, error_message)`` pair; ``error_message`` is ``None``
        when the text parses cleanly.
    """
    try:
        ast.parse(code)
    except SyntaxError as e:
        return False, f"SyntaxError: {e.msg} at line {e.lineno}, column {e.offset}"
    except Exception as e:
        # ast.parse can raise non-SyntaxError failures (e.g. ValueError on
        # null bytes); report those under a generic label.
        return False, f"ParseError: {str(e)}"
    return True, None


@dataclass
class HunkResult:
    """Outcome of applying a single diff hunk to the source text."""

    hunk_index: int  # 0-based position of the hunk across the whole patch set
    source_file: str  # path reported by unidiff for the file this hunk targets
    applied: bool  # True when the hunk was successfully applied
    confidence: float  # 0.0-1.0 similarity between expected context and matched text
    location_found: int  # character offset where the hunk matched; -1 when no match
    failed_reason: Optional[str] = None  # human-readable reason when applied is False


def apply_patch(

    code: str,

    diff: str,

    match_threshold: float = 0.5,

    match_distance: int = 2000,

) -> Tuple[str, List[HunkResult]]:
    """Apply a unified diff to *code* using fuzzy (Bitap) matching.

    The diff is parsed with ``unidiff`` into structured, typed hunk objects;
    each hunk is then located in the current text via diff-match-patch's
    fuzzy matcher, so slightly stale line numbers or drifted context can
    still apply.

    Args:
        code: Original source text to patch.
        diff: Unified-diff text (may contain multiple files/hunks).
        match_threshold: Assigned to DMP ``Match_Threshold`` and
            ``Patch_DeleteThreshold``; higher tolerates fuzzier matches.
        match_distance: Assigned to DMP ``Match_Distance`` (search-window
            size in characters around the location hint).

    Returns:
        ``(patched_code, results)`` with one :class:`HunkResult` per hunk,
        so a caller (e.g. an RL reward function) gets per-hunk confidence
        scores rather than a single pass/fail. If the final patched text is
        not valid Python, the ORIGINAL ``code`` is returned together with a
        single validation-failure result.

    Raises:
        ValueError: If ``unidiff`` cannot parse ``diff``.
    """
    # Configure the fuzzy matcher once; all hunks share these thresholds.
    dmp_init = dmp_module.diff_match_patch()
    dmp_init.Match_Threshold = match_threshold
    dmp_init.Match_Distance = match_distance
    dmp_init.Patch_DeleteThreshold = match_threshold

    try:
        patch_set = unidiff.PatchSet(diff)
    except unidiff.UnidiffParseError as e:
        raise ValueError(f"unidiff failed to parse the diff: {e}")

    results: List[HunkResult] = []
    curr_code = code
    hunk_idx = 0

    for patched_file in patch_set:
        for hunk in patched_file:

            # prev_content = hunk pre-image (context + '-'), new_content =
            # hunk post-image (context + '+').
            prev_content = _reconstruct_from_hunk(hunk, include_added=False)
            new_content  = _reconstruct_from_hunk(hunk, include_removed=False)

            # Pure insertion: no old text to match against, so place the
            # inserted text at the line hint instead of fuzzy-matching.
            if not prev_content.strip():
                char_hint = _line_to_char(curr_code, hunk.source_start - 1)
                diffs   = [(dmp_module.diff_match_patch.DIFF_INSERT, new_content)]
                patches = dmp_init.patch_make("", diffs)
                for patch in patches:
                    # Retarget the patch from offset 0 to the line hint.
                    patch.start1 = char_hint
                    patch.start2 = char_hint

                new_code, ok = dmp_init.patch_apply(patches, curr_code)
                applied = ok[0] if ok else False

                results.append(HunkResult(
                    hunk_index=hunk_idx,
                    source_file=patched_file.path,
                    applied=applied,
                    confidence=1.0 if applied else 0.0,
                    location_found=char_hint,
                    failed_reason=None if applied else "Pure insertion failed",
                ))

                if applied:
                    curr_code = new_code

                hunk_idx += 1
                continue

            # Convert unidiff's 1-based line hint to a character offset that
            # seeds DMP's search window.
            char_hint = _line_to_char(curr_code, hunk.source_start - 1)

            # Fuzzy-locate the hunk's pre-image in the current text.
            # NOTE(review): DMP's match_main raises ValueError when the
            # pattern has no exact occurrence and exceeds Match_MaxBits
            # (default 32 chars) — TODO confirm long hunks are handled
            # (or Match_MaxBits raised) by the caller.
            loc = dmp_init.match_main(curr_code, prev_content, char_hint)
            confidence = 0.0

            if loc != -1:
                # Confidence = 1 - (edit chars between expected pre-image and
                # the slice actually found) / pre-image length, clamped at 0.
                actual_slice = curr_code[loc: loc + len(prev_content)]
                edits = dmp_init.diff_main(prev_content, actual_slice)
                edit_chars = sum(
                    len(txt) for op, txt in edits
                    if op != dmp_module.diff_match_patch.DIFF_EQUAL
                )
                confidence = max(0.0, 1.0 - edit_chars / max(len(prev_content), 1))
            else:
                results.append(HunkResult(
                    hunk_index=hunk_idx,
                    source_file=patched_file.path,
                    applied=False,
                    confidence=0.0,
                    location_found=-1,
                    failed_reason="Bitap match failed  context too stale",
                ))
                hunk_idx += 1
                continue

            # Build patch against the ACTUAL slice found, not stale line numbers
            actual_old = curr_code[loc: loc + len(prev_content)]
            
            diffs   = dmp_init.diff_main(actual_old, new_content)
            dmp_init.diff_cleanupSemantic(diffs)
            patches = dmp_init.patch_make(actual_old, diffs)
            for p in patches:
                # Retarget each patch at the matched offset.
                p.start1 = loc
                p.start2 = loc

            new_code, apply_results = dmp_init.patch_apply(patches, curr_code)
            applied = all(apply_results)

            results.append(HunkResult(
                hunk_index=hunk_idx,
                source_file=patched_file.path,
                applied=applied,
                confidence=confidence if applied else 0.0,
                location_found=loc,
                failed_reason=None if applied else "patch_apply returned False",
            ))

            if applied:
                curr_code = new_code

            hunk_idx += 1

    # Validate the final patched code is valid Python
    is_valid, error_msg = validate_python_syntax(curr_code)
    if not is_valid:
        # Return original code if patched code is invalid Python.
        # NOTE(review): per-hunk results are discarded here and replaced by a
        # single synthetic failure record — confirm callers expect that.
        return code, [HunkResult(
            hunk_index=0,
            source_file="validation",
            applied=False,
            confidence=0.0,
            location_found=0,
            failed_reason=f"Invalid Python after patch: {error_msg}",
        )]

    return curr_code, results


def _reconstruct_from_hunk(

    hunk,

    include_added: bool = True,

    include_removed: bool = True,

) -> str:
    res = ""
    for line in hunk:
        if line.line_type == ' ':
            res += line.value
        elif line.line_type == '-' and include_removed:
            res += line.value
        elif line.line_type == '+' and include_added:
            res += line.value
    return res


def _line_to_char(text: str, line_idx: int) -> int:
    """0-based line number  character offset."""
    lines = text.splitlines(keepends=True)
    return sum(len(l) for l in lines[:max(0, line_idx)])