Molbap HF Staff committed on
Commit
7a4622c
·
verified ·
1 Parent(s): a5519b2

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/build_index.py +30 -108
scripts/build_index.py CHANGED
@@ -1,20 +1,7 @@
1
  #!/usr/bin/env python3
2
 
3
- """
4
- Build local similarity indexes from a Transformers checkout.
5
-
6
- This script reuses core components from `utils/modular_model_detector.py` in a local
7
- Transformers clone, including the embedding pipeline and sanitization/tokenization.
8
-
9
- Outputs are written to the chosen output directory (default: repo root):
10
- - embeddings*.safetensors
11
- - code_index_map*.json
12
- - code_index_tokens*.json
13
- """
14
-
15
  from __future__ import annotations
16
 
17
- import argparse
18
  import ast
19
  import importlib.util
20
  import json
@@ -22,29 +9,26 @@ import os
22
  from pathlib import Path
23
 
24
  import numpy as np
 
25
  from safetensors.numpy import save_file as safetensors_save
26
- try:
27
- from tqdm import tqdm
28
- except ImportError: # pragma: no cover - optional dependency
29
- def tqdm(iterable, **_kwargs):
30
- return iterable
31
 
32
  ROOT = Path(__file__).resolve().parent.parent
33
 
34
 
35
- def _load_detector_module(transformers_dir: Path):
36
  module_path = transformers_dir / "utils" / "modular_model_detector.py"
37
  if not module_path.exists():
38
- raise SystemExit(f"Expected modular_model_detector at {module_path}")
39
  spec = importlib.util.spec_from_file_location("modular_model_detector", module_path)
40
  if spec is None or spec.loader is None:
41
- raise SystemExit(f"Could not load modular_model_detector from {module_path}")
42
  module = importlib.util.module_from_spec(spec)
43
  spec.loader.exec_module(module)
44
  return module
45
 
46
 
47
- def _extract_segment(source: str, node: ast.AST, lines: list[str]) -> str | None:
48
  segment = ast.get_source_segment(source, node)
49
  if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
50
  start = max(0, node.lineno - 1)
@@ -53,35 +37,13 @@ def _extract_segment(source: str, node: ast.AST, lines: list[str]) -> str | None
53
  return segment
54
 
55
 
56
- def _collect_definitions(analyzer, models_root: Path) -> tuple[list[str], list[str], dict[str, list[str]]]:
57
  identifiers: list[str] = []
58
  sanitized_sources: list[str] = []
59
  tokens_map: dict[str, list[str]] = {}
60
 
61
  modeling_files = sorted(models_root.rglob("modeling_*.py"))
62
- print(f"Parsing {len(modeling_files)} modeling files (definition granularity)...")
63
- for file_path in tqdm(modeling_files, desc="parse definitions", unit="file"):
64
- try:
65
- definitions_raw, definitions_sanitized, definitions_tokens, _ = analyzer._extract_definitions(
66
- file_path, models_root, analyzer._infer_model_from_relative_path(file_path)
67
- )
68
- except (OSError, SyntaxError):
69
- continue
70
- for identifier, sanitized in definitions_sanitized.items():
71
- identifiers.append(identifier)
72
- sanitized_sources.append(sanitized)
73
- tokens_map[identifier] = definitions_tokens[identifier]
74
- return identifiers, sanitized_sources, tokens_map
75
-
76
-
77
- def _collect_methods(detector, analyzer, models_root: Path) -> tuple[list[str], list[str], dict[str, list[str]]]:
78
- identifiers: list[str] = []
79
- sanitized_sources: list[str] = []
80
- tokens_map: dict[str, list[str]] = {}
81
-
82
- modeling_files = sorted(models_root.rglob("modeling_*.py"))
83
- print(f"Parsing {len(modeling_files)} modeling files (method granularity)...")
84
- for file_path in tqdm(modeling_files, desc="parse methods", unit="file"):
85
  try:
86
  source = file_path.read_text(encoding="utf-8")
87
  except OSError:
