File size: 10,742 Bytes
559dd34
 
 
 
6f4d334
559dd34
6f4d334
559dd34
 
 
 
 
6f4d334
 
 
 
 
 
 
 
 
 
 
 
 
 
559dd34
 
d5c979a
 
 
2730f0a
8b42d65
d5c979a
7db04dd
 
d5c979a
559dd34
 
52c1352
2730f0a
8b42d65
559dd34
7db04dd
 
 
 
559dd34
6f4d334
559dd34
2730f0a
8b42d65
6f4d334
559dd34
 
 
 
6f4d334
 
 
 
 
7db04dd
 
 
 
 
559dd34
 
 
 
57007fe
559dd34
 
 
 
 
 
 
 
 
 
 
 
57007fe
559dd34
 
 
 
 
57007fe
559dd34
 
 
6f4d334
559dd34
 
 
 
 
 
 
 
 
 
 
 
 
 
2730f0a
 
 
 
 
559dd34
57007fe
559dd34
 
 
7db04dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5c979a
7db04dd
 
d0e366f
 
7db04dd
d5c979a
 
 
7db04dd
 
 
 
 
 
 
 
 
 
 
 
39898b4
 
 
7db04dd
 
 
39898b4
 
 
7db04dd
d5c979a
 
27dd60a
6f4d334
559dd34
 
 
27dd60a
559dd34
27dd60a
6f4d334
 
 
 
 
7db04dd
6f4d334
 
7db04dd
559dd34
 
 
d5c979a
559dd34
6f4d334
 
 
559dd34
6f4d334
 
 
 
559dd34
 
27dd60a
 
 
 
 
 
 
 
 
 
559dd34
 
 
 
