"""Embedding-based code-similarity analysis for Transformers model files.

Encodes function/method definitions with a sentence-embedding model
(Qwen3-Embedding) and searches a pre-built safetensors index of the
Transformers `models/` tree to suggest reuse/inheritance targets for
`modular_*.py` files.
"""

import ast
import json
import math
import os
import re
from functools import cache
from pathlib import Path
from typing import Callable, Iterable

import numpy as np
import torch
from huggingface_hub import snapshot_download
from safetensors.numpy import load_file as safetensors_load
import transformers
from transformers import AutoModel, AutoTokenizer

_LIB_PATH = Path(transformers.__file__).resolve().parent

# Resolve the root of the `models/` tree: prefer a source checkout pointed to
# by TRANSFORMERS_REPO, fall back to the installed library location.
_ENV_REPO = os.getenv("TRANSFORMERS_REPO")
if _ENV_REPO:
    _env_path = Path(_ENV_REPO)
    _candidate = _env_path / "src" / "transformers" / "models"
    if _candidate.exists():
        MODELS_ROOT = _candidate
    else:
        _fallback = _env_path / "models"
        MODELS_ROOT = _fallback if _fallback.exists() else _LIB_PATH / "models"
else:
    MODELS_ROOT = _LIB_PATH / "models"

EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
BATCH_SIZE = 16
MAX_LENGTH = 4096
HUB_DATASET_DEFAULT = "Molbap/modular-detector-embeddings"

# Method names that carry little architectural signal; excluded from
# class-reconstruction scoring and report output.
BOILERPLATE_NAMES = {
    "__init__",
    "_init_weights",
    "__repr__",
    "extra_repr",
    "get_input_embeddings",
    "set_input_embeddings",
    "get_output_embeddings",
    "set_output_embeddings",
    "tie_weights",
    "post_init",
    "forward",
    "init_weights",
    "reset_parameters",
    "training",
}

_RE_COMMENT = re.compile(r"#.*")
_RE_IMPORT = re.compile(r"\s*(from|import)\s+")
_RE_MODEL_HINT = re.compile(r"\d+")
_RE_LEADING_PREFIX = re.compile(r"^([A-Z][a-z0-9]+)")
_RE_ALPHANUM = re.compile(r"^([A-Za-z0-9]+)")
_RE_NORMALIZE = re.compile(r"[^a-z0-9]+")


def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str | None) -> str:
    """Prepare raw source for embedding: strip docstrings/comments/imports and
    replace model-specific name variants with the neutral token "Model".

    The replacement keeps embeddings comparable across model families by
    removing the model name itself as a similarity signal.
    """
    code = _strip_docstrings(code)
    cleaned = _RE_COMMENT.sub("", code)
    base = "\n".join(
        line for line in cleaned.splitlines() if line.strip() and not _RE_IMPORT.match(line)
    )
    variants = set()
    if model_hint:
        variants.add(model_hint)
        variants.add(model_hint.replace("_", ""))
        variants.add(_RE_MODEL_HINT.sub("", model_hint))
    if symbol_hint:
        match = _RE_LEADING_PREFIX.match(symbol_hint) or _RE_ALPHANUM.match(symbol_hint)
        prefix = match.group(1) if match else ""
        if prefix:
            variants.add(prefix)
            variants.add(prefix.replace("_", ""))
            variants.add(_RE_MODEL_HINT.sub("", prefix))
    variants |= {variant.lower() for variant in list(variants)}
    sanitized = base
    # Longest-first so e.g. "LlamaModel" is rewritten before "Llama".
    for variant in sorted({x for x in variants if len(x) >= 3}, key=len, reverse=True):
        sanitized = re.sub(re.escape(variant), "Model", sanitized, flags=re.IGNORECASE)
    return sanitized


def _normalize(value: str | None) -> str:
    """Lowercase and drop every non-alphanumeric character; '' for falsy input."""
    return _RE_NORMALIZE.sub("", value.lower()) if value else ""


def _leading_prefix(name: str) -> str:
    """Return the leading CamelCase (or plain alphanumeric) prefix of *name*."""
    match = _RE_LEADING_PREFIX.match(name) or _RE_ALPHANUM.match(name)
    return match.group(1) if match else ""


def _infer_model_hint(definitions_kind: dict[str, str]) -> str | None:
    """Guess the model name as the most frequent class-name prefix among
    class/method definitions; None when there are none."""
    counts: dict[str, int] = {}
    for identifier, kind in definitions_kind.items():
        if kind == "class":
            name = identifier
        elif kind == "method":
            name = identifier.split(".", 1)[0]
        else:
            continue
        prefix = _leading_prefix(name)
        if prefix:
            counts[prefix] = counts.get(prefix, 0) + 1
    if not counts:
        return None
    return max(counts.items(), key=lambda item: item[1])[0]