@@ -97,7 +59,7 @@ def _collect_methods(detector, analyzer, models_root: Path) -> tuple[list[str],
97
 
98
  for node in ast.iter_child_nodes(tree):
99
  if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
100
- segment = _extract_segment(source, node, lines)
101
  if not segment:
102
  continue
103
  identifier = f"{relative_path}:{node.name}"
@@ -110,7 +72,7 @@ def _collect_methods(detector, analyzer, models_root: Path) -> tuple[list[str],
110
  if not isinstance(node, ast.ClassDef):
111
  continue
112
 
113
- class_segment = _extract_segment(source, node, lines)
114
  class_header = class_segment.splitlines()[0].strip() if class_segment else ""
115
  class_docstring = ast.get_docstring(node)
116
  class_context = class_header
@@ -121,7 +83,7 @@ def _collect_methods(detector, analyzer, models_root: Path) -> tuple[list[str],
121
  for child in node.body:
122
  if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
123
  continue
124
- segment = _extract_segment(source, child, lines)
125
  if not segment:
126
  continue
127
  identifier = f"{relative_path}:{node.name}.{child.name}"
@@ -130,74 +92,34 @@ def _collect_methods(detector, analyzer, models_root: Path) -> tuple[list[str],
130
  identifiers.append(identifier)
131
  sanitized_sources.append(sanitized)
132
  tokens_map[identifier] = sorted(detector._tokenize(sanitized))
133
- return identifiers, sanitized_sources, tokens_map
134
-
135
-
136
- def _write_index(output_dir: Path, granularity: str, identifiers: list[str], embeddings: np.ndarray, tokens: dict) -> None:
137
- output_dir.mkdir(parents=True, exist_ok=True)
138
- if granularity == "method":
139
- emb_name = "embeddings_methods.safetensors"
140
- map_name = "code_index_map_methods.json"
141
- tok_name = "code_index_tokens_methods.json"
142
- else:
143
- emb_name = "embeddings.safetensors"
144
- map_name = "code_index_map.json"
145
- tok_name = "code_index_tokens.json"
146
-
147
- safetensors_save({"embeddings": embeddings.astype("float32")}, output_dir / emb_name)
148
- with open(output_dir / map_name, "w", encoding="utf-8") as file:
149
- json.dump({int(i): identifiers[i] for i in range(len(identifiers))}, file)
150
- with open(output_dir / tok_name, "w", encoding="utf-8") as file:
151
- json.dump(tokens, file)
152
-
153
 
154
- def build_index(detector, analyzer, transformers_dir: Path, output_dir: Path, granularity: str) -> None:
155
- models_root = transformers_dir / "src" / "transformers" / "models"
156
- if not models_root.exists():
157
- raise SystemExit(f"Expected models directory at {models_root}")
158
-
159
- if granularity == "method":
160
- identifiers, sanitized_sources, tokens_map = _collect_methods(detector, analyzer, models_root)
161
- else:
162
- identifiers, sanitized_sources, tokens_map = _collect_definitions(analyzer, models_root)
163
  if not identifiers:
164
- raise SystemExit("No modeling definitions found to index.")
165
 
166
- print(f"Encoding {len(identifiers)} definitions ({granularity}) with {detector.EMBEDDING_MODEL}")
167
  embeddings = analyzer.encode(sanitized_sources)
168
- _write_index(output_dir, granularity, identifiers, embeddings, tokens_map)
169
- print(f"Wrote index ({granularity}) to {output_dir}")
 
 
 
 
 
170
 
171
 
172
  def main() -> None:
173
- parser = argparse.ArgumentParser(description="Build modular model graph indexes locally.")
174
- parser.add_argument(
175
- "--transformers-dir",
176
- type=Path,
177
- default=ROOT / "transformers",
178
- help="Path to a transformers git clone (expects src/transformers/models inside).",
179
- )
180
- parser.add_argument(
181
- "--output-dir",
182
- type=Path,
183
- default=ROOT,
184
- help="Where to place the generated index files (default: repo root).",
185
- )
186
- parser.add_argument(
187
- "--granularity",
188
- choices=["definition", "method", "both"],
189
- default="both",
190
- help="Which index to build. 'both' runs definition + method sequentially.",
191
- )
192
- args = parser.parse_args()
193
-
194
- detector = _load_detector_module(args.transformers_dir.resolve())
195
  hub_dataset = os.getenv("HUB_DATASET", detector.HUB_DATASET_DEFAULT)
196
  analyzer = detector.CodeSimilarityAnalyzer(hub_dataset=hub_dataset)
197
- analyzer.models_root = (args.transformers_dir / "src" / "transformers" / "models").resolve()
198
- targets = ["definition", "method"] if args.granularity == "both" else [args.granularity]
199
- for target in targets:
200
- build_index(detector, analyzer, args.transformers_dir.resolve(), args.output_dir.resolve(), target)
201
 
202
 
203
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from __future__ import annotations
4
 
 
5
  import ast
6
  import importlib.util
7
  import json
 
9
  from pathlib import Path
10
 
11
  import numpy as np
12
+ import torch
13
  from safetensors.numpy import save_file as safetensors_save
14
+
 
 
 
 
15
 
16
  ROOT = Path(__file__).resolve().parent.parent
17
 
18
 
19
def load_detector(transformers_dir: Path):
    """Dynamically import ``utils/modular_model_detector.py`` from *transformers_dir*.

    Returns the imported module object. Exits via ``SystemExit`` when the file
    is missing or when importlib cannot build a loadable spec for it.
    """
    module_path = transformers_dir.joinpath("utils", "modular_model_detector.py")
    if not module_path.exists():
        raise SystemExit(f"Missing modular_model_detector.py at {module_path}")
    detector_spec = importlib.util.spec_from_file_location("modular_model_detector", module_path)
    if detector_spec is None or detector_spec.loader is None:
        raise SystemExit(f"Could not load detector from {module_path}")
    detector_module = importlib.util.module_from_spec(detector_spec)
    detector_spec.loader.exec_module(detector_module)
    return detector_module
29
 
30
 
31
+ def extract_segment(source: str, node: ast.AST, lines: list[str]) -> str | None:
32
  segment = ast.get_source_segment(source, node)
33
  if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
34
  start = max(0, node.lineno - 1)
 
37
  return segment
38
 
39
 
40
+ def build_method_index(detector, analyzer, models_root: Path, output_dir: Path) -> None:
41
  identifiers: list[str] = []
42
  sanitized_sources: list[str] = []
43
  tokens_map: dict[str, list[str]] = {}
44
 
45
  modeling_files = sorted(models_root.rglob("modeling_*.py"))
46
+ for file_path in modeling_files:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  try:
48
  source = file_path.read_text(encoding="utf-8")
49
  except OSError:
 
59
 
60
  for node in ast.iter_child_nodes(tree):
61
  if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
62
+ segment = extract_segment(source, node, lines)
63
  if not segment:
64
  continue
65
  identifier = f"{relative_path}:{node.name}"
 
72
  if not isinstance(node, ast.ClassDef):
73
  continue
74
 
75
+ class_segment = extract_segment(source, node, lines)
76
  class_header = class_segment.splitlines()[0].strip() if class_segment else ""
77
  class_docstring = ast.get_docstring(node)
78
  class_context = class_header
 
83
  for child in node.body:
84
  if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
85
  continue
86
+ segment = extract_segment(source, child, lines)
87
  if not segment:
88
  continue
89
  identifier = f"{relative_path}:{node.name}.{child.name}"
 
92
  identifiers.append(identifier)
93
  sanitized_sources.append(sanitized)
94
  tokens_map[identifier] = sorted(detector._tokenize(sanitized))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
 
 
 
 
 
 
 
 
 
96
  if not identifiers:
97
+ raise SystemExit("No modeling methods found.")
98
 
99
+ print(f"Encoding {len(identifiers)} definitions (method) with {detector.EMBEDDING_MODEL}")
100
  embeddings = analyzer.encode(sanitized_sources)
101
+
102
+ output_dir.mkdir(parents=True, exist_ok=True)
103
+ safetensors_save({"embeddings": embeddings.astype("float32")}, output_dir / "embeddings_methods.safetensors")
104
+ with open(output_dir / "code_index_map_methods.json", "w", encoding="utf-8") as file:
105
+ json.dump({int(i): identifiers[i] for i in range(len(identifiers))}, file)
106
+ with open(output_dir / "code_index_tokens_methods.json", "w", encoding="utf-8") as file:
107
+ json.dump(tokens_map, file)
108
 
109
 
110
def main() -> None:
    """Locate a transformers clone, configure the analyzer, and build the index.

    Looks for a checkout at ``./transformers`` first, falling back to
    ``./transformers_repo`` (both relative to the repo root), then writes the
    method-level index files into the repo root.
    """
    candidates = (ROOT / "transformers", ROOT / "transformers_repo")
    transformers_dir = next((path for path in candidates if path.exists()), None)
    if transformers_dir is None:
        raise SystemExit("Expected a transformers clone at ./transformers or ./transformers_repo")

    detector = load_detector(transformers_dir)
    # The target dataset can be overridden from the environment.
    hub_dataset = os.getenv("HUB_DATASET", detector.HUB_DATASET_DEFAULT)
    analyzer = detector.CodeSimilarityAnalyzer(hub_dataset=hub_dataset)
    analyzer.models_root = (transformers_dir / "src" / "transformers" / "models").resolve()
    # Half precision is only worthwhile on GPU; stay in float32 on CPU.
    if analyzer.device.type == "cuda":
        analyzer.dtype = torch.float16
    else:
        analyzer.dtype = torch.float32
    build_method_index(detector, analyzer, analyzer.models_root, ROOT)
 
123
 
124
 
125
  if __name__ == "__main__":