Spaces:
Sleeping
Sleeping
| import ast | |
| import json | |
| import math | |
| import os | |
| import re | |
| from typing import Callable, Iterable | |
| from functools import cache | |
| from pathlib import Path | |
| import numpy as np | |
| from huggingface_hub import snapshot_download | |
| from safetensors.numpy import load_file as safetensors_load | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer | |
| import transformers | |
# Resolve the root of the transformers "models" package. An explicit checkout can
# be pointed to via the TRANSFORMERS_REPO env var; a source layout
# ("src/transformers/models") is preferred, then a bare "models" directory,
# falling back to the installed library's own models folder.
_LIB_PATH = Path(transformers.__file__).resolve().parent
_ENV_REPO = os.getenv("TRANSFORMERS_REPO")
if _ENV_REPO:
    _env_path = Path(_ENV_REPO)
    _candidate = _env_path / "src" / "transformers" / "models"
    if _candidate.exists():
        MODELS_ROOT = _candidate
    else:
        _fallback = _env_path / "models"
        # NOTE(review): the final `_LIB_PATH / "models"` fallback is not
        # existence-checked here — downstream helpers treat it as best-effort.
        MODELS_ROOT = _fallback if _fallback.exists() else _LIB_PATH / "models"
else:
    MODELS_ROOT = _LIB_PATH / "models"
# Embedding model used for both the prebuilt index and live queries.
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
# Encoding batch size and per-snippet token budget.
BATCH_SIZE = 16
MAX_LENGTH = 4096
# Default Hub dataset holding the prebuilt embedding index.
HUB_DATASET_DEFAULT = "Molbap/modular-detector-embeddings"
# Method names too generic to count as evidence when scoring class similarity.
BOILERPLATE_NAMES = {
    "__init__",
    "_init_weights",
    "__repr__",
    "extra_repr",
    "get_input_embeddings",
    "set_input_embeddings",
    "get_output_embeddings",
    "set_output_embeddings",
    "tie_weights",
    "post_init",
    "forward",
    "init_weights",
    "reset_parameters",
    "training",
}
# Pre-compiled patterns shared by the sanitizing helpers below.
_RE_COMMENT = re.compile(r"#.*")  # "#" comment to end of line
_RE_IMPORT = re.compile(r"\s*(from|import)\s+")  # import statement lines
_RE_MODEL_HINT = re.compile(r"\d+")  # digits (version suffixes in model names)
_RE_LEADING_PREFIX = re.compile(r"^([A-Z][a-z0-9]+)")  # first CamelCase word
_RE_ALPHANUM = re.compile(r"^([A-Za-z0-9]+)")  # leading alphanumeric run
_RE_NORMALIZE = re.compile(r"[^a-z0-9]+")  # characters dropped when normalizing
def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str | None) -> str:
    """Prepare a code snippet for embedding.

    Docstrings, comments, import lines, and blank lines are removed, then every
    model-identifying name fragment (derived from *model_hint* and the leading
    CamelCase word of *symbol_hint*) is masked with the literal "Model" so that
    embeddings compare structure rather than model names.
    """
    stripped = _RE_COMMENT.sub("", _strip_docstrings(code))
    kept = [
        line
        for line in stripped.splitlines()
        if line.strip() and not _RE_IMPORT.match(line)
    ]
    sanitized = "\n".join(kept)

    hints: set[str] = set()
    if model_hint:
        hints |= {model_hint, model_hint.replace("_", ""), _RE_MODEL_HINT.sub("", model_hint)}
    if symbol_hint:
        head = _RE_LEADING_PREFIX.match(symbol_hint) or _RE_ALPHANUM.match(symbol_hint)
        lead = head.group(1) if head else ""
        if lead:
            hints |= {lead, lead.replace("_", ""), _RE_MODEL_HINT.sub("", lead)}
    # Include lowercase variants as well (subsumed by IGNORECASE, kept for parity).
    hints |= {hint.lower() for hint in list(hints)}

    # Longest hints first so substrings do not clip longer matches.
    for hint in sorted({h for h in hints if len(h) >= 3}, key=len, reverse=True):
        sanitized = re.sub(re.escape(hint), "Model", sanitized, flags=re.IGNORECASE)
    return sanitized
def _normalize(value: str | None) -> str:
    """Lowercase *value* and strip every non-alphanumeric character ("" for falsy input)."""
    if not value:
        return ""
    return _RE_NORMALIZE.sub("", value.lower())
def _leading_prefix(name: str) -> str:
    """Return the leading CamelCase word of *name*, or its leading alphanumeric run."""
    for pattern in (_RE_LEADING_PREFIX, _RE_ALPHANUM):
        match = pattern.match(name)
        if match:
            return match.group(1)
    return ""
def _infer_model_hint(definitions_kind: dict[str, str]) -> str | None:
    """Guess the model name as the most frequent class-name prefix.

    Class identifiers are used directly; method identifiers ("Class.method")
    contribute their class part. Other kinds are ignored. Returns None when no
    prefix can be derived.
    """
    tally: dict[str, int] = {}
    for identifier, kind in definitions_kind.items():
        if kind not in ("class", "method"):
            continue
        owner = identifier if kind == "class" else identifier.split(".", 1)[0]
        prefix = _leading_prefix(owner)
        if prefix:
            tally[prefix] = tally.get(prefix, 0) + 1
    if not tally:
        return None
    # Ties resolve to the first-seen prefix, matching dict insertion order.
    return max(tally, key=tally.get)