def _infer_model_prefixes(definitions_kind: dict[str, str]) -> set[str]:
    """Collect every normalized class-name prefix seen in the definitions."""
    prefixes: set[str] = set()
    for identifier, kind in definitions_kind.items():
        if kind == "class":
            name = identifier
        elif kind == "method":
            name = identifier.split(".", 1)[0]
        else:
            continue
        prefix = _leading_prefix(name)
        norm = _normalize(prefix)
        if norm:
            prefixes.add(norm)
    return prefixes


def _calculate_reconstruction_score(
    contributors: Iterable[dict[str, object]],
    query_method_count: int,
) -> tuple[float, int]:
    """Score how well a candidate class reconstructs the query class.

    Takes the best similarity per query method and returns
    (sum-of-best / query_method_count, number of covered query methods).
    """
    if query_method_count <= 0:
        return 0.0, 0
    best_scores: dict[str, float] = {}
    for contributor in contributors:
        query_name = contributor.get("query")
        if not query_name:
            continue
        score = float(contributor.get("score", 0.0))
        if score > best_scores.get(str(query_name), 0.0):
            best_scores[str(query_name)] = score
    total_similarity = sum(best_scores.values())
    return total_similarity / float(query_method_count), len(best_scores)


def _normalize_source_path(path: str | None) -> str | None:
    """Normalize *path* to a POSIX path relative to MODELS_ROOT when possible."""
    if not path:
        return None
    raw = Path(path)
    if raw.is_absolute():
        try:
            return str(raw.resolve().relative_to(MODELS_ROOT.resolve()))
        except Exception:
            # Not under MODELS_ROOT (or resolution failed): keep it absolute.
            return raw.as_posix()
    parts = raw.parts
    if "models" in parts:
        idx = parts.index("models")
        return "/".join(parts[idx + 1 :])
    return raw.as_posix()


def _strip_docstrings(source: str) -> str:
    """Return *source* re-rendered with module/class/function docstrings removed.

    Falls back to the unmodified source when parsing or unparsing fails.
    """
    try:
        tree = ast.parse(source)
    except SyntaxError:
        return source
    # Delegate to the shared tree walker instead of duplicating its body.
    _strip_docstrings_in_tree(tree)
    ast.fix_missing_locations(tree)
    try:
        return ast.unparse(tree)
    except Exception:
        return source


def _strip_docstrings_in_tree(tree: ast.AST) -> None:
    """Remove leading string-constant statements (docstrings) in place from the
    module body and from every function/class body in *tree*."""

    def strip_in_body(body: list[ast.stmt]) -> None:
        if not body:
            return
        first = body[0]
        if isinstance(first, ast.Expr) and isinstance(getattr(first, "value", None), ast.Constant):
            if isinstance(first.value.value, str):
                body.pop(0)

    strip_in_body(tree.body)
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            strip_in_body(node.body)


def _sanitize_unparsed_code(code: str, model_hint: str | None, symbol_hint: str | None) -> str:
    """Like _sanitize_for_embedding, but for already-unparsed code: strips
    comments/imports and replaces model-name hints with "Model" in one regex pass."""
    cleaned = _RE_COMMENT.sub("", code)
    base = "\n".join(line for line in cleaned.splitlines() if line.strip() and not _RE_IMPORT.match(line))
    hints: set[str] = set()
    if model_hint:
        hints.add(model_hint)
        hints.add(model_hint.replace("_", ""))
        hints.add(_RE_MODEL_HINT.sub("", model_hint))
    if symbol_hint:
        match = _RE_LEADING_PREFIX.match(symbol_hint) or _RE_ALPHANUM.match(symbol_hint)
        prefix = match.group(1) if match else ""
        if prefix:
            hints.add(prefix)
            hints.add(prefix.replace("_", ""))
            hints.add(_RE_MODEL_HINT.sub("", prefix))
    hints = {h for h in hints if len(h) >= 3}
    if hints:
        # Alternation ordered longest-first so longer hints win over substrings.
        pattern = re.compile("|".join(re.escape(h) for h in sorted(hints, key=len, reverse=True)), re.IGNORECASE)
        base = pattern.sub("Model", base)
    return base


def _normalize_code_for_compare(source: str) -> str:
    """Collapse *source* (docstrings stripped) into one whitespace-free string
    for exact-duplicate comparison."""
    stripped = _strip_docstrings(source)
    return "".join(line.strip() for line in stripped.splitlines() if line.strip())


@cache
def _load_source(relative_path: str) -> str:
    """Read a file under MODELS_ROOT, cached; '' when unreadable/missing."""
    file_path = MODELS_ROOT / relative_path
    try:
        return file_path.read_text(encoding="utf-8")
    except (FileNotFoundError, OSError):
        return ""


