juliaturc commited on
Commit
7db04dd
·
1 Parent(s): 73702d3

Add option to exclude based on file and folder name, not just extension.

Browse files
repo2vec/data_manager.py CHANGED
@@ -30,13 +30,17 @@ class GitHubRepoManager(DataManager):
30
  self,
31
  repo_id: str,
32
  local_dir: str = None,
33
- included_extensions: set = None,
34
- excluded_extensions: set = None,
35
  ):
36
  """
37
  Args:
38
  repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/repo2vec".
39
  local_dir: The local directory where the repository will be cloned.
 
 
 
 
40
  """
41
  super().__init__(dataset_id=repo_id)
42
  self.repo_id = repo_id
@@ -51,8 +55,12 @@ class GitHubRepoManager(DataManager):
51
  os.makedirs(self.log_dir)
52
 
53
  self.access_token = os.getenv("GITHUB_TOKEN")
54
- self.included_extensions = included_extensions
55
- self.excluded_extensions = excluded_extensions
 
 
 
 
56
 
57
  @cached_property
58
  def is_public(self) -> bool:
@@ -101,19 +109,62 @@ class GitHubRepoManager(DataManager):
101
  return False
102
  return True
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def _should_include(self, file_path: str) -> bool:
105
- """Checks whether the file should be indexed, based on the included and excluded extensions."""
 
106
  if os.path.islink(file_path):
107
  return False
108
- _, extension = os.path.splitext(file_path)
109
- extension = extension.lower()
110
- if self.included_extensions and extension not in self.included_extensions:
111
- return False
112
- if self.excluded_extensions and extension in self.excluded_extensions:
113
- return False
114
  # Exclude hidden files and directories.
115
  if any(part.startswith(".") for part in file_path.split(os.path.sep)):
116
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  return True
118
 
119
  def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
@@ -130,8 +181,10 @@ class GitHubRepoManager(DataManager):
130
  excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
131
  if os.path.exists(included_log_file):
132
  os.remove(included_log_file)
 
133
  if os.path.exists(excluded_log_file):
134
  os.remove(excluded_log_file)
 
135
 
136
  for root, _, files in os.walk(self.local_path):
137
  file_paths = [os.path.join(root, file) for file in files]
 
30
  self,
31
  repo_id: str,
32
  local_dir: str = None,
33
+ inclusion_file: str = None,
34
+ exclusion_file: str = None,
35
  ):
36
  """
37
  Args:
38
  repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/repo2vec".
39
  local_dir: The local directory where the repository will be cloned.
40
+ inclusion_file: A file with a lists of files/directories/extensions to include. Each line must be in one of
41
+ the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
42
+ exclusion_file: A file with a lists of files/directories/extensions to exclude. Each line must be in one of
43
+ the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
44
  """
45
  super().__init__(dataset_id=repo_id)
46
  self.repo_id = repo_id
 
55
  os.makedirs(self.log_dir)
56
 
57
  self.access_token = os.getenv("GITHUB_TOKEN")
58
+
59
+ if inclusion_file and exclusion_file:
60
+ raise ValueError("Only one of inclusion_file or exclusion_file should be provided.")
61
+
62
+ self.inclusions = self._parse_filter_file(inclusion_file) if inclusion_file else None
63
+ self.exclusions = self._parse_filter_file(exclusion_file) if exclusion_file else None
64
 
65
  @cached_property
66
  def is_public(self) -> bool:
 
109
  return False
110
  return True
111
 
112
+ def _parse_filter_file(self, file_path: str) -> bool:
113
+ """Parses a file with files/directories/extensions to include/exclude.
114
+
115
+ Lines are expected to be in the format:
116
+ # Comment that will be ignored, or
117
+ ext:.my-extension, or
118
+ file:my-file.py, or
119
+ dir:my-directory
120
+ """
121
+ with open(file_path, "r") as f:
122
+ lines = f.readlines()
123
+
124
+ parsed_data = {"ext": [], "file": [], "dir": []}
125
+ for line in lines:
126
+ if line.startswith("#"):
127
+ # This is a comment line.
128
+ continue
129
+ key, value = line.strip().split(":")
130
+ if key in parsed_data:
131
+ parsed_data[key].append(value)
132
+ else:
133
+ logging.error("Unrecognized key in line: %s. Skipping.", line)
134
+
135
+ return parsed_data
136
+
137
  def _should_include(self, file_path: str) -> bool:
