| """ |
| Embedding Test Suite CLI |
| |
| Tests embedding quality by sampling sequences from EC dataset (default), |
| embedding them with various pooling methods, and reporting statistics |
| on distribution, NaNs, and sparsity. |
| """ |
|
|
| import json |
| import argparse |
| import math |
| import random |
| import numpy as np |
| import torch |
| from typing import Dict, List, Optional |
|
|
| try: |
| from data.data_mixin import DataMixin, DataArguments |
| from embedder import Embedder, EmbeddingArguments |
| from base_models.get_base_models import standard_models |
| from seed_utils import set_global_seed, get_global_seed |
| from utils import print_message |
| except ImportError: |
| from ..data.data_mixin import DataMixin, DataArguments |
| from ..embedder import Embedder, EmbeddingArguments |
| from ..base_models.get_base_models import standard_models |
| from ..seed_utils import set_global_seed, get_global_seed |
| from ..utils import print_message |
|
|
|
|
| |
# Datasets exercised when the caller passes no --datasets argument.
DEFAULT_TEST_DATASETS = [
    'EC',
]




# Import-time side effect: if a global seed was registered elsewhere
# (e.g. by a parent CLI via set_global_seed), seed the stdlib and NumPy
# RNGs here so module-level sampling is reproducible.
# NOTE(review): torch is not seeded here — presumably handled by
# set_global_seed; confirm.
seed = get_global_seed()
if seed is not None:
    random.seed(seed)
    np.random.seed(seed)
|
|
def load_and_sample_sequences(
    dataset_names: List[str],
    sample_frac: float = 0.1,
    max_length: int = 1024,
    trim: bool = False
) -> Dict[str, List[str]]:
    """
    Load datasets and sample sequences from them.

    Args:
        dataset_names: List of dataset names to load
        sample_frac: Fraction of sequences to sample (default 0.1 = 10%)
        max_length: Maximum sequence length
        trim: Whether to trim sequences to max_length

    Returns:
        Dictionary mapping dataset names to lists of sampled sequences.
        Datasets that fail to load are logged and skipped, not raised.
    """
    dataset_seqs: Dict[str, List[str]] = {}

    for dataset_name in dataset_names:
        print_message(f"Loading dataset: {dataset_name}")

        try:
            data_args = DataArguments(
                data_names=[dataset_name],
                max_length=max_length,
                trim=trim
            )
            data_mixin = DataMixin(data_args)
            datasets, all_seqs = data_mixin.get_data()

            sequences: List[str] = []
            if dataset_name in datasets:
                train_set, valid_set, test_set, _, _, ppi = datasets[dataset_name]

                if ppi:
                    # PPI datasets carry paired sequences in two columns.
                    for split in (train_set, valid_set, test_set):
                        sequences.extend(list(split['SeqA']))
                        sequences.extend(list(split['SeqB']))
                else:
                    for split in (train_set, valid_set, test_set):
                        sequences.extend(list(split['seqs']))
            else:
                # Dataset not keyed in the dict: fall back to the flat list.
                sequences = list(all_seqs)

            # Deduplicate while preserving first-seen order. dict.fromkeys
            # (rather than set()) keeps ordering stable across runs, so
            # random.sample is actually reproducible under a fixed seed —
            # set() iteration order depends on PYTHONHASHSEED.
            sequences = list(dict.fromkeys(sequences))
            n_samples = max(1, math.ceil(len(sequences) * sample_frac))
            sampled = random.sample(sequences, min(n_samples, len(sequences)))
            dataset_seqs[dataset_name] = sampled

            print_message(f"Sampled {len(sampled)} sequences from {len(sequences)} total")

        except Exception as e:
            # Best-effort loop: a broken dataset is reported and skipped so
            # the remaining datasets can still be tested.
            print_message(f"Error loading dataset {dataset_name}: {e}")
            continue

    return dataset_seqs
