File size: 5,029 Bytes
5e4a986
 
 
 
 
 
 
 
 
 
 
7a4622c
5e4a986
7a4622c
5e4a986
 
afd3562
5e4a986
 
7a4622c
5e4a986
 
7a4622c
5e4a986
 
7a4622c
5e4a986
 
 
 
 
7a4622c
5e4a986
 
 
 
 
 
 
 
7a4622c
5e4a986
 
 
 
 
7a4622c
5e4a986
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a4622c
5e4a986
 
 
 
 
 
 
 
 
 
 
 
7a4622c
5e4a986
 
 
 
 
 
7a4622c
5e4a986
 
 
 
 
 
 
 
 
 
7a4622c
5e4a986
7a4622c
5e4a986
7a4622c
 
 
 
 
 
 
5e4a986
 
 
7a4622c
 
 
 
 
 
 
afd3562
5e4a986
 
7a4622c
 
 
5e4a986
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3

from __future__ import annotations

import ast
import importlib.util
import json
import os
from pathlib import Path

import numpy as np
import torch
from safetensors.numpy import save_file as safetensors_save


# Repository root: assumes this script lives one directory below the root
# (e.g. scripts/), so parent.parent is the repo top level — TODO confirm.
ROOT = Path(__file__).resolve().parent.parent
# Embedding model name forced onto the detector module in main(),
# overriding whatever default the detector ships with.
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"


def load_detector(transformers_dir: Path):
    """Dynamically import ``utils/modular_model_detector.py`` from a checkout.

    The file is loaded directly by path, so no package installation of the
    transformers repo is required.  Returns the executed module object.

    Raises SystemExit when the file is absent or no import spec/loader can
    be built for it.
    """
    detector_path = transformers_dir.joinpath("utils", "modular_model_detector.py")
    if not detector_path.exists():
        raise SystemExit(f"Missing modular_model_detector.py at {detector_path}")
    spec = importlib.util.spec_from_file_location("modular_model_detector", detector_path)
    if spec is None or spec.loader is None:
        raise SystemExit(f"Could not load detector from {detector_path}")
    detector_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(detector_module)
    return detector_module


def extract_segment(source: str, node: ast.AST, lines: list[str]) -> str | None:
    segment = ast.get_source_segment(source, node)
    if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
        start = max(0, node.lineno - 1)
        end = node.end_lineno
        segment = "\n".join(lines[start:end])
    return segment


