sky2 / skydiscover /utils /code_utils.py
JustinTX's picture
Add files using upload-large-folder tool
af83196 verified
"""
Utilities for code parsing, diffing, and manipulation
"""
import os
import re
from pathlib import Path
from typing import List, Optional, Set, Tuple
def apply_diff(original_solution: str, diff_text: str) -> str:
    """
    Apply SEARCH/REPLACE blocks to a solution, line by line.

    Each block is applied to the result of the previous one. Only the first
    exact, whole-line match of a block's search text is replaced; a block
    whose search text is not found is skipped silently.

    Args:
        original_solution: Original source solution
        diff_text: Diff in the SEARCH/REPLACE format

    Returns:
        Modified solution
    """
    working = original_solution.split("\n")

    for old_chunk, new_chunk in extract_diffs(diff_text):
        needle = old_chunk.split("\n")
        width = len(needle)

        # Index of the first window of `width` lines equal to the needle
        hit = next(
            (
                pos
                for pos in range(len(working) - width + 1)
                if working[pos : pos + width] == needle
            ),
            None,
        )
        if hit is not None:
            working[hit : hit + width] = new_chunk.split("\n")

    return "\n".join(working)
def extract_diffs(diff_text: str) -> List[Tuple[str, str]]:
    """
    Pull every SEARCH/REPLACE block out of the diff text.

    Text outside the block markers is ignored. Trailing whitespace
    (including trailing blank lines) is removed from both halves.

    Args:
        diff_text: Diff in the SEARCH/REPLACE format

    Returns:
        List of tuples (search_text, replace_text), in order of appearance
    """
    block_re = re.compile(
        r"<<<<<<< SEARCH\n(.*?)=======\n(.*?)>>>>>>> REPLACE", re.DOTALL
    )
    pairs: List[Tuple[str, str]] = []
    for hit in block_re.finditer(diff_text):
        pairs.append((hit.group(1).rstrip(), hit.group(2).rstrip()))
    return pairs
def parse_full_rewrite(llm_response: str, language: str = "python") -> Optional[str]:
    """
    Extract a full rewrite from an LLM response

    Lookup order:
        1. the first fenced block tagged with ``language``
        2. the first fenced block of any kind (a leading language-tag line,
           if present, is dropped)
        3. the whole response, unchanged

    Args:
        llm_response: Response from the LLM
        language: Programming language tag expected after the opening fence

    Returns:
        Extracted code with surrounding whitespace stripped, or the raw
        response when it contains no fenced block
    """
    # Escape the tag so names with regex metacharacters ("c++") match literally
    tagged = re.search(
        r"```" + re.escape(language) + r"\n(.*?)```", llm_response, re.DOTALL
    )
    if tagged:
        return tagged.group(1).strip()

    # Fallback to any fenced block. The optional group swallows a tag-only
    # first line (e.g. "js") so it is not returned as part of the code.
    generic = re.search(r"```(?:[\w+#.-]*\n)?(.*?)```", llm_response, re.DOTALL)
    if generic:
        return generic.group(1).strip()

    # Fallback to plain text
    return llm_response
def _extract_def_info(solution: str) -> Optional[Tuple[str, str, Optional[str]]]:
    """
    Extract function/class name and docstring (or first comment as fallback) from solution block.

    The definition that appears first in the text is the one reported, so a
    class whose body contains methods is described as the class rather than
    as its first method. Both ``def`` and ``async def`` count as functions.

    Args:
        solution: Source code snippet to inspect

    Returns:
        Tuple of (kind, name, docstring_first_line) or None if not found.
        ``kind`` is "function" or "class"; the third element is None when
        neither a docstring nor a leading comment was found.
    """
    func_match = re.search(
        r"^\s*(?:async\s+)?def\s+(\w+)\s*\(", solution, re.MULTILINE
    )
    class_match = re.search(r"^\s*class\s+(\w+)", solution, re.MULTILINE)

    found = [
        (kind, match)
        for kind, match in (("function", func_match), ("class", class_match))
        if match is not None
    ]
    if not found:
        return None

    # Report whichever definition starts earliest in the snippet
    kind, match = min(found, key=lambda item: item[1].start())
    name = match.group(1)

    # Try to extract docstring, fallback to first comment
    docstring = _extract_docstring(solution, match.end())
    if not docstring:
        docstring = _extract_first_comment(solution, match.start())
    return (kind, name, docstring)
