| |
| """Encode caption text and compute block Vendi scores. |
| |
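| The script is split into subcommands: |
| - `inspect`: report tokenizer/config limits for candidate encoders |
| - `encode`: cache normalized text embeddings from JSONL captions |
| - `encode-bge-m3`: cache BGE-M3 dense embeddings via FlagEmbedding |
| - `encode-sentence-transformer`: cache embeddings via SentenceTransformer |
| - `vendi`: compute sampled block Vendi/effective-rank summaries from caches |
| - `geometry`: compute embedding-distribution geometry summaries |
| - `knn`: compute exact nearest-neighbor support between two embedding caches |
| - `support`: compute PRDC-style query-in-gallery manifold support |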
| |
| The encoder path is GPU-ready but the same code can be sanity-checked on CPU |
| with a tiny sample before H200 allocation. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import random |
| import sys |
| import time |
| import types |
| from dataclasses import asdict, dataclass |
| from pathlib import Path |
| from typing import Any, Iterable |
|
|
| import numpy as np |
| import torch |
|
|
|
|
@dataclass
class EmbeddingShard:
    """Manifest record describing one embedding shard saved as a ``.npy`` file."""

    # Path of the saved array, stored as a string so it serializes to JSON.
    path: str
    # Number of embedding rows held by the shard.
    rows: int
    # Embedding width of the shard.
    dim: int
    # Name of the on-disk dtype (e.g. "float16").
    dtype: str
    # Offset of the shard's first row within this worker's output.
    start_row: int
    # Exclusive end offset, i.e. ``start_row + rows``.
    end_row: int
|
|
|
|
def _add_remote_code_args(sub: argparse.ArgumentParser) -> None:
    """Register the Hugging Face remote-code flags shared by `inspect` and `encode`."""
    sub.add_argument("--trust-remote-code", action="store_true")
    sub.add_argument(
        "--compat-remote-code",
        action="store_true",
        help="Install small compatibility shims for older HF remote-code embedding models.",
    )


def _add_caption_source_args(sub: argparse.ArgumentParser) -> None:
    """Register the JSONL source, record-selection and shard-output flags shared by all encode-style subcommands."""
    sub.add_argument("--input", required=True, help="JSONL input")
    sub.add_argument("--text-field", default="caption")
    sub.add_argument("--id-field", default=None)
    sub.add_argument("--output-dir", required=True)
    sub.add_argument("--max-records", type=int, default=None)
    sub.add_argument(
        "--sample-records",
        type=int,
        default=None,
        help="Reservoir-sample this many records before modulo splitting. Mutually exclusive with --max-records.",
    )
    sub.add_argument("--sample-seed", type=int, default=0)
    sub.add_argument("--split-count", type=int, default=1, help="Modulo split count for multi-GPU extraction")
    sub.add_argument("--split-index", type=int, default=0, help="Modulo split index for this worker")
    sub.add_argument("--batch-size", type=int, default=256)
    sub.add_argument("--device", default="cuda")
    sub.add_argument("--embedding-dtype", default="float16", choices=["float16", "float32"])
    sub.add_argument("--shard-rows", type=int, default=100_000)


def _add_text_format_args(sub: argparse.ArgumentParser, *, stage: str, template_help: str) -> None:
    """Register `--text-prefix` / `--text-template`; only the help wording differs per subcommand."""
    sub.add_argument("--text-prefix", default="", help=f"Prefix applied to every text before {stage}")
    sub.add_argument("--text-template", default=None, help=template_help)


def _add_query_gallery_args(
    sub: argparse.ArgumentParser,
    *,
    query_batch_size: int,
    query_help: str | None = None,
    gallery_help: str | None = None,
) -> None:
    """Register the query/gallery manifest flags shared by `knn` and `support`."""
    sub.add_argument("--query-manifest", required=True, help=query_help)
    sub.add_argument("--gallery-manifest", required=True, help=gallery_help)
    sub.add_argument("--output", required=True)
    sub.add_argument("--query-max-rows", type=int, default=None)
    sub.add_argument("--gallery-max-rows", type=int, default=None)
    sub.add_argument("--seed", type=int, default=0)
    sub.add_argument("--device", default="cuda")
    sub.add_argument("--dtype", default="float16", choices=["float16", "bfloat16", "float32"])
    sub.add_argument("--query-batch-size", type=int, default=query_batch_size)


def parse_args() -> argparse.Namespace:
    """Build the CLI and parse ``sys.argv``.

    Registers the `inspect`, `encode`, `encode-bge-m3`,
    `encode-sentence-transformer`, `vendi`, `geometry`, `knn` and `support`
    subcommands. The chosen subcommand name is stored in ``args.cmd``.

    Returns:
        The parsed namespace for the selected subcommand.

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments.
    """
    parser = argparse.ArgumentParser(description="Caption embedding cache and Vendi utilities")
    subparsers = parser.add_subparsers(dest="cmd", required=True)
    float_dtypes = ["float16", "bfloat16", "float32"]
    short_template_help = "Python format template containing '{text}'"

    # inspect -----------------------------------------------------------
    inspect_cmd = subparsers.add_parser("inspect", help="Inspect tokenizer/model text limits")
    inspect_cmd.add_argument("--model", action="append", required=True, help="HF model id/path; may be repeated")
    _add_remote_code_args(inspect_cmd)

    # encode ------------------------------------------------------------
    encode = subparsers.add_parser("encode", help="Extract normalized text embeddings")
    _add_caption_source_args(encode)
    encode.add_argument("--model", required=True)
    encode.add_argument("--max-length", type=int, default=None)
    encode.add_argument("--dtype", default="float16", choices=float_dtypes)
    encode.add_argument("--pooling", default="auto", choices=["auto", "cls", "mean", "pooler", "last"])
    encode.add_argument("--padding-side", default=None, choices=["left", "right"], help="Override tokenizer padding side")
    _add_text_format_args(
        encode,
        stage="tokenization",
        template_help="Python format template applied before tokenization. Must contain '{text}'. Overrides --text-prefix.",
    )
    _add_remote_code_args(encode)
    encode.add_argument("--compile", action="store_true")

    # encode-bge-m3 -----------------------------------------------------
    bge = subparsers.add_parser("encode-bge-m3", help="Extract official BGE-M3 dense embeddings via FlagEmbedding")
    _add_caption_source_args(bge)
    bge.add_argument("--model", default="BAAI/bge-m3")
    bge.add_argument("--max-length", type=int, default=512)
    bge.add_argument("--use-fp16", action=argparse.BooleanOptionalAction, default=True)
    _add_text_format_args(bge, stage="encoding", template_help=short_template_help)
    bge.add_argument("--encode-mode", default="corpus", choices=["corpus", "queries", "encode"])
    bge.add_argument("--query-instruction", default=None, help="Optional BGEM3 query_instruction_for_retrieval")
    bge.add_argument("--query-instruction-format", default="{}{}", help="BGEM3 query_instruction_format")

    # encode-sentence-transformer ----------------------------------------
    st = subparsers.add_parser(
        "encode-sentence-transformer",
        help="Extract embeddings with SentenceTransformer's model-specific encode protocol",
    )
    _add_caption_source_args(st)
    st.add_argument("--model", required=True)
    st.add_argument("--max-length", type=int, default=None)
    _add_text_format_args(st, stage="encoding", template_help=short_template_help)
    st.add_argument("--prompt-name", default=None, help="SentenceTransformer prompt_name, e.g. document or query")

    # vendi -------------------------------------------------------------
    vendi = subparsers.add_parser("vendi", help="Compute sampled block Vendi from embedding cache")
    vendi.add_argument("--manifest", required=True)
    vendi.add_argument("--output", required=True)
    vendi.add_argument("--block-size", type=int, default=4096)
    vendi.add_argument("--blocks", type=int, default=64)
    vendi.add_argument(
        "--sampling",
        choices=["random", "partition"],
        default="random",
        help="random samples blocks; partition shuffles once and uses every row in disjoint blocks.",
    )
    vendi.add_argument("--seed", type=int, default=0)
    vendi.add_argument("--device", default="cuda")
    vendi.add_argument("--matrix-device", default=None, help="Override device for eigvalsh; defaults to --device")
    vendi.add_argument("--dtype", default="float32", choices=float_dtypes)

    # geometry ----------------------------------------------------------
    geom = subparsers.add_parser("geometry", help="Compute embedding-distribution geometry summaries")
    geom.add_argument("--manifest", required=True)
    geom.add_argument("--output", required=True)
    geom.add_argument("--max-rows", type=int, default=100_000)
    geom.add_argument("--seed", type=int, default=0)
    geom.add_argument("--device", default="cuda")
    geom.add_argument("--dtype", default="float32", choices=float_dtypes)

    # knn ---------------------------------------------------------------
    knn = subparsers.add_parser("knn", help="Compute exact nearest-neighbor support between two embedding caches")
    _add_query_gallery_args(knn, query_batch_size=1024)
    knn.add_argument(
        "--gallery-chunk-size",
        type=int,
        default=0,
        help="0 keeps the full gallery resident on device; positive values stream gallery chunks.",
    )
    knn.add_argument("--thresholds", default="0.60,0.70,0.75,0.80,0.85,0.90")
    knn.add_argument("--save-scores", default=None, help="Optional .npy path for per-query nearest-neighbor cosine scores")

    # support -----------------------------------------------------------
    support = subparsers.add_parser("support", help="Compute PRDC-style query-in-gallery manifold support")
    _add_query_gallery_args(
        support,
        query_batch_size=512,
        query_help="Prompt/query embedding manifest P",
        gallery_help="Caption/support embedding manifest C",
    )
    support.add_argument("--k", type=int, default=10)
    support.add_argument("--gallery-batch-size", type=int, default=512)
    support.add_argument("--save-scores", default=None, help="Optional .npz path for per-query support scores")

    return parser.parse_args()
|
|
|
|
def torch_dtype(name: str) -> torch.dtype:
    """Translate a CLI dtype name into the matching ``torch.dtype``.

    Raises:
        KeyError: if ``name`` is not "float16", "bfloat16" or "float32".
    """
    by_name = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    return by_name[name]
|
|
|
|
def numpy_dtype(name: str) -> np.dtype:
    """Translate an embedding dtype name into the matching NumPy scalar type.

    Raises:
        KeyError: if ``name`` is neither "float16" nor "float32".
    """
    by_name = {
        "float16": np.float16,
        "float32": np.float32,
    }
    return by_name[name]
|
|
|
|
def load_transformers():
    """Import Transformers lazily and return ``(AutoConfig, AutoModel, AutoTokenizer)``.

    Raises:
        SystemExit: with an installation hint when ``transformers`` cannot be imported.
    """
    try:
        from transformers import AutoConfig, AutoModel, AutoTokenizer
    except ImportError as exc:
        hint = "transformers is required. Run through `uv run` after sourcing .env."
        raise SystemExit(hint) from exc
    else:
        return AutoConfig, AutoModel, AutoTokenizer
|
|
|
|
def install_remote_code_compat() -> None:
    """Install best-effort compatibility shims for embedding-model remote code.

    Jina v2 imports `transformers.onnx.OnnxConfig`, which is absent in the
    current Transformers build used by this project. Jina v3 also expects the
    legacy `all_tied_weights_keys` property on PreTrainedModel. The shims are
    intentionally minimal and only installed when requested.

    A third shim adds `find_pruneable_heads_and_indices` and
    `prune_linear_layer` to `transformers.pytorch_utils` when they are missing.

    The function patches process-global state (`sys.modules`, the
    `transformers` package and `PreTrainedModel`). It returns without doing
    anything if `transformers` cannot be imported.
    """
    try:
        import transformers
        from transformers import PreTrainedModel
    except ImportError:
        # Nothing to patch; callers that need transformers report the error themselves.
        return

    # Shim 1: register an empty stand-in `transformers.onnx` module exposing `OnnxConfig`.
    # NOTE(review): this only checks `sys.modules`, so a real `transformers.onnx`
    # submodule that exists but has not been imported yet would be replaced by
    # the stub - confirm that is acceptable for the pinned Transformers build.
    if "transformers.onnx" not in sys.modules:
        onnx_module = types.ModuleType("transformers.onnx")

        class OnnxConfig:
            pass

        onnx_module.OnnxConfig = OnnxConfig
        sys.modules["transformers.onnx"] = onnx_module
        setattr(transformers, "onnx", onnx_module)

    # Shim 2: read/write `all_tied_weights_keys` property on PreTrainedModel.
    if not hasattr(PreTrainedModel, "all_tied_weights_keys"):

        def all_tied_weights_keys(self: Any) -> dict[str, None]:
            # Prefer a value stored through the setter; otherwise derive the
            # mapping from the `_tied_weights_keys` attribute.
            stored = getattr(self, "_compat_all_tied_weights_keys", None)
            if stored is not None:
                return stored
            keys = getattr(self, "_tied_weights_keys", None) or []
            return {key: None for key in keys}

        def set_all_tied_weights_keys(self: Any, value: Any) -> None:
            # Normalize every assigned value to a dict: dicts are stored as-is,
            # None becomes an empty dict, other iterables become {key: None}.
            if isinstance(value, dict):
                self._compat_all_tied_weights_keys = value
            elif value is None:
                self._compat_all_tied_weights_keys = {}
            else:
                self._compat_all_tied_weights_keys = {key: None for key in value}

        PreTrainedModel.all_tied_weights_keys = property(
            all_tied_weights_keys,
            set_all_tied_weights_keys,
        )

    # Shim 3: head-pruning helpers expected under `transformers.pytorch_utils`.
    try:
        import transformers.pytorch_utils as pytorch_utils

        if not hasattr(pytorch_utils, "find_pruneable_heads_and_indices"):
            def find_pruneable_heads_and_indices(
                heads: list[int] | set[int],
                n_heads: int,
                head_size: int,
                already_pruned_heads: set[int],
            ) -> tuple[set[int], torch.Tensor]:
                # Returns the heads still to prune and the flat indices of the
                # entries that are kept.
                heads = set(heads) - already_pruned_heads
                mask = torch.ones(n_heads, head_size)
                for head in heads:
                    # Shift the row index down by the number of already-pruned heads before it.
                    pruned_before = sum(1 if pruned_head < head else 0 for pruned_head in already_pruned_heads)
                    mask[head - pruned_before] = 0
                mask = mask.view(-1).contiguous().eq(1)
                index = torch.arange(len(mask))[mask].long()
                return heads, index

            pytorch_utils.find_pruneable_heads_and_indices = find_pruneable_heads_and_indices
        if not hasattr(pytorch_utils, "prune_linear_layer"):
            from transformers.modeling_utils import prune_linear_layer

            pytorch_utils.prune_linear_layer = prune_linear_layer
    except Exception:
        # Deliberately best-effort: any failure leaves pytorch_utils as it was.
        # NOTE(review): the broad catch also hides unexpected errors; consider
        # narrowing to ImportError/AttributeError once the failure modes are known.
        pass
|
|
|
|
def iter_jsonl(
    path: Path,
    text_field: str,
    id_field: str | None,
    max_records: int | None,
    split_count: int,
    split_index: int,
) -> Iterable[tuple[str, str | None, int]]:
    """Stream ``(text, row_id, row_index)`` triples from a JSONL file.

    ``row_index`` is the zero-based physical line number, so blank lines
    consume an index even though they are never yielded. Only lines whose
    index satisfies ``row_index % split_count == split_index`` are parsed.
    A non-string text value is replaced by ``""``; ``row_id`` is the
    stringified ``id_field`` value, or ``None`` when the field is unset or
    missing. Iteration stops once ``max_records`` rows have been yielded.
    """
    produced = 0
    with path.open("r", encoding="utf-8") as handle:
        for row_index, raw_line in enumerate(handle):
            if max_records is not None and produced >= max_records:
                break
            payload = raw_line.strip()
            # Blank lines and rows owned by another worker are skipped unparsed.
            if not payload or row_index % split_count != split_index:
                continue
            record = json.loads(payload)
            text = record.get(text_field)
            if not isinstance(text, str):
                text = ""
            raw_id = record.get(id_field) if id_field else None
            row_id = None if raw_id is None else str(raw_id)
            produced += 1
            yield text, row_id, row_index
|
|
|
|
def iter_jsonl_sampled(
    path: Path,
    text_field: str,
    id_field: str | None,
    sample_records: int,
    sample_seed: int,
    split_count: int,
    split_index: int,
) -> Iterable[tuple[str, str | None, int]]:
    """Reservoir-sample a JSONL file, then yield this worker's share of the sample.

    Every non-blank line is parsed; ``row_index`` counts non-blank records
    only. Up to ``sample_records`` rows are kept using a
    ``random.Random(sample_seed)`` reservoir, sorted back into source order,
    and the modulo split is applied to positions within the sorted sample.

    Raises:
        SystemExit: on first iteration when ``sample_records`` is below 1.
    """
    if sample_records < 1:
        raise SystemExit("--sample-records must be >= 1")
    rng = random.Random(sample_seed)
    kept: list[tuple[str, str | None, int]] = []
    record_count = 0
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            payload = raw_line.strip()
            if not payload:
                continue
            row_index = record_count
            record_count += 1
            record = json.loads(payload)
            text = record.get(text_field)
            if not isinstance(text, str):
                text = ""
            raw_id = record.get(id_field) if id_field else None
            row_id = None if raw_id is None else str(raw_id)
            entry = (text, row_id, row_index)
            if len(kept) < sample_records:
                # The reservoir is still filling: keep everything.
                kept.append(entry)
                continue
            # Reservoir full: draw a slot among all records seen so far and
            # overwrite only when it lands inside the reservoir.
            slot = rng.randrange(record_count)
            if slot < sample_records:
                kept[slot] = entry
    kept.sort(key=lambda entry: entry[2])
    for position, entry in enumerate(kept):
        if position % split_count != split_index:
            continue
        yield entry
|
|
|
|
def batched(items: Iterable[tuple[str, str | None, int]], batch_size: int) -> Iterable[list[tuple[str, str | None, int]]]:
    """Group ``items`` into consecutive lists of ``batch_size`` entries.

    The final list may be shorter; nothing is yielded for an empty input.
    """
    pending: list[tuple[str, str | None, int]] = []
    for entry in items:
        pending.append(entry)
        if len(pending) < batch_size:
            continue
        yield pending
        # Start a fresh list so the batch already handed out is never mutated.
        pending = []
    if pending:
        yield pending
|
|
|
|
def config_text_limit(config: Any) -> int | None:
    """Return the smallest positive text-length limit advertised by a model config.

    Looks at ``config`` and, when present, ``config.text_config`` for the
    attributes ``max_position_embeddings``, ``max_sequence_length``,
    ``context_length`` and ``seq_length``. Only positive ``int`` values count.
    Returns ``None`` when no such value is found.
    """
    limit_names = ("max_position_embeddings", "max_sequence_length", "context_length", "seq_length")
    found: list[int] = []
    for holder in (config, getattr(config, "text_config", None)):
        if holder is None:
            continue
        values = (getattr(holder, name, None) for name in limit_names)
        found.extend(value for value in values if isinstance(value, int) and value > 0)
    return min(found, default=None)
|
|
|
|
def inspect_models(args: argparse.Namespace) -> int:
    """Run the `inspect` subcommand: print tokenizer/config limits as JSON.

    For every id in ``args.model`` the tokenizer and config are loaded and
    their length- and width-related attributes are collected into one report
    row. The rows are printed as a JSON list on stdout.

    Returns:
        Always 0.
    """
    if args.compat_remote_code:
        install_remote_code_compat()
    AutoConfig, _AutoModel, AutoTokenizer = load_transformers()
    report = []
    for model_id in args.model:
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=args.trust_remote_code)
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=args.trust_remote_code)
        text_config = getattr(config, "text_config", None)
        # Fall back to the alternative attribute when the first is missing or falsy.
        projection = getattr(config, "projection_dim", None) or getattr(config, "projection_size", None)
        hidden = getattr(config, "hidden_size", None) or getattr(text_config, "hidden_size", None)
        entry = {
            "model": model_id,
            "model_type": getattr(config, "model_type", None),
            "tokenizer_model_max_length": getattr(tokenizer, "model_max_length", None),
            "config_text_limit": config_text_limit(config),
            "text_config_max_position_embeddings": getattr(text_config, "max_position_embeddings", None),
            "max_position_embeddings": getattr(config, "max_position_embeddings", None),
            "projection_dim": projection,
            "hidden_size": hidden,
        }
        report.append(entry)
    print(json.dumps(report, indent=2, ensure_ascii=False))
    return 0
|
|
|
|
def load_encoder(
    model_id: str,
    device: str,
    dtype: str,
    trust_remote_code: bool,
    compile_model: bool,
    compat_remote_code: bool,
    padding_side: str | None,
):
    """Load the tokenizer and model for ``model_id`` and move the model to ``device``.

    When ``compat_remote_code`` is set, the compatibility shims are installed
    first and the config is loaded explicitly so that attributes some
    remote-code models expect can be filled in if the config lacks them.

    Returns:
        ``(tokenizer, model)``; the model is in eval mode and is wrapped by
        ``torch.compile`` when ``compile_model`` is true.
    """
    if compat_remote_code:
        install_remote_code_compat()
    AutoConfig, AutoModel, AutoTokenizer = load_transformers()

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    if padding_side is not None:
        tokenizer.padding_side = padding_side

    config = None
    if compat_remote_code:
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
        fallback_attrs = {
            "is_decoder": False,
            "add_cross_attention": False,
            "chunk_size_feed_forward": 0,
            "use_return_dict": True,
            "output_attentions": False,
            "output_hidden_states": False,
        }
        for attr, fallback in fallback_attrs.items():
            # Only fill gaps; never override a value the config already carries.
            if hasattr(config, attr):
                continue
            setattr(config, attr, fallback)

    model = AutoModel.from_pretrained(
        model_id,
        config=config,
        dtype=torch_dtype(dtype),
        trust_remote_code=trust_remote_code,
    )
    model.eval()
    model.to(device)
    if not compile_model:
        return tokenizer, model
    return tokenizer, torch.compile(model)
|
|
|
|
def pool_outputs(model: Any, outputs: Any, encoded: dict[str, torch.Tensor], pooling: str) -> torch.Tensor:
    """Reduce model outputs to one embedding per input sequence.

    Resolution order:
      1. ``outputs.text_embeds`` when present, regardless of ``pooling``.
      2. ``outputs.pooler_output`` when present and ``pooling`` is "auto" or "pooler".
      3. Token pooling over ``last_hidden_state`` (or ``outputs[0]``):
         "cls" takes the first token, "last" takes the final attended token,
         "auto"/"mean" take the attention-masked mean.
      4. The first token as a fallback (e.g. mean pooling without a mask).

    ``model`` is accepted for interface symmetry and is not used.
    """
    text_embeds = getattr(outputs, "text_embeds", None)
    if text_embeds is not None:
        return text_embeds

    if pooling in {"auto", "pooler"}:
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is not None:
            return pooled

    if hasattr(outputs, "last_hidden_state"):
        hidden = outputs.last_hidden_state
    else:
        hidden = outputs[0]

    if pooling == "cls":
        return hidden[:, 0]

    mask = encoded.get("attention_mask")

    if pooling == "last":
        # If every sequence attends to the final position, the last column is
        # the last real token (left padding or no padding at all).
        if mask is None or bool((mask[:, -1].sum() == mask.shape[0]).item()):
            return hidden[:, -1]
        last_positions = mask.sum(dim=1) - 1
        row_selector = torch.arange(hidden.shape[0], device=hidden.device)
        return hidden[row_selector, last_positions]

    if mask is not None and pooling in {"auto", "mean"}:
        weights = mask.to(hidden.dtype).unsqueeze(-1)
        summed = (hidden * weights).sum(dim=1)
        # clamp avoids dividing by zero for a fully masked row.
        return summed / weights.sum(dim=1).clamp_min(1.0)

    return hidden[:, 0]
|
|
|
|
@torch.inference_mode()
def encode_batch(
    tokenizer: Any,
    model: Any,
    texts: list[str],
    device: str,
    max_length: int | None,
    pooling: str,
) -> torch.Tensor:
    """Tokenize ``texts``, run the model and return L2-normalized float32 embeddings on CPU.

    Models exposing ``get_text_features`` are called through it; if that
    returns something other than a tensor, or for every other model, the
    output is reduced with ``pool_outputs`` according to ``pooling``.
    """
    tokenized = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    inputs = {name: tensor.to(device) for name, tensor in tokenized.items()}
    if hasattr(model, "get_text_features"):
        embeddings = model.get_text_features(**inputs)
        if not isinstance(embeddings, torch.Tensor):
            embeddings = pool_outputs(model, embeddings, inputs, pooling)
    else:
        embeddings = pool_outputs(model, model(**inputs), inputs, pooling)
    # Normalize in float32 regardless of the model's compute dtype.
    unit = torch.nn.functional.normalize(embeddings.float(), dim=-1)
    return unit.cpu()
|
|
|
|
def flush_shard(
    output_dir: Path,
    shard_index: int,
    start_row: int,
    rows: list[np.ndarray],
    embedding_dtype: str,
) -> EmbeddingShard:
    """Save buffered embedding rows as ``embeddings-<index>.npy`` and describe the shard.

    The rows are stacked into one array cast to ``embedding_dtype`` and written
    under ``output_dir``. The returned record covers rows
    ``[start_row, start_row + len(array))``; ``dim`` is 0 when the stacked
    array is not two-dimensional (e.g. an empty row list).
    """
    stacked = np.asarray(rows, dtype=numpy_dtype(embedding_dtype))
    target = output_dir / f"embeddings-{shard_index:05d}.npy"
    np.save(target, stacked)

    row_count = int(stacked.shape[0])
    width = int(stacked.shape[1]) if stacked.ndim == 2 else 0
    return EmbeddingShard(
        path=str(target),
        rows=row_count,
        dim=width,
        dtype=embedding_dtype,
        start_row=start_row,
        end_row=start_row + row_count,
    )
|
|
|
|
def encode_main(args: argparse.Namespace) -> int:
    """Run the `encode` subcommand: embed JSONL captions and cache them as shards.

    Reads captions from ``args.input`` (optionally reservoir-sampled and
    modulo-split across workers), encodes them in batches, and writes to
    ``args.output_dir``:

    - ``embeddings-XXXXX.npy`` shards of at least ``args.shard_rows`` rows
      (the last shard may be smaller),
    - ``row_ids.jsonl`` mapping each output row to its source row and id
      (only when at least one row was encoded),
    - ``manifest.json`` describing the run and its shards.

    Args:
        args: Namespace produced by the `encode` subparser.

    Returns:
        Always 0.

    Raises:
        SystemExit: on inconsistent split/sample/template arguments. These are
            checked before the output directory is created or the model is
            loaded, so a bad flag fails immediately.
    """
    if args.split_count < 1:
        raise SystemExit("--split-count must be >= 1")
    if not (0 <= args.split_index < args.split_count):
        raise SystemExit("--split-index must satisfy 0 <= split_index < split_count")
    if args.sample_records is not None and args.max_records is not None:
        raise SystemExit("--sample-records and --max-records are mutually exclusive")
    if args.text_template is not None and "{text}" not in args.text_template:
        raise SystemExit("--text-template must contain '{text}'")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    tokenizer, model = load_encoder(
        args.model,
        args.device,
        args.dtype,
        args.trust_remote_code,
        args.compile,
        args.compat_remote_code,
        args.padding_side,
    )

    # Resolve the truncation length. An explicit --max-length wins. Otherwise
    # take the smaller of the config limit and the tokenizer limit: for
    # RoBERTa/XLM-R style models the config's position table (e.g. 514) is
    # larger than the usable input length the tokenizer reports (512), and
    # truncating to the config value would index past the position embeddings.
    # Limits above 1_000_000 are "no limit" sentinels and are ignored.
    config_limit = config_text_limit(getattr(model, "config", None))
    tokenizer_limit = getattr(tokenizer, "model_max_length", None)
    if args.max_length:
        max_length = args.max_length
    else:
        sane_limits = [
            limit
            for limit in (config_limit, tokenizer_limit)
            if isinstance(limit, int) and 0 < limit <= 1_000_000
        ]
        max_length = min(sane_limits) if sane_limits else None
    if isinstance(max_length, int) and max_length > 1_000_000:
        max_length = None

    if args.sample_records is not None:
        source = iter_jsonl_sampled(
            Path(args.input),
            args.text_field,
            args.id_field,
            args.sample_records,
            args.sample_seed,
            args.split_count,
            args.split_index,
        )
    else:
        source = iter_jsonl(
            Path(args.input),
            args.text_field,
            args.id_field,
            args.max_records,
            args.split_count,
            args.split_index,
        )

    rows: list[np.ndarray] = []
    row_ids: list[str | None] = []
    row_indices: list[int] = []
    shards: list[EmbeddingShard] = []
    total = 0
    shard_start = 0
    started = time.time()
    for batch in batched(source, args.batch_size):
        texts = [text for text, _row_id, _row_index in batch]
        # --text-template takes precedence over --text-prefix.
        if args.text_template is not None:
            texts = [args.text_template.format(text=text) for text in texts]
        elif args.text_prefix:
            texts = [f"{args.text_prefix}{text}" for text in texts]
        features = encode_batch(tokenizer, model, texts, args.device, max_length, args.pooling)
        rows.extend(features.numpy())
        row_ids.extend(row_id for _text, row_id, _row_index in batch)
        row_indices.extend(row_index for _text, _row_id, row_index in batch)
        total += len(batch)
        if len(rows) >= args.shard_rows:
            shards.append(flush_shard(output_dir, len(shards), shard_start, rows, args.embedding_dtype))
            shard_start += len(rows)
            rows = []
    if rows:
        shards.append(flush_shard(output_dir, len(shards), shard_start, rows, args.embedding_dtype))

    if row_indices:
        with (output_dir / "row_ids.jsonl").open("w", encoding="utf-8") as handle:
            for index, (row_id, row_index) in enumerate(zip(row_ids, row_indices, strict=True)):
                handle.write(
                    json.dumps(
                        {"split_row": index, "source_row": row_index, "id": row_id},
                        ensure_ascii=False,
                    )
                    + "\n"
                )

    # Measure once so "seconds" and "rows_per_second" describe the same interval.
    elapsed = time.time() - started
    manifest = {
        "input": args.input,
        "text_field": args.text_field,
        "id_field": args.id_field,
        "model": args.model,
        "max_length": max_length,
        "max_records": args.max_records,
        "sample_records": args.sample_records,
        "sample_seed": args.sample_seed,
        "split_count": args.split_count,
        "split_index": args.split_index,
        "pooling": args.pooling,
        "padding_side": getattr(tokenizer, "padding_side", None),
        "text_prefix": args.text_prefix,
        "text_template": args.text_template,
        "compat_remote_code": args.compat_remote_code,
        "device": args.device,
        "dtype": args.dtype,
        "embedding_dtype": args.embedding_dtype,
        "rows": total,
        "seconds": round(elapsed, 3),
        "rows_per_second": round(total / max(elapsed, 1e-6), 3),
        "shards": [asdict(shard) for shard in shards],
    }
    (output_dir / "manifest.json").write_text(json.dumps(manifest, indent=2, ensure_ascii=False), encoding="utf-8")
    print(json.dumps({"output_dir": str(output_dir), "rows": total, "shards": len(shards), "max_length": max_length}, indent=2))
    return 0
|
|
|
|
def encode_bge_m3_main(args: argparse.Namespace) -> int:
    """Run the `encode-bge-m3` subcommand: cache BGE-M3 dense embeddings via FlagEmbedding.

    Reads captions from ``args.input`` (optionally reservoir-sampled and
    modulo-split across workers), encodes them with ``BGEM3FlagModel`` using
    the method selected by ``args.encode_mode``, and writes to
    ``args.output_dir``:

    - ``embeddings-XXXXX.npy`` shards of at least ``args.shard_rows`` rows
      (the last shard may be smaller),
    - ``row_ids.jsonl`` mapping each output row to its source row and id
      (only when at least one row was encoded),
    - ``manifest.json`` describing the run and its shards.

    Args:
        args: Namespace produced by the `encode-bge-m3` subparser.

    Returns:
        Always 0.

    Raises:
        SystemExit: on inconsistent split/sample/template arguments (checked
            first, before anything is imported or created), or when
            FlagEmbedding is not installed.
    """
    # Cheap argument checks go first so a bad flag fails before the heavy
    # import, the output directory creation and the model load.
    if args.split_count < 1:
        raise SystemExit("--split-count must be >= 1")
    if not (0 <= args.split_index < args.split_count):
        raise SystemExit("--split-index must satisfy 0 <= split_index < split_count")
    if args.sample_records is not None and args.max_records is not None:
        raise SystemExit("--sample-records and --max-records are mutually exclusive")
    if args.text_template is not None and "{text}" not in args.text_template:
        raise SystemExit("--text-template must contain '{text}'")

    try:
        from FlagEmbedding import BGEM3FlagModel
    except ImportError as exc:
        raise SystemExit("FlagEmbedding is required for encode-bge-m3. Install with `uv sync --extra eval`.") from exc

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    model = BGEM3FlagModel(
        args.model,
        normalize_embeddings=True,
        use_fp16=args.use_fp16,
        devices=args.device,
        pooling_method="cls",
        batch_size=args.batch_size,
        query_max_length=args.max_length,
        passage_max_length=args.max_length,
        return_dense=True,
        return_sparse=False,
        return_colbert_vecs=False,
        query_instruction_for_retrieval=args.query_instruction,
        query_instruction_format=args.query_instruction_format,
    )
    # The encode entry point depends only on the CLI flag, so resolve it once
    # instead of rebuilding the dispatch table for every batch.
    encode_fn = {
        "corpus": model.encode_corpus,
        "queries": model.encode_queries,
        "encode": model.encode,
    }[args.encode_mode]

    if args.sample_records is not None:
        source = iter_jsonl_sampled(
            Path(args.input),
            args.text_field,
            args.id_field,
            args.sample_records,
            args.sample_seed,
            args.split_count,
            args.split_index,
        )
    else:
        source = iter_jsonl(
            Path(args.input),
            args.text_field,
            args.id_field,
            args.max_records,
            args.split_count,
            args.split_index,
        )

    rows: list[np.ndarray] = []
    row_ids: list[str | None] = []
    row_indices: list[int] = []
    shards: list[EmbeddingShard] = []
    total = 0
    shard_start = 0
    started = time.time()
    for batch in batched(source, args.batch_size):
        texts = [text for text, _row_id, _row_index in batch]
        # --text-template takes precedence over --text-prefix.
        if args.text_template is not None:
            texts = [args.text_template.format(text=text) for text in texts]
        elif args.text_prefix:
            texts = [f"{args.text_prefix}{text}" for text in texts]
        encoded = encode_fn(
            texts,
            batch_size=args.batch_size,
            max_length=args.max_length,
            return_dense=True,
            return_sparse=False,
            return_colbert_vecs=False,
        )
        features = np.asarray(encoded["dense_vecs"], dtype=np.float32)
        # Re-normalize in float32; the floor guards against division by zero.
        features /= np.maximum(np.linalg.norm(features, axis=1, keepdims=True), 1e-12)
        rows.extend(features)
        row_ids.extend(row_id for _text, row_id, _row_index in batch)
        row_indices.extend(row_index for _text, _row_id, row_index in batch)
        total += len(batch)
        if len(rows) >= args.shard_rows:
            shards.append(flush_shard(output_dir, len(shards), shard_start, rows, args.embedding_dtype))
            shard_start += len(rows)
            rows = []
    if rows:
        shards.append(flush_shard(output_dir, len(shards), shard_start, rows, args.embedding_dtype))

    if row_indices:
        with (output_dir / "row_ids.jsonl").open("w", encoding="utf-8") as handle:
            for index, (row_id, row_index) in enumerate(zip(row_ids, row_indices, strict=True)):
                handle.write(
                    json.dumps(
                        {"split_row": index, "source_row": row_index, "id": row_id},
                        ensure_ascii=False,
                    )
                    + "\n"
                )

    elapsed = time.time() - started
    manifest = {
        "input": args.input,
        "text_field": args.text_field,
        "id_field": args.id_field,
        "model": args.model,
        "backend": "FlagEmbedding.BGEM3FlagModel",
        "max_length": args.max_length,
        "max_records": args.max_records,
        "sample_records": args.sample_records,
        "sample_seed": args.sample_seed,
        "split_count": args.split_count,
        "split_index": args.split_index,
        "pooling": "cls",
        "encode_mode": args.encode_mode,
        "normalize_embeddings": True,
        "text_prefix": args.text_prefix,
        "text_template": args.text_template,
        "query_instruction": args.query_instruction,
        "query_instruction_format": args.query_instruction_format,
        "device": args.device,
        "use_fp16": args.use_fp16,
        "embedding_dtype": args.embedding_dtype,
        "rows": total,
        "seconds": round(elapsed, 3),
        "rows_per_second": round(total / max(elapsed, 1e-6), 3),
        "shards": [asdict(shard) for shard in shards],
    }
    (output_dir / "manifest.json").write_text(json.dumps(manifest, indent=2, ensure_ascii=False), encoding="utf-8")
    print(json.dumps({"output_dir": str(output_dir), "rows": total, "shards": len(shards), "max_length": args.max_length}, indent=2))
    return 0
|
|
|
|
def encode_sentence_transformer_main(args: argparse.Namespace) -> int:
    """Encode JSONL captions with a SentenceTransformer and cache the result.

    Embeddings are written as shards through ``flush_shard``. A
    ``row_ids.jsonl`` file (only when at least one row was encoded) and a
    ``manifest.json`` describing the run are placed in ``args.output_dir``.

    Args:
        args: Parsed CLI namespace of the ``encode-sentence-transformer``
            subcommand.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If ``sentence-transformers`` cannot be imported, or the
            split, sampling or template options are inconsistent.
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as exc:
        raise SystemExit("sentence-transformers is required. Run `uv sync --extra eval`.") from exc

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Reject inconsistent options before paying for the model load.
    if args.split_count < 1:
        raise SystemExit("--split-count must be >= 1")
    if args.split_index < 0 or args.split_index >= args.split_count:
        raise SystemExit("--split-index must satisfy 0 <= split_index < split_count")
    if args.max_records is not None and args.sample_records is not None:
        raise SystemExit("--sample-records and --max-records are mutually exclusive")
    template = args.text_template
    if template is not None and "{text}" not in template:
        raise SystemExit("--text-template must contain '{text}'")

    model = SentenceTransformer(args.model, device=args.device)
    if args.max_length is not None:
        model.max_seq_length = args.max_length
    # Report the limit the model actually uses; fall back to the CLI value.
    model_limit = getattr(model, "max_seq_length", None)
    max_length = args.max_length if model_limit is None else int(model_limit)

    input_path = Path(args.input)
    if args.sample_records is None:
        source = iter_jsonl(
            input_path,
            args.text_field,
            args.id_field,
            args.max_records,
            args.split_count,
            args.split_index,
        )
    else:
        source = iter_jsonl_sampled(
            input_path,
            args.text_field,
            args.id_field,
            args.sample_records,
            args.sample_seed,
            args.split_count,
            args.split_index,
        )

    # The encode options are the same for every batch, so build them once.
    encode_kwargs = {
        "batch_size": args.batch_size,
        "normalize_embeddings": True,
        "convert_to_numpy": True,
        "show_progress_bar": False,
    }
    if args.prompt_name is not None:
        encode_kwargs["prompt_name"] = args.prompt_name

    pending_rows: list[np.ndarray] = []
    row_ids: list[str | None] = []
    row_indices: list[int] = []
    shards: list[EmbeddingShard] = []
    total = 0
    shard_start = 0
    started = time.time()
    for batch in batched(source, args.batch_size):
        texts = []
        for text, row_id, row_index in batch:
            # A template takes precedence over a plain prefix.
            if template is not None:
                text = template.format(text=text)
            elif args.text_prefix:
                text = f"{args.text_prefix}{text}"
            texts.append(text)
            row_ids.append(row_id)
            row_indices.append(row_index)

        features = np.asarray(model.encode(texts, **encode_kwargs), dtype=np.float32)
        # Re-normalize after the float32 cast; the floor avoids dividing by zero.
        norms = np.linalg.norm(features, axis=1, keepdims=True)
        features /= np.maximum(norms, 1e-12)
        pending_rows.extend(features)
        total += len(batch)

        if len(pending_rows) >= args.shard_rows:
            shards.append(flush_shard(output_dir, len(shards), shard_start, pending_rows, args.embedding_dtype))
            shard_start += len(pending_rows)
            pending_rows = []
    if pending_rows:
        shards.append(flush_shard(output_dir, len(shards), shard_start, pending_rows, args.embedding_dtype))

    if row_indices:
        # One JSON line per encoded row, mapping cache position to source row.
        with (output_dir / "row_ids.jsonl").open("w", encoding="utf-8") as handle:
            handle.writelines(
                json.dumps(
                    {"split_row": position, "source_row": source_row, "id": row_id},
                    ensure_ascii=False,
                )
                + "\n"
                for position, (row_id, source_row) in enumerate(zip(row_ids, row_indices, strict=True))
            )

    elapsed = time.time() - started
    manifest = {
        "input": args.input,
        "text_field": args.text_field,
        "id_field": args.id_field,
        "model": args.model,
        "backend": "sentence_transformers.SentenceTransformer",
        "max_length": max_length,
        "max_records": args.max_records,
        "sample_records": args.sample_records,
        "sample_seed": args.sample_seed,
        "split_count": args.split_count,
        "split_index": args.split_index,
        "pooling": "model_default",
        "normalize_embeddings": True,
        "text_prefix": args.text_prefix,
        "text_template": args.text_template,
        "prompt_name": args.prompt_name,
        "available_prompts": getattr(model, "prompts", None),
        "device": args.device,
        "embedding_dtype": args.embedding_dtype,
        "rows": total,
        "seconds": round(elapsed, 3),
        "rows_per_second": round(total / max(elapsed, 1e-6), 3),
        "shards": [asdict(shard) for shard in shards],
    }
    manifest_text = json.dumps(manifest, indent=2, ensure_ascii=False)
    (output_dir / "manifest.json").write_text(manifest_text, encoding="utf-8")
    console_summary = {
        "output_dir": str(output_dir),
        "rows": total,
        "shards": len(shards),
        "max_length": max_length,
    }
    print(json.dumps(console_summary, indent=2))
    return 0
|
|
|
|
def load_embedding_manifest(path: Path) -> tuple[dict[str, Any], np.ndarray]:
    """Load an embedding manifest together with the rows of all its shards.

    Each shard path is first tried exactly as stored. If that file is
    missing, the loader falls back to locations relative to the manifest's
    own directory, so a cache still loads after it was moved or when the
    current working directory differs from the one used while encoding.

    Args:
        path: Location of the ``manifest.json`` file.

    Returns:
        A ``(manifest, embeddings)`` pair. ``embeddings`` is the row-wise
        concatenation of all shards, or an empty ``(0, 0)`` float32 array
        when the manifest lists no shards.

    Raises:
        FileNotFoundError: If a shard cannot be found in any candidate
            location (raised by ``np.load`` for the stored path).
    """
    manifest = json.loads(path.read_text(encoding="utf-8"))
    manifest_dir = path.parent

    def resolve(stored: str) -> Path:
        # Return the first existing candidate location for one shard file.
        stored_path = Path(stored)
        candidates = [stored_path]
        if not stored_path.is_absolute():
            candidates.append(manifest_dir / stored_path)
        # The encoder writes shards into the same directory as the manifest.
        candidates.append(manifest_dir / stored_path.name)
        for candidate in candidates:
            if candidate.exists():
                return candidate
        # Nothing found: keep the stored path so the error names it.
        return stored_path

    arrays = [np.load(resolve(shard["path"]), mmap_mode="r") for shard in manifest["shards"]]
    if not arrays:
        return manifest, np.zeros((0, 0), dtype=np.float32)
    return manifest, np.concatenate(arrays, axis=0)
|
|
|
|
| def sample_embeddings(embeddings: np.ndarray, max_rows: int | None, seed: int) -> tuple[np.ndarray, list[int]]: |
| n = int(embeddings.shape[0]) |
| if max_rows is None or max_rows >= n: |
| indices = list(range(n)) |
| else: |
| rng = random.Random(seed) |
| indices = sorted(rng.sample(range(n), max_rows)) |
| return np.asarray(embeddings[indices], dtype=np.float32), indices |
|
|
|
|
def vendi_from_block(block: torch.Tensor) -> dict[str, float]:
    """Compute Vendi statistics for one block of embedding rows.

    Rows are L2-normalized, their cosine kernel is formed, and the kernel's
    eigenvalues (clamped at zero and rescaled to sum to one) are treated as
    a probability distribution. The score is the exponential of that
    distribution's Shannon entropy.

    Args:
        block: 2-D tensor with one embedding per row.

    Returns:
        A dict with ``vendi`` and ``effective_rank`` (the same value),
        ``trace`` (sum of the clamped eigenvalues) and ``max_eigen_prob``
        (largest normalized eigenvalue).
    """
    unit_rows = torch.nn.functional.normalize(block.float(), dim=-1)
    gram = unit_rows @ unit_rows.T
    # Clamp tiny negative eigenvalues caused by round-off.
    spectrum = torch.linalg.eigvalsh(gram).clamp_min(0)
    mass = spectrum.sum().clamp_min(1e-12)
    weights = spectrum / mass
    # The floor inside the log keeps zero-weight terms finite (0 * log(eps)).
    log_weights = torch.log(weights.clamp_min(1e-12))
    entropy = -(weights * log_weights).sum()
    score = float(torch.exp(entropy).item())
    return {
        "vendi": score,
        "effective_rank": score,
        "trace": float(mass.item()),
        "max_eigen_prob": float(weights.max().item()),
    }
|
|
|
|
def mean_ci(values: list[float]) -> dict[str, float]:
    """Return the mean of ``values`` with a normal-approximation 95% interval.

    Args:
        values: Samples to summarize.

    Returns:
        A dict with ``mean``, ``ci95_low`` and ``ci95_high``. All three are
        ``0.0`` for an empty list, and all equal the single value when only
        one sample is given.
    """
    if not values:
        return dict.fromkeys(("mean", "ci95_low", "ci95_high"), 0.0)
    count = len(values)
    center = sum(values) / count
    if count == 1:
        # A single sample carries no spread information.
        low = high = center
    else:
        # Unbiased sample variance, then 1.96 standard errors on each side.
        sample_variance = sum((value - center) ** 2 for value in values) / (count - 1)
        margin = 1.96 * math.sqrt(sample_variance / count)
        low, high = center - margin, center + margin
    return {"mean": center, "ci95_low": low, "ci95_high": high}
|
|
|
|
def parse_thresholds(text: str) -> list[float]:
    """Parse a comma-separated list of cosine thresholds.

    Empty items (for example from a trailing comma) are skipped.

    Args:
        text: Raw option value such as ``"0.5,0.7,0.9"``.

    Returns:
        The thresholds as floats, in the order given.

    Raises:
        SystemExit: If an item is not a number, lies outside ``[-1, 1]``
            (NaN included), or no value remains after skipping empty items.
    """
    values: list[float] = []
    for raw_item in text.split(","):
        item = raw_item.strip()
        if not item:
            continue
        try:
            value = float(item)
        except ValueError:
            # Report bad input the same way as the range check below,
            # rather than letting a traceback escape.
            raise SystemExit(f"invalid cosine threshold (not a number): {item!r}") from None
        if not -1.0 <= value <= 1.0:
            raise SystemExit(f"invalid cosine threshold outside [-1, 1]: {value}")
        values.append(value)
    if not values:
        raise SystemExit("--thresholds must contain at least one value")
    return values
|
|
|
|
def summarize_scores(scores: np.ndarray, thresholds: list[float]) -> dict[str, Any]:
    """Summarize nearest-neighbor cosine scores.

    Args:
        scores: Array of per-query scores.
        thresholds: Cosine cut-offs for the ``support_at_*`` fractions.

    Returns:
        A dict holding the mean, the sample standard deviation (``0.0`` for
        a single score), the percentiles ``p01`` to ``p99``, and for each
        threshold the fraction of scores at or above it.
    """
    summary: dict[str, Any] = {
        "mean_nn_cosine": float(np.mean(scores)),
        "std_nn_cosine": float(np.std(scores, ddof=1)) if scores.size > 1 else 0.0,
    }
    for percentile in (1, 5, 10, 25, 50, 75, 90, 95, 99):
        summary[f"p{percentile:02d}"] = float(np.percentile(scores, percentile))
    for threshold in thresholds:
        summary[f"support_at_{threshold:.2f}"] = float(np.mean(scores >= threshold))
    return summary
|
|
|
|
def summarize_support(covered: np.ndarray, density: np.ndarray, nn_cosine: np.ndarray) -> dict[str, Any]:
    """Aggregate per-query support results into scalar metrics.

    Args:
        covered: Per-query coverage flags.
        density: Per-query density values.
        nn_cosine: Per-query nearest-neighbor cosine similarities.

    Returns:
        A dict with the mean coverage, mean and percentile density, and
        mean/percentile summaries of the nearest-neighbor cosine and of the
        corresponding cosine distance (``1 - cosine``).
    """

    def pct(values: np.ndarray, q: int) -> float:
        # Percentile as a plain Python float.
        return float(np.percentile(values, q))

    nn_distance = 1.0 - nn_cosine
    return {
        "coverage": float(np.mean(covered)),
        "density": float(np.mean(density)),
        "density_p50": pct(density, 50),
        "density_p95": pct(density, 95),
        "nn_cosine_mean": float(np.mean(nn_cosine)),
        "nn_cosine_p50": pct(nn_cosine, 50),
        "nn_cosine_p05": pct(nn_cosine, 5),
        "nn_distance_p95": pct(nn_distance, 95),
        "nn_distance_p99": pct(nn_distance, 99),
    }
|
|
|
|
@torch.inference_mode()
def exact_nn_cosine(
    query: np.ndarray,
    gallery: np.ndarray,
    device: str,
    dtype: torch.dtype,
    query_batch_size: int,
    gallery_chunk_size: int,
) -> np.ndarray:
    """Return, per query row, the highest cosine similarity to any gallery row.

    Args:
        query: 2-D array of query embeddings.
        gallery: 2-D array of gallery embeddings with the same width.
        device: Torch device on which the similarities are computed.
        dtype: Torch dtype used for the matrix products.
        query_batch_size: Number of query rows processed per step.
        gallery_chunk_size: Number of gallery rows moved to the device per
            step; ``0`` keeps the whole gallery on the device at once.

    Returns:
        1-D float32 array with one score per query row.

    Raises:
        SystemExit: If the inputs are not 2-D, differ in width, are empty,
            or a batch/chunk size is out of range.
    """
    if query.ndim != 2 or gallery.ndim != 2:
        raise SystemExit("query and gallery embeddings must be 2D arrays")
    if query.shape[1] != gallery.shape[1]:
        raise SystemExit(f"dimension mismatch: query dim {query.shape[1]} vs gallery dim {gallery.shape[1]}")
    if query.shape[0] == 0 or gallery.shape[0] == 0:
        raise SystemExit("query and gallery embeddings must be non-empty")
    if query_batch_size < 1:
        raise SystemExit("--query-batch-size must be >= 1")
    if gallery_chunk_size < 0:
        raise SystemExit("--gallery-chunk-size must be >= 0")

    def to_unit(rows: np.ndarray) -> torch.Tensor:
        # Move rows to the device and L2-normalize them in float32 before
        # casting back to the working dtype.
        tensor = torch.from_numpy(rows).to(device=device, dtype=dtype)
        return torch.nn.functional.normalize(tensor.float(), dim=-1).to(dtype)

    query_starts = range(0, query.shape[0], query_batch_size)

    if gallery_chunk_size == 0:
        # Whole gallery resident: one matrix product per query batch.
        gallery_t = to_unit(gallery).T.contiguous()
        resident_scores = [
            (to_unit(query[begin : begin + query_batch_size]) @ gallery_t)
            .float()
            .max(dim=1)
            .values.cpu()
            .numpy()
            for begin in query_starts
        ]
        return np.concatenate(resident_scores, axis=0)

    chunked_scores: list[np.ndarray] = []
    gallery_starts = range(0, gallery.shape[0], gallery_chunk_size)
    for begin in query_starts:
        query_unit = to_unit(query[begin : begin + query_batch_size])
        # -2.0 lies below every cosine, so the first chunk always replaces it.
        running_best = torch.full((query_unit.shape[0],), -2.0, device=device, dtype=torch.float32)
        for chunk_begin in gallery_starts:
            chunk_unit = to_unit(gallery[chunk_begin : chunk_begin + gallery_chunk_size])
            chunk_best = (query_unit @ chunk_unit.T).float().max(dim=1).values
            running_best = torch.maximum(running_best, chunk_best)
        chunked_scores.append(running_best.cpu().numpy())
    return np.concatenate(chunked_scores, axis=0)
|
|
|
|
@torch.inference_mode()
def kth_self_neighbor_cosine(
    gallery: np.ndarray,
    k: int,
    device: str,
    dtype: torch.dtype,
    batch_size: int,
) -> np.ndarray:
    """Return each gallery row's cosine similarity to its k-th nearest other row.

    Args:
        gallery: 2-D array of gallery embeddings.
        k: Neighbor rank (1 = closest row other than the row itself).
        device: Torch device on which the similarities are computed.
        dtype: Torch dtype used for the matrix products.
        batch_size: Number of gallery rows compared against the full
            gallery per step.

    Returns:
        1-D float32 array with one similarity per gallery row.

    Raises:
        SystemExit: If ``k < 1``, the gallery has ``k`` rows or fewer, or
            ``batch_size < 1``.
    """
    if k < 1:
        raise SystemExit("--k must be >= 1")
    row_count = gallery.shape[0]
    if row_count <= k:
        raise SystemExit(f"gallery rows ({gallery.shape[0]}) must be > k ({k})")
    if batch_size < 1:
        raise SystemExit("--gallery-batch-size must be >= 1")

    unit = torch.from_numpy(gallery).to(device=device, dtype=dtype)
    unit = torch.nn.functional.normalize(unit.float(), dim=-1).to(dtype)
    unit_t = unit.T.contiguous()

    kth_per_batch: list[np.ndarray] = []
    for begin in range(0, row_count, batch_size):
        end = min(begin + batch_size, row_count)
        sims = unit[begin:end] @ unit_t
        # Mask each row's similarity with itself so it never ranks as a neighbor.
        local_rows = torch.arange(end - begin, device=device)
        sims[local_rows, local_rows + begin] = -2.0
        top_values = torch.topk(sims.float(), k=k, dim=1).values
        kth_per_batch.append(top_values[:, k - 1].cpu().numpy())
    return np.concatenate(kth_per_batch, axis=0)
|
|
|
|
@torch.inference_mode()
def prdc_query_in_gallery_support(
    query: np.ndarray,
    gallery: np.ndarray,
    gallery_thresholds: np.ndarray,
    k: int,
    device: str,
    dtype: torch.dtype,
    query_batch_size: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Measure how query rows fall inside per-gallery-row similarity balls.

    A query "hits" a gallery row when their cosine similarity is at least
    that gallery row's threshold.

    Args:
        query: 2-D array of query embeddings.
        gallery: 2-D array of gallery embeddings.
        gallery_thresholds: One cosine threshold per gallery row.
        k: Divisor applied to the hit count to obtain the density.
        device: Torch device on which the similarities are computed.
        dtype: Torch dtype used for the matrix products.
        query_batch_size: Number of query rows processed per step.

    Returns:
        ``(covered, density, nn_cosine)``: per query, whether at least one
        gallery row was hit, the hit count divided by ``k``, and the highest
        cosine similarity to any gallery row.

    Raises:
        SystemExit: If ``query_batch_size < 1``.
    """
    if query_batch_size < 1:
        raise SystemExit("--query-batch-size must be >= 1")

    gallery_unit = torch.from_numpy(gallery).to(device=device, dtype=dtype)
    gallery_unit = torch.nn.functional.normalize(gallery_unit.float(), dim=-1).to(dtype)
    gallery_t = gallery_unit.T.contiguous()
    # Row vector so it broadcasts against a (queries, gallery) similarity matrix.
    threshold_row = torch.from_numpy(gallery_thresholds.astype(np.float32)).to(device=device).unsqueeze(0)

    covered_parts: list[np.ndarray] = []
    density_parts: list[np.ndarray] = []
    nearest_parts: list[np.ndarray] = []
    for begin in range(0, query.shape[0], query_batch_size):
        query_unit = torch.from_numpy(query[begin : begin + query_batch_size]).to(device=device, dtype=dtype)
        query_unit = torch.nn.functional.normalize(query_unit.float(), dim=-1).to(dtype)
        sims = (query_unit @ gallery_t).float()
        hits_per_query = (sims >= threshold_row).sum(dim=1).float()
        covered_parts.append((hits_per_query > 0).cpu().numpy())
        density_parts.append((hits_per_query / float(k)).cpu().numpy())
        nearest_parts.append(sims.max(dim=1).values.cpu().numpy())

    covered = np.concatenate(covered_parts, axis=0)
    density = np.concatenate(density_parts, axis=0)
    nearest = np.concatenate(nearest_parts, axis=0)
    return covered, density, nearest
|
|
|
|
def vendi_main(args: argparse.Namespace) -> int:
    """Compute block Vendi scores from an embedding cache and write a JSON report.

    With ``args.sampling == "partition"`` the rows are shuffled and cut into
    consecutive blocks of ``block_size`` rows. Any other value draws
    ``args.blocks`` independent random blocks.

    Args:
        args: Parsed CLI namespace of the ``vendi`` subcommand.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If the cache is empty or the block size is below 1.
    """
    manifest, embeddings = load_embedding_manifest(Path(args.manifest))
    n = int(embeddings.shape[0])
    if n == 0:
        raise SystemExit("empty embedding cache")
    if args.block_size < 1:
        # A non-positive size would otherwise fail later inside range()/sample().
        raise SystemExit("block size must be >= 1")
    block_size = min(args.block_size, n)
    rng = random.Random(args.seed)
    matrix_device = args.matrix_device or args.device
    dtype = torch_dtype(args.dtype)

    if args.sampling == "partition":
        order = list(range(n))
        rng.shuffle(order)
        index_blocks = [order[start : start + block_size] for start in range(0, n, block_size)]
        # Fold an undersized trailing block into its predecessor. This needs
        # at least two blocks: a lone block (e.g. a one-row cache) has no
        # predecessor and is scored as-is.
        if len(index_blocks) > 1 and len(index_blocks[-1]) < max(2, block_size // 2):
            index_blocks[-2].extend(index_blocks[-1])
            index_blocks.pop()
    else:
        index_blocks = [
            rng.sample(range(n), block_size) if block_size < n else list(range(n))
            for _ in range(args.blocks)
        ]

    block_rows = []
    for block_index, indices in enumerate(index_blocks):
        array = np.asarray(embeddings[indices], dtype=np.float32)
        block = torch.from_numpy(array).to(matrix_device, dtype=dtype)
        stats = vendi_from_block(block)
        stats.update({"block_index": block_index, "block_size": len(indices)})
        block_rows.append(stats)

    vendi_values = [row["vendi"] for row in block_rows]
    payload = {
        "embedding_manifest": args.manifest,
        "source_model": manifest.get("model"),
        "source_rows": n,
        "block_size": block_size,
        "blocks": len(block_rows),
        "requested_blocks": args.blocks,
        "sampling": args.sampling,
        "seed": args.seed,
        "device": matrix_device,
        "summary": {
            "vendi": mean_ci(vendi_values),
            "max_eigen_prob": mean_ci([row["max_eigen_prob"] for row in block_rows]),
        },
        "block_rows": block_rows,
        "boundary": "Vendi is an embedding-space semantic diversity metric; it does not measure faithfulness, density, or downstream utility.",
    }
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
    # Report the number of blocks actually scored; in partition mode it is
    # unrelated to the requested count.
    print(json.dumps({"output": str(output), "vendi_mean": payload["summary"]["vendi"]["mean"], "blocks": len(block_rows)}, indent=2))
    return 0
|
|
|
|
def geometry_main(args: argparse.Namespace) -> int:
    """Compute embedding-geometry metrics for a cache and write a JSON report.

    A random subset of at most ``args.max_rows`` rows is L2-normalized. The
    report contains cosine-to-centroid statistics, a mean pairwise cosine
    estimate, and spectral statistics of the sample covariance matrix.

    Args:
        args: Parsed CLI namespace of the ``geometry`` subcommand.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If the embedding cache is empty.
    """
    manifest, embeddings = load_embedding_manifest(Path(args.manifest))
    n = int(embeddings.shape[0])
    if n == 0:
        raise SystemExit("empty embedding cache")

    rng = np.random.default_rng(args.seed)
    take = min(args.max_rows, n)
    if take < n:
        indices = rng.choice(n, size=take, replace=False)
    else:
        indices = np.arange(n)

    sample = np.asarray(embeddings[indices], dtype=np.float32)
    x = torch.from_numpy(sample).to(args.device, dtype=torch_dtype(args.dtype))
    x = torch.nn.functional.normalize(x.float(), dim=-1)

    # The column mean feeds the centroid, the centering step and the
    # pairwise-cosine estimate, so compute it once.
    mean_row = x.mean(dim=0, keepdim=True)
    centroid = torch.nn.functional.normalize(mean_row, dim=-1)
    cosine_to_centroid = (x @ centroid.T).squeeze(1)

    centered = x - mean_row
    degrees_of_freedom = max(take - 1, 1)
    cov = centered.T @ centered / degrees_of_freedom
    eig = torch.linalg.eigvalsh(cov).clamp_min(0)
    eig_sum = eig.sum().clamp_min(1e-12)
    probs = eig / eig_sum
    spectral_entropy = -(probs * torch.log(probs.clamp_min(1e-12))).sum()
    erank = torch.exp(spectral_entropy)
    participation = eig_sum.square() / eig.square().sum().clamp_min(1e-12)

    # For unit rows, |mean|^2 = (take + sum of off-diagonal cosines) / take^2,
    # which rearranges to the mean off-diagonal cosine used below.
    mean_norm_sq = mean_row.squeeze(0).square().sum().item()
    top_count = min(10, eig.numel())
    metrics = {
        "mean_cosine_to_centroid": float(cosine_to_centroid.mean().item()),
        "std_cosine_to_centroid": float(cosine_to_centroid.std(unbiased=True).item()) if take > 1 else 0.0,
        "mean_pairwise_cosine_estimate": float((mean_norm_sq * take - 1.0) / degrees_of_freedom),
        "cov_effective_rank": float(erank.item()),
        "cov_participation_ratio": float(participation.item()),
        "cov_top1_mass": float((eig.max() / eig_sum).item()),
        "cov_top10_mass": float((eig.topk(top_count).values.sum() / eig_sum).item()),
        "cov_trace": float(eig_sum.item()),
    }
    payload = {
        "embedding_manifest": args.manifest,
        "source_model": manifest.get("model"),
        "source_rows": n,
        "sample_rows": take,
        "seed": args.seed,
        "device": args.device,
        "metrics": metrics,
        "boundary": "Geometry metrics describe embedding distribution shape: concentration, anisotropy, and effective dimensionality. They do not measure faithfulness or prompt support.",
    }
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
    print(json.dumps({"output": str(output), **payload["metrics"]}, indent=2))
    return 0
|
|
|
|
def knn_main(args: argparse.Namespace) -> int:
    """Compute query-to-gallery nearest-neighbor cosine support and write a JSON report.

    Args:
        args: Parsed CLI namespace of the ``knn`` subcommand.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If the thresholds are invalid, or the embeddings fail
            the checks performed by ``exact_nn_cosine``.
    """
    # Validate the cheap textual option first so a typo does not surface
    # only after the expensive neighbor search has finished.
    thresholds = parse_thresholds(args.thresholds)

    query_manifest, query_embeddings_all = load_embedding_manifest(Path(args.query_manifest))
    gallery_manifest, gallery_embeddings_all = load_embedding_manifest(Path(args.gallery_manifest))
    # Distinct seeds keep the query and gallery subsamples independent.
    query_embeddings, query_indices = sample_embeddings(query_embeddings_all, args.query_max_rows, args.seed)
    gallery_embeddings, gallery_indices = sample_embeddings(gallery_embeddings_all, args.gallery_max_rows, args.seed + 1)
    started = time.time()
    scores = exact_nn_cosine(
        query_embeddings,
        gallery_embeddings,
        args.device,
        torch_dtype(args.dtype),
        args.query_batch_size,
        args.gallery_chunk_size,
    )
    payload = {
        "query_manifest": args.query_manifest,
        "gallery_manifest": args.gallery_manifest,
        "query_model": query_manifest.get("model"),
        "gallery_model": gallery_manifest.get("model"),
        "query_source_rows": int(query_embeddings_all.shape[0]),
        "gallery_source_rows": int(gallery_embeddings_all.shape[0]),
        "query_rows": int(query_embeddings.shape[0]),
        "gallery_rows": int(gallery_embeddings.shape[0]),
        "query_seed": args.seed,
        "gallery_seed": args.seed + 1,
        "query_indices_preview": query_indices[:10],
        "gallery_indices_preview": gallery_indices[:10],
        "device": args.device,
        "dtype": args.dtype,
        "query_batch_size": args.query_batch_size,
        "gallery_chunk_size": args.gallery_chunk_size,
        "seconds": round(time.time() - started, 3),
        "metrics": summarize_scores(scores, thresholds),
        "boundary": (
            "kNN support measures nearest-neighbor coverage in the chosen embedding space. "
            "It is directional, encoder-dependent, and not a faithfulness or density metric."
        ),
    }
    if args.save_scores is not None:
        score_path = Path(args.save_scores)
        # np.save appends ".npy" when the name lacks it; apply the same rule
        # here so the report records the file that is actually written.
        if not str(score_path).endswith(".npy"):
            score_path = Path(str(score_path) + ".npy")
        score_path.parent.mkdir(parents=True, exist_ok=True)
        np.save(score_path, scores.astype(np.float32))
        payload["score_path"] = str(score_path)
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
    print(json.dumps({"output": str(output), "query_rows": payload["query_rows"], "gallery_rows": payload["gallery_rows"], **payload["metrics"]}, indent=2))
    return 0
|
|
|
|
def support_main(args: argparse.Namespace) -> int:
    """Compute PRDC-style query-in-gallery support and write a JSON report.

    Each gallery row gets a threshold equal to the cosine similarity of its
    ``args.k``-th nearest other gallery row. Query rows are then scored
    against those thresholds.

    Args:
        args: Parsed CLI namespace of the ``support`` subcommand.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If the query set is empty, or the gallery, ``k`` or the
            batch sizes fail the checks of the helper functions.
    """
    query_manifest, query_embeddings_all = load_embedding_manifest(Path(args.query_manifest))
    gallery_manifest, gallery_embeddings_all = load_embedding_manifest(Path(args.gallery_manifest))
    # Distinct seeds keep the query and gallery subsamples independent.
    query_embeddings, query_indices = sample_embeddings(query_embeddings_all, args.query_max_rows, args.seed)
    gallery_embeddings, gallery_indices = sample_embeddings(gallery_embeddings_all, args.gallery_max_rows, args.seed + 1)
    if query_embeddings.shape[0] == 0:
        # Without this check the failure is an opaque "need at least one
        # array to concatenate" ValueError after the gallery pass has run.
        raise SystemExit("query embeddings must be non-empty")
    compute_dtype = torch_dtype(args.dtype)
    started = time.time()
    gallery_thresholds = kth_self_neighbor_cosine(
        gallery_embeddings,
        args.k,
        args.device,
        compute_dtype,
        args.gallery_batch_size,
    )
    covered, density, nn_cosine = prdc_query_in_gallery_support(
        query_embeddings,
        gallery_embeddings,
        gallery_thresholds,
        args.k,
        args.device,
        compute_dtype,
        args.query_batch_size,
    )
    payload = {
        "query_manifest": args.query_manifest,
        "gallery_manifest": args.gallery_manifest,
        "query_model": query_manifest.get("model"),
        "gallery_model": gallery_manifest.get("model"),
        "query_source_rows": int(query_embeddings_all.shape[0]),
        "gallery_source_rows": int(gallery_embeddings_all.shape[0]),
        "query_rows": int(query_embeddings.shape[0]),
        "gallery_rows": int(gallery_embeddings.shape[0]),
        "query_seed": args.seed,
        "gallery_seed": args.seed + 1,
        "query_indices_preview": query_indices[:10],
        "gallery_indices_preview": gallery_indices[:10],
        "k": args.k,
        "device": args.device,
        "dtype": args.dtype,
        "query_batch_size": args.query_batch_size,
        "gallery_batch_size": args.gallery_batch_size,
        "seconds": round(time.time() - started, 3),
        "gallery_thresholds": {
            "mean_kth_neighbor_cosine": float(np.mean(gallery_thresholds)),
            "p05_kth_neighbor_cosine": float(np.percentile(gallery_thresholds, 5)),
            "p50_kth_neighbor_cosine": float(np.percentile(gallery_thresholds, 50)),
            "p95_kth_neighbor_cosine": float(np.percentile(gallery_thresholds, 95)),
        },
        "metrics": summarize_support(covered, density, nn_cosine),
        "boundary": (
            "P-in-C support is a PRDC-style embedding-manifold estimate: query points are covered "
            "when they fall inside at least one gallery kNN ball. It measures support in the chosen "
            "embedding space, not image faithfulness or overall caption quality."
        ),
    }
    if args.save_scores is not None:
        score_path = Path(args.save_scores)
        # np.savez_compressed appends ".npz" when the name lacks it; apply the
        # same rule here so the report records the file that is actually written.
        if not str(score_path).endswith(".npz"):
            score_path = Path(str(score_path) + ".npz")
        score_path.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(
            score_path,
            covered=covered.astype(np.bool_),
            density=density.astype(np.float32),
            nn_cosine=nn_cosine.astype(np.float32),
            gallery_thresholds=gallery_thresholds.astype(np.float32),
        )
        payload["score_path"] = str(score_path)
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
    print(json.dumps({"output": str(output), "query_rows": payload["query_rows"], "gallery_rows": payload["gallery_rows"], **payload["metrics"]}, indent=2))
    return 0
|
|
|
|
def main() -> int:
    """Parse the command line and run the selected subcommand.

    Returns:
        The exit status returned by the subcommand handler.

    Raises:
        AssertionError: If the parsed command has no registered handler.
    """
    args = parse_args()
    handlers = {
        "inspect": inspect_models,
        "encode": encode_main,
        "encode-bge-m3": encode_bge_m3_main,
        "encode-sentence-transformer": encode_sentence_transformer_main,
        "vendi": vendi_main,
        "geometry": geometry_main,
        "knn": knn_main,
        "support": support_main,
    }
    handler = handlers.get(args.cmd)
    if handler is None:
        # The parser should only produce commands listed above.
        raise AssertionError(args.cmd)
    return handler(args)
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|
|