File size: 11,561 Bytes
4caa453
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384bb2f
 
 
 
 
 
4caa453
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384bb2f
 
 
 
 
 
 
 
4caa453
 
 
 
 
 
384bb2f
 
 
 
 
 
 
 
 
4caa453
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
"""Standalone CodeLoader for loading and processing GitHub repositories."""

import logging
import os
import shutil
from pathlib import Path
from typing import Callable

import git
import nbconvert
import nbformat

logger = logging.getLogger(__name__)


class CodeLoader:
    """Load and process GitHub repositories for code analysis."""

    def __init__(
        self,
        github_url: str,
        max_file_size_mb: float = 1.0,
        raw_repo_dir: str | Path = "data/repos-raw",
    ):
        logger.info(
            f"Initializing CodeLoader for {github_url} with max file size "
            f"{max_file_size_mb} MB and raw repo dir {raw_repo_dir}"
        )
        self.github_url = github_url
        self.max_file_size_mb = max_file_size_mb
        self.raw_repo_dir = Path(raw_repo_dir)
        self.repo_path = self.raw_repo_dir / self.github_url_to_repo_name

        self.clone_repo()
        self.files = self._get_files()

    @property
    def github_url_to_repo_name(self):
        """Convert GitHub URL to a safe directory name."""
        base_name = (
            self.github_url.rstrip("/").split("/")[-2]
            + "__"
            + self.github_url.rstrip("/").split("/")[-1]
        )
        # Remove .git suffix if present
        if base_name.endswith(".git"):
            base_name = base_name[:-4]
        return base_name

    def clone_repo(self):
        """Clone or validate existing repository."""
        if self.repo_path.exists():
            logger.info(f"Repository already exists at {self.repo_path}")

            # Validate repository integrity
            try:
                repo = git.Repo(self.repo_path)
                # Verify repository health
                try:
                    _ = repo.head.commit.hexsha
                except (ValueError, git.BadName) as e:
                    logger.warning(
                        f"Repository has missing or corrupted commits at "
                        f"{self.repo_path}, removing and re-cloning. Error: {e}"
                    )
                    shutil.rmtree(self.repo_path)
                    self.clone_repo()  # Recursive call to re-clone
                    return

                logger.info("Repository already exists and is valid")
                return

            except (git.InvalidGitRepositoryError, git.GitCommandError) as e:
                logger.warning(
                    f"Invalid or corrupted git repository at {self.repo_path}, "
                    f"removing and re-cloning. Error: {e}"
                )
                shutil.rmtree(self.repo_path)
                self.clone_repo()  # Recursive call to re-clone
                return

        # Clone the repository
        logger.info(f"Cloning repo {self.github_url} to {self.repo_path}")
        self.raw_repo_dir.mkdir(parents=True, exist_ok=True)
        repo = git.Repo.clone_from(self.github_url, str(self.repo_path))

        # Clean up the repository
        self._cleanup_repo()

    def _cleanup_repo(self):
        """Remove docs/test directories, convert notebooks, and remove large files."""
        # Remove docs/test directories
        for root, dirs, _ in os.walk(self.repo_path):
            # CRITICAL: Skip .git directory
            if ".git" in dirs:
                dirs.remove(".git")

            # Create a copy of dirs to avoid modification during iteration
            dirs_to_remove = [
                dir
                for dir in dirs
                if dir in ["docs", "doc", "test", "tests", "example", "examples"]
            ]
            for dir in dirs_to_remove:
                dir_path = Path(root) / dir
                logger.info(f"Removing directory: {dir_path}")
                shutil.rmtree(dir_path)
                dirs.remove(dir)

        # Convert Jupyter notebooks to Python files
        for root, dirs, files in os.walk(self.repo_path):
            # Skip .git directory
            if ".git" in dirs:
                dirs.remove(".git")

            for file in files:
                if file.endswith(".ipynb"):
                    logger.info(f"Converting Jupyter Notebook {file} to .py")
                    try:
                        nb = nbformat.read(Path(root) / file, as_version=4)
                        # Clear outputs
                        for cell in nb.cells:
                            if cell.get("cell_type") == "code":
                                cell["outputs"] = []
                                cell["execution_count"] = None

                        # Convert to .py
                        exporter = nbconvert.PythonExporter()
                        source, _ = exporter.from_notebook_node(nb)
                        source = (
                            "# This file was converted from a jupyter notebook "
                            f"called {file}. All outputs have been removed.\n{source}"
                        )
                        with open(Path(root) / file.replace(".ipynb", ".py"), "w") as f:
                            f.write(source)
                        # Remove the original notebook
                        os.remove(Path(root) / file)
                    except Exception as e:
                        logger.warning(f"Failed to convert notebook {file}: {e}")
                        raise e

        # Remove large files
        for root, dirs, files in os.walk(self.repo_path):
            # Skip .git directory
            if ".git" in dirs:
                dirs.remove(".git")

            for file in files:
                file_path = Path(root) / file
                try:
                    file_size = file_path.stat().st_size
                except FileNotFoundError as e:
                    logger.warning(f"Failed to get size of {file_path}: {e}")
                    continue
                if file_size > self.mb_to_bytes(self.max_file_size_mb):
                    logger.info(f"Removing large file: {file_path}")
                    os.remove(file_path)

    def _get_files(self):
        """Get all files from the repository."""
        files = {}
        for root, _, _files in os.walk(self.repo_path):
            for file in _files:
                file_path = Path(root) / file
                if ".git" in str(file_path):
                    continue

                # Get relative path from repo root
                file_path_key = file_path.relative_to(self.repo_path)

                try:
                    with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
                        content = f.read()
                        files[str(file_path_key)] = content
                except Exception as e:
                    logger.warning(f"Could not read {file_path}: {e}")

        # Order keys alphabetically
        files = dict(sorted(files.items()))
        return files

    @staticmethod
    def mb_to_bytes(mb: float) -> int:
        """Convert megabytes to bytes."""
        return int(mb * 1024 * 1024)

    def get_files_by_extension(
        self, extensions: list[str] | None = None
    ) -> dict[str, str]:
        """Get files filtered by extension."""
        if extensions is None:
            # Note: ipynb files are converted to .py during cleanup
            extensions = [
                ".c",
                ".cc",
                ".cpp",
                ".cu",
                ".h",
                ".hpp",
                ".java",
                ".jl",
                ".m",
                ".matlab",
                ".Makefile",
                ".md",
                ".pl",
                ".ps1",
                ".py",
                ".r",
                ".sh",
                "config.txt",
                ".rs",
                "readme.txt",
                "requirements_dev.txt",
                "requirements-dev.txt",
                "requirements.dev.txt",
                "requirements.txt",
                ".scala",
                ".yaml",
                ".yml",
            ]
        return {
            k: v
            for k, v in self.files.items()
            if k.lower().endswith(tuple(extensions))
        }

    def get_repo_tree(self):
        """Generate a tree representation of the repository."""
        repo_tree = ""
        for root, dirs, files in os.walk(self.repo_path):
            # Exclude the .git directory
            if ".git" in dirs:
                dirs.remove(".git")

            level = str(Path(root).relative_to(self.repo_path)).count(os.sep)
            indent = "β”‚   " * (level - 1) + "β”œβ”€β”€ " if level > 0 else ""

            # Don't print the starting path itself, just its contents
            if level > 0:
                repo_tree += f"{indent}{Path(root).name}/\n"

            sub_indent = "β”‚   " * level + "β”œβ”€β”€ "
            for f in files:
                repo_tree += f"{sub_indent}{f}\n"
        return repo_tree

    def get_code_prompt(
        self,
        file_extensions: list[str] | None = None,
        token_counter: Callable | None = None,
        max_tokens: int | None = None,
        code_changes: list[dict[str, str]] | None = None,
    ) -> str:
        """Generate code prompt with repo tree and file contents."""
        code_prompt = "Repo tree:\n" + self.get_repo_tree() + "\n\n"
        tokens = token_counter(code_prompt) if token_counter is not None else 0
        
        if token_counter is not None and max_tokens is not None:
            logger.info(
                f"Building code prompt: repo tree tokens={tokens}, max_tokens={max_tokens}, "
                f"remaining for files={max_tokens - tokens}"
            )

        files_to_replace = {}
        if code_changes:
            files_to_replace = {
                cc["file_name"]: cc["discrepancy_code"] for cc in code_changes
            }
            logger.debug(
                f"Files to replace: {len(files_to_replace)}: {files_to_replace.keys()}"
            )

        for file_path, file_content in self.get_files_by_extension(
            file_extensions
        ).items():
            if file_path in files_to_replace:
                logger.debug(f"Replacing code for {file_path} with changed code")
                file_content = files_to_replace[file_path]
            code_file = f"# ---\n# File: {file_path}\n# Content:\n{file_content}\n"
            if token_counter is not None:
                logger.debug(f"Adding file: {file_path}")
                num_tokens = token_counter(code_file)
                # Check if adding this file would exceed the limit BEFORE adding it
                if max_tokens and (tokens + num_tokens) > max_tokens:
                    logger.warning(
                        f"Truncating. Max tokens reached for {self.github_url}. "
                        f"Current tokens: {tokens}, File tokens: {num_tokens}, "
                        f"Max tokens for code is {max_tokens}"
                    )
                    break
                tokens += num_tokens
                logger.debug(
                    f"Number of tokens in file: {num_tokens}. "
                    f"Total number of tokens in code prompt: {tokens}"
                )
            code_prompt += code_file
        
        # Log final code prompt size
        if token_counter is not None:
            final_code_tokens = token_counter(code_prompt)
            logger.info(
                f"Code prompt built: {final_code_tokens} tokens "
                f"(max was {max_tokens if max_tokens else 'unlimited'})"
            )
        
        return code_prompt