"""modular-detector-v2 — app/detector.py

Code-similarity detector for building modular Transformers files.
Author: Molbap (HF Staff); commit 4fe7080 (verified): "Update app with better diff, new style".
"""
import ast
import json
import math
import os
import re
from collections import Counter
from functools import cache
from pathlib import Path
from typing import Callable, Iterable

import numpy as np
import torch
import transformers
from huggingface_hub import snapshot_download
from safetensors.numpy import load_file as safetensors_load
from transformers import AutoModel, AutoTokenizer
_LIB_PATH = Path(transformers.__file__).resolve().parent
_ENV_REPO = os.getenv("TRANSFORMERS_REPO")
if _ENV_REPO:
_env_path = Path(_ENV_REPO)
_candidate = _env_path / "src" / "transformers" / "models"
if _candidate.exists():
MODELS_ROOT = _candidate
else:
_fallback = _env_path / "models"
MODELS_ROOT = _fallback if _fallback.exists() else _LIB_PATH / "models"
else:
MODELS_ROOT = _LIB_PATH / "models"
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
BATCH_SIZE = 16
MAX_LENGTH = 4096
HUB_DATASET_DEFAULT = "Molbap/modular-detector-embeddings"
BOILERPLATE_NAMES = {
"__init__",
"_init_weights",
"__repr__",
"extra_repr",
"get_input_embeddings",
"set_input_embeddings",
"get_output_embeddings",
"set_output_embeddings",
"tie_weights",
"post_init",
"forward",
"init_weights",
"reset_parameters",
"training",
}
_RE_COMMENT = re.compile(r"#.*")
_RE_IMPORT = re.compile(r"\s*(from|import)\s+")
_RE_MODEL_HINT = re.compile(r"\d+")
_RE_LEADING_PREFIX = re.compile(r"^([A-Z][a-z0-9]+)")
_RE_ALPHANUM = re.compile(r"^([A-Za-z0-9]+)")
_RE_NORMALIZE = re.compile(r"[^a-z0-9]+")
def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str | None) -> str:
code = _strip_docstrings(code)
cleaned = _RE_COMMENT.sub("", code)
base = "\n".join(
line for line in cleaned.splitlines() if line.strip() and not _RE_IMPORT.match(line)
)
variants = set()
if model_hint:
variants.add(model_hint)
variants.add(model_hint.replace("_", ""))
variants.add(_RE_MODEL_HINT.sub("", model_hint))
if symbol_hint:
match = _RE_LEADING_PREFIX.match(symbol_hint) or _RE_ALPHANUM.match(symbol_hint)
prefix = match.group(1) if match else ""
if prefix:
variants.add(prefix)
variants.add(prefix.replace("_", ""))
variants.add(_RE_MODEL_HINT.sub("", prefix))
variants |= {variant.lower() for variant in list(variants)}
sanitized = base
for variant in sorted({x for x in variants if len(x) >= 3}, key=len, reverse=True):
sanitized = re.sub(re.escape(variant), "Model", sanitized, flags=re.IGNORECASE)
return sanitized
def _normalize(value: str | None) -> str:
return _RE_NORMALIZE.sub("", value.lower()) if value else ""
def _leading_prefix(name: str) -> str:
match = _RE_LEADING_PREFIX.match(name) or _RE_ALPHANUM.match(name)
return match.group(1) if match else ""
def _infer_model_hint(definitions_kind: dict[str, str]) -> str | None:
counts: dict[str, int] = {}
for identifier, kind in definitions_kind.items():
if kind == "class":
name = identifier
elif kind == "method":
name = identifier.split(".", 1)[0]
else:
continue
prefix = _leading_prefix(name)
if prefix:
counts[prefix] = counts.get(prefix, 0) + 1
if not counts:
return None
return max(counts.items(), key=lambda item: item[1])[0]
def _infer_model_prefixes(definitions_kind: dict[str, str]) -> set[str]:
prefixes: set[str] = set()
for identifier, kind in definitions_kind.items():
if kind == "class":
name = identifier
elif kind == "method":
name = identifier.split(".", 1)[0]
else:
continue
prefix = _leading_prefix(name)
norm = _normalize(prefix)
if norm:
prefixes.add(norm)
return prefixes
def _calculate_reconstruction_score(
contributors: Iterable[dict[str, object]],
query_method_count: int,
) -> tuple[float, int]:
if query_method_count <= 0:
return 0.0, 0
best_scores: dict[str, float] = {}
for contributor in contributors:
query_name = contributor.get("query")
if not query_name:
continue
score = float(contributor.get("score", 0.0))
if score > best_scores.get(str(query_name), 0.0):
best_scores[str(query_name)] = score
total_similarity = sum(best_scores.values())
return total_similarity / float(query_method_count), len(best_scores)
def _normalize_source_path(path: str | None) -> str | None:
if not path:
return None
raw = Path(path)
if raw.is_absolute():
try:
return str(raw.resolve().relative_to(MODELS_ROOT.resolve()))
except Exception:
return raw.as_posix()
parts = raw.parts
if "models" in parts:
idx = parts.index("models")
return "/".join(parts[idx + 1 :])
return raw.as_posix()
def _strip_docstrings(source: str) -> str:
try:
tree = ast.parse(source)
except SyntaxError:
return source
def strip_in_body(body: list[ast.stmt]) -> None:
if not body:
return
first = body[0]
if isinstance(first, ast.Expr) and isinstance(getattr(first, "value", None), ast.Constant):
if isinstance(first.value.value, str):
body.pop(0)
strip_in_body(tree.body)
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
strip_in_body(node.body)
ast.fix_missing_locations(tree)
try:
return ast.unparse(tree)
except Exception:
return source
def _strip_docstrings_in_tree(tree: ast.AST) -> None:
def strip_in_body(body: list[ast.stmt]) -> None:
if not body:
return
first = body[0]
if isinstance(first, ast.Expr) and isinstance(getattr(first, "value", None), ast.Constant):
if isinstance(first.value.value, str):
body.pop(0)
strip_in_body(tree.body)
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
strip_in_body(node.body)
def _sanitize_unparsed_code(code: str, model_hint: str | None, symbol_hint: str | None) -> str:
cleaned = _RE_COMMENT.sub("", code)
base = "\n".join(line for line in cleaned.splitlines() if line.strip() and not _RE_IMPORT.match(line))
hints: set[str] = set()
if model_hint:
hints.add(model_hint)
hints.add(model_hint.replace("_", ""))
hints.add(_RE_MODEL_HINT.sub("", model_hint))
if symbol_hint:
match = _RE_LEADING_PREFIX.match(symbol_hint) or _RE_ALPHANUM.match(symbol_hint)
prefix = match.group(1) if match else ""
if prefix:
hints.add(prefix)
hints.add(prefix.replace("_", ""))
hints.add(_RE_MODEL_HINT.sub("", prefix))
hints = {h for h in hints if len(h) >= 3}
if hints:
pattern = re.compile("|".join(re.escape(h) for h in sorted(hints, key=len, reverse=True)), re.IGNORECASE)
base = pattern.sub("Model", base)
return base
def _normalize_code_for_compare(source: str) -> str:
stripped = _strip_docstrings(source)
return "".join(line.strip() for line in stripped.splitlines() if line.strip())
@cache
def _load_source(relative_path: str) -> str:
file_path = MODELS_ROOT / relative_path
try:
return file_path.read_text(encoding="utf-8")
except (FileNotFoundError, OSError):
return ""
def _get_definition_segment(relative_path: str, match_name: str) -> str | None:
source = _load_source(relative_path)
if not source:
return None
try:
tree = ast.parse(source)
except SyntaxError:
return None
lines = source.splitlines()
target = None
if "." in match_name:
class_name, method_name = match_name.split(".", 1)
for node in tree.body:
if isinstance(node, ast.ClassDef) and node.name == class_name:
for child in node.body:
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == method_name:
target = child
break
else:
for node in tree.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and node.name == match_name:
target = node
break
if target is None:
return None
segment = ast.get_source_segment(source, target)
if segment is None and hasattr(target, "lineno") and hasattr(target, "end_lineno"):
start = max(0, target.lineno - 1)
end = target.end_lineno
segment = "\n".join(lines[start:end])
return segment
@cache
def _load_definition_line_map(relative_path: str) -> dict[str, int]:
file_path = MODELS_ROOT / relative_path
try:
source = file_path.read_text(encoding="utf-8")
except (FileNotFoundError, OSError):
return {}
try:
tree = ast.parse(source)
except SyntaxError:
return {}
line_map: dict[str, int] = {}
for node in ast.iter_child_nodes(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
line_map[node.name] = getattr(node, "lineno", None) or 1
if isinstance(node, ast.ClassDef):
for child in node.body:
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
key = f"{node.name}.{child.name}"
line_map[key] = getattr(child, "lineno", None) or 1
return line_map
def _resolve_definition_location(relative_path: str, definition: str) -> tuple[str, int | None]:
full_path = (MODELS_ROOT / relative_path).resolve()
line = _load_definition_line_map(relative_path).get(definition)
return str(full_path), line
class CodeSimilarityAnalyzer:
def __init__(self, hub_dataset: str, precision: str = "float32", granularity: str = "method"):
self.hub_dataset = hub_dataset
self.precision = precision
self.requested_granularity = granularity
self.index_granularity = granularity
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL, trust_remote_code=True)
self.model = AutoModel.from_pretrained(
EMBEDDING_MODEL, trust_remote_code=True, dtype=self.dtype, device_map=None
).to(self.device)
self.model.eval()
self.index_dir: Path | None = None
self.index_origin: str | None = None
self.missing_files: tuple[str, ...] = ()
self._index_cache: dict[str, object] | None = None
def _embedding_filename(self, granularity: str | None = None) -> str:
granularity = granularity or self.index_granularity
suffix = ""
if granularity == "method":
suffix += "_methods"
if self.precision == "int8":
suffix += "_int8"
if not suffix:
return "embeddings.safetensors"
return f"embeddings{suffix}.safetensors"
def _index_map_filename(self, granularity: str | None = None) -> str:
granularity = granularity or self.index_granularity
if granularity == "method":
return "code_index_map_methods.json"
return "code_index_map.json"
def _resolve_index_path(self, filename: str) -> Path:
if self.index_dir is None:
return Path(filename)
return self.index_dir / filename
def _required_index_files(self, granularity: str | None = None) -> tuple[str, ...]:
return (
self._embedding_filename(granularity),
self._index_map_filename(granularity),
)
def ensure_local_index(self) -> None:
required_files = self._required_index_files(self.requested_granularity)
if self.index_dir is not None and all((self.index_dir / fname).exists() for fname in required_files):
return
def missing_files(directory: Path, granularity: str) -> list[str]:
return [fname for fname in self._required_index_files(granularity) if not (directory / fname).exists()]
candidates: list[tuple[str, Path]] = []
env_dir = os.getenv("INDEX_DIR")
if env_dir:
candidates.append(("env", Path(env_dir)))
candidates.append(("cwd", Path.cwd()))
candidates.append(("repo", Path(__file__).resolve().parent.parent))
missing_preferred: list[str] = []
for origin, candidate in candidates:
missing_preferred = missing_files(candidate, self.requested_granularity)
if not missing_preferred:
self.index_dir = candidate
self.index_origin = origin
self.index_granularity = self.requested_granularity
self.missing_files = ()
self._index_cache = None
return
fallback_dir: Path | None = None
fallback_origin: str | None = None
fallback_missing: list[str] = []
if self.requested_granularity == "method":
for origin, candidate in candidates:
fallback_missing = missing_files(candidate, "definition")
if not fallback_missing:
fallback_dir = candidate
fallback_origin = origin
break
snapshot_dir = Path(snapshot_download(repo_id=self.hub_dataset, repo_type="dataset"))
hub_missing_preferred = missing_files(snapshot_dir, self.requested_granularity)
hub_missing_fallback: list[str] = []
if self.requested_granularity == "method":
hub_missing_fallback = missing_files(snapshot_dir, "definition")
if not hub_missing_preferred:
self.index_dir = snapshot_dir
self.index_origin = "hub"
self.index_granularity = self.requested_granularity
self.missing_files = ()
self._index_cache = None
return
if self.requested_granularity == "method" and not hub_missing_fallback:
self.index_dir = snapshot_dir
self.index_origin = "hub"
self.index_granularity = "definition"
self.missing_files = tuple(hub_missing_preferred)
self._index_cache = None
return
if fallback_dir is not None:
self.index_dir = fallback_dir
self.index_origin = fallback_origin
self.index_granularity = "definition"
self.missing_files = tuple(missing_preferred)
self._index_cache = None
return
missing_detail = ", ".join(hub_missing_preferred or missing_preferred)
raise FileNotFoundError(
"Missing expected files for requested granularity; unable to fall back to definition index. "
f"Missing: {missing_detail}"
)
def _load_index(self) -> dict[str, object]:
if self._index_cache is not None:
return self._index_cache
self.ensure_local_index()
embedding_path = self._resolve_index_path(self._embedding_filename())
base = safetensors_load(str(embedding_path))
base_embeddings = base["embeddings"]
scales = base.get("scale") if self.precision == "int8" else None
with open(self._resolve_index_path(self._index_map_filename()), "r", encoding="utf-8") as file:
identifier_map = {int(key): value for key, value in json.load(file).items()}
self._index_cache = {
"embeddings": base_embeddings,
"scales": scales,
"identifier_map": identifier_map,
}
return self._index_cache
def _encode_batch(self, texts: list[str]) -> np.ndarray:
encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
encoded = {key: value.to(self.device) for key, value in encoded.items()}
with torch.no_grad():
if self.device.type == "cuda":
with torch.autocast(device_type="cuda", dtype=self.dtype):
output = self.model(**encoded)
else:
output = self.model(**encoded)
if hasattr(output, "last_hidden_state"):
embeddings = output.last_hidden_state
mask = encoded["attention_mask"].unsqueeze(-1)
embeddings = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-9)
elif hasattr(output, "pooler_output"):
embeddings = output.pooler_output
else:
embeddings = output[0].mean(dim=1)
embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1)
return embeddings.cpu().numpy().astype("float32")
def encode(
self,
texts: list[str],
progress: Callable[[int, int, str], None] | None = None,
progress_offset: int = 0,
progress_total: int | None = None,
) -> np.ndarray:
if not texts:
return np.zeros((0, 0), dtype="float32")
output = []
for i in range(0, len(texts), BATCH_SIZE):
output.append(self._encode_batch(texts[i : i + BATCH_SIZE]))
if progress and progress_total is not None:
batch_index = i // BATCH_SIZE
progress(progress_offset + batch_index + 1, progress_total, "encoding")
if self.device.type == "cuda":
torch.cuda.empty_cache()
return np.vstack(output) if output else np.zeros((0, 0), dtype="float32")
def _topk(
self,
query_embedding_row: np.ndarray,
base_embeddings: np.ndarray,
scales: np.ndarray | None,
identifier_map: dict[int, str],
self_model_normalized: set[str],
self_name: str,
exclude_same_file: str | None,
k: int,
pool_size: int | None = None,
) -> list[tuple[str, float]]:
if self.precision == "int8":
if scales is None:
raise ValueError("Missing int8 scales for int8 search.")
weighted_query = (query_embedding_row * scales).astype("float32")
similarities = weighted_query @ base_embeddings.T.astype("float32")
else:
similarities = query_embedding_row @ base_embeddings.T
pool = k + 32 if pool_size is None else max(k, pool_size)
indices = np.argpartition(-similarities, pool)[:pool]
indices = indices[np.argsort(-similarities[indices])]
output = []
for match_id in indices:
identifier = identifier_map[int(match_id)]
if ":" not in identifier:
continue
relative_path, match_name = identifier.split(":", 1)
if exclude_same_file and relative_path == exclude_same_file:
continue
if match_name == self_name:
continue
if self_model_normalized:
parent_model = Path(relative_path).parts[0] if relative_path else ""
if _normalize(parent_model) in self_model_normalized:
continue
output.append((identifier, float(similarities[match_id])))
if len(output) >= k:
break
return output
def _extract_definitions_from_code(
self,
code: str,
model_hint: str | None,
granularity: str,
) -> tuple[dict[str, str], dict[str, str], dict[str, str]]:
definitions_raw: dict[str, str] = {}
definitions_sanitized: dict[str, str] = {}
definitions_kind: dict[str, str] = {}
lines = code.splitlines()
tree = ast.parse(code)
entries: list[tuple[str, str, ast.AST, ast.ClassDef | None]] = []
for node in ast.iter_child_nodes(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and granularity in ("definition", "method"):
segment = ast.get_source_segment(code, node)
if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
start = max(0, node.lineno - 1)
end = node.end_lineno
segment = "\n".join(lines[start:end])
if not segment:
continue
identifier = node.name
definitions_raw[identifier] = segment
definitions_kind[identifier] = "function"
entries.append((identifier, "function", node, None))
continue
if isinstance(node, ast.ClassDef):
class_segment = ast.get_source_segment(code, node)
if class_segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
start = max(0, node.lineno - 1)
end = node.end_lineno
class_segment = "\n".join(lines[start:end])
if not class_segment:
continue
if granularity == "definition":
identifier = node.name
definitions_raw[identifier] = class_segment
definitions_kind[identifier] = "class"
entries.append((identifier, "class", node, None))
continue
for child in node.body:
if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
segment = ast.get_source_segment(code, child)
if segment is None and hasattr(child, "lineno") and hasattr(child, "end_lineno"):
start = max(0, child.lineno - 1)
end = child.end_lineno
segment = "\n".join(lines[start:end])
if not segment:
continue
method_name = child.name
identifier = f"{node.name}.{method_name}"
definitions_raw[identifier] = segment
definitions_kind[identifier] = "method"
entries.append((identifier, "method", child, node))
_strip_docstrings_in_tree(tree)
for identifier, kind, node, parent in entries:
try:
if kind == "method" and parent is not None:
parent_header = ast.unparse(parent).splitlines()[0]
combined = f"{parent_header}\n{ast.unparse(node)}"
sanitized = _sanitize_unparsed_code(combined, model_hint, parent.name)
else:
sanitized = _sanitize_unparsed_code(ast.unparse(node), model_hint, identifier)
except Exception:
sanitized = definitions_raw.get(identifier, "")
definitions_sanitized[identifier] = sanitized
return definitions_raw, definitions_sanitized, definitions_kind
def analyze_code(
self,
code: str,
top_k_per_item: int = 5,
model_hint: str | None = None,
exclude_same_model: bool = True,
exclude_identical: bool = True,
exclude_models: list[str] | None = None,
source_path: str | None = None,
progress: Callable[[int, int, str], None] | None = None,
) -> dict[str, dict[str, object]]:
if progress:
progress(0, 1, "starting")
index_data = self._load_index()
base_embeddings = index_data["embeddings"]
scales = index_data["scales"]
identifier_map = index_data["identifier_map"]
definitions_raw, definitions_sanitized, definitions_kind = self._extract_definitions_from_code(
code, model_hint, self.index_granularity
)
if model_hint is None:
model_hint = _infer_model_hint(definitions_kind)
self_model_normalized = _infer_model_prefixes(definitions_kind) if exclude_same_model else set()
if exclude_same_model and model_hint:
self_model_normalized.add(_normalize(model_hint))
if exclude_models:
self_model_normalized.update({_normalize(name) for name in exclude_models})
exclude_same_file = _normalize_source_path(source_path)
query_identifiers = list(definitions_raw.keys())
query_sources_sanitized = [definitions_sanitized[key] for key in query_identifiers]
query_compare = {
key: _normalize_code_for_compare(definitions_raw[key]) for key in query_identifiers if key in definitions_raw
}
total_batches = max(1, math.ceil(len(query_sources_sanitized) / BATCH_SIZE))
total_steps = 2 + total_batches
if progress:
progress(1, total_steps, "parsed")
query_embeddings = self.encode(
query_sources_sanitized,
progress=progress,
progress_offset=1,
progress_total=total_steps,
)
output = {}
output_all = {}
identical_filtered = 0
for i, query_identifier in enumerate(query_identifiers):
pool_size = max(top_k_per_item * 5, top_k_per_item + 32)
candidates = self._topk(
query_embeddings[i],
base_embeddings,
scales,
identifier_map,
self_model_normalized,
query_identifier,
exclude_same_file,
pool_size,
pool_size=pool_size,
)
entry: dict[str, object] = {
"kind": definitions_kind.get(query_identifier, "function"),
"embedding": [],
}
entry_all: dict[str, object] = {
"kind": definitions_kind.get(query_identifier, "function"),
"embedding": [],
}
for identifier, score in candidates:
relative_path, match_name = identifier.split(":", 1)
is_identical = False
match_segment = None
if exclude_identical or query_compare:
match_segment = _get_definition_segment(relative_path, match_name)
if match_segment is not None:
match_norm = _normalize_code_for_compare(match_segment)
query_norm = query_compare.get(query_identifier)
if query_norm and match_norm == query_norm:
is_identical = True
if len(entry_all["embedding"]) < top_k_per_item:
full_path, line = _resolve_definition_location(relative_path, match_name)
entry_all["embedding"].append(
{
"identifier": identifier,
"relative_path": relative_path,
"match_name": match_name,
"score": score,
"full_path": full_path,
"line": line,
"is_identical": is_identical,
}
)
if exclude_identical and is_identical:
identical_filtered += 1
continue
full_path, line = _resolve_definition_location(relative_path, match_name)
entry["embedding"].append(
{
"identifier": identifier,
"relative_path": relative_path,
"match_name": match_name,
"score": score,
"full_path": full_path,
"line": line,
"is_identical": is_identical,
}
)
if len(entry["embedding"]) >= top_k_per_item and len(entry_all["embedding"]) >= top_k_per_item:
break
output[query_identifier] = entry
output_all[query_identifier] = entry_all
def build_by_class(result_map: dict[str, dict[str, object]]) -> dict[str, list[dict[str, object]]]:
def candidate_class_key(relative_path: str, match_name: str) -> tuple[str, str]:
class_name = match_name.split(".", 1)[0] if "." in match_name else "<module>"
return relative_path, class_name
def query_class_key(query_identifier: str, kind: str) -> str:
return query_identifier.split(".", 1)[0] if kind == "method" else query_identifier
by_class: dict[str, dict[tuple[str, str], dict[str, object]]] = {}
for query_identifier, entry in result_map.items():
kind = entry.get("kind", "function")
if kind == "function":
matches = entry.get("embedding", [])
if not matches:
continue
best_per_cand: dict[tuple[str, str], dict[str, object]] = {}
for match in matches:
rel = match.get("relative_path")
mname = match.get("match_name")
score = match.get("score")
if rel is None or not mname or score is None:
continue
if "." in mname:
continue
ckey = (rel, mname)
prev = best_per_cand.get(ckey)
if prev is None or float(score) > float(prev.get("score", -1.0)):
best_per_cand[ckey] = match
for ckey, match in best_per_cand.items():
slot = by_class.setdefault(query_identifier, {}).setdefault(
ckey,
{
"relative_path": ckey[0],
"class_name": ckey[1],
"scores": [],
"contributors": [],
},
)
slot["scores"].append(float(match["score"]))
slot["contributors"].append(
{"query": query_identifier, "match": match["identifier"], "score": float(match["score"])}
)
continue
qcls = query_class_key(query_identifier, kind)
matches = entry.get("embedding", [])
if not matches:
continue
best_per_cand: dict[tuple[str, str], dict[str, object]] = {}
for match in matches:
rel = match.get("relative_path")
mname = match.get("match_name")
score = match.get("score")
if rel is None or not mname or score is None:
continue
ckey = candidate_class_key(rel, mname)
prev = best_per_cand.get(ckey)
if prev is None or float(score) > float(prev.get("score", -1.0)):
best_per_cand[ckey] = match
for ckey, match in best_per_cand.items():
slot = by_class.setdefault(qcls, {}).setdefault(
ckey,
{
"relative_path": ckey[0],
"class_name": ckey[1],
"scores": [],
"contributors": [],
},
)
slot["scores"].append(float(match["score"]))
slot["contributors"].append(
{"query": query_identifier, "match": match["identifier"], "score": float(match["score"])}
)
by_class_out: dict[str, list[dict[str, object]]] = {}
for qcls, cand_map in by_class.items():
q_method_count = len(
[
key
for key, kind in definitions_kind.items()
if kind == "method"
and key.startswith(f"{qcls}.")
and key.split(".")[-1] not in BOILERPLATE_NAMES
]
)
q_method_count = max(1, q_method_count)
rows = []
for _, slot in cand_map.items():
filtered_contributors = [
item
for item in slot["contributors"]
if str(item.get("query", "")).split(".")[-1] not in BOILERPLATE_NAMES
]
base_score, coverage_count = _calculate_reconstruction_score(
filtered_contributors, q_method_count
)
coverage_ratio = coverage_count / float(q_method_count)
contributors = sorted(slot["contributors"], key=lambda x: float(x["score"]), reverse=True)[:5]
rows.append(
{
"relative_path": slot["relative_path"],
"class_name": slot["class_name"],
"identifier": f"{slot['relative_path']}:{slot['class_name']}",
"score": base_score,
"coverage": coverage_count,
"coverage_pct": coverage_ratio,
"top_contributors": contributors,
}
)
rows.sort(key=lambda row: (float(row["score"]), int(row["coverage"])), reverse=True)
by_class_out[qcls] = rows[:10]
return by_class_out
def build_overall(result_map: dict[str, dict[str, object]]) -> list[dict[str, object]]:
per_query_best: list[dict[str, float]] = []
for query_identifier, data in result_map.items():
if query_identifier.split(".")[-1] in BOILERPLATE_NAMES:
continue
best_by_file: dict[str, float] = {}
for match in data.get("embedding", []):
rel = match.get("relative_path")
score = match.get("score")
if rel is None or score is None:
continue
best_by_file[rel] = max(best_by_file.get(rel, -1.0), float(score))
per_query_best.append(best_by_file)
aggregate: dict[str, dict[str, object]] = {}
for best_by_file in per_query_best:
for rel, score in best_by_file.items():
slot = aggregate.setdefault(rel, {"relative_path": rel, "scores": []})
slot["scores"].append(score)
overall = []
for rel, slot in aggregate.items():
scores = sorted(slot["scores"], reverse=True)
overall.append(
{
"relative_path": rel,
"score": float(sum(scores) / max(1, len(scores))),
"count": len(scores),
"best_score": float(scores[0]) if scores else -1.0,
}
)
overall.sort(key=lambda item: float(item["score"]), reverse=True)
return overall
by_class_out = build_by_class(output)
by_class_all = build_by_class(output_all)
overall = build_overall(output)
overall_all = build_overall(output_all)
if progress:
progress(total_steps, total_steps, "done")
return {
"results": output,
"overall": overall,
"by_class": by_class_out,
"identical_filtered": identical_filtered,
"by_class_all": by_class_all,
"overall_all": overall_all,
}
def index_status(self) -> dict[str, object]:
return {
"requested_granularity": self.requested_granularity,
"resolved_granularity": self.index_granularity,
"precision": self.precision,
"hub_dataset": self.hub_dataset,
"index_dir": str(self.index_dir) if self.index_dir else None,
"index_origin": self.index_origin,
"missing_files": list(self.missing_files),
"embedding_model": EMBEDDING_MODEL,
}
@staticmethod
def list_models() -> list[str]:
models = []
for path in MODELS_ROOT.iterdir():
if path.is_dir() and not path.name.startswith("__"):
models.append(path.name)
return sorted(models)
def generate_agent_report(analysis_results: dict) -> str:
lines = [
"### MODULAR RECONSTRUCTION SPECIFICATION (FOR AGENT)",
"You are operating inside the Transformers repository. Your task is to produce a `modular_*.py` file that",
"maximizes reuse of existing Transformers components.",
"Purpose: the `modular_model_converter.py` script reads the modular file, resolves inheritance, and",
"generates the final `modeling_*.py` file. The modular file should prefer inheriting from existing",
"Transformers classes (e.g., attention, MLP, blocks, model heads) rather than starting fresh from",
"`nn.Module`, so the generated model shares as much code as possible and only overrides the minimal deltas.",
"Use this report to choose those base classes and to decide where custom logic is actually required.",
"",
]
overall = analysis_results.get("overall", [])
if overall:
lines.append(f"Top architectural base: `{overall[0]['relative_path']}`")
lines.append("\n#### 1. REUSE PLAN (INHERITANCE)")
by_class = analysis_results.get("by_class", {})
if not by_class:
lines.append("- (none)")
for q_cls, matches in by_class.items():
if not matches:
continue
top = matches[0]
if float(top.get("score", 0.0)) >= 0.85:
lines.append(
f"- **{q_cls}**: inherit from `{top['class_name']}` in `{top['relative_path']}`"
)
lines.append(
f" Evidence: {top['coverage']} methods matched (~{top['score']:.2f} similarity)."
)
else:
lines.append(f"- **{q_cls}**: custom implementation (no strong base class match).")
lines.append("\n#### 2. DATAFLOW / STRUCTURAL ANCHORS")
results = analysis_results.get("results", {})
if not results:
lines.append("- (none)")
for query_name, entry in results.items():
if query_name.split(".")[-1] in BOILERPLATE_NAMES:
continue
matches = entry.get("embedding", [])
if matches and float(matches[0].get("score", 0.0)) >= 0.92:
top = matches[0]
lines.append(
f"- `{query_name}` is near-identical to `{top['match_name']}` in `{top['relative_path']}`."
)
lines.append(" Read both methods and reuse as much as possible through modular inheritance.")
lines.append("\n#### 3. PROMPT STARTER")
lines.append("```text")
lines.append("I am building a modular transformers file. Based on similarity analysis:")
lines.append(f"Base model: {overall[0]['relative_path'] if overall else 'Unknown'}")
lines.append("Implement the modular classes using # Copied from tags where possible.")
lines.append("```")
return "\n".join(lines)
def get_default_hub_dataset() -> str:
return os.getenv("HUB_DATASET", HUB_DATASET_DEFAULT)