"""
Utilities for code parsing, diffing, and manipulation.
"""
import os
import re
from pathlib import Path
from typing import List, Optional, Set, Tuple
def apply_diff(original_solution: str, diff_text: str) -> str:
    """
    Apply SEARCH/REPLACE diff blocks to a piece of source text.

    Blocks are applied in the order they appear in ``diff_text``. Matching is
    done on whole lines: a block only applies where a run of lines equals the
    block's search lines exactly. Only the first such run is replaced, and a
    block whose search lines are not found is skipped.

    Args:
        original_solution: Original source solution
        diff_text: Diff in the SEARCH/REPLACE format

    Returns:
        Modified solution
    """
    working = original_solution.split("\n")

    for search_text, replace_text in extract_diffs(diff_text):
        needle = search_text.split("\n")
        replacement = replace_text.split("\n")
        width = len(needle)

        # Index of the first window of `width` lines equal to the search lines.
        hit = next(
            (
                start
                for start in range(len(working) - width + 1)
                if working[start : start + width] == needle
            ),
            None,
        )
        if hit is not None:
            working[hit : hit + width] = replacement

    return "\n".join(working)
def extract_diffs(diff_text: str) -> List[Tuple[str, str]]:
    """
    Pull every SEARCH/REPLACE block out of a diff string.

    A block starts with a ``<<<<<<< SEARCH`` line, is split by a ``=======``
    line and ends at ``>>>>>>> REPLACE``. Trailing whitespace is removed from
    both halves of each block.

    Args:
        diff_text: Diff in the SEARCH/REPLACE format

    Returns:
        List of tuples (search_text, replace_text), in order of appearance
    """
    block_re = re.compile(
        r"<<<<<<< SEARCH\n(.*?)=======\n(.*?)>>>>>>> REPLACE", re.DOTALL
    )
    return [
        (found.group(1).rstrip(), found.group(2).rstrip())
        for found in block_re.finditer(diff_text)
    ]
def parse_full_rewrite(llm_response: str, language: str = "python") -> Optional[str]:
    """
    Extract a full rewrite from an LLM response

    Lookup order:
      1. the first fenced block tagged with ``language``;
      2. the first fenced block of any kind (a language tag on the opening
         fence line, if present, is dropped);
      3. the response itself, unchanged.

    Args:
        llm_response: Response from the LLM
        language: Programming language expected on the opening fence

    Returns:
        Extracted code with surrounding whitespace stripped, or the raw
        response when it contains no fenced block. The ``Optional`` return
        annotation is kept for interface compatibility; this implementation
        never returns None.
    """
    # re.escape keeps tags such as "c++" from being read as regex syntax.
    tagged_pattern = r"```" + re.escape(language) + r"\n(.*?)```"
    tagged = re.search(tagged_pattern, llm_response, re.DOTALL)
    if tagged:
        return tagged.group(1).strip()

    # Fallback to any fenced block. The optional group consumes a language tag
    # that sits directly on the opening fence line so it is not returned as
    # the first line of the extracted code.
    any_fence = re.search(r"```(?:[\w+#.-]*\n)?(.*?)```", llm_response, re.DOTALL)
    if any_fence:
        return any_fence.group(1).strip()

    # Fallback to plain text
    return llm_response
def _extract_def_info(solution: str) -> Optional[Tuple[str, str, Optional[str]]]:
    """
    Extract function/class name and docstring (or first comment as fallback) from solution block.

    The definition that appears first in the text is the one reported, so a
    class that contains methods is described as a class rather than as its
    first method. ``async def`` is recognised as a function.

    Args:
        solution: Source snippet to inspect

    Returns:
        Tuple of (kind, name, docstring_first_line) or None if not found.
        ``kind`` is "function" or "class"; the third element is None when
        neither a docstring nor leading comments were found.
    """
    func_match = re.search(r"^\s*(?:async\s+)?def\s+(\w+)\s*\(", solution, re.MULTILINE)
    class_match = re.search(r"^\s*class\s+(\w+)", solution, re.MULTILINE)

    found = [
        (kind, match)
        for kind, match in (("function", func_match), ("class", class_match))
        if match
    ]
    if not found:
        return None

    # Order by where the name sits in the text; leading whitespace swallowed
    # by `\s*` must not influence which definition counts as "first".
    kind, match = min(found, key=lambda item: item[1].start(1))

    # Try to extract docstring, fallback to first comment
    docstring = _extract_docstring(solution, match.end())
    if not docstring:
        docstring = _extract_first_comment(solution, match.start())
    return (kind, match.group(1), docstring)