def _infer_model_prefixes(definitions_kind: dict[str, str]) -> set[str]:
    """Collect the normalized class-name prefixes present in *definitions_kind*."""
    owners = (
        identifier if kind == "class" else identifier.split(".", 1)[0]
        for identifier, kind in definitions_kind.items()
        if kind in ("class", "method")
    )
    normalized = (_normalize(_leading_prefix(owner)) for owner in owners)
    return {norm for norm in normalized if norm}
| def _calculate_reconstruction_score( | |
| contributors: Iterable[dict[str, object]], | |
| query_method_count: int, | |
| ) -> tuple[float, int]: | |
| if query_method_count <= 0: | |
| return 0.0, 0 | |
| best_scores: dict[str, float] = {} | |
| for contributor in contributors: | |
| query_name = contributor.get("query") | |
| if not query_name: | |
| continue | |
| score = float(contributor.get("score", 0.0)) | |
| if score > best_scores.get(str(query_name), 0.0): | |
| best_scores[str(query_name)] = score | |
| total_similarity = sum(best_scores.values()) | |
| return total_similarity / float(query_method_count), len(best_scores) | |
def _normalize_source_path(path: str | None) -> str | None:
    """Reduce *path* to a models-root-relative form (e.g. "llama/modeling_llama.py").

    Absolute paths are made relative to MODELS_ROOT when possible; relative
    paths keep only what follows a "models" component. Anything else is
    returned as a POSIX-style string; None/empty input yields None.
    """
    if not path:
        return None
    candidate = Path(path)
    if candidate.is_absolute():
        try:
            return str(candidate.resolve().relative_to(MODELS_ROOT.resolve()))
        except Exception:
            return candidate.as_posix()
    pieces = candidate.parts
    if "models" in pieces:
        anchor = pieces.index("models")
        return "/".join(pieces[anchor + 1 :])
    return candidate.as_posix()
| def _strip_docstrings(source: str) -> str: | |
| try: | |
| tree = ast.parse(source) | |
| except SyntaxError: | |
| return source | |
| def strip_in_body(body: list[ast.stmt]) -> None: | |
| if not body: | |
| return | |
| first = body[0] | |
| if isinstance(first, ast.Expr) and isinstance(getattr(first, "value", None), ast.Constant): | |
| if isinstance(first.value.value, str): | |
| body.pop(0) | |
| strip_in_body(tree.body) | |
| for node in ast.walk(tree): | |
| if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): | |
| strip_in_body(node.body) | |
| ast.fix_missing_locations(tree) | |
| try: | |
| return ast.unparse(tree) | |
| except Exception: | |
| return source | |
| def _strip_docstrings_in_tree(tree: ast.AST) -> None: | |
| def strip_in_body(body: list[ast.stmt]) -> None: | |
| if not body: | |
| return | |
| first = body[0] | |
| if isinstance(first, ast.Expr) and isinstance(getattr(first, "value", None), ast.Constant): | |
| if isinstance(first.value.value, str): | |
| body.pop(0) | |
| strip_in_body(tree.body) | |
| for node in ast.walk(tree): | |
| if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): | |
| strip_in_body(node.body) | |
def _sanitize_unparsed_code(code: str, model_hint: str | None, symbol_hint: str | None) -> str:
    """Sanitize already-unparsed (docstring-free) code for embedding.

    Comments, imports, and blank lines are dropped, then every hint fragment
    (from *model_hint* and the leading CamelCase word of *symbol_hint*) is
    masked with "Model" using a single combined case-insensitive pattern.
    """
    without_comments = _RE_COMMENT.sub("", code)
    body = "\n".join(
        line for line in without_comments.splitlines() if line.strip() and not _RE_IMPORT.match(line)
    )
    masks: set[str] = set()
    if model_hint:
        masks |= {model_hint, model_hint.replace("_", ""), _RE_MODEL_HINT.sub("", model_hint)}
    if symbol_hint:
        head = _RE_LEADING_PREFIX.match(symbol_hint) or _RE_ALPHANUM.match(symbol_hint)
        lead = head.group(1) if head else ""
        if lead:
            masks |= {lead, lead.replace("_", ""), _RE_MODEL_HINT.sub("", lead)}
    masks = {mask for mask in masks if len(mask) >= 3}
    if masks:
        # Longest alternatives first so shorter fragments cannot shadow them.
        alternation = "|".join(re.escape(mask) for mask in sorted(masks, key=len, reverse=True))
        body = re.compile(alternation, re.IGNORECASE).sub("Model", body)
    return body
def _normalize_code_for_compare(source: str) -> str:
    """Collapse *source* (docstrings removed) to a whitespace-free string for equality checks."""
    docless = _strip_docstrings(source)
    return "".join(line.strip() for line in docless.splitlines() if line.strip())
def _load_source(relative_path: str) -> str:
    """Read a model file under MODELS_ROOT; missing or unreadable files yield ""."""
    try:
        # FileNotFoundError is an OSError subclass, so one clause covers both.
        return (MODELS_ROOT / relative_path).read_text(encoding="utf-8")
    except OSError:
        return ""
