Spaces:
Running
Running
| import ast | |
| import unidiff | |
| import diff_match_patch as dmp_module | |
| from dataclasses import dataclass | |
| from typing import List, Tuple, Optional | |
def validate_python_syntax(code: str) -> Tuple[bool, Optional[str]]:
    """
    Check whether `code` is syntactically valid Python by compiling it to an AST.

    Args:
        code: Source text to validate.

    Returns:
        (is_valid, error_message) — error_message is None on success, otherwise
        a short human-readable description of the parse failure.
    """
    try:
        ast.parse(code)
    except SyntaxError as e:
        # Include position info so callers can surface a precise diagnostic.
        return False, f"SyntaxError: {e.msg} at line {e.lineno}, column {e.offset}"
    except Exception as e:
        # ast.parse can raise non-SyntaxError failures (e.g. ValueError on null bytes).
        return False, f"ParseError: {str(e)}"
    return True, None
@dataclass
class HunkResult:
    """
    Per-hunk outcome of a patch application attempt.

    `apply_patch` constructs these with keyword arguments, so the class must be
    a dataclass (the decorator was missing, making every instantiation raise
    TypeError: HunkResult() takes no arguments).
    """
    hunk_index: int                     # 0-based index of the hunk within the whole diff
    source_file: str                    # path reported by unidiff for the hunk's file
    applied: bool                       # True if the hunk was successfully applied
    confidence: float                   # fuzzy-match confidence in [0.0, 1.0]; 0.0 on failure
    location_found: int                 # character offset where the hunk matched; -1 if no match
    failed_reason: Optional[str] = None # human-readable failure description, None on success