def _extract_first_comment(solution: str, func_start: int) -> Optional[str]:
"""
Extract consecutive comment lines inside a function/class body.
Used as fallback when no docstring is available.
Returns up to 5 lines of comments joined together.
"""
remaining = solution[func_start:]
colon_match = re.search(r"(?:\)|[^:]+):\s*\n", remaining)
if not colon_match:
return None
# Get the body after the colon
body_start = colon_match.end()
body = remaining[body_start:]
# Collect consecutive comment lines
comment_lines = []
lines = body.split("\n")
for line in lines[:10]: # Check first 10 lines for comments
stripped = line.strip()
if stripped.startswith("#"):
# Remove the # and leading space
comment_text = stripped[1:].strip()
if comment_text:
comment_lines.append(comment_text)
if len(comment_lines) >= 5: # Max 5 lines
break
elif stripped and not stripped.startswith("#"):
# Hit actual code, stop collecting
break
return "\n".join(comment_lines) if comment_lines else None
def _extract_docstring(solution: str, start_pos: int) -> Optional[str]:
"""
Extract first line of docstring after a given position.
Args:
solution: Source code
start_pos: Position to start searching from
"""
remaining = solution[start_pos:]
docstring_match = re.search(r':\s*\n\s*("""|\'\'\')(.*?)("""|\'\'\')', remaining, re.DOTALL)
if docstring_match:
docstring_content = docstring_match.group(2).strip()
return docstring_content.split("\n")[0].strip()
return None
def format_diff_summary(diff_blocks: List[Tuple[str, str]]) -> str:
    """
    Create a human-readable summary of the diff.

    One "Change N: ..." line is produced per block:
        - if a function/class is detected in either side, the line names it
          and reports a rename, the new docstring, or the old→new line
          counts, depending on what changed
        - a one-line-to-one-line edit shows both lines verbatim
        - anything else shows the first non-empty old line as context, or
          plain line counts when either side is empty

    Args:
        diff_blocks: List of (search_text, replace_text) tuples

    Returns:
        Summary string, one line per diff block
    """
    summary = []

    for index, (search_text, replace_text) in enumerate(diff_blocks, start=1):
        search_lines = search_text.strip().split("\n")
        replace_lines = replace_text.strip().split("\n")
        # Shared "old→new" size marker; the arrow keeps the two counts apart
        line_counts = f"{len(search_lines)}→{len(replace_lines)} lines"

        # Try to extract meaningful info from the solution
        old_info = _extract_def_info(search_text)
        new_info = _extract_def_info(replace_text)

        if old_info or new_info:
            # Prefer the new side's description when both are available
            kind, name, docstring = new_info or old_info

            # Get docstrings from both to compare
            old_docstring = old_info[2] if old_info else None
            new_docstring = new_info[2] if new_info else None

            if old_info and new_info and old_info[1] != new_info[1]:
                # Renamed function/class - always show this
                desc = f"Renamed {old_info[0]} `{old_info[1]}` → `{new_info[1]}`"
            elif old_docstring and new_docstring and old_docstring != new_docstring:
                # Docstrings are DIFFERENT - show the new docstring
                desc = f"Modified {kind} `{name}`: {new_docstring}"
            elif old_docstring == new_docstring:
                # Docstrings are IDENTICAL (or both absent) - just line counts
                desc = f"Modified {kind} `{name}` ({line_counts})"
            elif docstring:
                # Only one side has a docstring
                desc = f"Modified {kind} `{name}`: {docstring}"
            else:
                desc = f"Modified {kind} `{name}` ({line_counts})"

            summary.append(f"Change {index}: {desc}")
        elif len(search_lines) == 1 and len(replace_lines) == 1:
            # Single line change - show the actual change
            summary.append(
                f"Change {index}: '{search_lines[0].strip()}' → '{replace_lines[0].strip()}'"
            )
        else:
            # Fallback: show first non-empty line as context
            first_old = next((line.strip() for line in search_lines if line.strip()), "")
            first_new = next((line.strip() for line in replace_lines if line.strip()), "")
            if first_old and first_new:
                summary.append(
                    f"Change {index}: Near `{first_old[:50]}...` ({line_counts})"
                )
            else:
                summary.append(
                    f"Change {index}: Replace {len(search_lines)} lines with {len(replace_lines)} lines"
                )

    return "\n".join(summary)
def extract_solution_language(solution: str) -> str:
    """
    Guess the language of a solution snippet from line-start keywords.

    The signatures are tried in the order listed below and the first one
    that matches at the start of any line wins, so earlier entries take
    priority over later ones.

    Args:
        solution: Solution snippet

    Returns:
        Detected language or "text" by default if no language is detected
    """
    # (language, line-start signature) pairs in priority order
    signatures = (
        ("python", r"^(import|from|def|class)\s"),
        ("java", r"^(package|import java|public class)"),
        ("cpp", r"^(#include|int main|void main)"),
        ("javascript", r"^(function|var|let|const|console\.log)"),
        ("rust", r"^(module|fn|let mut|impl)"),
        ("sql", r"^(SELECT|CREATE TABLE|INSERT INTO)"),
    )

    for label, signature in signatures:
        if re.search(signature, solution, re.MULTILINE):
            return label

    return "text"
def build_repo_map(
    root: str,
    *,
    max_depth: int = 4,
    allowed_extensions: Tuple[str, ...] = (".py",),
    excluded_dirs: Tuple[str, ...] = (".git", "__pycache__"),
) -> str:
    """Render a depth-limited directory tree of *root* as text.

    Each entry is one line, indented one extra space per nesting level.
    Directories carry a trailing ``/`` and are listed before files; within
    each group entries are ordered by name.

    Args:
        root: Directory to map
        max_depth: Deepest level whose contents are expanded (root is 0)
        allowed_extensions: File suffixes to include
        excluded_dirs: Entry names to skip entirely

    Returns:
        The tree as a newline-joined string. Entries whose name starts with
        "." are always skipped, and directories that cannot be read are not
        expanded. Empty string if *root* is empty or is not a directory.
    """
    if not (root and os.path.isdir(root)):
        return ""

    skip_names: Set[str] = set(excluded_dirs)
    keep_suffixes: Set[str] = set(allowed_extensions)
    tree: List[str] = []

    def dirs_before_files(item: Path) -> Tuple[bool, str]:
        # False sorts before True, so non-files come first
        return (item.is_file(), item.name)

    def descend(folder: Path, indent: str, level: int) -> None:
        if level > max_depth:
            return
        try:
            children = sorted(folder.iterdir(), key=dirs_before_files)
        except PermissionError:
            # Unreadable directory: leave it unexpanded
            return

        for child in children:
            child_name = child.name
            if child_name.startswith(".") or child_name in skip_names:
                continue
            if child.is_dir():
                tree.append(f"{indent}{child_name}/")
                descend(child, indent + " ", level + 1)
                continue
            if child.suffix in keep_suffixes:
                tree.append(f"{indent}{child_name}")

    descend(Path(root).resolve(), " ", 0)
    return "\n".join(tree)