57007fe
559dd34
6f4d334
 
 
559dd34
6f4d334
57007fe
27dd60a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""Utility classes to maniuplate GitHub repositories."""

import logging
import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, Dict, Generator, Tuple

import requests
from git import GitCommandError, Repo


class DataManager(ABC):
    """Abstract interface for a dataset that can be downloaded and iterated.

    Subclasses must implement `download` and `walk`. Inheriting from `ABC` is
    required for `@abstractmethod` to have any effect; without it, the class
    could be instantiated with no-op abstract methods.
    """

    def __init__(self, dataset_id: str):
        # Identifier of the dataset (for repos, the "owner/repo" string).
        self.dataset_id = dataset_id

    @abstractmethod
    def download(self) -> bool:
        """Downloads the data from a remote location. Returns True on success."""

    @abstractmethod
    def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
        """Yields a tuple of (data, metadata) for each data item in the dataset."""


class GitHubRepoManager(DataManager):
    """Class to manage a local clone of a GitHub repository."""

    def __init__(
        self,
        repo_id: str,
        commit_hash: str = None,
        access_token: str = None,
        local_dir: str = None,
        inclusion_file: str = None,
        exclusion_file: str = None,
    ):
        """
        Args:
            repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/sage".
            commit_hash: Optional commit hash to checkout. If not specified, we pull the latest version of the repo.
            access_token: A GitHub access token to use for cloning private repositories. Not needed for public repos.
            local_dir: The local directory where the repository will be cloned.
            inclusion_file: A file with a list of files/directories/extensions to include. Each line must be in one of
                the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
            exclusion_file: A file with a list of files/directories/extensions to exclude. Each line must be in one of
                the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".

        Raises:
            ValueError: If both inclusion_file and exclusion_file are provided.
        """
        super().__init__(dataset_id=repo_id)
        self.repo_id = repo_id
        self.commit_hash = commit_hash
        self.access_token = access_token

        self.local_dir = local_dir or "/tmp/"
        os.makedirs(self.local_dir, exist_ok=True)
        self.local_path = os.path.join(self.local_dir, repo_id)

        # Included/excluded file lists produced by walk() are written here.
        self.log_dir = os.path.join(self.local_dir, "logs", repo_id)
        os.makedirs(self.log_dir, exist_ok=True)

        if inclusion_file and exclusion_file:
            raise ValueError("Only one of inclusion_file or exclusion_file should be provided.")

        self.inclusions = self._parse_filter_file(inclusion_file) if inclusion_file else None
        self.exclusions = self._parse_filter_file(exclusion_file) if exclusion_file else None

    @cached_property
    def is_public(self) -> bool:
        """Checks whether a GitHub repository is publicly visible."""
        response = requests.get(f"https://api.github.com/repos/{self.repo_id}", timeout=10)
        # Note that the response will be 404 for both private and non-existent repos.
        return response.status_code == 200

    @cached_property
    def default_branch(self) -> str:
        """Fetches the default branch of the repository from GitHub."""
        headers = {
            "Accept": "application/vnd.github.v3+json",
        }
        if self.access_token:
            headers["Authorization"] = f"token {self.access_token}"

        # Timeout matches is_public so a stalled GitHub API call cannot hang us.
        response = requests.get(f"https://api.github.com/repos/{self.repo_id}", headers=headers, timeout=10)
        if response.status_code == 200:
            branch = response.json().get("default_branch", "main")
        else:
            # This happens sometimes when we exceed the Github rate limit. The best bet in this case is to assume the
            # most common naming for the default branch ("main").
            logging.warning(f"Unable to fetch default branch for {self.repo_id}: {response.text}")
            branch = "main"
        return branch

    def download(self) -> bool:
        """Clones the repository to the local directory, if it's not already cloned.

        Returns:
            True if the repository is available locally (freshly cloned or pre-existing), False on clone failure.

        Raises:
            ValueError: If the repository is private (or nonexistent) and no access token was provided.
        """
        if os.path.exists(self.local_path):
            # The repository is already cloned.
            return True

        if not self.is_public and not self.access_token:
            raise ValueError(f"Repo {self.repo_id} is private or doesn't exist.")

        if self.access_token:
            # Embed the token in the URL so git can authenticate non-interactively.
            clone_url = f"https://{self.access_token}@github.com/{self.repo_id}.git"
        else:
            clone_url = f"https://github.com/{self.repo_id}.git"

        try:
            if self.commit_hash:
                # A specific commit requires full history; a shallow clone might not contain it.
                repo = Repo.clone_from(clone_url, self.local_path)
                repo.git.checkout(self.commit_hash)
            else:
                Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
        except GitCommandError as e:
            logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e)
            return False
        return True

    def _parse_filter_file(self, file_path: str) -> Dict:
        """Parses a file with files/directories/extensions to include/exclude.

        Lines are expected to be in the format:
        # Comment that will be ignored, or
        ext:.my-extension, or
        file:my-file.py, or
        dir:my-directory

        Returns:
            A dict mapping each of "ext"/"file"/"dir" to the list of values parsed for that key.
        """
        with open(file_path, "r") as f:
            lines = f.readlines()

        parsed_data = {"ext": [], "file": [], "dir": []}
        for line in lines:
            line = line.strip()
            if not line or line.startswith("#"):
                # Skip blank lines and comments.
                continue
            # Split on the first colon only, so values may themselves contain colons;
            # a naive split(":") would raise ValueError on such lines.
            key, sep, value = line.partition(":")
            if sep and key in parsed_data:
                parsed_data[key].append(value)
            else:
                logging.error("Unrecognized key in line: %s. Skipping.", line)

        return parsed_data

    def _should_include(self, file_path: str) -> bool:
        """Checks whether the file should be indexed."""
        # Exclude symlinks.
        if os.path.islink(file_path):
            return False

        # Exclude hidden files and directories.
        if any(part.startswith(".") for part in file_path.split(os.path.sep)):
            return False

        if not self.inclusions and not self.exclusions:
            return True

        # Filter based on file extensions, file names and directory names.
        _, extension = os.path.splitext(file_path)
        extension = extension.lower()
        file_name = os.path.basename(file_path)
        # Use os.path.sep (not a hard-coded "/") so the check also works on Windows.
        dirs = os.path.dirname(file_path).split(os.path.sep)

        if self.inclusions:
            return (
                extension in self.inclusions.get("ext", [])
                or file_name in self.inclusions.get("file", [])
                or any(d in dirs for d in self.inclusions.get("dir", []))
            )
        elif self.exclusions:
            return (
                extension not in self.exclusions.get("ext", [])
                and file_name not in self.exclusions.get("file", [])
                and all(d not in dirs for d in self.exclusions.get("dir", []))
            )
        return True

    def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
        """Walks the local repository path and yields a tuple of (content, metadata) for each file.
        The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").

        Args:
            get_content: When set to True, yields (content, metadata) tuples. When set to False, yields metadata only.
        """
        # We will keep appending to these files during the iteration, so we need to clear them first.
        repo_name = self.repo_id.replace("/", "_")
        included_log_file = os.path.join(self.log_dir, f"included_{repo_name}.txt")
        excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
        if os.path.exists(included_log_file):
            os.remove(included_log_file)
        if os.path.exists(excluded_log_file):
            os.remove(excluded_log_file)
        # Announce the log locations unconditionally (not only when a stale file was removed).
        logging.info("Logging included files at %s", included_log_file)
        logging.info("Logging excluded files at %s", excluded_log_file)

        for root, _, files in os.walk(self.local_path):
            file_paths = [os.path.join(root, file) for file in files]
            included_file_paths = [f for f in file_paths if self._should_include(f)]

            with open(included_log_file, "a") as f:
                for path in included_file_paths:
                    f.write(path + "\n")

            excluded_file_paths = set(file_paths).difference(set(included_file_paths))
            with open(excluded_log_file, "a") as f:
                for path in excluded_file_paths:
                    f.write(path + "\n")

            for file_path in included_file_paths:
                # relpath correctly handles a local_dir with or without a trailing separator;
                # the previous string slicing (len(local_dir) + 1) was off by one for the
                # default local_dir of "/tmp/", mangling every relative path.
                relative_file_path = os.path.relpath(file_path, self.local_dir)
                metadata = {
                    "file_path": relative_file_path,
                    "url": self.url_for_file(relative_file_path),
                }

                if not get_content:
                    yield metadata
                    continue

                try:
                    with open(file_path, "r") as f:
                        contents = f.read()
                except UnicodeDecodeError:
                    # Binary files (images, compiled artifacts, etc.) are silently skipped.
                    logging.warning("Unable to decode file %s. Skipping.", file_path)
                    continue
                yield contents, metadata

    def url_for_file(self, file_path: str) -> str:
        """Converts a repository file path to a GitHub link.

        The input is expected to start with the "owner/repo" prefix, which is stripped off.
        """
        file_path = file_path[len(self.repo_id) + 1 :]
        return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"

    def read_file(self, relative_file_path: str) -> str:
        """Reads the content of the file at the given path (relative to self.local_dir)."""
        file_path = os.path.join(self.local_dir, relative_file_path)
        with open(file_path, "r") as f:
            return f.read()

    @staticmethod
    def from_args(args: Any) -> "GitHubRepoManager":
        """Creates a GitHubRepoManager from command-line arguments and clones the underlying repository.

        Args:
            args: An argparse.Namespace-like object with repo_id, commit_hash, local_dir, include and exclude
                attributes (attributes are accessed, not dict keys).

        Raises:
            ValueError: If the repository cannot be cloned.
        """
        repo_manager = GitHubRepoManager(
            repo_id=args.repo_id,
            commit_hash=args.commit_hash,
            access_token=os.getenv("GITHUB_TOKEN"),
            local_dir=args.local_dir,
            inclusion_file=args.include,
            exclusion_file=args.exclude,
        )
        success = repo_manager.download()
        if not success:
            raise ValueError(
                f"Unable to clone {args.repo_id}. Please check that it exists and you have access to it. "
                "For private repositories, please set the GITHUB_TOKEN variable in your environment."
            )
        return repo_manager