| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| |
|
| | """ |
| | Modular model detector: utilities for detecting code similarities between model implementations. |
| | |
| | This module provides tools to analyze and detect similarities between different model implementations |
| | in the transformers library. It uses both embedding-based and token-based (Jaccard) similarity metrics |
| | to identify similar code patterns across different model definitions. |
| | |
| | Its purpose is to identify which models can be _modular_-ized, i.e. which classes already present in the |
| | codebase look very similar to the ones defined in the modeling file being analyzed. |
| | |
| | Two scores are computed, one is a code embedding, and the other is a simple Jaccard bag-of-tokens index for overlap |
| | of token sets. A score of 1.00 means the code is identical. |
| | |
| | Usage: |
| | |
| | ```bash |
| | cd transformers |
| | |
| | # Use directly the util, it will download the index embedding from the hub. It will require some RAM/VRAM. |
| | |
| | >>> python utils/modular_model_detector.py --modeling-file my_new_beit3_modeling_file.py |
| | Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 33.62it/s] |
| | encoding 21 query definitions with Qwen/Qwen3-Embedding-4B (device=cuda, batch=16, max_length=4096) |
| | stuff.py::Beit3ImageTextMatchingOutput: |
| | embedding: |
| | blip_2::Blip2ImageTextMatchingModelOutput (0.9994) |
| | chinese_clip::ChineseCLIPOutput (0.9818) |
| | owlvit::OwlViTOutput (0.9818) |
| | aimv2::Aimv2Output (0.9818) |
| | blip::BlipOutput (0.9818) |
| | jaccard: |
| | owlv2::Owlv2Output (0.9667) |
| | metaclip_2::MetaClip2Output (0.9667) |
| | altclip::AltCLIPOutput (0.9667) |
| | owlvit::OwlViTOutput (0.9667) |
| | blip::BlipOutput (0.9667) |
| | intersection: |
| | blip::BlipOutput |
| | owlvit::OwlViTOutput |
| | |
| | stuff.py::Beit3MLP: |
| | embedding: |
| | efficientloftr::EfficientLoFTRMLP (0.9718) |
| | seggpt::SegGptMlp (0.9650) |
| | mgp_str::MgpstrMlp (0.9646) |
| | vitpose_backbone::VitPoseBackboneMLP (0.9640) |
| | granitemoeshared::GraniteMoeSharedMLP (0.9633) |
| | jaccard: |
| | chinese_clip::ChineseCLIPTextSelfOutput (0.5294) |
| | convbert::ConvBertSelfOutput (0.5294) |
| | bert::BertSelfOutput (0.5294) |
| | roformer::RoFormerSelfOutput (0.5294) |
| | layoutlmv3::LayoutLMv3SelfOutput (0.5294) |
| | intersection: |
| | |
| | stuff.py::Beit3FeedForwardNetwork: |
| | embedding: |
| | prophetnet::ProphetNetFeedForward (0.9766) |
| | dab_detr::DabDetrDecoderLayerFFN (0.9730) |
| | kosmos2::Kosmos2TextFFN (0.9697) |
| | kosmos2_5::Kosmos2_5TextFFN (0.9697) |
| | parakeet::ParakeetEncoderFeedForward (0.9678) |
| | jaccard: |
| | groupvit::GroupViTMLP (0.4898) |
| | convbert::ConvBertOutput (0.4600) |
| | chinese_clip::ChineseCLIPTextOutput (0.4565) |
| | bert::BertOutput (0.4565) |
| | roformer::RoFormerOutput (0.4565) |
| | intersection: |
| | |
| | |
| | |
| | ``` |
| | |
| | |
| | # If you wish to build the index first, you can run |
| | |
| | python utils/modular_model_detector.py --build |
| | |
| | # You can also change the embedding model for a larger/smaller one. |
| | """ |
| |
|
| | import argparse |
| | import ast |
| | import json |
| | import logging |
| | import os |
| | import re |
| | from datetime import datetime |
| | from functools import cache |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import torch |
| | from huggingface_hub import HfApi, snapshot_download |
| | from huggingface_hub import logging as huggingface_hub_logging |
| | from safetensors.numpy import load_file as safetensors_load |
| | from safetensors.numpy import save_file as safetensors_save |
| | from tqdm import tqdm |
| |
|
| | import transformers |
| | from transformers import AutoModel, AutoTokenizer |
| | from transformers.utils import enable_tf32 |
| | from transformers.utils import logging as transformers_logging |
| |
|
| |
|
| | |
# ANSI SGR escape sequences used to colour the report printed by `main`.

# Attribute reset and plain emphasis.
ANSI_RESET = "\033[0m"
ANSI_BOLD = "\033[1m"