def _get_definition_segment(relative_path: str, match_name: str) -> str | None:
    """Return the source text of a top-level definition (or 'Class.method')
    inside the file at *relative_path*; None when not found/unparsable."""
    source = _load_source(relative_path)
    if not source:
        return None
    try:
        tree = ast.parse(source)
    except SyntaxError:
        return None
    lines = source.splitlines()
    target = None
    if "." in match_name:
        class_name, method_name = match_name.split(".", 1)
        for node in tree.body:
            if isinstance(node, ast.ClassDef) and node.name == class_name:
                for child in node.body:
                    if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == method_name:
                        target = child
                        break
    else:
        for node in tree.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and node.name == match_name:
                target = node
                break
    if target is None:
        return None
    segment = ast.get_source_segment(source, target)
    if segment is None and hasattr(target, "lineno") and hasattr(target, "end_lineno"):
        # get_source_segment can return None; reconstruct from line numbers.
        start = max(0, target.lineno - 1)
        end = target.end_lineno
        segment = "\n".join(lines[start:end])
    return segment


@cache
def _load_definition_line_map(relative_path: str) -> dict[str, int]:
    """Map 'name' / 'Class.method' identifiers to 1-based line numbers for the
    file at *relative_path*; {} when unreadable or unparsable. Cached."""
    file_path = MODELS_ROOT / relative_path
    try:
        source = file_path.read_text(encoding="utf-8")
    except (FileNotFoundError, OSError):
        return {}
    try:
        tree = ast.parse(source)
    except SyntaxError:
        return {}
    line_map: dict[str, int] = {}
    for node in ast.iter_child_nodes(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            line_map[node.name] = getattr(node, "lineno", None) or 1
            if isinstance(node, ast.ClassDef):
                for child in node.body:
                    if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                        key = f"{node.name}.{child.name}"
                        line_map[key] = getattr(child, "lineno", None) or 1
    return line_map


def _resolve_definition_location(relative_path: str, definition: str) -> tuple[str, int | None]:
    """Return (absolute file path, line number or None) for *definition*."""
    full_path = (MODELS_ROOT / relative_path).resolve()
    line = _load_definition_line_map(relative_path).get(definition)
    return str(full_path), line


class CodeSimilarityAnalyzer:
    """Finds similar code in a pre-built embedding index of Transformers models.

    Loads the Qwen3 embedding model once, locates (or downloads) the index
    files for the requested granularity ("method" or "definition"), and
    exposes `analyze_code` to rank reuse candidates for arbitrary source.
    """

    def __init__(self, hub_dataset: str, precision: str = "float32", granularity: str = "method"):
        self.hub_dataset = hub_dataset
        self.precision = precision
        self.requested_granularity = granularity
        # May be downgraded to "definition" by ensure_local_index fallback.
        self.index_granularity = granularity
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            EMBEDDING_MODEL, trust_remote_code=True, dtype=self.dtype, device_map=None
        ).to(self.device)
        self.model.eval()
        self.index_dir: Path | None = None
        self.index_origin: str | None = None
        self.missing_files: tuple[str, ...] = ()
        self._index_cache: dict[str, object] | None = None

    def _embedding_filename(self, granularity: str | None = None) -> str:
        """Embedding file name for a granularity/precision combination."""
        granularity = granularity or self.index_granularity
        suffix = ""
        if granularity == "method":
            suffix += "_methods"
        if self.precision == "int8":
            suffix += "_int8"
        if not suffix:
            return "embeddings.safetensors"
        return f"embeddings{suffix}.safetensors"

    def _index_map_filename(self, granularity: str | None = None) -> str:
        """Identifier-map JSON file name for a granularity."""
        granularity = granularity or self.index_granularity
        if granularity == "method":
            return "code_index_map_methods.json"
        return "code_index_map.json"

    def _resolve_index_path(self, filename: str) -> Path:
        """Resolve *filename* inside the chosen index directory (cwd if unset)."""
        if self.index_dir is None:
            return Path(filename)
        return self.index_dir / filename

    def _required_index_files(self, granularity: str | None = None) -> tuple[str, ...]:
        """The two files an index directory must contain to be usable."""
        return (
            self._embedding_filename(granularity),
            self._index_map_filename(granularity),
        )

    def ensure_local_index(self) -> None:
        """Locate a usable index, downloading from the Hub as a last resort.

        Search order: current index_dir, INDEX_DIR env var, cwd, repo parent,
        then a Hub snapshot. A "method" request may fall back to the coarser
        "definition" index; `missing_files` records what forced the fallback.

        Raises:
            FileNotFoundError: when no candidate provides the needed files.
        """
        required_files = self._required_index_files(self.requested_granularity)
        if self.index_dir is not None and all((self.index_dir / fname).exists() for fname in required_files):
            return

        def missing_files(directory: Path, granularity: str) -> list[str]:
            return [fname for fname in self._required_index_files(granularity) if not (directory / fname).exists()]

        candidates: list[tuple[str, Path]] = []
        env_dir = os.getenv("INDEX_DIR")
        if env_dir:
            candidates.append(("env", Path(env_dir)))
        candidates.append(("cwd", Path.cwd()))
        candidates.append(("repo", Path(__file__).resolve().parent.parent))
        # NOTE(review): after the loop this holds the LAST candidate's missing
        # list, which is what gets reported on fallback/error.
        missing_preferred: list[str] = []
        for origin, candidate in candidates:
            missing_preferred = missing_files(candidate, self.requested_granularity)
            if not missing_preferred:
                self.index_dir = candidate
                self.index_origin = origin
                self.index_granularity = self.requested_granularity
                self.missing_files = ()
                self._index_cache = None
                return
        fallback_dir: Path | None = None
        fallback_origin: str | None = None
        fallback_missing: list[str] = []
        if self.requested_granularity == "method":
            for origin, candidate in candidates:
                fallback_missing = missing_files(candidate, "definition")
                if not fallback_missing:
                    fallback_dir = candidate
                    fallback_origin = origin
                    break
        snapshot_dir = Path(snapshot_download(repo_id=self.hub_dataset, repo_type="dataset"))
        hub_missing_preferred = missing_files(snapshot_dir, self.requested_granularity)
        hub_missing_fallback: list[str] = []
        if self.requested_granularity == "method":
            hub_missing_fallback = missing_files(snapshot_dir, "definition")
        if not hub_missing_preferred:
            self.index_dir = snapshot_dir
            self.index_origin = "hub"
            self.index_granularity = self.requested_granularity
            self.missing_files = ()
            self._index_cache = None
            return
        if self.requested_granularity == "method" and not hub_missing_fallback:
            self.index_dir = snapshot_dir
            self.index_origin = "hub"
            self.index_granularity = "definition"
            self.missing_files = tuple(hub_missing_preferred)
            self._index_cache = None
            return
        if fallback_dir is not None:
            self.index_dir = fallback_dir
            self.index_origin = fallback_origin
            self.index_granularity = "definition"
            self.missing_files = tuple(missing_preferred)
            self._index_cache = None
            return
        missing_detail = ", ".join(hub_missing_preferred or missing_preferred)
        raise FileNotFoundError(
            "Missing expected files for requested granularity; unable to fall back to definition index. "
            f"Missing: {missing_detail}"
        )

    def _load_index(self) -> dict[str, object]:
        """Load (and cache) embeddings, optional int8 scales, and the id map."""
        if self._index_cache is not None:
            return self._index_cache
        self.ensure_local_index()
        embedding_path = self._resolve_index_path(self._embedding_filename())
        base = safetensors_load(str(embedding_path))
        base_embeddings = base["embeddings"]
        scales = base.get("scale") if self.precision == "int8" else None
        with open(self._resolve_index_path(self._index_map_filename()), "r", encoding="utf-8") as file:
            identifier_map = {int(key): value for key, value in json.load(file).items()}
        self._index_cache = {
            "embeddings": base_embeddings,
            "scales": scales,
            "identifier_map": identifier_map,
        }
        return self._index_cache

    def _encode_batch(self, texts: list[str]) -> np.ndarray:
        """Encode one batch of texts into L2-normalized float32 embeddings.

        Uses mean pooling over `last_hidden_state` weighted by the attention
        mask; falls back to `pooler_output` or a plain mean when absent.
        """
        encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
        encoded = {key: value.to(self.device) for key, value in encoded.items()}
        with torch.no_grad():
            if self.device.type == "cuda":
                with torch.autocast(device_type="cuda", dtype=self.dtype):
                    output = self.model(**encoded)
            else:
                output = self.model(**encoded)
        if hasattr(output, "last_hidden_state"):
            embeddings = output.last_hidden_state
            mask = encoded["attention_mask"].unsqueeze(-1)
            # Masked mean; clamp avoids division by zero on all-pad rows.
            embeddings = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-9)
        elif hasattr(output, "pooler_output"):
            embeddings = output.pooler_output
        else:
            embeddings = output[0].mean(dim=1)
        embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1)
        return embeddings.cpu().numpy().astype("float32")

    def encode(
        self,
        texts: list[str],
        progress: Callable[[int, int, str], None] | None = None,
        progress_offset: int = 0,
        progress_total: int | None = None,
    ) -> np.ndarray:
        """Encode *texts* in batches of BATCH_SIZE; (0, 0) array when empty.

        *progress*, when given with *progress_total*, is called after each
        batch with (step, total, "encoding").
        """
        if not texts:
            return np.zeros((0, 0), dtype="float32")
        output = []
        for i in range(0, len(texts), BATCH_SIZE):
            output.append(self._encode_batch(texts[i : i + BATCH_SIZE]))
            if progress and progress_total is not None:
                batch_index = i // BATCH_SIZE
                progress(progress_offset + batch_index + 1, progress_total, "encoding")
        if self.device.type == "cuda":
            torch.cuda.empty_cache()
        return np.vstack(output) if output else np.zeros((0, 0), dtype="float32")

    def _topk(
        self,
        query_embedding_row: np.ndarray,
        base_embeddings: np.ndarray,
        scales: np.ndarray | None,
        identifier_map: dict[int, str],
        self_model_normalized: set[str],
        self_name: str,
        exclude_same_file: str | None,
        k: int,
        pool_size: int | None = None,
    ) -> list[tuple[str, float]]:
        """Return up to *k* (identifier, score) matches for one query row.

        Candidates from the same file, with the same name, or belonging to an
        excluded model prefix are skipped, so a pool larger than *k* is
        partially sorted first.

        Raises:
            ValueError: when precision is int8 but no scales are available.
        """
        if self.precision == "int8":
            if scales is None:
                raise ValueError("Missing int8 scales for int8 search.")
            weighted_query = (query_embedding_row * scales).astype("float32")
            similarities = weighted_query @ base_embeddings.T.astype("float32")
        else:
            similarities = query_embedding_row @ base_embeddings.T
        pool = k + 32 if pool_size is None else max(k, pool_size)
        num_candidates = int(similarities.shape[0])
        # BUGFIX: np.argpartition requires kth < n; with a small index the
        # original call crashed. Fall back to a full sort in that case.
        if pool < num_candidates:
            indices = np.argpartition(-similarities, pool)[:pool]
            indices = indices[np.argsort(-similarities[indices])]
        else:
            indices = np.argsort(-similarities)
        output = []
        for match_id in indices:
            identifier = identifier_map[int(match_id)]
            if ":" not in identifier:
                continue
            relative_path, match_name = identifier.split(":", 1)
            if exclude_same_file and relative_path == exclude_same_file:
                continue
            if match_name == self_name:
                continue
            if self_model_normalized:
                # First path component is the model directory name.
                parent_model = Path(relative_path).parts[0] if relative_path else ""
                if _normalize(parent_model) in self_model_normalized:
                    continue
            output.append((identifier, float(similarities[match_id])))
            if len(output) >= k:
                break
        return output

    def _extract_definitions_from_code(
        self,
        code: str,
        model_hint: str | None,
        granularity: str,
    ) -> tuple[dict[str, str], dict[str, str], dict[str, str]]:
        """Extract definitions from *code* at the given granularity.

        Returns three dicts keyed by identifier ('name' or 'Class.method'):
        raw source segments, sanitized (docstring-stripped, name-neutralized)
        source, and the kind ("function" | "class" | "method").
        """
        definitions_raw: dict[str, str] = {}
        definitions_sanitized: dict[str, str] = {}
        definitions_kind: dict[str, str] = {}
        lines = code.splitlines()
        tree = ast.parse(code)
        entries: list[tuple[str, str, ast.AST, ast.ClassDef | None]] = []
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and granularity in ("definition", "method"):
                segment = ast.get_source_segment(code, node)
                if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
                    start = max(0, node.lineno - 1)
                    end = node.end_lineno
                    segment = "\n".join(lines[start:end])
                if not segment:
                    continue
                identifier = node.name
                definitions_raw[identifier] = segment
                definitions_kind[identifier] = "function"
                entries.append((identifier, "function", node, None))
                continue
            if isinstance(node, ast.ClassDef):
                class_segment = ast.get_source_segment(code, node)
                if class_segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
                    start = max(0, node.lineno - 1)
                    end = node.end_lineno
                    class_segment = "\n".join(lines[start:end])
                if not class_segment:
                    continue
                if granularity == "definition":
                    identifier = node.name
                    definitions_raw[identifier] = class_segment
                    definitions_kind[identifier] = "class"
                    entries.append((identifier, "class", node, None))
                    continue
                for child in node.body:
                    if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                        continue
                    segment = ast.get_source_segment(code, child)
                    if segment is None and hasattr(child, "lineno") and hasattr(child, "end_lineno"):
                        start = max(0, child.lineno - 1)
                        end = child.end_lineno
                        segment = "\n".join(lines[start:end])
                    if not segment:
                        continue
                    method_name = child.name
                    identifier = f"{node.name}.{method_name}"
                    definitions_raw[identifier] = segment
                    definitions_kind[identifier] = "method"
                    entries.append((identifier, "method", child, node))
        # Strip docstrings once on the shared tree, then unparse per entry.
        _strip_docstrings_in_tree(tree)
        for identifier, kind, node, parent in entries:
            try:
                if kind == "method" and parent is not None:
                    # Prefix the class header line so the embedding keeps
                    # class context for the method body.
                    parent_header = ast.unparse(parent).splitlines()[0]
                    combined = f"{parent_header}\n{ast.unparse(node)}"
                    sanitized = _sanitize_unparsed_code(combined, model_hint, parent.name)
                else:
                    sanitized = _sanitize_unparsed_code(ast.unparse(node), model_hint, identifier)
            except Exception:
                # Best effort: fall back to the raw segment on unparse failure.
                sanitized = definitions_raw.get(identifier, "")
            definitions_sanitized[identifier] = sanitized
        return definitions_raw, definitions_sanitized, definitions_kind

    def analyze_code(
        self,
        code: str,
        top_k_per_item: int = 5,
        model_hint: str | None = None,
        exclude_same_model: bool = True,
        exclude_identical: bool = True,
        exclude_models: list[str] | None = None,
        source_path: str | None = None,
        progress: Callable[[int, int, str], None] | None = None,
    ) -> dict[str, dict[str, object]]:
        """Analyze *code* against the index and rank reuse candidates.

        Returns a dict with per-definition matches ("results"), per-file
        aggregates ("overall"), per-class reconstruction scores ("by_class"),
        the count of filtered exact duplicates ("identical_filtered"), and
        the same aggregates computed without duplicate filtering
        ("by_class_all", "overall_all").
        """
        if progress:
            progress(0, 1, "starting")
        index_data = self._load_index()
        base_embeddings = index_data["embeddings"]
        scales = index_data["scales"]
        identifier_map = index_data["identifier_map"]
        definitions_raw, definitions_sanitized, definitions_kind = self._extract_definitions_from_code(
            code, model_hint, self.index_granularity
        )
        if model_hint is None:
            model_hint = _infer_model_hint(definitions_kind)
        self_model_normalized = _infer_model_prefixes(definitions_kind) if exclude_same_model else set()
        if exclude_same_model and model_hint:
            self_model_normalized.add(_normalize(model_hint))
        if exclude_models:
            self_model_normalized.update({_normalize(name) for name in exclude_models})
        exclude_same_file = _normalize_source_path(source_path)
        query_identifiers = list(definitions_raw.keys())
        query_sources_sanitized = [definitions_sanitized[key] for key in query_identifiers]
        query_compare = {
            key: _normalize_code_for_compare(definitions_raw[key])
            for key in query_identifiers
            if key in definitions_raw
        }
        total_batches = max(1, math.ceil(len(query_sources_sanitized) / BATCH_SIZE))
        total_steps = 2 + total_batches
        if progress:
            progress(1, total_steps, "parsed")
        query_embeddings = self.encode(
            query_sources_sanitized,
            progress=progress,
            progress_offset=1,
            progress_total=total_steps,
        )
        output = {}
        output_all = {}
        identical_filtered = 0
        for i, query_identifier in enumerate(query_identifiers):
            # Over-fetch so filtering (same model/file, duplicates) still
            # leaves top_k_per_item usable results.
            pool_size = max(top_k_per_item * 5, top_k_per_item + 32)
            candidates = self._topk(
                query_embeddings[i],
                base_embeddings,
                scales,
                identifier_map,
                self_model_normalized,
                query_identifier,
                exclude_same_file,
                pool_size,
                pool_size=pool_size,
            )
            entry: dict[str, object] = {
                "kind": definitions_kind.get(query_identifier, "function"),
                "embedding": [],
            }
            entry_all: dict[str, object] = {
                "kind": definitions_kind.get(query_identifier, "function"),
                "embedding": [],
            }
            for identifier, score in candidates:
                relative_path, match_name = identifier.split(":", 1)
                is_identical = False
                match_segment = None
                if exclude_identical or query_compare:
                    match_segment = _get_definition_segment(relative_path, match_name)
                    if match_segment is not None:
                        match_norm = _normalize_code_for_compare(match_segment)
                        query_norm = query_compare.get(query_identifier)
                        if query_norm and match_norm == query_norm:
                            is_identical = True
                if len(entry_all["embedding"]) < top_k_per_item:
                    full_path, line = _resolve_definition_location(relative_path, match_name)
                    entry_all["embedding"].append(
                        {
                            "identifier": identifier,
                            "relative_path": relative_path,
                            "match_name": match_name,
                            "score": score,
                            "full_path": full_path,
                            "line": line,
                            "is_identical": is_identical,
                        }
                    )
                if exclude_identical and is_identical:
                    identical_filtered += 1
                    continue
                # NOTE(review): unlike entry_all, this append is not capped at
                # top_k_per_item before the break below, so `entry` can exceed
                # top_k while entry_all is still filling — confirm intended.
                full_path, line = _resolve_definition_location(relative_path, match_name)
                entry["embedding"].append(
                    {
                        "identifier": identifier,
                        "relative_path": relative_path,
                        "match_name": match_name,
                        "score": score,
                        "full_path": full_path,
                        "line": line,
                        "is_identical": is_identical,
                    }
                )
                if len(entry["embedding"]) >= top_k_per_item and len(entry_all["embedding"]) >= top_k_per_item:
                    break
            output[query_identifier] = entry
            output_all[query_identifier] = entry_all

        def build_by_class(result_map: dict[str, dict[str, object]]) -> dict[str, list[dict[str, object]]]:
            """Aggregate per-method matches into per-class candidate rankings."""

            def candidate_class_key(relative_path: str, match_name: str) -> tuple[str, str]:
                class_name = match_name.split(".", 1)[0] if "." in match_name else ""
                return relative_path, class_name

            def query_class_key(query_identifier: str, kind: str) -> str:
                return query_identifier.split(".", 1)[0] if kind == "method" else query_identifier

            by_class: dict[str, dict[tuple[str, str], dict[str, object]]] = {}
            for query_identifier, entry in result_map.items():
                kind = entry.get("kind", "function")
                if kind == "function":
                    matches = entry.get("embedding", [])
                    if not matches:
                        continue
                    # Keep only the best-scoring match per candidate function.
                    best_per_cand: dict[tuple[str, str], dict[str, object]] = {}
                    for match in matches:
                        rel = match.get("relative_path")
                        mname = match.get("match_name")
                        score = match.get("score")
                        if rel is None or not mname or score is None:
                            continue
                        if "." in mname:
                            # Free functions only compete with free functions.
                            continue
                        ckey = (rel, mname)
                        prev = best_per_cand.get(ckey)
                        if prev is None or float(score) > float(prev.get("score", -1.0)):
                            best_per_cand[ckey] = match
                    for ckey, match in best_per_cand.items():
                        slot = by_class.setdefault(query_identifier, {}).setdefault(
                            ckey,
                            {
                                "relative_path": ckey[0],
                                "class_name": ckey[1],
                                "scores": [],
                                "contributors": [],
                            },
                        )
                        slot["scores"].append(float(match["score"]))
                        slot["contributors"].append(
                            {"query": query_identifier, "match": match["identifier"], "score": float(match["score"])}
                        )
                    continue
                qcls = query_class_key(query_identifier, kind)
                matches = entry.get("embedding", [])
                if not matches:
                    continue
                best_per_cand: dict[tuple[str, str], dict[str, object]] = {}
                for match in matches:
                    rel = match.get("relative_path")
                    mname = match.get("match_name")
                    score = match.get("score")
                    if rel is None or not mname or score is None:
                        continue
                    ckey = candidate_class_key(rel, mname)
                    prev = best_per_cand.get(ckey)
                    if prev is None or float(score) > float(prev.get("score", -1.0)):
                        best_per_cand[ckey] = match
                for ckey, match in best_per_cand.items():
                    slot = by_class.setdefault(qcls, {}).setdefault(
                        ckey,
                        {
                            "relative_path": ckey[0],
                            "class_name": ckey[1],
                            "scores": [],
                            "contributors": [],
                        },
                    )
                    slot["scores"].append(float(match["score"]))
                    slot["contributors"].append(
                        {"query": query_identifier, "match": match["identifier"], "score": float(match["score"])}
                    )
            by_class_out: dict[str, list[dict[str, object]]] = {}
            for qcls, cand_map in by_class.items():
                q_method_count = len(
                    [
                        key
                        for key, kind in definitions_kind.items()
                        if kind == "method" and key.startswith(f"{qcls}.") and key.split(".")[-1] not in BOILERPLATE_NAMES
                    ]
                )
                q_method_count = max(1, q_method_count)
                rows = []
                for _, slot in cand_map.items():
                    filtered_contributors = [
                        item
                        for item in slot["contributors"]
                        if str(item.get("query", "")).split(".")[-1] not in BOILERPLATE_NAMES
                    ]
                    base_score, coverage_count = _calculate_reconstruction_score(
                        filtered_contributors, q_method_count
                    )
                    coverage_ratio = coverage_count / float(q_method_count)
                    contributors = sorted(slot["contributors"], key=lambda x: float(x["score"]), reverse=True)[:5]
                    rows.append(
                        {
                            "relative_path": slot["relative_path"],
                            "class_name": slot["class_name"],
                            "identifier": f"{slot['relative_path']}:{slot['class_name']}",
                            "score": base_score,
                            "coverage": coverage_count,
                            "coverage_pct": coverage_ratio,
                            "top_contributors": contributors,
                        }
                    )
                rows.sort(key=lambda row: (float(row["score"]), int(row["coverage"])), reverse=True)
                by_class_out[qcls] = rows[:10]
            return by_class_out

        def build_overall(result_map: dict[str, dict[str, object]]) -> list[dict[str, object]]:
            """Aggregate best-per-file scores across all non-boilerplate queries."""
            per_query_best: list[dict[str, float]] = []
            for query_identifier, data in result_map.items():
                if query_identifier.split(".")[-1] in BOILERPLATE_NAMES:
                    continue
                best_by_file: dict[str, float] = {}
                for match in data.get("embedding", []):
                    rel = match.get("relative_path")
                    score = match.get("score")
                    if rel is None or score is None:
                        continue
                    best_by_file[rel] = max(best_by_file.get(rel, -1.0), float(score))
                per_query_best.append(best_by_file)
            aggregate: dict[str, dict[str, object]] = {}
            for best_by_file in per_query_best:
                for rel, score in best_by_file.items():
                    slot = aggregate.setdefault(rel, {"relative_path": rel, "scores": []})
                    slot["scores"].append(score)
            overall = []
            for rel, slot in aggregate.items():
                scores = sorted(slot["scores"], reverse=True)
                overall.append(
                    {
                        "relative_path": rel,
                        "score": float(sum(scores) / max(1, len(scores))),
                        "count": len(scores),
                        "best_score": float(scores[0]) if scores else -1.0,
                    }
                )
            overall.sort(key=lambda item: float(item["score"]), reverse=True)
            return overall

        by_class_out = build_by_class(output)
        by_class_all = build_by_class(output_all)
        overall = build_overall(output)
        overall_all = build_overall(output_all)
        if progress:
            progress(total_steps, total_steps, "done")
        return {
            "results": output,
            "overall": overall,
            "by_class": by_class_out,
            "identical_filtered": identical_filtered,
            "by_class_all": by_class_all,
            "overall_all": overall_all,
        }

    def index_status(self) -> dict[str, object]:
        """Report the resolved index configuration for diagnostics."""
        return {
            "requested_granularity": self.requested_granularity,
            "resolved_granularity": self.index_granularity,
            "precision": self.precision,
            "hub_dataset": self.hub_dataset,
            "index_dir": str(self.index_dir) if self.index_dir else None,
            "index_origin": self.index_origin,
            "missing_files": list(self.missing_files),
            "embedding_model": EMBEDDING_MODEL,
        }

    @staticmethod
    def list_models() -> list[str]:
        """List model directory names under MODELS_ROOT (skipping __pycache__ etc.)."""
        models = []
        for path in MODELS_ROOT.iterdir():
            if path.is_dir() and not path.name.startswith("__"):
                models.append(path.name)
        return sorted(models)


