code-crawler / sage /data_manager.py
juliaturc's picture
Add an LLM retriever (#70)
27dd60a
raw
history blame
10.7 kB
"""Utility classes to maniuplate GitHub repositories."""
import logging
import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, Dict, Generator, Optional, Tuple

import requests
from git import GitCommandError, Repo
class DataManager(ABC):
    """Abstract interface for a dataset that can be downloaded and iterated over.

    Note: the class must inherit from ABC, otherwise the @abstractmethod decorators
    below have no effect and incomplete subclasses would silently instantiate.
    """

    def __init__(self, dataset_id: str):
        # Identifier of the dataset (e.g. a GitHub "owner/repo" string).
        self.dataset_id = dataset_id

    @abstractmethod
    def download(self) -> bool:
        """Downloads the data from a remote location."""

    @abstractmethod
    def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
        """Yields a tuple of (data, metadata) for each data item in the dataset."""
class GitHubRepoManager(DataManager):
    """Class to manage a local clone of a GitHub repository."""

    def __init__(
        self,
        repo_id: str,
        commit_hash: Optional[str] = None,
        access_token: Optional[str] = None,
        local_dir: Optional[str] = None,
        inclusion_file: Optional[str] = None,
        exclusion_file: Optional[str] = None,
    ):
        """
        Args:
            repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/sage".
            commit_hash: Optional commit hash to checkout. If not specified, we pull the latest version of the repo.
            access_token: A GitHub access token to use for cloning private repositories. Not needed for public repos.
            local_dir: The local directory where the repository will be cloned.
            inclusion_file: A file with a list of files/directories/extensions to include. Each line must be in one
                of the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
            exclusion_file: A file with a list of files/directories/extensions to exclude. Each line must be in one
                of the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".

        Raises:
            ValueError: If both inclusion_file and exclusion_file are provided.
        """
        super().__init__(dataset_id=repo_id)
        self.repo_id = repo_id
        self.commit_hash = commit_hash
        self.access_token = access_token
        self.local_dir = local_dir or "/tmp/"
        # exist_ok avoids a race between checking for the directory and creating it.
        os.makedirs(self.local_dir, exist_ok=True)
        self.local_path = os.path.join(self.local_dir, repo_id)
        self.log_dir = os.path.join(self.local_dir, "logs", repo_id)
        os.makedirs(self.log_dir, exist_ok=True)

        if inclusion_file and exclusion_file:
            raise ValueError("Only one of inclusion_file or exclusion_file should be provided.")
        self.inclusions = self._parse_filter_file(inclusion_file) if inclusion_file else None
        self.exclusions = self._parse_filter_file(exclusion_file) if exclusion_file else None
@cached_property
def is_public(self) -> bool:
"""Checks whether a GitHub repository is publicly visible."""
response = requests.get(f"https://api.github.com/repos/{self.repo_id}", timeout=10)
# Note that the response will be 404 for both private and non-existent repos.
return response.status_code == 200
@cached_property
def default_branch(self) -> str:
"""Fetches the default branch of the repository from GitHub."""
headers = {
"Accept": "application/vnd.github.v3+json",
}
if self.access_token:
headers["Authorization"] = f"token {self.access_token}"
response = requests.get(f"https://api.github.com/repos/{self.repo_id}", headers=headers)
if response.status_code == 200:
branch = response.json().get("default_branch", "main")
else:
# This happens sometimes when we exceed the Github rate limit. The best bet in this case is to assume the
# most common naming for the default branch ("main").
logging.warn(f"Unable to fetch default branch for {self.repo_id}: {response.text}")
branch = "main"
return branch
def download(self) -> bool:
"""Clones the repository to the local directory, if it's not already cloned."""
if os.path.exists(self.local_path):
# The repository is already cloned.
return True
if not self.is_public and not self.access_token:
raise ValueError(f"Repo {self.repo_id} is private or doesn't exist.")
if self.access_token:
clone_url = f"https://{self.access_token}@github.com/{self.repo_id}.git"
else:
clone_url = f"https://github.com/{self.repo_id}.git"
try:
if self.commit_hash:
repo = Repo.clone_from(clone_url, self.local_path)
repo.git.checkout(self.commit_hash)
else:
Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
except GitCommandError as e:
logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e)
return False
return True
def _parse_filter_file(self, file_path: str) -> bool:
"""Parses a file with files/directories/extensions to include/exclude.
Lines are expected to be in the format:
# Comment that will be ignored, or
ext:.my-extension, or
file:my-file.py, or
dir:my-directory
"""
with open(file_path, "r") as f:
lines = f.readlines()
parsed_data = {"ext": [], "file": [], "dir": []}
for line in lines:
if line.startswith("#"):
# This is a comment line.
continue
key, value = line.strip().split(":")
if key in parsed_data:
parsed_data[key].append(value)
else:
logging.error("Unrecognized key in line: %s. Skipping.", line)
return parsed_data
def _should_include(self, file_path: str) -> bool:
"""Checks whether the file should be indexed."""
# Exclude symlinks.
if os.path.islink(file_path):
return False
# Exclude hidden files and directories.
if any(part.startswith(".") for part in file_path.split(os.path.sep)):
return False
if not self.inclusions and not self.exclusions:
return True
# Filter based on file extensions, file names and directory names.
_, extension = os.path.splitext(file_path)
extension = extension.lower()
file_name = os.path.basename(file_path)
dirs = os.path.dirname(file_path).split("/")
if self.inclusions:
return (
extension in self.inclusions.get("ext", [])
or file_name in self.inclusions.get("file", [])
or any(d in dirs for d in self.inclusions.get("dir", []))
)
elif self.exclusions:
return (
extension not in self.exclusions.get("ext", [])
and file_name not in self.exclusions.get("file", [])
and all(d not in dirs for d in self.exclusions.get("dir", []))
)
return True
def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
"""Walks the local repository path and yields a tuple of (content, metadata) for each file.
The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").
Args:
get_content: When set to True, yields (content, metadata) tuples. When set to False, yields metadata only.
"""
# We will keep appending to these files during the iteration, so we need to clear them first.
repo_name = self.repo_id.replace("/", "_")
included_log_file = os.path.join(self.log_dir, f"included_{repo_name}.txt")
excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
if os.path.exists(included_log_file):
os.remove(included_log_file)
logging.info("Logging included files at %s", included_log_file)
if os.path.exists(excluded_log_file):
os.remove(excluded_log_file)
logging.info("Logging excluded files at %s", excluded_log_file)
for root, _, files in os.walk(self.local_path):
file_paths = [os.path.join(root, file) for file in files]
included_file_paths = [f for f in file_paths if self._should_include(f)]
with open(included_log_file, "a") as f:
for path in included_file_paths:
f.write(path + "\n")
excluded_file_paths = set(file_paths).difference(set(included_file_paths))
with open(excluded_log_file, "a") as f:
for path in excluded_file_paths:
f.write(path + "\n")
for file_path in included_file_paths:
relative_file_path = file_path[len(self.local_dir) + 1 :]
metadata = {
"file_path": relative_file_path,
"url": self.url_for_file(relative_file_path),
}
if not get_content:
yield metadata
continue
with open(file_path, "r") as f:
try:
contents = f.read()
except UnicodeDecodeError:
logging.warning("Unable to decode file %s. Skipping.", file_path)
continue
yield contents, metadata
def url_for_file(self, file_path: str) -> str:
"""Converts a repository file path to a GitHub link."""
file_path = file_path[len(self.repo_id) + 1 :]
return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
def read_file(self, relative_file_path: str) -> str:
"""Reads the content of the file at the given path."""
file_path = os.path.join(self.local_dir, relative_file_path)
with open(file_path, "r") as f:
return f.read()
def from_args(args) -> "GitHubRepoManager":
    """Creates a GitHubRepoManager from command-line arguments and clones the underlying repository.

    Args:
        args: An argparse-style namespace exposing the attributes repo_id, commit_hash,
            local_dir, include and exclude. (The previous `Dict` annotation was wrong:
            the values are read via attribute access, not item access.)

    Raises:
        ValueError: If the repository could not be cloned.
    """
    repo_manager = GitHubRepoManager(
        repo_id=args.repo_id,
        commit_hash=args.commit_hash,
        access_token=os.getenv("GITHUB_TOKEN"),
        local_dir=args.local_dir,
        inclusion_file=args.include,
        exclusion_file=args.exclude,
    )
    success = repo_manager.download()
    if not success:
        raise ValueError(
            f"Unable to clone {args.repo_id}. Please check that it exists and you have access to it. "
            "For private repositories, please set the GITHUB_TOKEN variable in your environment."
        )
    return repo_manager