# Structural elements of the report: headings, table header/separators, ordinary rows.
ANSI_HEADER = "\033[1;36m"  # bold cyan
ANSI_SECTION = "\033[1;35m"  # bold magenta
ANSI_ROW = "\033[0;37m"  # regular white

# Row highlights explained by the legend printed above the tables.
ANSI_HIGHLIGHT_TOP = "\033[1;32m"  # bold green
ANSI_HIGHLIGHT_OLD = "\033[1;33m"  # bold yellow
ANSI_HIGHLIGHT_CANDIDATE = "\033[1;34m"  # bold blue
| |
|
| |
|
# Environment flags set at import time to keep Hub progress bars and transformers logs out of the report.
os.environ.update(
    {
        "HF_HUB_DISABLE_PROGRESS_BARS": "1",
        "TRANSFORMERS_VERBOSITY": "error",
    }
)

# Directory scanned for `modeling_*.py` files, relative to the current working directory.
MODELS_ROOT = Path("src") / "transformers" / "models"

# File names of the three artifacts that make up the similarity index.
EMBEDDINGS_PATH = "embeddings.safetensors"
INDEX_MAP_PATH = "code_index_map.json"
TOKENS_PATH = "code_index_tokens.json"

# Default Hub dataset repository the index is pulled from / pushed to.
HUB_DATASET_DEFAULT = "hf-internal-testing/transformers_code_embeddings"

# Embedding model checkpoint and the batching/truncation settings used when encoding code.
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
BATCH_SIZE = 16
MAX_LENGTH = 4096
| |
|
| |
|
| | def _normalize(string: str | None) -> str: |
| | """ |
| | Normalize a string by removing all non-alphanumeric characters and converting to lowercase. |
| | |
| | Args: |
| | string (`str` or `None`): The string to normalize. |
| | |
| | Returns: |
| | `str`: The normalized string, or empty string if input is None. |
| | """ |
| | return re.sub(r"[^a-z0-9]+", "", string.lower()) if string else "" |
| |
|
| |
|
| | def _strip_source_for_tokens(code: str) -> str: |
| | """ |
| | Strip docstrings, comments, and import statements from source code. |
| | |
| | Args: |
| | code (`str`): The source code to strip. |
| | |
| | Returns: |
| | `str`: The stripped source code. |
| | """ |
| | code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code) |
| | code = re.sub(r"#.*", "", code) |
| | return "\n".join(line for line in code.splitlines() if not re.match(r"\s*(from|import)\s+", line)) |
| |
|
| |
|
| | def _tokenize(code: str) -> set[str]: |
| | """ |
| | Extract all Python identifiers from source code. |
| | |
| | Args: |
| | code (`str`): The source code to tokenize. |
| | |
| | Returns: |
| | `set[str]`: A set of all identifiers found in the code. |
| | """ |
| | return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code)) |
| |
|
| |
|
| | def _leading_symbol_prefix(name: str) -> str: |
| | """ |
| | Extract the leading prefix from a symbol name (e.g., 'Llama' from 'LlamaAttention'). |
| | |
| | Args: |
| | name (`str`): The symbol name to extract prefix from. |
| | |
| | Returns: |
| | `str`: The leading prefix, or empty string if no match. |
| | """ |
| | match = re.match(r"^([A-Z][a-z0-9]+)", name) or re.match(r"^([A-Za-z0-9]+)", name) |
| | return match.group(1) if match else "" |
| |
|
| |
|
def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str | None) -> str:
    """
    Prepare code for embedding by replacing model-specific name fragments with the placeholder 'Model'.

    The code is first passed through `_strip_source_for_tokens`. Name fragments are derived from the model
    hint and from the leading prefix of the symbol hint (as-is, without underscores, without digits, and the
    lowercase form of each). Fragments shorter than 3 characters are ignored and longer fragments are
    replaced first; matching is case-insensitive.

    Args:
        code (`str`): The source code to sanitize.
        model_hint (`str` or `None`): Hint about the model name (e.g., 'llama').
        symbol_hint (`str` or `None`): Hint about the symbol name (e.g., 'LlamaAttention').

    Returns:
        `str`: The stripped code with every matched fragment replaced by 'Model'.
    """
    # Seeds are the raw names every spelling variant is derived from.
    seeds = []
    if model_hint:
        seeds.append(model_hint)
    if symbol_hint:
        symbol_prefix = _leading_symbol_prefix(symbol_hint)
        if symbol_prefix:
            seeds.append(symbol_prefix)

    variants = set()
    for seed in seeds:
        for spelling in (seed, seed.replace("_", ""), re.sub(r"\d+", "", seed)):
            variants.update((spelling, spelling.lower()))

    # Longest first, so a short fragment cannot break up a longer one before it is replaced.
    candidates = [variant for variant in variants if len(variant) >= 3]
    candidates.sort(key=len, reverse=True)

    sanitized = _strip_source_for_tokens(code)
    for candidate in candidates:
        sanitized = re.sub(re.escape(candidate), "Model", sanitized, flags=re.IGNORECASE)
    return sanitized