def generate_agent_report(analysis_results: dict) -> str:
    """Render `analyze_code` output as a markdown spec for a code agent.

    Sections: reuse plan (classes scoring >= 0.85), near-identical method
    anchors (>= 0.92), and a prompt starter naming the top base file.
    """
    lines = [
        "### MODULAR RECONSTRUCTION SPECIFICATION (FOR AGENT)",
        "You are operating inside the Transformers repository. Your task is to produce a `modular_*.py` file that",
        "maximizes reuse of existing Transformers components.",
        "Purpose: the `modular_model_converter.py` script reads the modular file, resolves inheritance, and",
        "generates the final `modeling_*.py` file. The modular file should prefer inheriting from existing",
        "Transformers classes (e.g., attention, MLP, blocks, model heads) rather than starting fresh from",
        "`nn.Module`, so the generated model shares as much code as possible and only overrides the minimal deltas.",
        "Use this report to choose those base classes and to decide where custom logic is actually required.",
        "",
    ]
    overall = analysis_results.get("overall", [])
    if overall:
        lines.append(f"Top architectural base: `{overall[0]['relative_path']}`")
    lines.append("\n#### 1. REUSE PLAN (INHERITANCE)")
    by_class = analysis_results.get("by_class", {})
    if not by_class:
        lines.append("- (none)")
    for q_cls, matches in by_class.items():
        if not matches:
            continue
        top = matches[0]
        if float(top.get("score", 0.0)) >= 0.85:
            lines.append(
                f"- **{q_cls}**: inherit from `{top['class_name']}` in `{top['relative_path']}`"
            )
            lines.append(
                f"  Evidence: {top['coverage']} methods matched (~{top['score']:.2f} similarity)."
            )
        else:
            lines.append(f"- **{q_cls}**: custom implementation (no strong base class match).")
    lines.append("\n#### 2. DATAFLOW / STRUCTURAL ANCHORS")
    results = analysis_results.get("results", {})
    if not results:
        lines.append("- (none)")
    for query_name, entry in results.items():
        if query_name.split(".")[-1] in BOILERPLATE_NAMES:
            continue
        matches = entry.get("embedding", [])
        if matches and float(matches[0].get("score", 0.0)) >= 0.92:
            top = matches[0]
            lines.append(
                f"- `{query_name}` is near-identical to `{top['match_name']}` in `{top['relative_path']}`."
            )
            lines.append("  Read both methods and reuse as much as possible through modular inheritance.")
    lines.append("\n#### 3. PROMPT STARTER")
    lines.append("```text")
    lines.append("I am building a modular transformers file. Based on similarity analysis:")
    lines.append(f"Base model: {overall[0]['relative_path'] if overall else 'Unknown'}")
    lines.append("Implement the modular classes using # Copied from tags where possible.")
    lines.append("```")
    return "\n".join(lines)


def get_default_hub_dataset() -> str:
    """Return the Hub dataset id, overridable via the HUB_DATASET env var."""
    return os.getenv("HUB_DATASET", HUB_DATASET_DEFAULT)