def _get_definition_segment(relative_path: str, match_name: str) -> str | None:
    """Return the source text of a definition inside a model file.

    *match_name* is either a top-level function/class name or "Class.method".
    Returns None when the file is missing, unparsable, or the definition is
    not found.
    """
    source = _load_source(relative_path)
    if not source:
        return None
    try:
        tree = ast.parse(source)
    except SyntaxError:
        return None

    target = None
    if "." in match_name:
        class_name, method_name = match_name.split(".", 1)
        for node in tree.body:
            if not (isinstance(node, ast.ClassDef) and node.name == class_name):
                continue
            found = next(
                (
                    child
                    for child in node.body
                    if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and child.name == method_name
                ),
                None,
            )
            # Keep an earlier hit if a later same-named class lacks the method.
            if found is not None:
                target = found
    else:
        target = next(
            (
                node
                for node in tree.body
                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
                and node.name == match_name
            ),
            None,
        )
    if target is None:
        return None
    segment = ast.get_source_segment(source, target)
    if segment is None and hasattr(target, "lineno") and hasattr(target, "end_lineno"):
        # Fallback: slice by line numbers when the exact segment is unavailable.
        all_lines = source.splitlines()
        segment = "\n".join(all_lines[max(0, target.lineno - 1) : target.end_lineno])
    return segment
