#!/usr/bin/env python3
"""Build a per-method embedding index over a transformers clone's modeling files.

Walks every ``modeling_*.py`` under the models root, extracts each top-level
function and each class method, sanitizes and embeds them with the detector
loaded from ``utils/modular_model_detector.py``, and writes three artifacts:
an embeddings safetensors file, an index->identifier JSON map, and an
identifier->tokens JSON map.
"""
from __future__ import annotations

import ast
import importlib.util
import json
import os
from pathlib import Path

import numpy as np
import torch
from safetensors.numpy import save_file as safetensors_save

# Repository root: two levels above this script's location.
ROOT = Path(__file__).resolve().parent.parent
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"


def load_detector(transformers_dir: Path):
    """Dynamically import ``utils/modular_model_detector.py`` from a transformers clone.

    Returns the loaded module object.

    Raises:
        SystemExit: if the detector file is missing or a loader cannot be built.
    """
    module_path = transformers_dir / "utils" / "modular_model_detector.py"
    if not module_path.exists():
        raise SystemExit(f"Missing modular_model_detector.py at {module_path}")
    spec = importlib.util.spec_from_file_location("modular_model_detector", module_path)
    if spec is None or spec.loader is None:
        raise SystemExit(f"Could not load detector from {module_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def extract_segment(source: str, node: ast.AST, lines: list[str]) -> str | None:
    """Return the source text for *node*, or None when it cannot be recovered.

    Prefers ``ast.get_source_segment``; when that returns None, falls back to
    slicing the pre-split *lines* using the node's ``lineno``/``end_lineno``
    (both guarded with ``hasattr`` since not every AST node carries them).
    """
    segment = ast.get_source_segment(source, node)
    if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
        start = max(0, node.lineno - 1)  # lineno is 1-based; clamp defensively
        end = node.end_lineno
        segment = "\n".join(lines[start:end])
    return segment


def build_method_index(detector, analyzer, models_root: Path, output_dir: Path) -> None:
    """Embed every top-level function and class method found under *models_root*.

    Files that cannot be read (OSError) or parsed (SyntaxError) are skipped.
    Methods are embedded together with their class header line for context.

    Raises:
        SystemExit: if no definitions were collected at all.
    """
    identifiers: list[str] = []
    sanitized_sources: list[str] = []
    tokens_map: dict[str, list[str]] = {}
    modeling_files = sorted(models_root.rglob("modeling_*.py"))
    for file_path in modeling_files:
        try:
            source = file_path.read_text(encoding="utf-8")
        except OSError:
            continue
        try:
            tree = ast.parse(source)
        except SyntaxError:
            continue
        lines = source.splitlines()
        model_hint = analyzer._infer_model_from_relative_path(file_path)
        relative_path = file_path.relative_to(models_root).as_posix()
        for node in ast.iter_child_nodes(tree):
            # Top-level (module-scope) functions: identifier is "path:name".
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                segment = extract_segment(source, node, lines)
                if not segment:
                    continue
                identifier = f"{relative_path}:{node.name}"
                sanitized = detector._sanitize_for_embedding(segment, model_hint, node.name)
                identifiers.append(identifier)
                sanitized_sources.append(sanitized)
                tokens_map[identifier] = sorted(detector._tokenize(sanitized))
                continue
            if not isinstance(node, ast.ClassDef):
                continue
            # Class methods: prepend the class header line as context so the
            # embedding sees which class the method belongs to.
            class_segment = extract_segment(source, node, lines)
            class_header = class_segment.splitlines()[0].strip() if class_segment else ""
            class_context = class_header
            for child in node.body:
                if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    continue
                segment = extract_segment(source, child, lines)
                if not segment:
                    continue
                identifier = f"{relative_path}:{node.name}.{child.name}"
                combined = f"{class_context}\n{segment}" if class_context else segment
                sanitized = detector._sanitize_for_embedding(combined, model_hint, node.name)
                identifiers.append(identifier)
                sanitized_sources.append(sanitized)
                tokens_map[identifier] = sorted(detector._tokenize(sanitized))
    if not identifiers:
        raise SystemExit("No modeling methods found.")
    print(f"Encoding {len(identifiers)} definitions (method) with {detector.EMBEDDING_MODEL}")
    embeddings = analyzer.encode(sanitized_sources)
    output_dir.mkdir(parents=True, exist_ok=True)
    safetensors_save(
        {"embeddings": embeddings.astype("float32")},
        output_dir / "embeddings_methods.safetensors",
    )
    with open(output_dir / "code_index_map_methods.json", "w", encoding="utf-8") as file:
        # JSON stringifies int keys, so dict(enumerate(...)) serializes
        # identically to the original index->identifier comprehension.
        json.dump(dict(enumerate(identifiers)), file)
    with open(output_dir / "code_index_tokens_methods.json", "w", encoding="utf-8") as file:
        json.dump(tokens_map, file)


def main() -> None:
    """Locate a transformers clone, configure the detector, and build the index."""
    transformers_dir = ROOT / "transformers"
    if not transformers_dir.exists():
        transformers_dir = ROOT / "transformers_repo"
    if not transformers_dir.exists():
        raise SystemExit("Expected a transformers clone at ./transformers or ./transformers_repo")
    detector = load_detector(transformers_dir)
    detector.EMBEDDING_MODEL = EMBEDDING_MODEL
    hub_dataset = os.getenv("HUB_DATASET", detector.HUB_DATASET_DEFAULT)
    analyzer = detector.CodeSimilarityAnalyzer(hub_dataset=hub_dataset)
    analyzer.models_root = (transformers_dir / "src" / "transformers" / "models").resolve()
    # Half precision only when a CUDA device is available.
    analyzer.dtype = torch.float16 if analyzer.device.type == "cuda" else torch.float32
    build_method_index(detector, analyzer, analyzer.models_root, ROOT)


if __name__ == "__main__":
    main()