File size: 10,572 Bytes
af83196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
"""
Utilities for code parsing, diffing, and manipulation
"""

import os
import re
from pathlib import Path
from typing import List, Optional, Set, Tuple


def apply_diff(original_solution: str, diff_text: str) -> str:
    """
    Apply SEARCH/REPLACE diff blocks to a piece of source text

    Blocks are applied in the order they appear in the diff. Each block
    replaces only the first run of lines that equals its SEARCH section
    exactly; a block whose SEARCH section is not found is skipped.

    Args:
        original_solution: Original source solution
        diff_text: Diff in the SEARCH/REPLACE format

    Returns:
        Modified solution
    """
    # Work line-by-line so matching is always done on whole lines
    lines = original_solution.split("\n")

    for old_text, new_text in extract_diffs(diff_text):
        needle = old_text.split("\n")
        width = len(needle)

        # Index of the first window of `width` lines equal to the SEARCH section
        hit = next(
            (
                pos
                for pos in range(len(lines) - width + 1)
                if lines[pos : pos + width] == needle
            ),
            None,
        )
        if hit is not None:
            lines[hit : hit + width] = new_text.split("\n")

    return "\n".join(lines)


def extract_diffs(diff_text: str) -> List[Tuple[str, str]]:
    """
    Pull every SEARCH/REPLACE block out of a diff

    A block is delimited by the markers ``<<<<<<< SEARCH``, ``=======`` and
    ``>>>>>>> REPLACE``. Trailing whitespace is stripped from both sections.

    Args:
        diff_text: Diff in the SEARCH/REPLACE format

    Returns:
        List of tuples (search_text, replace_text), in order of appearance
    """
    # DOTALL lets each section span several lines
    block_re = re.compile(
        r"<<<<<<< SEARCH\n(.*?)=======\n(.*?)>>>>>>> REPLACE", re.DOTALL
    )
    return [
        (found.group(1).rstrip(), found.group(2).rstrip())
        for found in block_re.finditer(diff_text)
    ]


def parse_full_rewrite(llm_response: str, language: str = "python") -> Optional[str]:
    """
    Extract a full rewrite from an LLM response

    Lookup order:
      1. the first fenced block tagged with ``language``;
      2. otherwise the first fenced block of any kind (everything between
         the fences is returned, including a language tag if one follows
         the opening fence);
      3. otherwise the response itself, unchanged.

    Args:
        llm_response: Response from the LLM
        language: Programming language expected on the opening fence

    Returns:
        The stripped contents of the selected fenced block, or the raw
        response when it contains no fenced block
    """
    # Escape the tag so names such as "c++" are matched literally instead of
    # being interpreted as regex quantifiers
    fenced = re.search(
        r"```" + re.escape(language) + r"\n(.*?)```", llm_response, re.DOTALL
    )

    if fenced is None:
        # Fallback to any solution block
        fenced = re.search(r"```(.*?)```", llm_response, re.DOTALL)

    if fenced is not None:
        return fenced.group(1).strip()

    # Fallback to plain text
    return llm_response


def _extract_def_info(solution: str) -> Optional[Tuple[str, str, Optional[str]]]:
    """
    Describe the first function (or, failing that, the first class) in a solution block.

    A ``def`` anywhere in the text takes precedence over a ``class``. The
    description is the first docstring line, or the leading body comments when
    no docstring is found.

    Returns:
        Tuple of (kind, name, docstring_first_line) or None if not found
    """
    # Probed in order: functions first, then classes
    probes = (
        ("function", r"^\s*def\s+(\w+)\s*\("),
        ("class", r"^\s*class\s+(\w+)"),
    )

    for kind, pattern in probes:
        found = re.search(pattern, solution, re.MULTILINE)
        if found is None:
            continue

        # Docstring lookup starts after the matched header text; the comment
        # fallback starts at the beginning of the match
        description = _extract_docstring(solution, found.end()) or _extract_first_comment(
            solution, found.start()
        )
        return (kind, found.group(1), description)

    return None


def _extract_first_comment(solution: str, func_start: int) -> Optional[str]:
    """
    Extract consecutive comment lines inside a function/class body.
    Used as fallback when no docstring is available.
    Returns up to 5 lines of comments joined together.
    """
    remaining = solution[func_start:]
    colon_match = re.search(r"(?:\)|[^:]+):\s*\n", remaining)
    if not colon_match:
        return None

    # Get the body after the colon
    body_start = colon_match.end()
    body = remaining[body_start:]

    # Collect consecutive comment lines
    comment_lines = []
    lines = body.split("\n")
    for line in lines[:10]:  # Check first 10 lines for comments
        stripped = line.strip()
        if stripped.startswith("#"):
            # Remove the # and leading space
            comment_text = stripped[1:].strip()
            if comment_text:
                comment_lines.append(comment_text)
            if len(comment_lines) >= 5:  # Max 5 lines
                break
        elif stripped and not stripped.startswith("#"):
            # Hit actual code, stop collecting
            break

    return "\n".join(comment_lines) if comment_lines else None


def _extract_docstring(solution: str, start_pos: int) -> Optional[str]:
    """
    Extract first line of docstring after a given position.

    Args:
        solution: Source code
        start_pos: Position to start searching from
    """
    remaining = solution[start_pos:]
    docstring_match = re.search(r':\s*\n\s*("""|\'\'\')(.*?)("""|\'\'\')', remaining, re.DOTALL)

    if docstring_match:
        docstring_content = docstring_match.group(2).strip()
        return docstring_content.split("\n")[0].strip()

    return None


