File size: 5,029 Bytes
4340743
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3

from __future__ import annotations

import ast
import importlib.util
import json
import os
from pathlib import Path

import numpy as np
import torch
from safetensors.numpy import save_file as safetensors_save


# Repository root: this script is assumed to live one directory below it.
ROOT = Path(__file__).resolve().parent.parent
# Embedding model identifier; assigned onto the loaded detector module in main().
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"


def load_detector(transformers_dir: Path):
    """Dynamically import ``utils/modular_model_detector.py`` from a transformers checkout.

    Exits via ``SystemExit`` with a clear message when the file is missing or
    an import spec cannot be built for it; otherwise returns the loaded module.
    """
    detector_path = transformers_dir / "utils" / "modular_model_detector.py"
    if not detector_path.exists():
        raise SystemExit(f"Missing modular_model_detector.py at {detector_path}")
    spec = importlib.util.spec_from_file_location("modular_model_detector", detector_path)
    if spec is None or spec.loader is None:
        raise SystemExit(f"Could not load detector from {detector_path}")
    detector_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(detector_module)
    return detector_module


def extract_segment(source: str, node: ast.AST, lines: list[str]) -> str | None:
    segment = ast.get_source_segment(source, node)
    if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
        start = max(0, node.lineno - 1)
        end = node.end_lineno
        segment = "\n".join(lines[start:end])
    return segment


def build_method_index(detector, analyzer, models_root: Path, output_dir: Path) -> None:
    """Embed every top-level function and class method found under *models_root*.

    Walks ``modeling_*.py`` files, extracts each top-level function and each
    direct method of each top-level class, sanitizes the source through
    *detector*, encodes the result with *analyzer*, and writes three artifacts
    into *output_dir*:

    - ``embeddings_methods.safetensors`` — float32 embedding matrix, one row
      per definition.
    - ``code_index_map_methods.json`` — row index -> ``relative_path:qualname``.
    - ``code_index_tokens_methods.json`` — identifier -> sorted token list.

    Raises:
        SystemExit: if no definitions were found at all.
    """
    identifiers: list[str] = []
    sanitized_sources: list[str] = []
    tokens_map: dict[str, list[str]] = {}

    def record(identifier: str, sanitized: str) -> None:
        # Single entry point keeps the three parallel structures in sync.
        identifiers.append(identifier)
        sanitized_sources.append(sanitized)
        tokens_map[identifier] = sorted(detector._tokenize(sanitized))

    for file_path in sorted(models_root.rglob("modeling_*.py")):
        try:
            source = file_path.read_text(encoding="utf-8")
        except OSError:
            # Unreadable file: skip it rather than abort the whole index build.
            continue
        try:
            tree = ast.parse(source)
        except SyntaxError:
            continue

        lines = source.splitlines()
        model_hint = analyzer._infer_model_from_relative_path(file_path)
        relative_path = file_path.relative_to(models_root).as_posix()

        for node in ast.iter_child_nodes(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                segment = extract_segment(source, node, lines)
                if segment:
                    record(
                        f"{relative_path}:{node.name}",
                        detector._sanitize_for_embedding(segment, model_hint, node.name),
                    )
                continue

            if not isinstance(node, ast.ClassDef):
                continue

            # Prefix each method with its class header line so the embedding
            # retains some class-level context.
            class_segment = extract_segment(source, node, lines)
            class_context = class_segment.splitlines()[0].strip() if class_segment else ""

            for child in node.body:
                if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    continue
                segment = extract_segment(source, child, lines)
                if not segment:
                    continue
                combined = f"{class_context}\n{segment}" if class_context else segment
                record(
                    f"{relative_path}:{node.name}.{child.name}",
                    detector._sanitize_for_embedding(combined, model_hint, node.name),
                )

    if not identifiers:
        raise SystemExit("No modeling methods found.")

    print(f"Encoding {len(identifiers)} definitions (method) with {detector.EMBEDDING_MODEL}")
    embeddings = analyzer.encode(sanitized_sources)

    output_dir.mkdir(parents=True, exist_ok=True)
    safetensors_save({"embeddings": embeddings.astype("float32")}, output_dir / "embeddings_methods.safetensors")
    # json stringifies int keys, so dict(enumerate(...)) serializes identically
    # to the original index-keyed comprehension.
    with open(output_dir / "code_index_map_methods.json", "w", encoding="utf-8") as file:
        json.dump(dict(enumerate(identifiers)), file)
    with open(output_dir / "code_index_tokens_methods.json", "w", encoding="utf-8") as file:
        json.dump(tokens_map, file)


def main() -> None:
    """Locate a transformers checkout, configure the detector, and build the method index."""
    candidates = (ROOT / "transformers", ROOT / "transformers_repo")
    transformers_dir = next((path for path in candidates if path.exists()), None)
    if transformers_dir is None:
        raise SystemExit("Expected a transformers clone at ./transformers or ./transformers_repo")

    detector = load_detector(transformers_dir)
    detector.EMBEDDING_MODEL = EMBEDDING_MODEL
    analyzer = detector.CodeSimilarityAnalyzer(
        hub_dataset=os.getenv("HUB_DATASET", detector.HUB_DATASET_DEFAULT)
    )
    analyzer.models_root = (transformers_dir / "src" / "transformers" / "models").resolve()
    # Half precision only pays off on GPU; stay in float32 on CPU.
    analyzer.dtype = torch.float16 if analyzer.device.type == "cuda" else torch.float32
    build_method_index(detector, analyzer, analyzer.models_root, ROOT)


# Entry point: run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()