Spaces:
Running
Running
"""Utility classes to manipulate GitHub repositories."""
import logging
import os
from abc import abstractmethod
from functools import cached_property
from typing import Any, Dict, Generator, List, Tuple

import requests
from git import GitCommandError, Repo
| class DataManager: | |
| def __init__(self, dataset_id: str): | |
| self.dataset_id = dataset_id | |
| def download(self) -> bool: | |
| """Downloads the data from a remote location.""" | |
| def walk(self) -> Generator[Tuple[Any, Dict], None, None]: | |
| """Yields a tuple of (data, metadata) for each data item in the dataset.""" | |
class GitHubRepoManager(DataManager):
    """Class to manage a local clone of a GitHub repository."""

    def __init__(
        self,
        repo_id: str,
        commit_hash: str = None,
        access_token: str = None,
        local_dir: str = None,
        inclusion_file: str = None,
        exclusion_file: str = None,
    ):
        """
        Args:
            repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/sage".
            commit_hash: Optional commit hash to checkout. If not specified, we pull the latest version of the repo.
            access_token: A GitHub access token to use for cloning private repositories. Not needed for public repos.
            local_dir: The local directory where the repository will be cloned.
            inclusion_file: A file with a lists of files/directories/extensions to include. Each line must be in one of
                the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
            exclusion_file: A file with a lists of files/directories/extensions to exclude. Each line must be in one of
                the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".

        Raises:
            ValueError: If both inclusion_file and exclusion_file are provided.
        """
        super().__init__(dataset_id=repo_id)
        self.repo_id = repo_id
        self.commit_hash = commit_hash
        self.access_token = access_token
        self.local_dir = local_dir or "/tmp/"
        if not os.path.exists(self.local_dir):
            os.makedirs(self.local_dir)
        self.local_path = os.path.join(self.local_dir, repo_id)

        self.log_dir = os.path.join(self.local_dir, "logs", repo_id)
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        if inclusion_file and exclusion_file:
            raise ValueError("Only one of inclusion_file or exclusion_file should be provided.")

        self.inclusions = self._parse_filter_file(inclusion_file) if inclusion_file else None
        self.exclusions = self._parse_filter_file(exclusion_file) if exclusion_file else None

    @cached_property
    def is_public(self) -> bool:
        """Checks whether a GitHub repository is publicly visible.

        Cached so the GitHub API is hit at most once per manager instance.
        """
        response = requests.get(f"https://api.github.com/repos/{self.repo_id}", timeout=10)
        # Note that the response will be 404 for both private and non-existent repos.
        return response.status_code == 200

    @cached_property
    def default_branch(self) -> str:
        """Fetches the default branch of the repository from GitHub.

        Cached so the GitHub API is hit at most once per manager instance.
        """
        headers = {
            "Accept": "application/vnd.github.v3+json",
        }
        if self.access_token:
            headers["Authorization"] = f"token {self.access_token}"

        response = requests.get(f"https://api.github.com/repos/{self.repo_id}", headers=headers, timeout=10)
        if response.status_code == 200:
            branch = response.json().get("default_branch", "main")
        else:
            # This happens sometimes when we exceed the Github rate limit. The best bet in this case is to assume the
            # most common naming for the default branch ("main").
            logging.warning("Unable to fetch default branch for %s: %s", self.repo_id, response.text)
            branch = "main"
        return branch

    def download(self) -> bool:
        """Clones the repository to the local directory, if it's not already cloned.

        Returns:
            True if the repository is available locally (either freshly cloned or already present), False if the
            clone failed.

        Raises:
            ValueError: If the repository is private (or non-existent) and no access token was provided.
        """
        if os.path.exists(self.local_path):
            # The repository is already cloned.
            return True

        if not self.is_public and not self.access_token:
            raise ValueError(f"Repo {self.repo_id} is private or doesn't exist.")

        if self.access_token:
            clone_url = f"https://{self.access_token}@github.com/{self.repo_id}.git"
        else:
            clone_url = f"https://github.com/{self.repo_id}.git"

        try:
            if self.commit_hash:
                # A specific commit requires full history, so no shallow clone here.
                repo = Repo.clone_from(clone_url, self.local_path)
                repo.git.checkout(self.commit_hash)
            else:
                # Shallow clone of the default branch only; we don't need history for the latest snapshot.
                Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
        except GitCommandError as e:
            logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e)
            return False
        return True

    def _parse_filter_file(self, file_path: str) -> Dict[str, List[str]]:
        """Parses a file with files/directories/extensions to include/exclude.

        Lines are expected to be in the format:
            # Comment that will be ignored, or
            ext:.my-extension, or
            file:my-file.py, or
            dir:my-directory

        Returns:
            A dict mapping each of "ext", "file" and "dir" to the list of values found in the file.
        """
        with open(file_path, "r") as f:
            lines = f.readlines()

        parsed_data = {"ext": [], "file": [], "dir": []}
        for line in lines:
            line = line.strip()
            if not line or line.startswith("#"):
                # Skip blank lines and comments.
                continue
            # Split on the first colon only, so that values may themselves contain colons.
            key, _, value = line.partition(":")
            if key in parsed_data:
                parsed_data[key].append(value)
            else:
                logging.error("Unrecognized key in line: %s. Skipping.", line)
        return parsed_data

    def _should_include(self, file_path: str) -> bool:
        """Checks whether the file should be indexed."""
        # Exclude symlinks.
        if os.path.islink(file_path):
            return False

        # Exclude hidden files and directories.
        if any(part.startswith(".") for part in file_path.split(os.path.sep)):
            return False

        if not self.inclusions and not self.exclusions:
            return True

        # Filter based on file extensions, file names and directory names.
        _, extension = os.path.splitext(file_path)
        extension = extension.lower()
        file_name = os.path.basename(file_path)
        dirs = os.path.dirname(file_path).split(os.path.sep)

        if self.inclusions:
            return (
                extension in self.inclusions.get("ext", [])
                or file_name in self.inclusions.get("file", [])
                or any(d in dirs for d in self.inclusions.get("dir", []))
            )
        elif self.exclusions:
            return (
                extension not in self.exclusions.get("ext", [])
                and file_name not in self.exclusions.get("file", [])
                and all(d not in dirs for d in self.exclusions.get("dir", []))
            )
        return True

    def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
        """Walks the local repository path and yields a tuple of (content, metadata) for each file.

        The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").

        Args:
            get_content: When set to True, yields (content, metadata) tuples. When set to False, yields metadata only.
        """
        # We will keep appending to these files during the iteration, so we need to clear them first.
        repo_name = self.repo_id.replace("/", "_")
        included_log_file = os.path.join(self.log_dir, f"included_{repo_name}.txt")
        excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
        if os.path.exists(included_log_file):
            os.remove(included_log_file)
        logging.info("Logging included files at %s", included_log_file)
        if os.path.exists(excluded_log_file):
            os.remove(excluded_log_file)
        logging.info("Logging excluded files at %s", excluded_log_file)

        for root, _, files in os.walk(self.local_path):
            file_paths = [os.path.join(root, file) for file in files]
            included_file_paths = [f for f in file_paths if self._should_include(f)]

            with open(included_log_file, "a") as f:
                for path in included_file_paths:
                    f.write(path + "\n")

            excluded_file_paths = set(file_paths).difference(set(included_file_paths))
            with open(excluded_log_file, "a") as f:
                for path in excluded_file_paths:
                    f.write(path + "\n")

            for file_path in included_file_paths:
                # os.path.relpath handles local_dir values both with and without a trailing separator; naive string
                # slicing (len(local_dir) + 1) chops the first character when local_dir ends with "/", as the
                # default "/tmp/" does.
                relative_file_path = os.path.relpath(file_path, self.local_dir)
                metadata = {
                    "file_path": relative_file_path,
                    "url": self.url_for_file(relative_file_path),
                }

                if not get_content:
                    yield metadata
                    continue

                with open(file_path, "r") as f:
                    try:
                        contents = f.read()
                    except UnicodeDecodeError:
                        logging.warning("Unable to decode file %s. Skipping.", file_path)
                        continue
                yield contents, metadata

    def url_for_file(self, file_path: str) -> str:
        """Converts a repository file path to a GitHub link."""
        # Strip the leading "owner/repo/" prefix; GitHub URLs take paths relative to the repo root.
        file_path = file_path[len(self.repo_id) + 1 :]
        return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"

    def read_file(self, relative_file_path: str) -> str:
        """Reads the content of the file at the given path."""
        file_path = os.path.join(self.local_dir, relative_file_path)
        with open(file_path, "r") as f:
            return f.read()
def from_args(args) -> "GitHubRepoManager":
    """Creates a GitHubRepoManager from command-line arguments and clones the underlying repository.

    Args:
        args: Parsed command-line arguments (e.g. an argparse.Namespace) exposing the attributes repo_id,
            commit_hash, local_dir, include and exclude. Note the values are read via attribute access, so a plain
            dict does not work here.

    Returns:
        A GitHubRepoManager with the repository cloned locally.

    Raises:
        ValueError: If the repository could not be cloned (missing, private without GITHUB_TOKEN, etc.).
    """
    # The GitHub token is read from the environment so it never has to appear on the command line.
    repo_manager = GitHubRepoManager(
        repo_id=args.repo_id,
        commit_hash=args.commit_hash,
        access_token=os.getenv("GITHUB_TOKEN"),
        local_dir=args.local_dir,
        inclusion_file=args.include,
        exclusion_file=args.exclude,
    )
    success = repo_manager.download()
    if not success:
        raise ValueError(
            f"Unable to clone {args.repo_id}. Please check that it exists and you have access to it. "
            "For private repositories, please set the GITHUB_TOKEN variable in your environment."
        )
    return repo_manager