| |
|
| |
|
| | class CodeSimilarityAnalyzer: |
| | """ |
| | Analyzer for detecting code similarities between model implementations. |
| | |
| | This class uses embedding-based and token-based similarity metrics to identify similar |
| | code patterns across different model definitions in the transformers library. |
| | |
| | Args: |
| | hub_dataset (`str`): The Hub dataset repository ID containing the code embeddings index. |
| | """ |
| |
|
| | def __init__(self, hub_dataset: str): |
| | for name in ("huggingface_hub", "httpx", "urllib3", "transformers"): |
| | logging.getLogger(name).setLevel(logging.ERROR) |
| | huggingface_hub_logging.set_verbosity_error() |
| | transformers_logging.set_verbosity_error() |
| | enable_tf32(True) |
| | torch.set_grad_enabled(False) |
| |
|
| | self.models_root = MODELS_ROOT |
| | self.hub_dataset = hub_dataset |
| | self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL) |
| | self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map="auto").eval() |
| |
|
| | self.device = self.model.device |
| | self.index_dir: Path | None = None |
| |
|
| | |
| |
|
| | def _resolve_index_path(self, filename: str) -> Path: |
| | if self.index_dir is None: |
| | return Path(filename) |
| | return self.index_dir / filename |
| |
|
| | def ensure_local_index(self) -> None: |
| | """Ensure index files are available locally, preferring Hub cache snapshots.""" |
| | if self.index_dir is not None and all( |
| | (self.index_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH) |
| | ): |
| | return |
| |
|
| | workspace_dir = Path.cwd() |
| | if all((workspace_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH)): |
| | self.index_dir = workspace_dir |
| | return |
| |
|
| | logging.info(f"downloading index from hub cache: {self.hub_dataset}") |
| | snapshot_path = snapshot_download(repo_id=self.hub_dataset, repo_type="dataset") |
| | snapshot_dir = Path(snapshot_path) |
| | missing = [ |
| | fname for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH) if not (snapshot_dir / fname).exists() |
| | ] |
| | if missing: |
| | raise FileNotFoundError("Missing expected files in Hub snapshot: " + ", ".join(missing)) |
| | self.index_dir = snapshot_dir |
| |
|
| | def push_index_to_hub(self) -> None: |
| | """Upload index files to the Hub dataset repository.""" |
| | api = HfApi() |
| | api.create_repo(repo_id=self.hub_dataset, repo_type="dataset", exist_ok=True) |
| | for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH): |
| | logging.info(f"pushing {fname} -> {self.hub_dataset}") |
| | api.upload_file( |
| | path_or_fileobj=fname, |
| | path_in_repo=os.path.basename(fname), |
| | repo_id=self.hub_dataset, |
| | repo_type="dataset", |
| | ) |
| |
|
| | |
| |
|
| | def _extract_definitions( |
| | self, file_path: Path, relative_to: Path | None = None, model_hint: str | None = None |
| | ) -> tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]: |
| | """ |
| | Extract class and function definitions from a Python file. |
| | |
| | Args: |
| | file_path (`Path`): Path to the Python file to parse. |
| | relative_to (`Path` or `None`): Base path for computing relative identifiers. |
| | model_hint (`str` or `None`): Model name hint for sanitization. |
| | |
| | Returns: |
| | `tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]`: A tuple containing: |
| | - definitions_raw: Mapping of identifiers to raw source code |
| | - definitions_sanitized: Mapping of identifiers to sanitized source code |
| | - definitions_tokens: Mapping of identifiers to sorted token lists |
| | - definitions_kind: Mapping of identifiers to either "class" or "function" |
| | """ |
| | definitions_raw = {} |
| | definitions_sanitized = {} |
| | definitions_tokens = {} |
| | definitions_kind = {} |
| | source = file_path.read_text(encoding="utf-8") |
| | lines = source.splitlines() |
| | tree = ast.parse(source) |
| | for node in ast.iter_child_nodes(tree): |
| | if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): |
| | segment = ast.get_source_segment(source, node) |
| | if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"): |
| | start = max(0, node.lineno - 1) |
| | end = node.end_lineno |
| | segment = "\n".join(lines[start:end]) |
| | if segment: |
| | identifier = ( |
| | f"{file_path.relative_to(relative_to)}:{node.name}" |
| | if relative_to |
| | else f"{file_path.name}:{node.name}" |
| | ) |
| | definitions_raw[identifier] = segment |
| | sanitized = _sanitize_for_embedding(segment, model_hint, node.name) |
| | definitions_sanitized[identifier] = sanitized |
| | definitions_tokens[identifier] = sorted(_tokenize(sanitized)) |
| | if isinstance(node, ast.ClassDef): |
| | definitions_kind[identifier] = "class" |
| | else: |
| | definitions_kind[identifier] = "function" |
| | return definitions_raw, definitions_sanitized, definitions_tokens, definitions_kind |
| |
|
| | def _infer_model_from_relative_path(self, relative_path: Path) -> str | None: |
| | try: |
| | relative = relative_path.resolve().relative_to(self.models_root.resolve()) |
| | return relative.parts[0] |
| | except Exception: |
| | return None |
| |
|
| | def _infer_query_model_name(self, modeling_file: Path) -> str | None: |
| | model = self._infer_model_from_relative_path(modeling_file) |
| | if model: |
| | return model |
| | stem = modeling_file.stem |
| | if stem.startswith("modeling_") and len(stem) > len("modeling_"): |
| | return stem[len("modeling_") :] |
| | return None |
| |
|
| | def _encode_batch(self, texts: list[str]) -> np.ndarray: |
| | """ |
| | Encode a batch of texts into normalized embeddings. |
| | |
| | Args: |
| | texts (`list[str]`): List of text strings to encode. |
| | |
| | Returns: |
| | `np.ndarray`: Normalized embeddings as a float32 numpy array. |
| | """ |
| | encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt") |
| | encoded = {key: value.to(self.device) for key, value in encoded.items()} |
| | with ( |
| | torch.autocast(device_type=self.device.type, dtype=self.dtype) |
| | if self.device.type == "cuda" |
| | else torch.no_grad() |
| | ): |
| | output = self.model(**encoded) |
| | if hasattr(output, "last_hidden_state"): |
| | embeddings = output.last_hidden_state |
| | mask = encoded["attention_mask"].unsqueeze(-1) |
| | embeddings = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-9) |
| | elif hasattr(output, "pooler_output"): |
| | embeddings = output.pooler_output |
| | else: |
| | embeddings = output[0].mean(dim=1) |
| | embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1) |
| | return embeddings.cpu().numpy().astype("float32") |
| |
|
| | def encode(self, texts: list[str]) -> np.ndarray: |
| | """ |
| | Encode a list of texts into embeddings, processing in batches. |
| | |
| | Args: |
| | texts (`list[str]`): List of text strings to encode. |
| | |
| | Returns: |
| | `np.ndarray`: Stacked embeddings for all texts. |
| | """ |
| | output = [] |
| | for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="encode", leave=False): |
| | output.append(self._encode_batch(texts[i : i + BATCH_SIZE])) |
| | if self.device.type == "cuda": |
| | torch.cuda.empty_cache() |
| | return np.vstack(output) if output else np.zeros((0, 0), dtype="float32") |
| |
|
| | |
| |
|
| | def build_index(self) -> None: |
| | """Build the code similarity index from all modeling files and save to disk.""" |
| | logging.info("collecting files") |
| | files = list(self.models_root.rglob("modeling_*.py")) |
| | logging.info(f"parsing {len(files)} files") |
| |
|
| | identifiers = [] |
| | sanitized_sources = [] |
| | tokens_map = {} |
| |
|
| | for file_path in tqdm(files, desc="parse", leave=False): |
| | model_hint = self._infer_model_from_relative_path(file_path) |
| | ( |
| | _, |
| | definitions_sanitized, |
| | definitions_tokens, |
| | _, |
| | ) = self._extract_definitions(file_path, self.models_root, model_hint) |
| | for identifier in definitions_sanitized.keys(): |
| | identifiers.append(identifier) |
| | sanitized_sources.append(definitions_sanitized[identifier]) |
| | tokens_map[identifier] = definitions_tokens[identifier] |
| |
|
| | logging.info( |
| | f"encoding {len(sanitized_sources)} definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})" |
| | ) |
| | embeddings = self.encode(sanitized_sources) |
| | safetensors_save({"embeddings": embeddings}, EMBEDDINGS_PATH) |
| | with open(INDEX_MAP_PATH, "w", encoding="utf-8") as file: |
| | json.dump({int(i): identifiers[i] for i in range(len(identifiers))}, file) |
| | with open(TOKENS_PATH, "w", encoding="utf-8") as file: |
| | json.dump(tokens_map, file) |
| |
|
| | self.index_dir = Path.cwd() |
| |
|
| | def _topk_embedding( |
| | self, |
| | query_embedding_row: np.ndarray, |
| | base_embeddings: np.ndarray, |
| | identifier_map: dict[int, str], |
| | self_model_normalized: str, |
| | self_name: str, |
| | k: int, |
| | ) -> list[tuple[str, float]]: |
| | similarities = query_embedding_row @ base_embeddings.T |
| | indices = np.argpartition(-similarities, k + 32)[: k + 32] |
| | indices = indices[np.argsort(-similarities[indices])] |
| | output = [] |
| | for match_id in indices: |
| | identifier = identifier_map[int(match_id)] |
| | parent_relative_path, match_name = identifier.split(":", 1) |
| | parent_model = Path(parent_relative_path).parts[0] |
| | if match_name == self_name: |
| | continue |
| | if self_model_normalized and _normalize(parent_model) == self_model_normalized: |
| | continue |
| | output.append((identifier, float(similarities[match_id]))) |
| | if len(output) >= k: |
| | break |
| | return output |
| |
|
| | def _topk_jaccard( |
| | self, |
| | query_tokens: set[str], |
| | identifiers: list[str], |
| | tokens_map: dict[str, list[str]], |
| | self_model_normalized: str, |
| | self_name: str, |
| | k: int, |
| | ) -> list[tuple[str, float]]: |
| | """ |
| | Find top-k most similar definitions using Jaccard similarity on token sets. |
| | |
| | Args: |
| | query_tokens (`set[str]`): Set of tokens from the query definition. |
| | identifiers (`list[str]`): List of all definition identifiers in the index. |
| | tokens_map (`dict[str, list[str]]`): Mapping of identifiers to their token lists. |
| | self_model_normalized (`str`): Normalized name of the query model to exclude. |
| | self_name (`str`): Name of the query definition to exclude. |
| | k (`int`): Number of top results to return. |
| | |
| | Returns: |
| | `list[tuple[str, float]]`: List of (identifier, score) tuples. |
| | """ |
| | scores = [] |
| | for identifier in identifiers: |
| | parent_relative_path, match_name = identifier.split(":", 1) |
| | parent_model = Path(parent_relative_path).parts[0] |
| | if match_name == self_name: |
| | continue |
| | if self_model_normalized and _normalize(parent_model) == self_model_normalized: |
| | continue |
| | tokens = set(tokens_map.get(identifier, [])) |
| | if not tokens or not query_tokens: |
| | continue |
| | score = len(query_tokens & tokens) / len(query_tokens | tokens) |
| | if score > 0: |
| | scores.append((identifier, score)) |
| | scores.sort(key=lambda x: x[1], reverse=True) |
| | return scores[:k] |
| |
|
| | def analyze_file( |
| | self, modeling_file: Path, top_k_per_item: int = 5, allow_hub_fallback: bool = True, use_jaccard=False |
| | ) -> dict[str, dict[str, list]]: |
| | """ |
| | Analyze a modeling file and find similar code definitions in the index. |
| | |
| | Args: |
| | modeling_file (`Path`): Path to the modeling file to analyze. |
| | top_k_per_item (`int`, *optional*, defaults to 5): Number of top matches to return per definition. |
| | allow_hub_fallback (`bool`, *optional*, defaults to `True`): Whether to download index from Hub if not found locally. |
| | |
| | Returns: |
| | `dict[str, dict[str, list]]`: Dictionary mapping definition names to their similarity results. |
| | Each result contains 'embedding', 'jaccard', and 'intersection' keys. |
| | """ |
| | if allow_hub_fallback: |
| | self.ensure_local_index() |
| |
|
| | base = safetensors_load(str(self._resolve_index_path(EMBEDDINGS_PATH))) |
| | base_embeddings = base["embeddings"] |
| | with open(self._resolve_index_path(INDEX_MAP_PATH), "r", encoding="utf-8") as file: |
| | identifier_map = {int(key): value for key, value in json.load(file).items()} |
| | identifiers = [identifier_map[i] for i in range(len(identifier_map))] |
| | with open(self._resolve_index_path(TOKENS_PATH), "r", encoding="utf-8") as file: |
| | tokens_map = json.load(file) |
| |
|
| | self_model = self._infer_query_model_name(modeling_file) |
| | definitions_raw, definitions_sanitized, _, definitions_kind = self._extract_definitions( |
| | modeling_file, None, self_model |
| | ) |
| | query_identifiers = list(definitions_raw.keys()) |
| | query_sources_sanitized = [definitions_sanitized[key] for key in query_identifiers] |
| | query_tokens_list = [set(_tokenize(source)) for source in query_sources_sanitized] |
| | self_model_normalized = _normalize(self_model) |
| |
|
| | logging.info( |
| | f"encoding {len(query_sources_sanitized)} query definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})" |
| | ) |
| | query_embeddings = self.encode(query_sources_sanitized) |
| |
|
| | output = {} |
| | for i, query_identifier in enumerate(query_identifiers): |
| | query_name = query_identifier.split(":")[-1] |
| | embedding_top = self._topk_embedding( |
| | query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item |
| | ) |
| | embedding_set = {identifier for identifier, _ in embedding_top} |
| | kind = definitions_kind.get(query_identifier, "function") |
| | entry = {"kind": kind, "embedding": embedding_top} |
| | if use_jaccard: |
| | jaccard_top = self._topk_jaccard( |
| | query_tokens_list[i], identifiers, tokens_map, self_model_normalized, query_name, top_k_per_item |
| | ) |
| | jaccard_set = {identifier for identifier, _ in jaccard_top} |
| | intersection = set(embedding_set & jaccard_set) |
| |
|
| | entry.update({"jaccard": jaccard_top, "intersection": intersection}) |
| | output[query_name] = entry |
| | return output |
| |
|
| |
|
| | _RELEASE_RE = re.compile( |
| | r"(?:^|[\*_`\s>])(?:this|the)\s+model\s+was\s+released\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE |
| | ) |
| |
|
| |
|
| | def build_date_data() -> dict[str, str]: |
| | """ |
| | Scan Markdown files in `root_dir` and build {model_id: date_released}. |
| | |
| | - model_id is the filename without extension (e.g., "llama" for "llama.md") |
| | - date_released is the first YYYY-MM-DD matched after "...was released on ..." |
| | - Ignores non-*.md files and directories. |
| | |
| | Returns: |
| | dict[str, str]: mapping of model_id -> ISO date string (YYYY-MM-DD). |
| | Files without a match are simply omitted. |
| | """ |
| |
|
| | root_dir = transformers.__file__.split("src/transformers")[0] |
| | root = Path(root_dir).joinpath("docs/source/en/model_doc") |
| | result: dict[str, str] = {} |
| |
|
| | for md_path in root.glob("*.md"): |
| | try: |
| | text = md_path.read_text(encoding="utf-8", errors="ignore") |
| | except Exception: |
| | |
| | logging.info(f"Failed to read md for {md_path}") |
| |
|
| | m = _RELEASE_RE.search(text) |
| | if m: |
| | model_id = md_path.stem |
| | result[model_id] = m.group(1) |
| |
|
| | return result |
| |
|
| |
|
def _format_table(headers: list[str], rows: list[tuple[str, ...] | None], row_styles: list[str] | None = None) -> str:
    """
    Render rows as an ANSI-coloured, column-aligned text table.

    Args:
        headers (`list[str]`): Column titles.
        rows (`list[tuple[str, ...] | None]`): Table rows; a `None` entry is rendered as a full-width
            separator line.
        row_styles (`list[str]` or `None`): ANSI style per non-separator row, in order. Missing or empty
            entries fall back to `ANSI_ROW`.

    Returns:
        `str`: The formatted table, or a "(no matches)" line when `rows` is empty.
    """
    if not rows:
        return f"{ANSI_ROW}(no matches){ANSI_RESET}"

    # Each column is as wide as its widest cell (header included).
    widths = [len(header) for header in headers]
    for row in rows:
        if row is not None:
            for col, cell in enumerate(row):
                if len(cell) > widths[col]:
                    widths[col] = len(cell)

    def render(cells) -> str:
        return " | ".join(cell.ljust(widths[col]) for col, cell in enumerate(cells))

    total_width = sum(widths) + 3 * (len(headers) - 1)
    separator = f"{ANSI_SECTION}{'-' * total_width}{ANSI_RESET}"

    output_lines = [
        f"{ANSI_SECTION}{render(headers)}{ANSI_RESET}",
        "-+-".join("-" * width for width in widths),
    ]
    # Styles are consumed only by real rows; separators do not advance the cursor.
    styles = iter(row_styles or ())
    for row in rows:
        if row is None:
            output_lines.append(separator)
        else:
            style = next(styles, None) or ANSI_ROW
            output_lines.append(f"{style}{render(row)}{ANSI_RESET}")

    return "\n".join(output_lines)