|
|
|
|
def compute_diagnostics(embeddings: torch.Tensor, zero_eps: float = 1e-8) -> Dict[str, float]:
    """
    Compute distribution/NaN/Inf/sparsity diagnostics for a batch of embeddings.

    Args:
        embeddings: Tensor of shape (n_samples, dim). A 1-D tensor is
            treated as a single sample.
        zero_eps: Magnitude below which a finite value counts as "near zero".

    Returns:
        Dict of diagnostic statistics. When the tensor holds no finite
        values, distribution stats are NaN placeholders but the full key
        set is still present so downstream reporting can index it
        unconditionally.
    """
    emb = embeddings.detach().float().cpu().numpy()
    if emb.ndim == 1:
        # Promote a single vector to a batch of one so shape[1] is valid.
        emb = emb.reshape(1, -1)
    flat = emb.ravel()

    is_nan = np.isnan(flat)
    is_inf = np.isinf(flat)
    is_finite = np.isfinite(flat)
    finite = flat[is_finite]

    total = flat.size
    stats: Dict[str, float] = {
        "n_samples": int(emb.shape[0]),
        "embedding_dim": int(emb.shape[1]),
        "finite_count": int(finite.size),
        "finite_fraction": float(finite.size / total) if total else 0.0,
        "nan_count": int(is_nan.sum()),
        "nan_fraction": float(is_nan.mean()) if total else 0.0,
        "inf_count": int(is_inf.sum()),
        "inf_fraction": float(is_inf.mean()) if total else 0.0,
        "zero_eps": float(zero_eps),
    }

    if finite.size == 0:
        # Degenerate case: keep every key (NaN placeholders) so consumers
        # such as print_table_results never hit a KeyError.
        nan = float("nan")
        stats.update({
            "near_zero_count": 0,
            "near_zero_fraction": 0.0,
            "mean": nan, "std": nan, "min": nan, "max": nan,
            "p25": nan, "p50": nan, "p75": nan, "p95": nan, "p99": nan,
            "mean_l2": nan, "std_l2": nan, "p95_l2": nan,
        })
        return stats

    near_zero = np.abs(finite) < zero_eps
    stats["near_zero_count"] = int(near_zero.sum())
    stats["near_zero_fraction"] = float(near_zero.mean())

    # Per-sample L2 norms: zero out non-finite entries first so a single
    # NaN/Inf dimension does not poison all the norm statistics.
    safe = np.nan_to_num(emb, nan=0.0, posinf=0.0, neginf=0.0)
    sample_l2 = np.linalg.norm(safe, axis=1)

    stats.update({
        "mean": float(np.mean(finite)),
        "std": float(np.std(finite)),
        "min": float(np.min(finite)),
        "max": float(np.max(finite)),
        "p25": float(np.percentile(finite, 25)),
        "p50": float(np.percentile(finite, 50)),
        "p75": float(np.percentile(finite, 75)),
        "p95": float(np.percentile(finite, 95)),
        "p99": float(np.percentile(finite, 99)),
        "mean_l2": float(np.mean(sample_l2)),
        "std_l2": float(np.std(sample_l2)),
        "p95_l2": float(np.percentile(sample_l2, 95)),
    })
    return stats
|
|
|
|
def embed_and_diagnose(
    sequences: List[str],
    model_name: str,
    pooling_types: List[str],
    batch_size: int = 16,
    num_workers: int = 0
) -> Dict[str, Dict[str, float]]:
    """
    Embed sequences and compute diagnostics for each pooling type.

    Args:
        sequences: List of sequences to embed
        model_name: Name of the model to use
        pooling_types: List of pooling specs to test. An entry may be a
            comma-separated combination (e.g. 'mean,var'), tested as one
            combined pooling configuration.
        batch_size: Batch size for embedding
        num_workers: Number of workers for data loading

    Returns:
        Dictionary mapping pooling specs to their diagnostics. Pooling
        configurations that fail are logged and omitted from the result.
    """
    print_message(f"Embedding {len(sequences)} sequences with {model_name}")

    # Expand each pooling spec into the list of primitive pooling types the
    # Embedder expects ('mean,var' -> ['mean', 'var']).
    pooling_list: Dict[str, List[str]] = {}
    for pool_type in pooling_types:
        if ',' in pool_type:
            pooling_list[pool_type] = [p.strip() for p in pool_type.split(',')]
        else:
            pooling_list[pool_type] = [pool_type]

    results: Dict[str, Dict[str, float]] = {}

    print_message(f"Loading model: {model_name}")
    # Mirror the module-level import fallback: the bare absolute import
    # only works when running from the package root, and would break when
    # this module is imported as part of a package.
    try:
        from base_models.get_base_models import get_base_model
    except ImportError:
        from ..base_models.get_base_models import get_base_model
    model, tokenizer = get_base_model(model_name)

    for pool_type, pool_list in pooling_list.items():
        print_message(f"Testing pooling: {pool_type} (types: {pool_list})")

        embedder_args = EmbeddingArguments(
            embedding_batch_size=batch_size,
            embedding_num_workers=num_workers,
            download_embeddings=False,
            matrix_embed=False,
            embedding_pooling_types=pool_list,
            save_embeddings=False,
            embed_dtype=torch.float32,
            sql=False,
            embedding_save_dir='embeddings'
        )

        embedder = Embedder(embedder_args, sequences)

        try:
            # Reuse any cached embeddings from disk; only embed what is
            # missing.
            to_embed, save_path, embeddings_dict = embedder._read_embeddings_from_disk(model_name)

            if len(to_embed) > 0:
                result = embedder._embed_sequences(
                    to_embed, save_path, model, tokenizer, embeddings_dict
                )
                if result is not None:
                    embeddings_dict = result

            if embeddings_dict is None or len(embeddings_dict) == 0:
                print_message(f"Warning: No embeddings returned for {model_name} with {pool_type}")
                continue

            # Collect embeddings in input order, skipping any sequence that
            # did not get embedded.
            embedding_tensors = [embeddings_dict[seq] for seq in sequences if seq in embeddings_dict]

            if len(embedding_tensors) == 0:
                print_message(f"Error: No embeddings found for {pool_type}")
                continue

            embeddings = torch.stack(embedding_tensors)

            results[pool_type] = compute_diagnostics(embeddings)

        except Exception as e:
            # Best-effort: report the failure and move on to the next
            # pooling configuration.
            print_message(f"Error embedding with {model_name} using {pool_type}: {e}")
            import traceback
            traceback.print_exc()
            continue

    return results