def apply_patch(
    code: str,
    diff: str,
    match_threshold: float = 0.5,
    match_distance: int = 2000,
) -> Tuple[str, List[HunkResult]]:
    """
    Parse `diff` with unidiff (structured, typed hunk objects),
    then apply each hunk via DMP's fuzzy Bitap engine.
    Returns (patched_code, [HunkResult, ...]) so the RL reward
    function gets per-hunk confidence scores, not just pass/fail.

    Args:
        code: Current source text the diff should be applied to.
        diff: Unified-diff text (may span multiple files/hunks).
        match_threshold: DMP Bitap tolerance (0.0 = exact match only,
            1.0 = match anything); also used as the delete threshold.
        match_distance: How far (in characters) from the expected
            location DMP will search for a fuzzy match.

    Returns:
        (patched_code, results) — if the final patched text is not valid
        Python, the ORIGINAL `code` is returned with a single failure
        HunkResult instead of the per-hunk results.

    Raises:
        ValueError: if unidiff cannot parse `diff` at all.
    """
    dmp_init = dmp_module.diff_match_patch()
    dmp_init.Match_Threshold = match_threshold
    dmp_init.Match_Distance = match_distance
    dmp_init.Patch_DeleteThreshold = match_threshold
    try:
        patch_set = unidiff.PatchSet(diff)
    except unidiff.UnidiffParseError as e:
        raise ValueError(f"unidiff failed to parse the diff: {e}")
    results: List[HunkResult] = []
    curr_code = code
    hunk_idx = 0
    # NOTE(review): all hunks are applied against the single `code` string,
    # even when the PatchSet names multiple files — presumably callers pass
    # single-file diffs; verify against call sites.
    for patched_file in patch_set:
        for hunk in patched_file:
            # Reconstruct the "before" text (context + removed lines) and the
            # "after" text (context + added lines) of this hunk.
            prev_content = _reconstruct_from_hunk(hunk, include_added=False)
            new_content = _reconstruct_from_hunk(hunk, include_removed=False)
            # Pure insertion: no old text to match against
            if not prev_content.strip():
                # Use the hunk's (1-based) source line as the insertion point.
                char_hint = _line_to_char(curr_code, hunk.source_start - 1)
                diffs = [(dmp_module.diff_match_patch.DIFF_INSERT, new_content)]
                patches = dmp_init.patch_make("", diffs)
                # Retarget the patch from offset 0 to the computed hint.
                for patch in patches:
                    patch.start1 = char_hint
                    patch.start2 = char_hint
                new_code, ok = dmp_init.patch_apply(patches, curr_code)
                # patch_apply returns per-patch booleans; a single patch is expected here.
                applied = ok[0] if ok else False
                results.append(HunkResult(
                    hunk_index=hunk_idx,
                    source_file=patched_file.path,
                    applied=applied,
                    confidence=1.0 if applied else 0.0,
                    location_found=char_hint,
                    failed_reason=None if applied else "Pure insertion failed",
                ))
                if applied:
                    curr_code = new_code
                hunk_idx += 1
                continue
            # Convert unidiff's 1-based line hint to a char offset for DMP's search window.
            char_hint = _line_to_char(curr_code, hunk.source_start - 1)
            # Fuzzy-locate the hunk's "before" text near the hint (Bitap search
            # on the instance so it honors the configured threshold/distance).
            loc = dmp_init.match_main(curr_code, prev_content, char_hint)
            confidence = 0.0
            if loc != -1:
                # Score the match: fraction of characters that are NOT edits
                # between the expected "before" text and what was actually found.
                actual_slice = curr_code[loc: loc + len(prev_content)]
                edits = dmp_init.diff_main(prev_content, actual_slice)
                edit_chars = sum(
                    len(txt) for op, txt in edits
                    if op != dmp_module.diff_match_patch.DIFF_EQUAL
                )
                # max(..., 1) guards against division by zero on empty content.
                confidence = max(0.0, 1.0 - edit_chars / max(len(prev_content), 1))
            else:
                results.append(HunkResult(
                    hunk_index=hunk_idx,
                    source_file=patched_file.path,
                    applied=False,
                    confidence=0.0,
                    location_found=-1,
                    failed_reason="Bitap match failed context too stale",
                ))
                hunk_idx += 1
                continue
            # Build patch against the ACTUAL slice found, not stale line numbers
            actual_old = curr_code[loc: loc + len(prev_content)]
            diffs = dmp_init.diff_main(actual_old, new_content)
            dmp_init.diff_cleanupSemantic(diffs)
            patches = dmp_init.patch_make(actual_old, diffs)
            # Shift the patches to the matched location in the full text.
            for p in patches:
                p.start1 = loc
                p.start2 = loc
            new_code, apply_results = dmp_init.patch_apply(patches, curr_code)
            applied = all(apply_results)
            results.append(HunkResult(
                hunk_index=hunk_idx,
                source_file=patched_file.path,
                applied=applied,
                confidence=confidence if applied else 0.0,
                location_found=loc,
                failed_reason=None if applied else "patch_apply returned False",
            ))
            if applied:
                curr_code = new_code
            hunk_idx += 1
    # Validate the final patched code is valid Python
    is_valid, error_msg = validate_python_syntax(curr_code)
    if not is_valid:
        # Roll back: return the original code (and a single sentinel result)
        # rather than text that no longer parses.
        return code, [HunkResult(
            hunk_index=0,
            source_file="validation",
            applied=False,
            confidence=0.0,
            location_found=0,
            failed_reason=f"Invalid Python after patch: {error_msg}",
        )]
    return curr_code, results
| def _reconstruct_from_hunk( | |
| hunk, | |
| include_added: bool = True, | |
| include_removed: bool = True, | |
| ) -> str: | |
| res = "" | |
| for line in hunk: | |
| if line.line_type == ' ': | |
| res += line.value | |
| elif line.line_type == '-' and include_removed: | |
| res += line.value | |
| elif line.line_type == '+' and include_added: | |
| res += line.value | |
| return res | |
| def _line_to_char(text: str, line_idx: int) -> int: | |
| """0-based line number character offset.""" | |
| lines = text.splitlines(keepends=True) | |
| return sum(len(l) for l in lines[:max(0, line_idx)]) |