def _load_definition_line_map(relative_path: str) -> dict[str, int]:
    """Map "Name" / "Class.method" identifiers to 1-based line numbers in a model file.

    Missing, unreadable, or unparsable files yield an empty map.
    """
    try:
        source = (MODELS_ROOT / relative_path).read_text(encoding="utf-8")
        tree = ast.parse(source)
    except (OSError, SyntaxError):
        return {}
    line_map: dict[str, int] = {}
    for node in ast.iter_child_nodes(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            continue
        line_map[node.name] = getattr(node, "lineno", None) or 1
        if isinstance(node, ast.ClassDef):
            for child in node.body:
                if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    line_map[f"{node.name}.{child.name}"] = getattr(child, "lineno", None) or 1
    return line_map
def _resolve_definition_location(relative_path: str, definition: str) -> tuple[str, int | None]:
    """Return (absolute file path, line number or None) for a definition identifier."""
    absolute = (MODELS_ROOT / relative_path).resolve()
    return str(absolute), _load_definition_line_map(relative_path).get(definition)
class CodeSimilarityAnalyzer:
    """Embeds code definitions and searches a prebuilt index of Transformers
    model code for similar classes and methods.

    The index consists of a safetensors embedding matrix plus a JSON map from
    row id to "relative/path.py:Identifier". It is located (and, if necessary,
    downloaded from the Hub) by ensure_local_index(). Queries are parsed into
    definitions, sanitized to mask model-specific names, embedded with the same
    model as the index, and ranked by cosine-style dot-product similarity
    (embeddings are L2-normalized).
    """

    def __init__(self, hub_dataset: str, precision: str = "float32", granularity: str = "method"):
        """Load the embedding model/tokenizer and record index configuration.

        Args:
            hub_dataset: Hub dataset repo id holding the prebuilt index.
            precision: "float32" or "int8" (int8 rows are rescaled at query time).
            granularity: preferred index granularity, "method" or "definition".
        """
        self.hub_dataset = hub_dataset
        self.precision = precision
        self.requested_granularity = granularity
        # May be downgraded to "definition" by ensure_local_index() fallback.
        self.index_granularity = granularity
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            EMBEDDING_MODEL, trust_remote_code=True, dtype=self.dtype, device_map=None
        ).to(self.device)
        self.model.eval()
        # Populated by ensure_local_index().
        self.index_dir: Path | None = None
        self.index_origin: str | None = None
        self.missing_files: tuple[str, ...] = ()
        self._index_cache: dict[str, object] | None = None

    def _embedding_filename(self, granularity: str | None = None) -> str:
        """Safetensors filename for the given (or current) granularity and precision."""
        granularity = granularity or self.index_granularity
        suffix = ""
        if granularity == "method":
            suffix += "_methods"
        if self.precision == "int8":
            suffix += "_int8"
        if not suffix:
            return "embeddings.safetensors"
        return f"embeddings{suffix}.safetensors"

    def _index_map_filename(self, granularity: str | None = None) -> str:
        """JSON filename mapping embedding row ids to "path:identifier" strings."""
        granularity = granularity or self.index_granularity
        if granularity == "method":
            return "code_index_map_methods.json"
        return "code_index_map.json"

    def _resolve_index_path(self, filename: str) -> Path:
        """Resolve *filename* inside the chosen index directory (CWD when unset)."""
        if self.index_dir is None:
            return Path(filename)
        return self.index_dir / filename

    def _required_index_files(self, granularity: str | None = None) -> tuple[str, ...]:
        """The files that must all exist for the index at *granularity* to be usable."""
        return (
            self._embedding_filename(granularity),
            self._index_map_filename(granularity),
        )

    def ensure_local_index(self) -> None:
        """Locate a usable index directory, downloading from the Hub if needed.

        Search order: the already-set index_dir, an INDEX_DIR env override, the
        CWD, the repo directory above this file, then a Hub dataset snapshot.
        When the requested "method" granularity is unavailable everywhere, the
        coarser "definition" index is used instead and the missing filenames are
        recorded in self.missing_files.

        Raises:
            FileNotFoundError: when no usable index is found at all.
        """
        required_files = self._required_index_files(self.requested_granularity)
        if self.index_dir is not None and all((self.index_dir / fname).exists() for fname in required_files):
            return

        def missing_files(directory: Path, granularity: str) -> list[str]:
            # Required filenames absent from *directory* at this granularity.
            return [fname for fname in self._required_index_files(granularity) if not (directory / fname).exists()]

        candidates: list[tuple[str, Path]] = []
        env_dir = os.getenv("INDEX_DIR")
        if env_dir:
            candidates.append(("env", Path(env_dir)))
        candidates.append(("cwd", Path.cwd()))
        candidates.append(("repo", Path(__file__).resolve().parent.parent))
        missing_preferred: list[str] = []
        for origin, candidate in candidates:
            missing_preferred = missing_files(candidate, self.requested_granularity)
            if not missing_preferred:
                # Found a complete local index at the requested granularity.
                self.index_dir = candidate
                self.index_origin = origin
                self.index_granularity = self.requested_granularity
                self.missing_files = ()
                self._index_cache = None
                return
        # Remember a local directory that could serve the "definition" fallback.
        fallback_dir: Path | None = None
        fallback_origin: str | None = None
        fallback_missing: list[str] = []
        if self.requested_granularity == "method":
            for origin, candidate in candidates:
                fallback_missing = missing_files(candidate, "definition")
                if not fallback_missing:
                    fallback_dir = candidate
                    fallback_origin = origin
                    break
        snapshot_dir = Path(snapshot_download(repo_id=self.hub_dataset, repo_type="dataset"))
        hub_missing_preferred = missing_files(snapshot_dir, self.requested_granularity)
        hub_missing_fallback: list[str] = []
        if self.requested_granularity == "method":
            hub_missing_fallback = missing_files(snapshot_dir, "definition")
        if not hub_missing_preferred:
            self.index_dir = snapshot_dir
            self.index_origin = "hub"
            self.index_granularity = self.requested_granularity
            self.missing_files = ()
            self._index_cache = None
            return
        if self.requested_granularity == "method" and not hub_missing_fallback:
            # The Hub only has the definition-level index: downgrade granularity.
            self.index_dir = snapshot_dir
            self.index_origin = "hub"
            self.index_granularity = "definition"
            self.missing_files = tuple(hub_missing_preferred)
            self._index_cache = None
            return
        if fallback_dir is not None:
            # Local definition-level index beats failing outright.
            self.index_dir = fallback_dir
            self.index_origin = fallback_origin
            self.index_granularity = "definition"
            self.missing_files = tuple(missing_preferred)
            self._index_cache = None
            return
        missing_detail = ", ".join(hub_missing_preferred or missing_preferred)
        raise FileNotFoundError(
            "Missing expected files for requested granularity; unable to fall back to definition index. "
            f"Missing: {missing_detail}"
        )

    def _load_index(self) -> dict[str, object]:
        """Load (and memoize) the embedding matrix, optional int8 scales, and id map."""
        if self._index_cache is not None:
            return self._index_cache
        self.ensure_local_index()
        embedding_path = self._resolve_index_path(self._embedding_filename())
        base = safetensors_load(str(embedding_path))
        base_embeddings = base["embeddings"]
        # int8 indexes carry a "scale" tensor used to dequantize at query time.
        scales = base.get("scale") if self.precision == "int8" else None
        with open(self._resolve_index_path(self._index_map_filename()), "r", encoding="utf-8") as file:
            identifier_map = {int(key): value for key, value in json.load(file).items()}
        self._index_cache = {
            "embeddings": base_embeddings,
            "scales": scales,
            "identifier_map": identifier_map,
        }
        return self._index_cache

    def _encode_batch(self, texts: list[str]) -> np.ndarray:
        """Embed one batch of texts into L2-normalized float32 vectors."""
        encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
        encoded = {key: value.to(self.device) for key, value in encoded.items()}
        with torch.no_grad():
            if self.device.type == "cuda":
                with torch.autocast(device_type="cuda", dtype=self.dtype):
                    output = self.model(**encoded)
            else:
                output = self.model(**encoded)
        if hasattr(output, "last_hidden_state"):
            # Attention-mask-weighted mean pooling over token embeddings.
            embeddings = output.last_hidden_state
            mask = encoded["attention_mask"].unsqueeze(-1)
            embeddings = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-9)
        elif hasattr(output, "pooler_output"):
            embeddings = output.pooler_output
        else:
            # Unknown output type: fall back to an unmasked mean over tokens.
            embeddings = output[0].mean(dim=1)
        embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1)
        return embeddings.cpu().numpy().astype("float32")

    def encode(
        self,
        texts: list[str],
        progress: Callable[[int, int, str], None] | None = None,
        progress_offset: int = 0,
        progress_total: int | None = None,
    ) -> np.ndarray:
        """Embed *texts* in batches of BATCH_SIZE, optionally reporting progress.

        Returns a (len(texts), dim) float32 matrix; shape (0, 0) when *texts*
        is empty. *progress*, when given together with *progress_total*, is
        called after each batch as progress(step, total, "encoding").
        """
        if not texts:
            return np.zeros((0, 0), dtype="float32")
        output = []
        for i in range(0, len(texts), BATCH_SIZE):
            output.append(self._encode_batch(texts[i : i + BATCH_SIZE]))
            if progress and progress_total is not None:
                batch_index = i // BATCH_SIZE
                progress(progress_offset + batch_index + 1, progress_total, "encoding")
        if self.device.type == "cuda":
            # Release transient activation memory between calls.
            torch.cuda.empty_cache()
        return np.vstack(output) if output else np.zeros((0, 0), dtype="float32")

    def _topk(
        self,
        query_embedding_row: np.ndarray,
        base_embeddings: np.ndarray,
        scales: np.ndarray | None,
        identifier_map: dict[int, str],
        self_model_normalized: set[str],
        self_name: str,
        exclude_same_file: str | None,
        k: int,
        pool_size: int | None = None,
    ) -> list[tuple[str, float]]:
        """Return up to *k* (identifier, similarity) pairs for one query embedding.

        Candidates matching the query's own name, its source file, or any model
        family in *self_model_normalized* are skipped. A larger candidate pool
        (*pool_size*, default k + 32) is fetched so filtering can still fill k.

        Raises:
            ValueError: when precision is "int8" but no scales are available.
        """
        if self.precision == "int8":
            if scales is None:
                raise ValueError("Missing int8 scales for int8 search.")
            # Fold the dequantization scale into the query so the stored int8
            # matrix can be used directly in the dot product.
            weighted_query = (query_embedding_row * scales).astype("float32")
            similarities = weighted_query @ base_embeddings.T.astype("float32")
        else:
            similarities = query_embedding_row @ base_embeddings.T
        pool = k + 32 if pool_size is None else max(k, pool_size)
        # NOTE(review): np.argpartition raises ValueError when pool >= corpus
        # size; this assumes the index always has more rows than the pool — confirm.
        indices = np.argpartition(-similarities, pool)[:pool]
        # Order the pool by descending similarity.
        indices = indices[np.argsort(-similarities[indices])]
        output = []
        for match_id in indices:
            identifier = identifier_map[int(match_id)]
            if ":" not in identifier:
                # Malformed map entry; expected "relative/path.py:Identifier".
                continue
            relative_path, match_name = identifier.split(":", 1)
            if exclude_same_file and relative_path == exclude_same_file:
                continue
            if match_name == self_name:
                continue
            if self_model_normalized:
                # First path component is the model folder (e.g. "llama/...").
                parent_model = Path(relative_path).parts[0] if relative_path else ""
                if _normalize(parent_model) in self_model_normalized:
                    continue
            output.append((identifier, float(similarities[match_id])))
            if len(output) >= k:
                break
        return output

    def _extract_definitions_from_code(
        self,
        code: str,
        model_hint: str | None,
        granularity: str,
    ) -> tuple[dict[str, str], dict[str, str], dict[str, str]]:
        """Split *code* into definitions at the requested granularity.

        Returns three dicts keyed by identifier ("name" or "Class.method"):
        the raw source segments, the sanitized segments used for embedding, and
        each definition's kind ("function" | "class" | "method").

        Raises:
            SyntaxError: propagated from ast.parse when *code* is invalid.
        """
        definitions_raw: dict[str, str] = {}
        definitions_sanitized: dict[str, str] = {}
        definitions_kind: dict[str, str] = {}
        lines = code.splitlines()
        tree = ast.parse(code)
        entries: list[tuple[str, str, ast.AST, ast.ClassDef | None]] = []
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and granularity in ("definition", "method"):
                segment = ast.get_source_segment(code, node)
                if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
                    # Fallback: slice by line numbers when the segment is unavailable.
                    start = max(0, node.lineno - 1)
                    end = node.end_lineno
                    segment = "\n".join(lines[start:end])
                if not segment:
                    continue
                identifier = node.name
                definitions_raw[identifier] = segment
                definitions_kind[identifier] = "function"
                entries.append((identifier, "function", node, None))
                continue
            if isinstance(node, ast.ClassDef):
                class_segment = ast.get_source_segment(code, node)
                if class_segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
                    start = max(0, node.lineno - 1)
                    end = node.end_lineno
                    class_segment = "\n".join(lines[start:end])
                if not class_segment:
                    continue
                if granularity == "definition":
                    # Whole class is a single query unit.
                    identifier = node.name
                    definitions_raw[identifier] = class_segment
                    definitions_kind[identifier] = "class"
                    entries.append((identifier, "class", node, None))
                    continue
                # "method" granularity: one entry per method of the class.
                for child in node.body:
                    if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                        continue
                    segment = ast.get_source_segment(code, child)
                    if segment is None and hasattr(child, "lineno") and hasattr(child, "end_lineno"):
                        start = max(0, child.lineno - 1)
                        end = child.end_lineno
                        segment = "\n".join(lines[start:end])
                    if not segment:
                        continue
                    method_name = child.name
                    identifier = f"{node.name}.{method_name}"
                    definitions_raw[identifier] = segment
                    definitions_kind[identifier] = "method"
                    entries.append((identifier, "method", child, node))
        # Strip docstrings once on the shared tree, then sanitize each entry.
        _strip_docstrings_in_tree(tree)
        for identifier, kind, node, parent in entries:
            try:
                if kind == "method" and parent is not None:
                    # Prepend the class header line so the embedding keeps class context.
                    parent_header = ast.unparse(parent).splitlines()[0]
                    combined = f"{parent_header}\n{ast.unparse(node)}"
                    sanitized = _sanitize_unparsed_code(combined, model_hint, parent.name)
                else:
                    sanitized = _sanitize_unparsed_code(ast.unparse(node), model_hint, identifier)
            except Exception:
                # Unparse can fail on exotic nodes; fall back to the raw segment.
                sanitized = definitions_raw.get(identifier, "")
            definitions_sanitized[identifier] = sanitized
        return definitions_raw, definitions_sanitized, definitions_kind

    def analyze_code(
        self,
        code: str,
        top_k_per_item: int = 5,
        model_hint: str | None = None,
        exclude_same_model: bool = True,
        exclude_identical: bool = True,
        exclude_models: list[str] | None = None,
        source_path: str | None = None,
        progress: Callable[[int, int, str], None] | None = None,
    ) -> dict[str, dict[str, object]]:
        """Find the most similar indexed definitions for every definition in *code*.

        Args:
            code: Python source to analyze (must parse; SyntaxError propagates).
            top_k_per_item: matches to keep per query definition.
            model_hint: model name used for masking/exclusion; inferred when None.
            exclude_same_model: drop matches from the query's own model family.
            exclude_identical: drop byte-equivalent matches from the filtered
                results (they still appear in the unfiltered "*_all" views).
            exclude_models: extra model names to exclude.
            source_path: path of *code*, used to skip same-file matches.
            progress: optional callback(step, total, stage).

        Returns:
            Dict with per-query "results", aggregated "overall" and "by_class"
            rankings, the count of "identical_filtered" matches, and unfiltered
            "by_class_all" / "overall_all" variants.
        """
        if progress:
            progress(0, 1, "starting")
        index_data = self._load_index()
        base_embeddings = index_data["embeddings"]
        scales = index_data["scales"]
        identifier_map = index_data["identifier_map"]
        definitions_raw, definitions_sanitized, definitions_kind = self._extract_definitions_from_code(
            code, model_hint, self.index_granularity
        )
        if model_hint is None:
            model_hint = _infer_model_hint(definitions_kind)
        # Normalized model names whose matches should be excluded from results.
        self_model_normalized = _infer_model_prefixes(definitions_kind) if exclude_same_model else set()
        if exclude_same_model and model_hint:
            self_model_normalized.add(_normalize(model_hint))
        if exclude_models:
            self_model_normalized.update({_normalize(name) for name in exclude_models})
        exclude_same_file = _normalize_source_path(source_path)
        query_identifiers = list(definitions_raw.keys())
        query_sources_sanitized = [definitions_sanitized[key] for key in query_identifiers]
        # Whitespace/docstring-insensitive forms used for identical-match detection.
        query_compare = {
            key: _normalize_code_for_compare(definitions_raw[key]) for key in query_identifiers if key in definitions_raw
        }
        total_batches = max(1, math.ceil(len(query_sources_sanitized) / BATCH_SIZE))
        total_steps = 2 + total_batches
        if progress:
            progress(1, total_steps, "parsed")
        query_embeddings = self.encode(
            query_sources_sanitized,
            progress=progress,
            progress_offset=1,
            progress_total=total_steps,
        )
        output = {}
        output_all = {}
        identical_filtered = 0
        for i, query_identifier in enumerate(query_identifiers):
            pool_size = max(top_k_per_item * 5, top_k_per_item + 32)
            # k == pool_size here: over-fetch so the filtered list can still
            # reach top_k_per_item after identical matches are removed.
            candidates = self._topk(
                query_embeddings[i],
                base_embeddings,
                scales,
                identifier_map,
                self_model_normalized,
                query_identifier,
                exclude_same_file,
                pool_size,
                pool_size=pool_size,
            )
            # "entry" holds the filtered matches, "entry_all" the unfiltered view.
            entry: dict[str, object] = {
                "kind": definitions_kind.get(query_identifier, "function"),
                "embedding": [],
            }
            entry_all: dict[str, object] = {
                "kind": definitions_kind.get(query_identifier, "function"),
                "embedding": [],
            }
            for identifier, score in candidates:
                relative_path, match_name = identifier.split(":", 1)
                is_identical = False
                match_segment = None
                if exclude_identical or query_compare:
                    match_segment = _get_definition_segment(relative_path, match_name)
                    if match_segment is not None:
                        match_norm = _normalize_code_for_compare(match_segment)
                        query_norm = query_compare.get(query_identifier)
                        if query_norm and match_norm == query_norm:
                            is_identical = True
                if len(entry_all["embedding"]) < top_k_per_item:
                    full_path, line = _resolve_definition_location(relative_path, match_name)
                    entry_all["embedding"].append(
                        {
                            "identifier": identifier,
                            "relative_path": relative_path,
                            "match_name": match_name,
                            "score": score,
                            "full_path": full_path,
                            "line": line,
                            "is_identical": is_identical,
                        }
                    )
                if exclude_identical and is_identical:
                    identical_filtered += 1
                    continue
                full_path, line = _resolve_definition_location(relative_path, match_name)
                entry["embedding"].append(
                    {
                        "identifier": identifier,
                        "relative_path": relative_path,
                        "match_name": match_name,
                        "score": score,
                        "full_path": full_path,
                        "line": line,
                        "is_identical": is_identical,
                    }
                )
                if len(entry["embedding"]) >= top_k_per_item and len(entry_all["embedding"]) >= top_k_per_item:
                    break
            output[query_identifier] = entry
            output_all[query_identifier] = entry_all

        def build_by_class(result_map: dict[str, dict[str, object]]) -> dict[str, list[dict[str, object]]]:
            """Aggregate per-definition matches into (file, class) candidates ranked
            by a per-query-class reconstruction score (top 10 rows per class)."""

            def candidate_class_key(relative_path: str, match_name: str) -> tuple[str, str]:
                # Module-level candidates are bucketed under "<module>".
                class_name = match_name.split(".", 1)[0] if "." in match_name else "<module>"
                return relative_path, class_name

            def query_class_key(query_identifier: str, kind: str) -> str:
                return query_identifier.split(".", 1)[0] if kind == "method" else query_identifier

            by_class: dict[str, dict[tuple[str, str], dict[str, object]]] = {}
            for query_identifier, entry in result_map.items():
                kind = entry.get("kind", "function")
                if kind == "function":
                    # Free functions only pair with module-level candidates.
                    matches = entry.get("embedding", [])
                    if not matches:
                        continue
                    best_per_cand: dict[tuple[str, str], dict[str, object]] = {}
                    for match in matches:
                        rel = match.get("relative_path")
                        mname = match.get("match_name")
                        score = match.get("score")
                        if rel is None or not mname or score is None:
                            continue
                        if "." in mname:
                            continue
                        ckey = (rel, mname)
                        prev = best_per_cand.get(ckey)
                        if prev is None or float(score) > float(prev.get("score", -1.0)):
                            best_per_cand[ckey] = match
                    for ckey, match in best_per_cand.items():
                        slot = by_class.setdefault(query_identifier, {}).setdefault(
                            ckey,
                            {
                                "relative_path": ckey[0],
                                "class_name": ckey[1],
                                "scores": [],
                                "contributors": [],
                            },
                        )
                        slot["scores"].append(float(match["score"]))
                        slot["contributors"].append(
                            {"query": query_identifier, "match": match["identifier"], "score": float(match["score"])}
                        )
                    continue
                qcls = query_class_key(query_identifier, kind)
                matches = entry.get("embedding", [])
                if not matches:
                    continue
                # Keep only the best match per candidate class for this query.
                best_per_cand: dict[tuple[str, str], dict[str, object]] = {}
                for match in matches:
                    rel = match.get("relative_path")
                    mname = match.get("match_name")
                    score = match.get("score")
                    if rel is None or not mname or score is None:
                        continue
                    ckey = candidate_class_key(rel, mname)
                    prev = best_per_cand.get(ckey)
                    if prev is None or float(score) > float(prev.get("score", -1.0)):
                        best_per_cand[ckey] = match
                for ckey, match in best_per_cand.items():
                    slot = by_class.setdefault(qcls, {}).setdefault(
                        ckey,
                        {
                            "relative_path": ckey[0],
                            "class_name": ckey[1],
                            "scores": [],
                            "contributors": [],
                        },
                    )
                    slot["scores"].append(float(match["score"]))
                    slot["contributors"].append(
                        {"query": query_identifier, "match": match["identifier"], "score": float(match["score"])}
                    )
            by_class_out: dict[str, list[dict[str, object]]] = {}
            for qcls, cand_map in by_class.items():
                # Count non-boilerplate methods of the query class (floor of 1).
                q_method_count = len(
                    [
                        key
                        for key, kind in definitions_kind.items()
                        if kind == "method"
                        and key.startswith(f"{qcls}.")
                        and key.split(".")[-1] not in BOILERPLATE_NAMES
                    ]
                )
                q_method_count = max(1, q_method_count)
                rows = []
                for _, slot in cand_map.items():
                    # Boilerplate methods are excluded from the score but still
                    # shown among the top contributors.
                    filtered_contributors = [
                        item
                        for item in slot["contributors"]
                        if str(item.get("query", "")).split(".")[-1] not in BOILERPLATE_NAMES
                    ]
                    base_score, coverage_count = _calculate_reconstruction_score(
                        filtered_contributors, q_method_count
                    )
                    coverage_ratio = coverage_count / float(q_method_count)
                    contributors = sorted(slot["contributors"], key=lambda x: float(x["score"]), reverse=True)[:5]
                    rows.append(
                        {
                            "relative_path": slot["relative_path"],
                            "class_name": slot["class_name"],
                            "identifier": f"{slot['relative_path']}:{slot['class_name']}",
                            "score": base_score,
                            "coverage": coverage_count,
                            "coverage_pct": coverage_ratio,
                            "top_contributors": contributors,
                        }
                    )
                rows.sort(key=lambda row: (float(row["score"]), int(row["coverage"])), reverse=True)
                by_class_out[qcls] = rows[:10]
            return by_class_out

        def build_overall(result_map: dict[str, dict[str, object]]) -> list[dict[str, object]]:
            """Rank candidate files by the mean of each query's best score in that file."""
            per_query_best: list[dict[str, float]] = []
            for query_identifier, data in result_map.items():
                if query_identifier.split(".")[-1] in BOILERPLATE_NAMES:
                    continue
                best_by_file: dict[str, float] = {}
                for match in data.get("embedding", []):
                    rel = match.get("relative_path")
                    score = match.get("score")
                    if rel is None or score is None:
                        continue
                    best_by_file[rel] = max(best_by_file.get(rel, -1.0), float(score))
                per_query_best.append(best_by_file)
            aggregate: dict[str, dict[str, object]] = {}
            for best_by_file in per_query_best:
                for rel, score in best_by_file.items():
                    slot = aggregate.setdefault(rel, {"relative_path": rel, "scores": []})
                    slot["scores"].append(score)
            overall = []
            for rel, slot in aggregate.items():
                scores = sorted(slot["scores"], reverse=True)
                overall.append(
                    {
                        "relative_path": rel,
                        "score": float(sum(scores) / max(1, len(scores))),
                        "count": len(scores),
                        "best_score": float(scores[0]) if scores else -1.0,
                    }
                )
            overall.sort(key=lambda item: float(item["score"]), reverse=True)
            return overall

        by_class_out = build_by_class(output)
        by_class_all = build_by_class(output_all)
        overall = build_overall(output)
        overall_all = build_overall(output_all)
        if progress:
            progress(total_steps, total_steps, "done")
        return {
            "results": output,
            "overall": overall,
            "by_class": by_class_out,
            "identical_filtered": identical_filtered,
            "by_class_all": by_class_all,
            "overall_all": overall_all,
        }

    def index_status(self) -> dict[str, object]:
        """Snapshot of the resolved index configuration for status display."""
        return {
            "requested_granularity": self.requested_granularity,
            "resolved_granularity": self.index_granularity,
            "precision": self.precision,
            "hub_dataset": self.hub_dataset,
            "index_dir": str(self.index_dir) if self.index_dir else None,
            "index_origin": self.index_origin,
            "missing_files": list(self.missing_files),
            "embedding_model": EMBEDDING_MODEL,
        }