138
+ """Checks whether the file should be indexed."""
139
+ # Exclude symlinks.
140
  if os.path.islink(file_path):
141
  return False
142
+
 
 
 
 
 
143
  # Exclude hidden files and directories.
144
  if any(part.startswith(".") for part in file_path.split(os.path.sep)):
145
  return False
146
+
147
+ if not self.inclusions and not self.exclusions:
148
+ return True
149
+
150
+ # Filter based on file extensions, file names and directory names.
151
+ _, extension = os.path.splitext(file_path)
152
+ extension = extension.lower()
153
+ file_name = os.path.basename(file_path)
154
+ dirs = os.path.dirname(file_path).split("/")
155
+
156
+ if self.inclusions:
157
+ return (
158
+ extension in self.inclusions.get("ext", []) or
159
+ file_name in self.inclusions.get("file", []) or
160
+ any(d in dirs for d in self.inclusions.get("dir", []))
161
+ )
162
+ elif self.exclusions:
163
+ return (
164
+ extension not in self.exclusions.get("ext", []) and
165
+ file_name not in self.exclusions.get("file", []) and
166
+ all(d not in dirs for d in self.exclusions.get("dir", []))
167
+ )
168
  return True
169
 
170
  def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
 
181
  excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
182
  if os.path.exists(included_log_file):
183
  os.remove(included_log_file)
184
+ logging.info("Logging included files at %s", included_log_file)
185
  if os.path.exists(excluded_log_file):
186
  os.remove(excluded_log_file)
187
+ logging.info("Logging excluded files at %s", excluded_log_file)
188
 
189
  for root, _, files in os.walk(self.local_path):
190
  file_paths = [os.path.join(root, file) for file in files]
repo2vec/embedder.py CHANGED
@@ -149,7 +149,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
149
  metadata={},
150
  )
151
  except Exception as e:
152
- print(f"Failed to create batch job with input_file_id={input_file_id}. Error: {e}")
153
  return None
154
 
155
  @staticmethod
 
149
  metadata={},
150
  )
151
  except Exception as e:
152
+ logging.error(f"Failed to create batch job with input_file_id={input_file_id}. Error: {e}")
153
  return None
154
 
155
  @staticmethod
repo2vec/index.py CHANGED
@@ -30,11 +30,6 @@ OPENAI_DEFAULT_EMBEDDING_SIZE = {
30
  }
31
 
32
 
33
- def _read_extensions(path):
34
- with open(path, "r") as f:
35
- return {line.strip().lower() for line in f}
36
-
37
-
38
  def main():
39
  parser = argparse.ArgumentParser(description="Batch-embeds a GitHub repository and its issues.")
40
  parser.add_argument("repo_id", help="The ID of the repository to index")
@@ -163,15 +158,12 @@ def main():
163
  # Index the repository.
164
  repo_embedder = None
165
  if args.index_repo:
166
- included_extensions = _read_extensions(args.include) if args.include else None
167
- excluded_extensions = _read_extensions(args.exclude) if args.exclude else None
168
-
169
  logging.info("Cloning the repository...")
170
  repo_manager = GitHubRepoManager(
171
  args.repo_id,
172
  local_dir=args.local_dir,
173
- included_extensions=included_extensions,
174
- excluded_extensions=excluded_extensions,
175
  )
176
  repo_manager.download()
177
  logging.info("Embedding the repo...")
 
30
  }
31
 
32
 
 
 
 
 
 
33
  def main():
34
  parser = argparse.ArgumentParser(description="Batch-embeds a GitHub repository and its issues.")
35
  parser.add_argument("repo_id", help="The ID of the repository to index")
 
158
  # Index the repository.
159
  repo_embedder = None
160
  if args.index_repo:
 
 
 
161
  logging.info("Cloning the repository...")