def build_method_index(detector, analyzer, models_root: Path, output_dir: Path) -> None:
    """Embed every top-level function and class method found in modeling files.

    Walks *models_root* recursively for ``modeling_*.py`` files, extracts each
    module-level function and each method of each top-level class, sanitizes
    the source via the detector's helpers, encodes everything in one batch
    with the analyzer, and writes three artifacts into *output_dir*:

    - ``embeddings_methods.safetensors`` — float32 embedding matrix,
      one row per definition.
    - ``code_index_map_methods.json`` — row index -> identifier
      (``relative/path.py:Name`` or ``relative/path.py:Class.method``).
    - ``code_index_tokens_methods.json`` — identifier -> sorted token list.

    *detector* and *analyzer* are duck-typed objects from the dynamically
    loaded detector module (see ``load_detector``); only the attributes used
    below are required.

    Raises SystemExit when no definitions are found at all.
    """
    # Three parallel structures: identifiers[i] names the definition whose
    # embedding ends up in row i of the encoded matrix.
    identifiers: list[str] = []
    sanitized_sources: list[str] = []
    tokens_map: dict[str, list[str]] = {}

    # Sorted so row order (and hence the index map) is deterministic.
    modeling_files = sorted(models_root.rglob("modeling_*.py"))
    for file_path in modeling_files:
        try:
            source = file_path.read_text(encoding="utf-8")
        except OSError:
            # Unreadable file: skip it rather than abort the whole index.
            continue
        try:
            tree = ast.parse(source)
        except SyntaxError:
            # Skip files that do not parse as plain Python.
            continue

        lines = source.splitlines()
        # NOTE(review): private analyzer helper; presumably derives a model-name
        # hint from the path — confirm it accepts absolute paths.
        model_hint = analyzer._infer_model_from_relative_path(file_path)
        relative_path = file_path.relative_to(models_root).as_posix()

        for node in ast.iter_child_nodes(tree):
            # Case 1: module-level (free-standing) functions.
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                segment = extract_segment(source, node, lines)
                if not segment:
                    continue
                identifier = f"{relative_path}:{node.name}"
                sanitized = detector._sanitize_for_embedding(segment, model_hint, node.name)
                identifiers.append(identifier)
                sanitized_sources.append(sanitized)
                tokens_map[identifier] = sorted(detector._tokenize(sanitized))
                continue

            # Anything else that is not a class (imports, assignments, ...) is ignored.
            if not isinstance(node, ast.ClassDef):
                continue

            # Case 2: methods — prepend the class header line so the embedding
            # sees the enclosing class context (bases, name).
            class_segment = extract_segment(source, node, lines)
            class_header = class_segment.splitlines()[0].strip() if class_segment else ""
            class_context = class_header

            for child in node.body:
                if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    continue
                segment = extract_segment(source, child, lines)
                if not segment:
                    continue
                identifier = f"{relative_path}:{node.name}.{child.name}"
                combined = f"{class_context}\n{segment}" if class_context else segment
                # Note: the class name (node.name), not the method name, is
                # passed as the sanitizer's name hint here.
                sanitized = detector._sanitize_for_embedding(combined, model_hint, node.name)
                identifiers.append(identifier)
                sanitized_sources.append(sanitized)
                tokens_map[identifier] = sorted(detector._tokenize(sanitized))

    if not identifiers:
        raise SystemExit("No modeling methods found.")

    print(f"Encoding {len(identifiers)} definitions (method) with {detector.EMBEDDING_MODEL}")
    embeddings = analyzer.encode(sanitized_sources)

    output_dir.mkdir(parents=True, exist_ok=True)
    # Cast to float32 so the on-disk format is stable regardless of the
    # dtype the analyzer encoded with (e.g. float16 on CUDA).
    safetensors_save({"embeddings": embeddings.astype("float32")}, output_dir / "embeddings_methods.safetensors")
    with open(output_dir / "code_index_map_methods.json", "w", encoding="utf-8") as file:
        # JSON object keys become strings; int(i) is kept as written.
        json.dump({int(i): identifiers[i] for i in range(len(identifiers))}, file)
    with open(output_dir / "code_index_tokens_methods.json", "w", encoding="utf-8") as file:
        json.dump(tokens_map, file)


def main() -> None:
    """Locate a transformers checkout, configure the detector, build the index.

    Looks for the clone at ``./transformers`` first, then ``./transformers_repo``
    (relative to the repo root), and exits if neither exists.  Artifacts are
    written next to the repo root (ROOT).
    """
    candidates = (ROOT / "transformers", ROOT / "transformers_repo")
    transformers_dir = next((path for path in candidates if path.exists()), None)
    if transformers_dir is None:
        raise SystemExit("Expected a transformers clone at ./transformers or ./transformers_repo")

    detector = load_detector(transformers_dir)
    # Override the detector's built-in model choice with this script's.
    detector.EMBEDDING_MODEL = EMBEDDING_MODEL
    analyzer = detector.CodeSimilarityAnalyzer(
        hub_dataset=os.getenv("HUB_DATASET", detector.HUB_DATASET_DEFAULT)
    )
    analyzer.models_root = (transformers_dir / "src" / "transformers" / "models").resolve()
    # Half precision only on GPU; CPU encoding stays in float32.
    analyzer.dtype = torch.float16 if analyzer.device.type == "cuda" else torch.float32
    build_method_index(detector, analyzer, analyzer.models_root, ROOT)


if __name__ == "__main__":
    # Standard script entry point: only run when executed directly.
    main()