| |
|
| |
|
| | def _parse_release_date(value: str) -> datetime | None: |
| | """Return a datetime parsed from YYYY-MM-DD strings, otherwise None.""" |
| | try: |
| | return datetime.strptime(value, "%Y-%m-%d") |
| | except (TypeError, ValueError): |
| | return None |
| |
|
| |
|
@cache
def _load_definition_line_map(relative_path: str) -> dict[str, int]:
    """
    Map each top-level class/function name of a modeling file to the line it is defined on.

    Results are memoized per `relative_path`. An empty mapping is returned when the file cannot be read
    or is not valid Python.

    Args:
        relative_path (`str`): File path relative to `MODELS_ROOT`.

    Returns:
        `dict[str, int]`: Mapping of definition names to 1-based line numbers.
    """
    try:
        tree = ast.parse((MODELS_ROOT / relative_path).read_text(encoding="utf-8"))
    except (OSError, SyntaxError):
        return {}

    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
    return {
        node.name: getattr(node, "lineno", None) or 1
        for node in ast.iter_child_nodes(tree)
        if isinstance(node, definition_types)
    }
| |
|
| |
|
def _resolve_definition_location(relative_path: str, definition: str) -> tuple[str, str]:
    """
    Locate a definition inside a modeling file.

    Args:
        relative_path (`str`): File path relative to `MODELS_ROOT`.
        definition (`str`): Name of the top-level class or function to look up.

    Returns:
        `tuple[str, str]`: The file path under `MODELS_ROOT` and the line number as text, or `"?"` when
        the definition is not found.
    """
    location = str(MODELS_ROOT / relative_path)
    line_map = _load_definition_line_map(relative_path)
    line_number = line_map.get(definition)
    if line_number is None:
        return location, "?"
    return location, str(line_number)