|
|
|
|
def run_test_suite(
    dataset_names: Optional[List[str]] = None,
    model_names: Optional[List[str]] = None,
    pooling_methods: Optional[List[str]] = None,
    sample_frac: float = 0.1,
    batch_size: int = 16,
    num_workers: int = 0
) -> Dict:
    """
    Run the embedding test suite.

    Args:
        dataset_names: Datasets to test (defaults to DEFAULT_TEST_DATASETS).
        model_names: Models to test (defaults to standard_models).
        pooling_methods: Pooling specs to test; defaults to
            ['cls', 'mean,var']. Uses a None sentinel rather than a
            mutable default list.
        sample_frac: Fraction of each dataset's sequences to sample.
        batch_size: Embedding batch size.
        num_workers: DataLoader workers for embedding.

    Returns:
        Nested dict: dataset -> model -> pooling -> diagnostics.
        Empty dict when no sequences could be loaded.
    """
    if dataset_names is None:
        dataset_names = DEFAULT_TEST_DATASETS

    if model_names is None:
        model_names = standard_models

    # None sentinel avoids the shared-mutable-default-argument pitfall.
    if pooling_methods is None:
        pooling_methods = ['cls', 'mean,var']

    print_message(f"Running embedding test suite")
    print_message(f"Datasets: {dataset_names}")
    print_message(f"Models: {model_names}")
    print_message(f"Pooling methods: {pooling_methods}")
    print_message(f"Sample fraction: {sample_frac}")

    dataset_seqs = load_and_sample_sequences(dataset_names, sample_frac=sample_frac)

    if len(dataset_seqs) == 0:
        print_message("Error: No sequences loaded")
        return {}

    all_results = {}

    for dataset_name, sequences in dataset_seqs.items():
        print_message(f"\nProcessing dataset: {dataset_name}")
        all_results[dataset_name] = {}

        for model_name in model_names:
            print_message(f"Model: {model_name}")
            model_results = embed_and_diagnose(
                sequences,
                model_name,
                pooling_methods,
                batch_size=batch_size,
                num_workers=num_workers
            )

            if model_results:
                all_results[dataset_name][model_name] = model_results

    print_table_results(all_results)
    print_json_results(all_results)

    return all_results
