Spaces:
Sleeping
Sleeping
File size: 5,029 Bytes
5e4a986 7a4622c 5e4a986 7a4622c 5e4a986 afd3562 5e4a986 7a4622c 5e4a986 7a4622c 5e4a986 7a4622c 5e4a986 7a4622c 5e4a986 7a4622c 5e4a986 7a4622c 5e4a986 7a4622c 5e4a986 7a4622c 5e4a986 7a4622c 5e4a986 7a4622c 5e4a986 7a4622c 5e4a986 7a4622c 5e4a986 7a4622c afd3562 5e4a986 7a4622c 5e4a986 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
#!/usr/bin/env python3
from __future__ import annotations
import ast
import importlib.util
import json
import os
from pathlib import Path
import numpy as np
import torch
from safetensors.numpy import save_file as safetensors_save
ROOT = Path(__file__).resolve().parent.parent
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
def load_detector(transformers_dir: Path):
    """Dynamically import ``utils/modular_model_detector.py`` from a transformers checkout.

    Exits via ``SystemExit`` with a descriptive message when the file is absent
    or a loadable module spec cannot be created for it.
    """
    module_path = transformers_dir / "utils" / "modular_model_detector.py"
    if not module_path.exists():
        raise SystemExit(f"Missing modular_model_detector.py at {module_path}")
    spec = importlib.util.spec_from_file_location("modular_model_detector", module_path)
    if spec is None or spec.loader is None:
        raise SystemExit(f"Could not load detector from {module_path}")
    detector_module = importlib.util.module_from_spec(spec)
    # Execute the module body so its functions/constants are populated.
    spec.loader.exec_module(detector_module)
    return detector_module
def extract_segment(source: str, node: ast.AST, lines: list[str]) -> str | None:
segment = ast.get_source_segment(source, node)
if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
start = max(0, node.lineno - 1)
end = node.end_lineno
segment = "\n".join(lines[start:end])
return segment
def _definitions_in_file(detector, analyzer, models_root: Path, file_path: Path):
    """Yield ``(identifier, sanitized_source)`` for every definition in one modeling file.

    Identifiers are ``relative/path.py:func`` for top-level functions and
    ``relative/path.py:Class.method`` for methods; methods are sanitized
    together with their class header line so the embedding keeps class context.
    Unreadable or syntactically invalid files are skipped silently.
    """
    try:
        source = file_path.read_text(encoding="utf-8")
    except OSError:
        return
    try:
        tree = ast.parse(source)
    except SyntaxError:
        return
    lines = source.splitlines()
    model_hint = analyzer._infer_model_from_relative_path(file_path)
    relative_path = file_path.relative_to(models_root).as_posix()
    for node in ast.iter_child_nodes(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            segment = extract_segment(source, node, lines)
            if segment:
                yield (
                    f"{relative_path}:{node.name}",
                    detector._sanitize_for_embedding(segment, model_hint, node.name),
                )
        elif isinstance(node, ast.ClassDef):
            # Prepend the class signature line so method embeddings retain class context.
            class_segment = extract_segment(source, node, lines)
            class_header = class_segment.splitlines()[0].strip() if class_segment else ""
            for child in node.body:
                if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    continue
                segment = extract_segment(source, child, lines)
                if not segment:
                    continue
                combined = f"{class_header}\n{segment}" if class_header else segment
                yield (
                    f"{relative_path}:{node.name}.{child.name}",
                    detector._sanitize_for_embedding(combined, model_hint, node.name),
                )


def build_method_index(detector, analyzer, models_root: Path, output_dir: Path) -> None:
    """Embed every modeling function/method under *models_root* and persist the index.

    Writes three artifacts into *output_dir*:
      - ``embeddings_methods.safetensors``: float32 embedding matrix, one row per definition
      - ``code_index_map_methods.json``: row index -> definition identifier
      - ``code_index_tokens_methods.json``: identifier -> sorted token list

    Raises ``SystemExit`` when no definitions are found.
    """
    identifiers: list[str] = []
    sanitized_sources: list[str] = []
    tokens_map: dict[str, list[str]] = {}
    for file_path in sorted(models_root.rglob("modeling_*.py")):
        for identifier, sanitized in _definitions_in_file(detector, analyzer, models_root, file_path):
            identifiers.append(identifier)
            sanitized_sources.append(sanitized)
            tokens_map[identifier] = sorted(detector._tokenize(sanitized))
    if not identifiers:
        raise SystemExit("No modeling methods found.")
    print(f"Encoding {len(identifiers)} definitions (method) with {detector.EMBEDDING_MODEL}")
    embeddings = analyzer.encode(sanitized_sources)
    output_dir.mkdir(parents=True, exist_ok=True)
    safetensors_save({"embeddings": embeddings.astype("float32")}, output_dir / "embeddings_methods.safetensors")
    # int keys are serialized as strings by json, matching the previous
    # {int(i): ...} comprehension byte-for-byte.
    with open(output_dir / "code_index_map_methods.json", "w", encoding="utf-8") as file:
        json.dump(dict(enumerate(identifiers)), file)
    with open(output_dir / "code_index_tokens_methods.json", "w", encoding="utf-8") as file:
        json.dump(tokens_map, file)
def main() -> None:
    """Locate a local transformers clone and build the method-level embedding index."""
    # Accept either ./transformers or ./transformers_repo as the clone location.
    for candidate in ("transformers", "transformers_repo"):
        transformers_dir = ROOT / candidate
        if transformers_dir.exists():
            break
    else:
        raise SystemExit("Expected a transformers clone at ./transformers or ./transformers_repo")
    detector = load_detector(transformers_dir)
    # Override the detector's embedding model with this script's choice.
    detector.EMBEDDING_MODEL = EMBEDDING_MODEL
    hub_dataset = os.getenv("HUB_DATASET", detector.HUB_DATASET_DEFAULT)
    analyzer = detector.CodeSimilarityAnalyzer(hub_dataset=hub_dataset)
    analyzer.models_root = (transformers_dir / "src" / "transformers" / "models").resolve()
    # Half precision only on GPU; CPU inference stays in float32.
    analyzer.dtype = torch.float16 if analyzer.device.type == "cuda" else torch.float32
    build_method_index(detector, analyzer, analyzer.models_root, ROOT)


if __name__ == "__main__":
    main()
|