|
|
"""Standalone CodeLoader for loading and processing GitHub repositories.""" |
|
|
|
|
|
import logging |
|
|
import os |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
from typing import Callable |
|
|
|
|
|
import git |
|
|
import nbconvert |
|
|
import nbformat |
|
|
|
|
|
# Module-level logger named after this module's import path (__name__).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class CodeLoader: |
|
|
"""Load and process GitHub repositories for code analysis.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
github_url: str, |
|
|
max_file_size_mb: float = 1.0, |
|
|
raw_repo_dir: str | Path = "data/repos-raw", |
|
|
): |
|
|
logger.info( |
|
|
f"Initializing CodeLoader for {github_url} with max file size " |
|
|
f"{max_file_size_mb} MB and raw repo dir {raw_repo_dir}" |
|
|
) |
|
|
self.github_url = github_url |
|
|
self.max_file_size_mb = max_file_size_mb |
|
|
self.raw_repo_dir = Path(raw_repo_dir) |
|
|
self.repo_path = self.raw_repo_dir / self.github_url_to_repo_name |
|
|
|
|
|
self.clone_repo() |
|
|
self.files = self._get_files() |
|
|
|
|
|
@property |
|
|
def github_url_to_repo_name(self): |
|
|
"""Convert GitHub URL to a safe directory name.""" |
|
|
base_name = ( |
|
|
self.github_url.rstrip("/").split("/")[-2] |
|
|
+ "__" |
|
|
+ self.github_url.rstrip("/").split("/")[-1] |
|
|
) |
|
|
|
|
|
if base_name.endswith(".git"): |
|
|
base_name = base_name[:-4] |
|
|
return base_name |
|
|
|
|
|
def clone_repo(self): |
|
|
"""Clone or validate existing repository.""" |
|
|
if self.repo_path.exists(): |
|
|
logger.info(f"Repository already exists at {self.repo_path}") |
|
|
|
|
|
|
|
|
try: |
|
|
repo = git.Repo(self.repo_path) |
|
|
|
|
|
try: |
|
|
_ = repo.head.commit.hexsha |
|
|
except (ValueError, git.BadName) as e: |
|
|
logger.warning( |
|
|
f"Repository has missing or corrupted commits at " |
|
|
f"{self.repo_path}, removing and re-cloning. Error: {e}" |
|
|
) |
|
|
shutil.rmtree(self.repo_path) |
|
|
self.clone_repo() |
|
|
return |
|
|
|
|
|
logger.info("Repository already exists and is valid") |
|
|
return |
|
|
|
|
|
except (git.InvalidGitRepositoryError, git.GitCommandError) as e: |
|
|
logger.warning( |
|
|
f"Invalid or corrupted git repository at {self.repo_path}, " |
|
|
f"removing and re-cloning. Error: {e}" |
|
|
) |
|
|
shutil.rmtree(self.repo_path) |
|
|
self.clone_repo() |
|
|
return |
|
|
|
|
|
|
|
|
logger.info(f"Cloning repo {self.github_url} to {self.repo_path}") |
|
|
self.raw_repo_dir.mkdir(parents=True, exist_ok=True) |
|
|
repo = git.Repo.clone_from(self.github_url, str(self.repo_path)) |
|
|
|
|
|
|
|
|
self._cleanup_repo() |
|
|
|
|
|
def _cleanup_repo(self): |
|
|
"""Remove docs/test directories, convert notebooks, and remove large files.""" |
|
|
|
|
|
for root, dirs, _ in os.walk(self.repo_path): |
|
|
|
|
|
if ".git" in dirs: |
|
|
dirs.remove(".git") |
|
|
|
|
|
|
|
|
dirs_to_remove = [ |
|
|
dir |
|
|
for dir in dirs |
|
|
if dir in ["docs", "doc", "test", "tests", "example", "examples"] |
|
|
] |
|
|
for dir in dirs_to_remove: |
|
|
dir_path = Path(root) / dir |
|
|
logger.info(f"Removing directory: {dir_path}") |
|
|
shutil.rmtree(dir_path) |
|
|
dirs.remove(dir) |
|
|
|
|
|
|
|
|
for root, dirs, files in os.walk(self.repo_path): |
|
|
|
|
|
if ".git" in dirs: |
|
|
dirs.remove(".git") |
|
|
|
|
|
for file in files: |
|
|
if file.endswith(".ipynb"): |
|
|
logger.info(f"Converting Jupyter Notebook {file} to .py") |
|
|
try: |
|
|
nb = nbformat.read(Path(root) / file, as_version=4) |
|
|
|
|
|
for cell in nb.cells: |
|
|
if cell.get("cell_type") == "code": |
|
|
cell["outputs"] = [] |
|
|
cell["execution_count"] = None |
|
|
|
|
|
|
|
|
exporter = nbconvert.PythonExporter() |
|
|
source, _ = exporter.from_notebook_node(nb) |
|
|
source = ( |
|
|
"# This file was converted from a jupyter notebook " |
|
|
f"called {file}. All outputs have been removed.\n{source}" |
|
|
) |
|
|
with open(Path(root) / file.replace(".ipynb", ".py"), "w") as f: |
|
|
f.write(source) |
|
|
|
|
|
os.remove(Path(root) / file) |
|
|
except Exception as e: |
|
|
logger.warning(f"Failed to convert notebook {file}: {e}") |
|
|
raise e |
|
|
|
|
|
|
|
|
for root, dirs, files in os.walk(self.repo_path): |
|
|
|
|
|
if ".git" in dirs: |
|
|
dirs.remove(".git") |
|
|
|
|
|
for file in files: |
|
|
file_path = Path(root) / file |
|
|
try: |
|
|
file_size = file_path.stat().st_size |
|
|
except FileNotFoundError as e: |
|
|
logger.warning(f"Failed to get size of {file_path}: {e}") |
|
|
continue |
|
|
if file_size > self.mb_to_bytes(self.max_file_size_mb): |
|
|
logger.info(f"Removing large file: {file_path}") |
|
|
os.remove(file_path) |
|
|
|
|
|
def _get_files(self): |
|
|
"""Get all files from the repository.""" |
|
|
files = {} |
|
|
for root, _, _files in os.walk(self.repo_path): |
|
|
for file in _files: |
|
|
file_path = Path(root) / file |
|
|
if ".git" in str(file_path): |
|
|
continue |
|
|
|
|
|
|
|
|
file_path_key = file_path.relative_to(self.repo_path) |
|
|
|
|
|
try: |
|
|
with open(file_path, "r", encoding="utf-8", errors="ignore") as f: |
|
|
content = f.read() |
|
|
files[str(file_path_key)] = content |
|
|
except Exception as e: |
|
|
logger.warning(f"Could not read {file_path}: {e}") |
|
|
|
|
|
|
|
|
files = dict(sorted(files.items())) |
|
|
return files |
|
|
|
|
|
@staticmethod |
|
|
def mb_to_bytes(mb: float) -> int: |
|
|
"""Convert megabytes to bytes.""" |
|
|
return int(mb * 1024 * 1024) |
|
|
|
|
|
def get_files_by_extension( |
|
|
self, extensions: list[str] | None = None |
|
|
) -> dict[str, str]: |
|
|
"""Get files filtered by extension.""" |
|
|
if extensions is None: |
|
|
|
|
|
extensions = [ |
|
|
".c", |
|
|
".cc", |
|
|
".cpp", |
|
|
".cu", |
|
|
".h", |
|
|
".hpp", |
|
|
".java", |
|
|
".jl", |
|
|
".m", |
|
|
".matlab", |
|
|
".Makefile", |
|
|
".md", |
|
|
".pl", |
|
|
".ps1", |
|
|
".py", |
|
|
".r", |
|
|
".sh", |
|
|
"config.txt", |
|
|
".rs", |
|
|
"readme.txt", |
|
|
"requirements_dev.txt", |
|
|
"requirements-dev.txt", |
|
|
"requirements.dev.txt", |
|
|
"requirements.txt", |
|
|
".scala", |
|
|
".yaml", |
|
|
".yml", |
|
|
] |
|
|
return { |
|
|
k: v |
|
|
for k, v in self.files.items() |
|
|
if k.lower().endswith(tuple(extensions)) |
|
|
} |
|
|
|
|
|
def get_repo_tree(self): |
|
|
"""Generate a tree representation of the repository.""" |
|
|
repo_tree = "" |
|
|
for root, dirs, files in os.walk(self.repo_path): |
|
|
|
|
|
if ".git" in dirs: |
|
|
dirs.remove(".git") |
|
|
|
|
|
level = str(Path(root).relative_to(self.repo_path)).count(os.sep) |
|
|
indent = "β " * (level - 1) + "βββ " if level > 0 else "" |
|
|
|
|
|
|
|
|
if level > 0: |
|
|
repo_tree += f"{indent}{Path(root).name}/\n" |
|
|
|
|
|
sub_indent = "β " * level + "βββ " |
|
|
for f in files: |
|
|
repo_tree += f"{sub_indent}{f}\n" |
|
|
return repo_tree |
|
|
|
|
|
def get_code_prompt( |
|
|
self, |
|
|
file_extensions: list[str] | None = None, |
|
|
token_counter: Callable | None = None, |
|
|
max_tokens: int | None = None, |
|
|
code_changes: list[dict[str, str]] | None = None, |
|
|
) -> str: |
|
|
"""Generate code prompt with repo tree and file contents.""" |
|
|
code_prompt = "Repo tree:\n" + self.get_repo_tree() + "\n\n" |
|
|
tokens = token_counter(code_prompt) if token_counter is not None else 0 |
|
|
|
|
|
if token_counter is not None and max_tokens is not None: |
|
|
logger.info( |
|
|
f"Building code prompt: repo tree tokens={tokens}, max_tokens={max_tokens}, " |
|
|
f"remaining for files={max_tokens - tokens}" |
|
|
) |
|
|
|
|
|
files_to_replace = {} |
|
|
if code_changes: |
|
|
files_to_replace = { |
|
|
cc["file_name"]: cc["discrepancy_code"] for cc in code_changes |
|
|
} |
|
|
logger.debug( |
|
|
f"Files to replace: {len(files_to_replace)}: {files_to_replace.keys()}" |
|
|
) |
|
|
|
|
|
for file_path, file_content in self.get_files_by_extension( |
|
|
file_extensions |
|
|
).items(): |
|
|
if file_path in files_to_replace: |
|
|
logger.debug(f"Replacing code for {file_path} with changed code") |
|
|
file_content = files_to_replace[file_path] |
|
|
code_file = f"# ---\n# File: {file_path}\n# Content:\n{file_content}\n" |
|
|
if token_counter is not None: |
|
|
logger.debug(f"Adding file: {file_path}") |
|
|
num_tokens = token_counter(code_file) |
|
|
|
|
|
if max_tokens and (tokens + num_tokens) > max_tokens: |
|
|
logger.warning( |
|
|
f"Truncating. Max tokens reached for {self.github_url}. " |
|
|
f"Current tokens: {tokens}, File tokens: {num_tokens}, " |
|
|
f"Max tokens for code is {max_tokens}" |
|
|
) |
|
|
break |
|
|
tokens += num_tokens |
|
|
logger.debug( |
|
|
f"Number of tokens in file: {num_tokens}. " |
|
|
f"Total number of tokens in code prompt: {tokens}" |
|
|
) |
|
|
code_prompt += code_file |
|
|
|
|
|
|
|
|
if token_counter is not None: |
|
|
final_code_tokens = token_counter(code_prompt) |
|
|
logger.info( |
|
|
f"Code prompt built: {final_code_tokens} tokens " |
|
|
f"(max was {max_tokens if max_tokens else 'unlimited'})" |
|
|
) |
|
|
|
|
|
return code_prompt |
|
|
|
|
|
|
|
|
|