def list_models() -> list[str]:
    """List model directory names under MODELS_ROOT, sorted alphabetically.

    Dunder folders (e.g. "__pycache__") are skipped. If the models root does
    not exist or cannot be read, an empty list is returned — matching the
    best-effort behavior of the other filesystem helpers in this module
    (the unguarded iterdir() previously raised FileNotFoundError when the
    un-checked `_LIB_PATH / "models"` fallback was missing).
    """
    try:
        return sorted(
            path.name
            for path in MODELS_ROOT.iterdir()
            if path.is_dir() and not path.name.startswith("__")
        )
    except OSError:
        return []
def generate_agent_report(analysis_results: dict) -> str:
    """Render the similarity analysis as a markdown spec a coding agent can follow.

    Sections: a fixed preamble, a reuse/inheritance plan per query class,
    near-identical method anchors, and a prompt starter block.
    """
    report: list[str] = [
        "### MODULAR RECONSTRUCTION SPECIFICATION (FOR AGENT)",
        "You are operating inside the Transformers repository. Your task is to produce a `modular_*.py` file that",
        "maximizes reuse of existing Transformers components.",
        "Purpose: the `modular_model_converter.py` script reads the modular file, resolves inheritance, and",
        "generates the final `modeling_*.py` file. The modular file should prefer inheriting from existing",
        "Transformers classes (e.g., attention, MLP, blocks, model heads) rather than starting fresh from",
        "`nn.Module`, so the generated model shares as much code as possible and only overrides the minimal deltas.",
        "Use this report to choose those base classes and to decide where custom logic is actually required.",
        "",
    ]
    overall = analysis_results.get("overall", [])
    if overall:
        report.append(f"Top architectural base: `{overall[0]['relative_path']}`")

    report.append("\n#### 1. REUSE PLAN (INHERITANCE)")
    by_class = analysis_results.get("by_class", {})
    if not by_class:
        report.append("- (none)")
    for query_class, ranked in by_class.items():
        if not ranked:
            continue
        best = ranked[0]
        if float(best.get("score", 0.0)) < 0.85:
            report.append(f"- **{query_class}**: custom implementation (no strong base class match).")
            continue
        report.extend(
            [
                f"- **{query_class}**: inherit from `{best['class_name']}` in `{best['relative_path']}`",
                f" Evidence: {best['coverage']} methods matched (~{best['score']:.2f} similarity).",
            ]
        )

    report.append("\n#### 2. DATAFLOW / STRUCTURAL ANCHORS")
    results = analysis_results.get("results", {})
    if not results:
        report.append("- (none)")
    for name, data in results.items():
        if name.split(".")[-1] in BOILERPLATE_NAMES:
            continue
        ranked = data.get("embedding", [])
        if ranked and float(ranked[0].get("score", 0.0)) >= 0.92:
            best = ranked[0]
            report.extend(
                [
                    f"- `{name}` is near-identical to `{best['match_name']}` in `{best['relative_path']}`.",
                    " Read both methods and reuse as much as possible through modular inheritance.",
                ]
            )

    report.extend(
        [
            "\n#### 3. PROMPT STARTER",
            "```text",
            "I am building a modular transformers file. Based on similarity analysis:",
            f"Base model: {overall[0]['relative_path'] if overall else 'Unknown'}",
            "Implement the modular classes using # Copied from tags where possible.",
            "```",
        ]
    )
    return "\n".join(report)
def get_default_hub_dataset() -> str:
    """Hub dataset id for the prebuilt index; overridable via the HUB_DATASET env var."""
    override = os.getenv("HUB_DATASET")
    return HUB_DATASET_DEFAULT if override is None else override