162
  repo_manager = GitHubRepoManager(
163
  args.repo_id,
164
  local_dir=args.local_dir,
165
+ inclusion_file=args.include,
166
+ exclusion_file=args.exclude,
167
  )
168
  repo_manager.download()
169
  logging.info("Embedding the repo...")
repo2vec/sample-exclude.txt CHANGED
@@ -1,64 +1,89 @@
1
- .avi
2
- .bazel
3
- .bin
4
- .binpb
5
- .bmp
6
- .crt
7
- .css
8
- .csv
9
- .dat
10
- .db
11
- .duckdb
12
- .eot
13
- .exe
14
- .gif
15
- .gguf
16
- .glb
17
- .gz
18
- .ico
19
- .icns
20
- .inp
21
- .ipynb
22
- .isl
23
- .jar
24
- .jpeg
25
- .jpg
26
- .json
27
- .key
28
- .lock
29
- .mo
30
- .model
31
- .mov
32
- .mp3
33
- .mp4
34
- .otf
35
- .out
36
- .Packages
37
- .pb
38
- .pdf
39
- .pem
40
- .pickle
41
- .png
42
- .pt
43
- .ptl
44
- .s
45
- .so
46
- .sqlite
47
- .stl
48
- .sum
49
- .svg
50
- .tar
51
- .th
52
- .tgz
53
- .toml
54
- .ts-fixture
55
- .ttf
56
- .wav
57
- .webp
58
- .wmv
59
- .woff
60
- .woff2
61
- .xml
62
- .yaml
63
- .yml
64
- .zip
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This list tends to be overly-aggressive. We're assuming by default devs are most interested in code files, not configs.
2
+ dir:deprecated
3
+ dir:docker
4
+ dir:downgrades
5
+ dir:fixtures
6
+ dir:integration-tests
7
+ dir:legacy
8
+ dir:library-tests
9
+ dir:logo
10
+ dir:logs
11
+ dir:node_modules
12
+ dir:old-change-notes
13
+ dir:test
14
+ dir:testdata
15
+ dir:tests
16
+ dir:third_party
17
+ dir:upgrades
18
+ dir:vendor
19
+ ext:.Packages
20
+ ext:.avi
21
+ ext:.bazel
22
+ ext:.bin
23
+ ext:.binpb
24
+ ext:.bmp
25
+ ext:.crt
26
+ ext:.css
27
+ ext:.csv
28
+ ext:.dat
29
+ ext:.db
30
+ ext:.duckdb
31
+ ext:.eot
32
+ ext:.exe
33
+ ext:.gguf
34
+ ext:.gif
35
+ ext:.glb
36
+ ext:.gz
37
+ ext:.icns
38
+ ext:.ico
39
+ ext:.inp
40
+ ext:.ipynb
41
+ ext:.isl
42
+ ext:.jar
43
+ ext:.jpeg
44
+ ext:.jpg
45
+ ext:.json
46
+ ext:.key
47
+ ext:.lock
48
+ ext:.mo
49
+ ext:.model
50
+ ext:.mov
51
+ ext:.mp3
52
+ ext:.mp4
53
+ ext:.otf
54
+ ext:.out
55
+ ext:.pb
56
+ ext:.pdf
57
+ ext:.pem
58
+ ext:.pickle
59
+ ext:.png
60
+ ext:.pt
61
+ ext:.ptl
62
+ ext:.s
63
+ ext:.so
64
+ ext:.sqlite
65
+ ext:.stl
66
+ ext:.sum
67
+ ext:.svg
68
+ ext:.tar
69
+ ext:.tgz
70
+ ext:.th
71
+ ext:.toml
72
+ ext:.ts-fixture
73
+ ext:.ttf
74
+ ext:.wav
75
+ ext:.webp
76
+ ext:.wmv
77
+ ext:.woff
78
+ ext:.woff2
79
+ ext:.xml
80
+ ext:.yaml
81
+ ext:.yml
82
+ ext:.zip
83
+ file:CODE_OF_CONDUCT.md
84
+ file:CONTRIBUTING.md
85
+ file:Dockerfile
86
+ file:__init__.py
87
+ file:code-of-conduct.md
88
+ file:conftest.py
89
+ file:package-lock.json