def _extract_first_comment(solution: str, func_start: int) -> Optional[str]:
    """
    Collect the leading ``#`` comments of a function/class body.

    Used as fallback when no docstring is available. Scanning starts after the
    first line-ending colon found at or after ``func_start`` and looks at no
    more than the next 10 lines. Blank lines and empty ``#`` lines are
    skipped; the first line of real code ends the scan.

    Args:
        solution: Source code
        func_start: Offset in ``solution`` where the definition begins

    Returns:
        Up to 5 comment texts (without the ``#``) joined by newlines, or None
        when no header colon or no comment text was found.
    """
    tail = solution[func_start:]
    header_end = re.search(r"(?:\)|[^:]+):\s*\n", tail)
    if header_end is None:
        return None

    collected: List[str] = []
    for raw_line in tail[header_end.end():].split("\n")[:10]:
        text = raw_line.strip()
        if not text:
            continue
        if not text.startswith("#"):
            # Hit actual code, stop collecting
            break
        note = text[1:].strip()
        if not note:
            continue
        collected.append(note)
        if len(collected) >= 5:
            break

    if not collected:
        return None
    return "\n".join(collected)
def _extract_docstring(solution: str, start_pos: int) -> Optional[str]:
    """
    Extract first line of docstring after a given position.

    The search looks for a colon that ends a line, followed on the next
    non-blank line by a triple-quoted string. The closing delimiter must be
    the same kind of triple quote as the opening one, so a docstring written
    with three double quotes may itself contain three single quotes (and vice
    versa) without being cut short.

    NOTE(review): the search is not anchored to the definition that ends at
    ``start_pos``; the first matching colon/triple-quote pair anywhere in the
    remaining text is used.

    Args:
        solution: Source code
        start_pos: Position to start searching from

    Returns:
        The stripped first line of the docstring body (possibly an empty
        string), or None if no docstring was found.
    """
    remaining = solution[start_pos:]
    # `\1` forces the closing quotes to match whichever opener was captured.
    docstring_match = re.search(
        r':\s*\n\s*("""|\'\'\')(.*?)\1', remaining, re.DOTALL
    )
    if docstring_match is None:
        return None
    docstring_content = docstring_match.group(2).strip()
    return docstring_content.split("\n")[0].strip()
def format_diff_summary(diff_blocks: List[Tuple[str, str]]) -> str:
    """
    Build a human-readable, one-line-per-block summary of a diff.

    For each block, in order of preference:
      * a function/class was found on either side: report a rename when both
        sides have one and the names differ; otherwise report the new
        docstring when both sides have one and they differ; otherwise report
        the docstring of the preferred side (new, else old) when the two
        sides' docstrings differ and it is non-empty; otherwise report the
        old→new line counts;
      * both sides are a single line: show the old and new line;
      * anything else: show the first non-empty old line as context, or just
        the line counts when either side has no non-empty line.

    Args:
        diff_blocks: List of (search_text, replace_text) tuples

    Returns:
        Summary string, one "Change N: ..." line per block
    """
    entries = []
    for number, (search_text, replace_text) in enumerate(diff_blocks, start=1):
        old_lines = search_text.strip().split("\n")
        new_lines = replace_text.strip().split("\n")
        size_change = f"({len(old_lines)}→{len(new_lines)} lines)"

        before = _extract_def_info(search_text)
        after = _extract_def_info(replace_text)

        if before or after:
            # Prefer what the new code says about itself.
            kind, name, docstring = after or before
            old_doc = before[2] if before else None
            new_doc = after[2] if after else None

            if before and after and before[1] != after[1]:
                desc = f"Renamed {before[0]} `{before[1]}` → `{after[1]}`"
            elif old_doc and new_doc and old_doc != new_doc:
                desc = f"Modified {kind} `{name}`: {new_doc}"
            elif old_doc != new_doc and docstring:
                # Only one side carries a docstring and it is the preferred one.
                desc = f"Modified {kind} `{name}`: {docstring}"
            else:
                # Same docstring on both sides (or none usable): counts only.
                desc = f"Modified {kind} `{name}` {size_change}"
            entries.append(f"Change {number}: {desc}")
            continue

        if len(old_lines) == 1 and len(new_lines) == 1:
            entries.append(
                f"Change {number}: '{old_lines[0].strip()}' → '{new_lines[0].strip()}'"
            )
            continue

        first_old = next((ln.strip() for ln in old_lines if ln.strip()), "")
        first_new = next((ln.strip() for ln in new_lines if ln.strip()), "")
        if first_old and first_new:
            entries.append(f"Change {number}: Near `{first_old[:50]}...` {size_change}")
        else:
            entries.append(
                f"Change {number}: Replace {len(old_lines)} lines with {len(new_lines)} lines"
            )

    return "\n".join(entries)