| |
|
| |
|
def _colorize_heading(text: str) -> str:
    """Wrap `text` in the header colour and bold escape codes, followed by a reset."""
    return "{}{}{}{}".format(ANSI_HEADER, ANSI_BOLD, text, ANSI_RESET)
| |
|
| |
|
def _parse_bool_flag(value: str) -> bool:
    """Parse a command-line boolean; `type=bool` would treat any non-empty text (even "False") as True."""
    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_match_row(
    identifier: str, metric: str, score_text: str, include_metric_column: bool, dates: dict[str, str]
) -> tuple[tuple[str, ...], str, str] | None:
    """Build one table row for a match; returns (row, relative_path, release) or None for a malformed identifier."""
    relative_path, separator, match_name = identifier.partition(":")
    if not separator:
        return None
    path_parts = Path(relative_path).parts
    model_id = path_parts[0] if path_parts else "?"
    match_release = dates.get(model_id, "unknown release date")
    full_path, line = _resolve_definition_location(relative_path, match_name)
    display_path = f"{full_path}:{line} ({match_name})"
    if include_metric_column:
        row = ("", metric, display_path, score_text, match_release)
    else:
        row = ("", display_path, score_text, match_release)
    return row, relative_path, match_release


def main():
    """
    CLI entry point for the modular model detector.

    With `--build`, the index is rebuilt (and optionally pushed to the Hub). Otherwise the file given by
    `--modeling-file` is analyzed and, for each of its classes and functions, a table of the closest
    indexed definitions is logged.

    Raises:
        `SystemExit`: If neither `--build` nor `--modeling-file` is provided.
    """
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    parser = argparse.ArgumentParser(prog="hf-code-sim")
    parser.add_argument("--build", action="store_true")
    parser.add_argument("--modeling-file", type=str, help='You can just specify "vits" if you are lazy like me.')
    parser.add_argument(
        "--push-new-index", action="store_true", help="After --build, push index files to a Hub dataset."
    )
    parser.add_argument(
        "--hub-dataset", type=str, default=HUB_DATASET_DEFAULT, help="Hub dataset repo id to pull/push the index."
    )
    # `nargs="?"` + `const=True` lets the flag be used bare, while `--use_jaccard True/False` keeps working.
    parser.add_argument(
        "--use_jaccard",
        "--use-jaccard",
        dest="use_jaccard",
        type=_parse_bool_flag,
        nargs="?",
        const=True,
        default=False,
        help="Whether or not to use jaccard index",
    )
    args = parser.parse_args()

    # Validate before constructing the analyzer, which loads the embedding model.
    if not args.build and not args.modeling_file:
        raise SystemExit("Provide --modeling-file or use --build")

    analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset)

    if args.build:
        analyzer.build_index()
        if args.push_new_index:
            analyzer.push_index_to_hub()
        return

    dates = build_date_data()
    modeling_file = args.modeling_file
    # A bare model name ("vits") is shorthand for its modeling file; a bare "*.py" file name is a real path.
    if os.sep not in modeling_file and not modeling_file.endswith(".py"):
        modeling_file = os.path.join("src", "transformers", "models", modeling_file, f"modeling_{modeling_file}.py")

    results = analyzer.analyze_file(
        Path(modeling_file), top_k_per_item=5, allow_hub_fallback=True, use_jaccard=args.use_jaccard
    )

    # Sum the embedding scores per candidate file to find the overall closest modeling file.
    aggregate_scores: dict[str, float] = {}
    for data in results.values():
        for identifier, score in data.get("embedding", []):
            relative_path, separator, _ = identifier.partition(":")
            if not separator:
                continue
            aggregate_scores[relative_path] = aggregate_scores.get(relative_path, 0.0) + score

    best_candidate_path: str | None = None
    if aggregate_scores:
        best_candidate_path = max(aggregate_scores, key=aggregate_scores.get)
        best_parts = Path(best_candidate_path).parts
        best_model = best_parts[0] if best_parts else "?"
        best_release = dates.get(best_model, "unknown release date")
        logging.info(
            f"{ANSI_HIGHLIGHT_CANDIDATE}Closest overall candidate: {MODELS_ROOT / best_candidate_path}"
            f" (release: {best_release}, total score: {aggregate_scores[best_candidate_path]:.4f}){ANSI_RESET}"
        )

    grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []}
    for query_name, data in results.items():
        grouped.setdefault(data.get("kind", "function"), []).append((query_name, data))

    legend_shown = False
    for kind, title in (("class", "Classes"), ("function", "Functions")):
        entries = grouped.get(kind, [])
        if not entries:
            continue

        # The "Metric" column is only needed when something other than embedding rows will be shown.
        include_metric_column = bool(
            args.use_jaccard and any(data.get("jaccard") or data.get("intersection") for _, data in entries)
        )
        if include_metric_column:
            headers = ["Symbol", "Metric", "Path", "Score", "Release"]
        else:
            headers = ["Symbol", "Path", "Score", "Release"]

        table_rows: list[tuple[str, ...] | None] = []
        row_styles: list[str] = []
        has_metric_rows = False

        logging.info(_colorize_heading(title))

        for query_name, data in entries:
            if table_rows:
                # `None` is rendered as a separator line between symbols.
                table_rows.append(None)

            table_rows.append((query_name,) + ("",) * (len(headers) - 1))
            row_styles.append(ANSI_BOLD)

            # (relative_path, score, release, position of the row's style in `row_styles`)
            embedding_matches: list[tuple[str, float, str, int]] = []
            for identifier, score in data.get("embedding", []):
                built = _build_match_row(identifier, "embedding", f"{score:.4f}", include_metric_column, dates)
                if built is None:
                    continue
                row, relative_path, match_release = built
                table_rows.append(row)
                row_styles.append(ANSI_ROW)
                embedding_matches.append((relative_path, score, match_release, len(row_styles) - 1))
                has_metric_rows = True

            if embedding_matches:
                # Highlight the best match (the first one on ties).
                _, highest_score, _, top_position = max(embedding_matches, key=lambda match: match[1])
                row_styles[top_position] = ANSI_HIGHLIGHT_TOP

                # Highlight the oldest model among matches scoring within 0.1 of the best one.
                oldest_position = None
                oldest_date = None
                for _, score, release_value, position in embedding_matches:
                    if highest_score - score > 0.1:
                        continue
                    parsed = _parse_release_date(release_value)
                    if parsed is None:
                        continue
                    if oldest_date is None or parsed < oldest_date:
                        oldest_date = parsed
                        oldest_position = position
                if oldest_position is not None and row_styles[oldest_position] != ANSI_HIGHLIGHT_TOP:
                    row_styles[oldest_position] = ANSI_HIGHLIGHT_OLD

                # Remaining plain rows that point at the closest overall candidate get their own colour.
                if best_candidate_path is not None:
                    for relative_path, _, _, position in embedding_matches:
                        if row_styles[position] == ANSI_ROW and relative_path == best_candidate_path:
                            row_styles[position] = ANSI_HIGHLIGHT_CANDIDATE

            if args.use_jaccard:
                extra_rows = [
                    (identifier, "jaccard", f"{score:.4f}") for identifier, score in data.get("jaccard", [])
                ]
                extra_rows += [
                    (identifier, "intersection", "--") for identifier in sorted(data.get("intersection", []))
                ]
                for identifier, metric, score_text in extra_rows:
                    built = _build_match_row(identifier, metric, score_text, include_metric_column, dates)
                    if built is None:
                        continue
                    row, relative_path, _ = built
                    table_rows.append(row)
                    row_styles.append(
                        ANSI_HIGHLIGHT_CANDIDATE if relative_path == best_candidate_path else ANSI_ROW
                    )
                    has_metric_rows = True

        if table_rows:
            if not legend_shown and has_metric_rows:
                logging.info(
                    "Legend: "
                    f"{ANSI_HIGHLIGHT_TOP}highest match{ANSI_RESET}, "
                    f"{ANSI_HIGHLIGHT_OLD}oldest within 0.1{ANSI_RESET}, "
                    f"{ANSI_HIGHLIGHT_CANDIDATE}closest overall candidate{ANSI_RESET}"
                )
                legend_shown = True

            logging.info(_format_table(headers, table_rows, row_styles))
            logging.info("")


if __name__ == "__main__":
    main()
| |
|