# modular-detector-v2 / scripts/build_index.py
# Author: Molbap (HF Staff) — "Update app", commit afd3562 (verified)
#!/usr/bin/env python3
from __future__ import annotations
import ast
import importlib.util
import json
import os
from pathlib import Path
import numpy as np
import torch
from safetensors.numpy import save_file as safetensors_save
ROOT = Path(__file__).resolve().parent.parent
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
def load_detector(transformers_dir: Path):
    """Dynamically import ``utils/modular_model_detector.py`` from a transformers checkout.

    Exits via ``SystemExit`` when the file is missing or cannot be turned
    into an importable module spec; otherwise returns the loaded module.
    """
    module_path = transformers_dir / "utils" / "modular_model_detector.py"
    if not module_path.exists():
        raise SystemExit(f"Missing modular_model_detector.py at {module_path}")
    spec = importlib.util.spec_from_file_location("modular_model_detector", module_path)
    if spec is None or spec.loader is None:
        raise SystemExit(f"Could not load detector from {module_path}")
    detector_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(detector_module)
    return detector_module
def extract_segment(source: str, node: ast.AST, lines: list[str]) -> str | None:
segment = ast.get_source_segment(source, node)
if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
start = max(0, node.lineno - 1)
end = node.end_lineno
segment = "\n".join(lines[start:end])
return segment
def build_method_index(detector, analyzer, models_root: Path, output_dir: Path) -> None:
    """Embed every top-level function and class method under *models_root*.

    Walks all ``modeling_*.py`` files, sanitizes each definition's source with
    the detector's helpers, encodes the whole batch with the analyzer, and
    writes three artifacts into *output_dir*:

    - ``embeddings_methods.safetensors`` — float32 embedding matrix
    - ``code_index_map_methods.json`` — row index -> "<relative_path>:<qualname>"
    - ``code_index_tokens_methods.json`` — identifier -> sorted token list

    Raises:
        SystemExit: when no definitions are found at all.
    """
    identifiers: list[str] = []
    sanitized_sources: list[str] = []
    tokens_map: dict[str, list[str]] = {}

    def record(identifier: str, raw_source: str, model_hint, context_name: str) -> None:
        # Shared bookkeeping for one definition: sanitize once, then register
        # the embedding input and its token set under `identifier`.
        sanitized = detector._sanitize_for_embedding(raw_source, model_hint, context_name)
        identifiers.append(identifier)
        sanitized_sources.append(sanitized)
        tokens_map[identifier] = sorted(detector._tokenize(sanitized))

    for file_path in sorted(models_root.rglob("modeling_*.py")):
        try:
            source = file_path.read_text(encoding="utf-8")
        except OSError:
            continue  # unreadable file: skip rather than abort the whole build
        try:
            tree = ast.parse(source)
        except SyntaxError:
            continue  # not parseable Python: skip
        lines = source.splitlines()
        model_hint = analyzer._infer_model_from_relative_path(file_path)
        relative_path = file_path.relative_to(models_root).as_posix()
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                # Module-scope function: indexed under "<path>:<name>".
                segment = extract_segment(source, node, lines)
                if segment:
                    record(f"{relative_path}:{node.name}", segment, model_hint, node.name)
            elif isinstance(node, ast.ClassDef):
                # Methods are prefixed with the class header line so the
                # embedding sees the enclosing class context.
                class_segment = extract_segment(source, node, lines)
                class_context = class_segment.splitlines()[0].strip() if class_segment else ""
                for child in node.body:
                    if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                        continue
                    segment = extract_segment(source, child, lines)
                    if not segment:
                        continue
                    combined = f"{class_context}\n{segment}" if class_context else segment
                    record(f"{relative_path}:{node.name}.{child.name}", combined, model_hint, node.name)

    if not identifiers:
        raise SystemExit("No modeling methods found.")
    print(f"Encoding {len(identifiers)} definitions (method) with {detector.EMBEDDING_MODEL}")
    embeddings = analyzer.encode(sanitized_sources)
    output_dir.mkdir(parents=True, exist_ok=True)
    safetensors_save({"embeddings": embeddings.astype("float32")}, output_dir / "embeddings_methods.safetensors")
    # NOTE: JSON serializes int keys as strings either way, so dict(enumerate(...))
    # is byte-equivalent to the previous explicit int-keyed comprehension.
    with open(output_dir / "code_index_map_methods.json", "w", encoding="utf-8") as file:
        json.dump(dict(enumerate(identifiers)), file)
    with open(output_dir / "code_index_tokens_methods.json", "w", encoding="utf-8") as file:
        json.dump(tokens_map, file)
def main() -> None:
    """Entry point: locate the transformers checkout, configure the analyzer, build the index."""
    # Accept either of the two conventional clone locations, first match wins.
    candidates = (ROOT / "transformers", ROOT / "transformers_repo")
    transformers_dir = next((path for path in candidates if path.exists()), None)
    if transformers_dir is None:
        raise SystemExit("Expected a transformers clone at ./transformers or ./transformers_repo")
    detector = load_detector(transformers_dir)
    # Override the detector's default embedding model with this script's choice.
    detector.EMBEDDING_MODEL = EMBEDDING_MODEL
    hub_dataset = os.getenv("HUB_DATASET", detector.HUB_DATASET_DEFAULT)
    analyzer = detector.CodeSimilarityAnalyzer(hub_dataset=hub_dataset)
    analyzer.models_root = (transformers_dir / "src" / "transformers" / "models").resolve()
    # Half precision only when running on CUDA; CPU stays in float32.
    on_cuda = analyzer.device.type == "cuda"
    analyzer.dtype = torch.float16 if on_cuda else torch.float32
    build_method_index(detector, analyzer, analyzer.models_root, ROOT)


if __name__ == "__main__":
    main()