# rl_code_fix_env/src/sandbox/patcher.py
# (uploaded via huggingface_hub by Viraj0112, revision 03a907a)
import ast
import unidiff
import diff_match_patch as dmp_module
from dataclasses import dataclass
from typing import List, Tuple, Optional
def validate_python_syntax(code: str) -> Tuple[bool, Optional[str]]:
    """Check whether *code* parses as valid Python via the AST compiler.

    Returns:
        (is_valid, error_message) — error_message is None when the code
        parses cleanly, otherwise a short human-readable description.
    """
    try:
        ast.parse(code)
    except SyntaxError as err:
        # Surface the exact location so callers can report it verbatim.
        return False, f"SyntaxError: {err.msg} at line {err.lineno}, column {err.offset}"
    except Exception as err:
        # ast.parse can raise e.g. ValueError/RecursionError on pathological input.
        return False, f"ParseError: {str(err)}"
    return True, None
@dataclass
class HunkResult:
    """Outcome of applying one diff hunk, reported back to the caller."""
    # 0-based index of the hunk across the whole patch set.
    hunk_index: int
    # Path of the file this hunk targets (taken from the diff header).
    source_file: str
    # True when the fuzzy patch application succeeded for this hunk.
    applied: bool
    # Match quality in [0.0, 1.0]; 0.0 when the hunk was not applied.
    confidence: float
    # Character offset in the code where the hunk landed (-1 if no match).
    location_found: int
    # Short failure description; None on success.
    failed_reason: Optional[str] = None
def apply_patch(
    code: str,
    diff: str,
    match_threshold: float = 0.5,
    match_distance: int = 2000,
) -> Tuple[str, List[HunkResult]]:
    """
    Parse `diff` with unidiff (structured, typed hunk objects),
    then apply each hunk via DMP's fuzzy Bitap engine.
    Returns (patched_code, [HunkResult, ...]) so the RL reward
    function gets per-hunk confidence scores, not just pass/fail.

    Raises:
        ValueError: if `diff` is not parseable as a unified diff.
    """
    # Configure diff-match-patch's fuzzy matcher; Match_Distance bounds how
    # far from the line-derived hint a match may drift.
    dmp_init = dmp_module.diff_match_patch()
    dmp_init.Match_Threshold = match_threshold
    dmp_init.Match_Distance = match_distance
    dmp_init.Patch_DeleteThreshold = match_threshold
    try:
        patch_set = unidiff.PatchSet(diff)
    except unidiff.UnidiffParseError as e:
        raise ValueError(f"unidiff failed to parse the diff: {e}")
    results: List[HunkResult] = []
    curr_code = code
    hunk_idx = 0
    for patched_file in patch_set:
        for hunk in patched_file:
            # Pre-image (context + removed) and post-image (context + added).
            prev_content = _reconstruct_from_hunk(hunk, include_added=False)
            new_content = _reconstruct_from_hunk(hunk, include_removed=False)
            # Pure insertion: no old text to match against
            if not prev_content.strip():
                char_hint = _line_to_char(curr_code, hunk.source_start - 1)
                diffs = [(dmp_module.diff_match_patch.DIFF_INSERT, new_content)]
                patches = dmp_init.patch_make("", diffs)
                # Force the patch to land at the line-derived char offset,
                # since an insertion-only patch carries no anchoring context.
                for patch in patches:
                    patch.start1 = char_hint
                    patch.start2 = char_hint
                new_code, ok = dmp_init.patch_apply(patches, curr_code)
                applied = ok[0] if ok else False
                results.append(HunkResult(
                    hunk_index=hunk_idx,
                    source_file=patched_file.path,
                    applied=applied,
                    confidence=1.0 if applied else 0.0,
                    location_found=char_hint,
                    failed_reason=None if applied else "Pure insertion failed",
                ))
                if applied:
                    curr_code = new_code
                hunk_idx += 1
                continue
            # Convert unidiff 1-based line hint char offset for DMP search window
            char_hint = _line_to_char(curr_code, hunk.source_start - 1)
            # Bitap fuzzy match Bug 3 fixed: called on dmp_init instance
            loc = dmp_init.match_main(curr_code, prev_content, char_hint)
            confidence = 0.0
            if loc != -1:
                # Score how closely the matched slice agrees with the hunk's
                # expected pre-image: 1.0 = exact, 0.0 = fully divergent.
                actual_slice = curr_code[loc: loc + len(prev_content)]
                edits = dmp_init.diff_main(prev_content, actual_slice)
                edit_chars = sum(
                    len(txt) for op, txt in edits
                    if op != dmp_module.diff_match_patch.DIFF_EQUAL
                )
                confidence = max(0.0, 1.0 - edit_chars / max(len(prev_content), 1))
            else:
                results.append(HunkResult(
                    hunk_index=hunk_idx,
                    source_file=patched_file.path,
                    applied=False,
                    confidence=0.0,
                    location_found=-1,
                    failed_reason="Bitap match failed context too stale",
                ))
                hunk_idx += 1
                continue
            # Build patch against the ACTUAL slice found, not stale line numbers
            actual_old = curr_code[loc: loc + len(prev_content)]
            diffs = dmp_init.diff_main(actual_old, new_content)
            dmp_init.diff_cleanupSemantic(diffs)
            patches = dmp_init.patch_make(actual_old, diffs)
            # Pin the patch coordinates to the fuzzy-match location.
            for p in patches:
                p.start1 = loc
                p.start2 = loc
            new_code, apply_results = dmp_init.patch_apply(patches, curr_code)
            applied = all(apply_results)
            results.append(HunkResult(
                hunk_index=hunk_idx,
                source_file=patched_file.path,
                applied=applied,
                confidence=confidence if applied else 0.0,
                location_found=loc,
                failed_reason=None if applied else "patch_apply returned False",
            ))
            if applied:
                curr_code = new_code
            hunk_idx += 1
    # Validate the final patched code is valid Python
    is_valid, error_msg = validate_python_syntax(curr_code)
    if not is_valid:
        # Return original code if patched code is invalid Python
        # NOTE(review): per-hunk results are discarded here and replaced by a
        # single validation-failure record — presumably intentional for the
        # RL reward, but confirm callers don't need the per-hunk history.
        return code, [HunkResult(
            hunk_index=0,
            source_file="validation",
            applied=False,
            confidence=0.0,
            location_found=0,
            failed_reason=f"Invalid Python after patch: {error_msg}",
        )]
    return curr_code, results
def _reconstruct_from_hunk(
hunk,
include_added: bool = True,
include_removed: bool = True,
) -> str:
res = ""
for line in hunk:
if line.line_type == ' ':
res += line.value
elif line.line_type == '-' and include_removed:
res += line.value
elif line.line_type == '+' and include_added:
res += line.value
return res
def _line_to_char(text: str, line_idx: int) -> int:
"""0-based line number character offset."""
lines = text.splitlines(keepends=True)
return sum(len(l) for l in lines[:max(0, line_idx)])