# NOTE(review): lines below were scraped Hugging Face Spaces page chrome
# ("Spaces:" / "Sleeping" / "Sleeping") — page status text, not part of the script.
#!/usr/bin/env python3
"""Build a per-method embedding index over a transformers checkout's modeling files."""
from __future__ import annotations

import ast
import importlib.util
import json
import os
from pathlib import Path

import numpy as np
import torch
from safetensors.numpy import save_file as safetensors_save

# Repository root: the script is assumed to live one directory below it
# (``parent.parent`` of this file) — TODO confirm against the actual layout.
ROOT = Path(__file__).resolve().parent.parent
# Embedding model forced onto the detector module in main().
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
def load_detector(transformers_dir: Path):
    """Dynamically import ``utils/modular_model_detector.py`` from a transformers checkout.

    Exits with a ``SystemExit`` message when the file is absent or a module
    spec cannot be built for it; otherwise returns the executed module object.
    """
    detector_path = transformers_dir / "utils" / "modular_model_detector.py"
    if not detector_path.exists():
        raise SystemExit(f"Missing modular_model_detector.py at {detector_path}")

    spec = importlib.util.spec_from_file_location("modular_model_detector", detector_path)
    if spec is None or spec.loader is None:
        raise SystemExit(f"Could not load detector from {detector_path}")

    detector_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(detector_module)
    return detector_module
def extract_segment(source: str, node: ast.AST, lines: list[str]) -> str | None:
    """Return the source text covered by *node*, or ``None`` when unknown.

    Prefers ``ast.get_source_segment``; when that yields ``None`` but the node
    still carries a ``lineno``/``end_lineno`` pair, the segment is rebuilt by
    joining the corresponding slice of *lines* (lineno is 1-based).
    """
    text = ast.get_source_segment(source, node)
    if text is not None:
        return text
    if not (hasattr(node, "lineno") and hasattr(node, "end_lineno")):
        return None
    first = max(0, node.lineno - 1)
    return "\n".join(lines[first:node.end_lineno])
def build_method_index(detector, analyzer, models_root: Path, output_dir: Path) -> None:
    """Embed every top-level function and class method in ``modeling_*.py`` files.

    Scans *models_root* recursively, sanitizes each definition through the
    detector's helpers, encodes everything with the analyzer, and writes three
    artifacts into *output_dir*: the embedding matrix (safetensors), a JSON
    row-index -> identifier map, and a JSON identifier -> sorted-token map.
    Exits via ``SystemExit`` when nothing was found.
    """
    ids: list[str] = []
    cleaned_sources: list[str] = []
    token_index: dict[str, list[str]] = {}

    for path in sorted(models_root.rglob("modeling_*.py")):
        # Unreadable or unparsable files are skipped; partial coverage is acceptable.
        try:
            text = path.read_text(encoding="utf-8")
            tree = ast.parse(text)
        except (OSError, SyntaxError):
            continue

        text_lines = text.splitlines()
        model_hint = analyzer._infer_model_from_relative_path(path)
        rel = path.relative_to(models_root).as_posix()

        def record(identifier: str, raw: str, name_hint: str) -> None:
            # Sanitize, then append to the three parallel structures together.
            sanitized = detector._sanitize_for_embedding(raw, model_hint, name_hint)
            ids.append(identifier)
            cleaned_sources.append(sanitized)
            token_index[identifier] = sorted(detector._tokenize(sanitized))

        for node in ast.iter_child_nodes(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                body = extract_segment(text, node, text_lines)
                if body:
                    record(f"{rel}:{node.name}", body, node.name)
            elif isinstance(node, ast.ClassDef):
                # Methods are embedded with the class header line prepended for context.
                class_source = extract_segment(text, node, text_lines)
                header = class_source.splitlines()[0].strip() if class_source else ""
                for member in node.body:
                    if not isinstance(member, (ast.FunctionDef, ast.AsyncFunctionDef)):
                        continue
                    body = extract_segment(text, member, text_lines)
                    if not body:
                        continue
                    combined = f"{header}\n{body}" if header else body
                    record(f"{rel}:{node.name}.{member.name}", combined, node.name)

    if not ids:
        raise SystemExit("No modeling methods found.")

    print(f"Encoding {len(ids)} definitions (method) with {detector.EMBEDDING_MODEL}")
    vectors = analyzer.encode(cleaned_sources)

    output_dir.mkdir(parents=True, exist_ok=True)
    safetensors_save(
        {"embeddings": vectors.astype("float32")},
        output_dir / "embeddings_methods.safetensors",
    )
    # JSON object keys are always strings, so dict(enumerate(...)) serializes
    # identically to the explicit int-keyed comprehension.
    with open(output_dir / "code_index_map_methods.json", "w", encoding="utf-8") as fh:
        json.dump(dict(enumerate(ids)), fh)
    with open(output_dir / "code_index_tokens_methods.json", "w", encoding="utf-8") as fh:
        json.dump(token_index, fh)
def main() -> None:
    """Locate a transformers clone, configure the detector, and build the method index."""
    repo = ROOT / "transformers"
    if not repo.exists():
        repo = ROOT / "transformers_repo"
    if not repo.exists():
        raise SystemExit("Expected a transformers clone at ./transformers or ./transformers_repo")

    detector = load_detector(repo)
    # Override the detector module's default embedding model with this script's choice.
    detector.EMBEDDING_MODEL = EMBEDDING_MODEL

    dataset = os.getenv("HUB_DATASET", detector.HUB_DATASET_DEFAULT)
    analyzer = detector.CodeSimilarityAnalyzer(hub_dataset=dataset)
    analyzer.models_root = (repo / "src" / "transformers" / "models").resolve()
    # Half precision only on CUDA devices; CPU encoding stays in float32.
    analyzer.dtype = torch.float16 if analyzer.device.type == "cuda" else torch.float32

    build_method_index(detector, analyzer, analyzer.models_root, ROOT)


if __name__ == "__main__":
    main()