|
|
|
|
def print_table_results(results: Dict):
    """Print per-dataset / per-model / per-pooling diagnostics as a text table.

    Tolerates the degenerate diagnostics dict produced when an embedding
    batch has no finite values (distribution keys may then be absent):
    missing fractions are derived from the counts, and distribution lines
    are skipped instead of raising KeyError.
    """
    print("\n" + "="*100)
    print("EMBEDDING TEST SUITE RESULTS")
    print("="*100)

    for dataset_name, dataset_results in results.items():
        print(f"\nDataset: {dataset_name}")
        print("-" * 100)

        for model_name, model_results in dataset_results.items():
            print(f"\n Model: {model_name}")

            for pool_type, diagnostics in model_results.items():
                n = diagnostics['n_samples']
                dim = diagnostics['embedding_dim']
                # Fall back to count/total when the fraction key is absent
                # (degenerate all-NaN diagnostics omit the fractions).
                total = max(1, n * dim)
                nan_frac = diagnostics.get('nan_fraction', diagnostics['nan_count'] / total)
                inf_frac = diagnostics.get('inf_fraction', diagnostics['inf_count'] / total)

                print(f"\nPooling: {pool_type}")
                print(f"Samples: {n}, Dim: {dim}")
                if 'mean' in diagnostics:
                    print(f"Mean: {diagnostics['mean']:.6f}, Std: {diagnostics['std']:.6f}")
                    print(f"Min: {diagnostics['min']:.6f}, Max: {diagnostics['max']:.6f}")
                    print(f"Percentiles: P25={diagnostics['p25']:.6f}, P50={diagnostics['p50']:.6f}, "
                          f"P75={diagnostics['p75']:.6f}, P95={diagnostics['p95']:.6f}, P99={diagnostics['p99']:.6f}")
                else:
                    print("Distribution stats unavailable (no finite values)")
                print(f"NaN: {diagnostics['nan_count']} ({nan_frac*100:.2f}%)")
                if 'near_zero_count' in diagnostics:
                    print(f"Near zeros: {diagnostics['near_zero_count']} ({diagnostics['near_zero_fraction']*100:.2f}%)")
                print(f"Inf: {diagnostics['inf_count']} ({inf_frac*100:.2f}%)")

                # Collect anomaly flags; thresholds match the original
                # heuristics (20% sparsity, |mean| or std above 100).
                anomalies = []
                if nan_frac > 0:
                    anomalies.append(f"NaNs detected ({nan_frac*100:.2f}%)")
                if diagnostics.get('near_zero_fraction', 0.0) > 0.2:
                    anomalies.append(f"High sparsity ({diagnostics['near_zero_fraction']*100:.2f}%)")
                if inf_frac > 0:
                    anomalies.append(f"Infs detected ({inf_frac*100:.2f}%)")
                if abs(diagnostics.get('mean', 0.0)) > 100:
                    anomalies.append(f"Extreme mean ({diagnostics['mean']:.2f})")
                if diagnostics.get('std', 0.0) > 100:
                    anomalies.append(f"Extreme std ({diagnostics['std']:.2f})")

                if anomalies:
                    print(f"Anomalies: {', '.join(anomalies)}")
                else:
                    print(f"No anomalies detected")
|
|
|
def print_json_results(results: Dict):
    """Dump the full results dictionary as pretty-printed JSON."""
    bar = "=" * 50
    print("\n" + bar)
    print("JSON RESULTS")
    print(bar)
    print(json.dumps(results, indent=2))
|
|
|
|
def main():
    """CLI entry point: parse arguments, apply the seed, run the suite.

    Returns:
        The nested results dict from run_test_suite.
    """
    parser = argparse.ArgumentParser(
        description='Embedding Test Suite - Test embedding quality across datasets and models'
    )

    parser.add_argument(
        '--datasets',
        nargs='+',
        default=None,
        help='List of dataset names to test (default: EC)'
    )

    parser.add_argument(
        '--model_names',
        nargs='+',
        default=None,
        # Help text fixed: the code falls back to standard_models, not
        # "currently_supported_models".
        help='List of model names to test (default: all standard_models)'
    )

    parser.add_argument(
        '--pooling_methods',
        nargs='+',
        # Comma-joined entries are tested as one combined pooling config.
        default=['cls', 'mean,var'],
        help="List of pooling methods to test; comma-join for combined "
             "pooling (default: cls and 'mean,var')"
    )

    parser.add_argument(
        '--sample_frac',
        type=float,
        default=0.1,
        help='Fraction of sequences to sample from each dataset (default: 0.1)'
    )

    parser.add_argument(
        '--batch_size',
        type=int,
        default=16,
        help='Batch size for embedding (default: 16)'
    )

    parser.add_argument(
        '--num_workers',
        type=int,
        default=0,
        help='Number of workers for data loading (default: 0)'
    )

    parser.add_argument(
        '--seed',
        type=int,
        default=None,
        help='Random seed for reproducibility'
    )

    args = parser.parse_args()

    # Register the seed globally; set_global_seed is assumed to seed the
    # RNGs used by the sampling code — TODO confirm.
    if args.seed is not None:
        set_global_seed(args.seed)

    results = run_test_suite(
        dataset_names=args.datasets,
        model_names=args.model_names,
        pooling_methods=args.pooling_methods,
        sample_frac=args.sample_frac,
        batch_size=args.batch_size,
        num_workers=args.num_workers
    )

    return results
|
|
|
|
# Script entry point: run the suite only when executed directly, not on import.
if __name__ == '__main__':
    main()
|
|
|
|