def extract_solution_language(solution: str) -> str:
    """
    Try to determine the language of a solution snippet in string format

    Detection is a heuristic: each rule is a regex tested against the start
    of every line, rules are tried in order, and the first one that matches
    decides the result.

    Args:
        solution: Solution snippet

    Returns:
        Detected language ("python", "java", "cpp", "javascript", "rust",
        "sql") or "text" by default if no language is detected
    """
    signatures = (
        # Java-platform imports are checked ahead of Python's generic
        # `import` rule; otherwise any Java file with imports is reported as
        # Python and the `import java` alternative below can never fire.
        ("java", r"^import\s+(?:static\s+)?javax?\."),
        ("python", r"^(import|from|def|class)\s"),
        ("java", r"^(package|import java|public class)"),
        ("cpp", r"^(#include|int main|void main)"),
        # `let mut` is checked ahead of JavaScript's bare `let` for the same
        # reason.
        ("rust", r"^let mut\s"),
        ("javascript", r"^(function|var|let|const|console\.log)"),
        ("rust", r"^(module|fn|let mut|impl)"),
        ("sql", r"^(SELECT|CREATE TABLE|INSERT INTO)"),
    )
    for language, pattern in signatures:
        if re.search(pattern, solution, re.MULTILINE):
            return language
    return "text"
def build_repo_map(
    root: str,
    *,
    max_depth: int = 4,
    allowed_extensions: Tuple[str, ...] = (".py",),
    excluded_dirs: Tuple[str, ...] = (".git", "__pycache__"),
) -> str:
    """Render a depth-limited directory tree of *root* as a string.

    Directories are listed before files at each level, both sorted by name,
    and every directory line ends with ``/``. Each nesting level extends the
    line prefix. Entries whose name starts with ``.`` or appears in
    *excluded_dirs* are left out, and files are listed only when their suffix
    is in *allowed_extensions*. A directory one level past *max_depth* is
    still listed by name but its contents are not. Directories that cannot be
    read because of a ``PermissionError`` contribute no children.

    Returns an empty string if *root* is empty, does not exist or is not a
    directory.
    """
    if not root or not os.path.isdir(root):
        return ""

    skip_names: Set[str] = set(excluded_dirs)
    wanted_suffixes: Set[str] = set(allowed_extensions)

    def render(folder: Path, indent: str, level: int):
        """Yield the tree lines for *folder* in depth-first order."""
        if level > max_depth:
            return
        try:
            # False sorts before True, so directories come ahead of files.
            children = sorted(folder.iterdir(), key=lambda p: (p.is_file(), p.name))
        except PermissionError:
            return
        for child in children:
            label = child.name
            if label.startswith(".") or label in skip_names:
                continue
            if child.is_dir():
                yield f"{indent}{label}/"
                yield from render(child, indent + " ", level + 1)
            elif child.suffix in wanted_suffixes:
                yield f"{indent}{label}"

    return "\n".join(render(Path(root).resolve(), " ", 0))
|