def format_diff_summary(diff_blocks: List[Tuple[str, str]]) -> str:
    """
    Create a human-readable summary of the diff.

    One line is produced per block, numbered from 1:

    - If a function/class definition is found in either side, the line names
      it. A rename is always reported; otherwise a changed docstring (or the
      only available docstring) is shown, and identical or missing docstrings
      fall back to the ``old→new`` line counts.
    - A one-line-for-one-line change shows both lines.
    - Anything else shows the first non-empty SEARCH line as context with the
      line counts, or just the counts when either side is blank.

    Args:
        diff_blocks: List of (search_text, replace_text) tuples

    Returns:
        Summary string, one "Change N: ..." line per block
    """
    summary = []

    for number, (search_text, replace_text) in enumerate(diff_blocks, start=1):
        search_lines = search_text.strip().split("\n")
        replace_lines = replace_text.strip().split("\n")

        # "old→new" size marker; the arrow keeps the two counts from running
        # together (3 and 5 lines must not read as "35 lines")
        line_counts = f"{len(search_lines)}→{len(replace_lines)} lines"

        # Try to extract meaningful info from the solution
        old_info = _extract_def_info(search_text)
        new_info = _extract_def_info(replace_text)

        # Build a meaningful summary
        if old_info or new_info:
            # Prefer the new side's description when both are available
            kind, name, docstring = new_info or old_info

            # Get docstrings from both to compare
            old_docstring = old_info[2] if old_info else None
            new_docstring = new_info[2] if new_info else None

            if old_info and new_info and old_info[1] != new_info[1]:
                # Renamed function/class - always show this
                desc = f"Renamed {old_info[0]} `{old_info[1]}` → `{new_info[1]}`"
            elif old_docstring and new_docstring and old_docstring != new_docstring:
                # Docstrings are DIFFERENT - show the new docstring
                desc = f"Modified {kind} `{name}`: {new_docstring}"
            elif old_docstring == new_docstring:
                # Docstrings are IDENTICAL - use simple format (just line counts)
                desc = f"Modified {kind} `{name}` ({line_counts})"
            elif docstring:
                # Only one has docstring
                desc = f"Modified {kind} `{name}`: {docstring}"
            else:
                desc = f"Modified {kind} `{name}` ({line_counts})"

            summary.append(f"Change {number}: {desc}")
        elif len(search_lines) == 1 and len(replace_lines) == 1:
            # Single line change - show the actual change
            summary.append(
                f"Change {number}: '{search_lines[0].strip()}' → '{replace_lines[0].strip()}'"
            )
        else:
            # Fallback: show first non-empty line as context
            first_old = next((line.strip() for line in search_lines if line.strip()), "")
            first_new = next((line.strip() for line in replace_lines if line.strip()), "")

            if first_old and first_new:
                summary.append(f"Change {number}: Near `{first_old[:50]}...` ({line_counts})")
            else:
                summary.append(
                    f"Change {number}: Replace {len(search_lines)} lines with {len(replace_lines)} lines"
                )

    return "\n".join(summary)


def extract_solution_language(solution: str) -> str:
    """
    Try to determine the language of a solution snippet in string format

    The snippet is tested against an ordered list of signatures, each of which
    must match at the very start of a line; the first signature that matches
    decides the result. This is a keyword heuristic, not a parser.

    Args:
        solution: Solution snippet

    Returns:
        Detected language or "text" by default if no language is detected
    """
    # Ordered (language, signature) pairs; specific signatures are listed
    # ahead of the generic rule that would otherwise shadow them.
    signatures = (
        # Must precede the Python rule, whose `import\s` claims every
        # `import java.…` line
        ("java", r"^import\s+javax?\."),
        ("python", r"^(import|from|def|class)\s"),
        ("java", r"^(package|import java|public class)"),
        ("cpp", r"^(#include|int main|void main)"),
        # Must precede the JavaScript rule, whose `let` claims `let mut`
        ("rust", r"^let mut\b"),
        ("javascript", r"^(function|var|let|const|console\.log)"),
        ("rust", r"^(module|fn|let mut|impl)"),
        ("sql", r"^(SELECT|CREATE TABLE|INSERT INTO)"),
    )

    for language, signature in signatures:
        if re.search(signature, solution, re.MULTILINE):
            return language

    return "text"


def build_repo_map(
    root: str,
    *,
    max_depth: int = 4,
    allowed_extensions: Tuple[str, ...] = (".py",),
    excluded_dirs: Tuple[str, ...] = (".git", "__pycache__"),
) -> str:
    """Render the directory tree under *root* as an indented listing.

    Every level is indented by two more spaces (top-level entries start with
    two). Within a directory, sub-directories come before files and each group
    is sorted by name; directory names carry a trailing ``/``.

    Only files whose suffix is in *allowed_extensions* are listed. Entries
    whose name starts with a dot or appears in *excluded_dirs* are skipped.
    The contents of directories nested deeper than *max_depth* are not
    expanded, and directories that cannot be read are silently left
    unexpanded.

    Returns an empty string if *root* is empty, does not exist, or is not a
    directory.
    """
    if not root or not os.path.isdir(root):
        return ""

    skip_names: Set[str] = set(excluded_dirs)
    keep_suffixes: Set[str] = set(allowed_extensions)

    def tree_lines(folder: Path, indent: str, level: int):
        # Yield the listing for `folder` in pre-order, one line per entry
        if level > max_depth:
            return
        try:
            # Directories first (is_file() is False for them), then by name
            children = sorted(folder.iterdir(), key=lambda child: (child.is_file(), child.name))
        except PermissionError:
            return

        for child in children:
            if child.name.startswith(".") or child.name in skip_names:
                continue
            if child.is_dir():
                yield f"{indent}{child.name}/"
                yield from tree_lines(child, indent + "  ", level + 1)
            elif child.suffix in keep_suffixes:
                yield f"{indent}{child.name}"

    return "\n".join(tree_lines(Path(root).resolve(), "  ", 0))