| import entrypoint_setup |
|
|
| import random |
| import tempfile |
| import os |
| import torch |
|
|
| from esm2.modeling_fastesm import FastEsmForMaskedLM |
| from esm_plusplus.modeling_esm_plusplus import ESMplusplusForMaskedLM |
| from e1_fastplms.modeling_e1 import E1ForMaskedLM |
| from dplm_fastplms.modeling_dplm import DPLMForMaskedLM |
| from dplm2_fastplms.modeling_dplm2 import ( |
| DPLM2ForMaskedLM, |
| _has_packed_multimodal_layout, |
| _normalize_dplm2_input_ids, |
| ) |
| from embedding_mixin import parse_fasta |
|
|
|
|
# The 20 canonical amino acids used to build random test sequences.
CANONICAL_AAS = "ACDEFGHIKLMNPQRSTVWY"
SEED = 42  # RNG seed so random test sequences are reproducible across runs
DEFAULT_BATCH_SIZE = 4
MAX_EMBED_LEN = 128  # fixed token length enforced by FixedLengthTokenizer




# Each entry: (display name, model class, HF checkpoint path, use_model_tokenizer).
# The final flag controls test_model's tokenizer handling: True wraps the
# model's own tokenizer in FixedLengthTokenizer; False (E1) passes
# tokenizer=None and relies on fixed-length input sequences instead.
MODEL_CONFIGS = [
    ("ESM2", FastEsmForMaskedLM, "Synthyra/ESM2-8M", True),
    ("ESM++", ESMplusplusForMaskedLM, "Synthyra/ESMplusplus_small", True),
    ("E1", E1ForMaskedLM, "Synthyra/Profluent-E1-150M", False),
    ("DPLM", DPLMForMaskedLM, "Synthyra/DPLM-150M", True),
    ("DPLM2", DPLM2ForMaskedLM, "Synthyra/DPLM2-150M", True),
]
|
|
|
|
def test_parse_fasta() -> None:
    """Test parse_fasta with single-line and multi-line sequences.

    Writes a small FASTA file to a temporary path, parses it, and checks
    that multi-line records are joined into single sequence strings and
    headers are dropped.
    """
    fasta_content = (
        ">seq1 a simple protein\n"
        "MKTLLLTLVVVTIVCLDLGYT\n"
        ">seq2 multi-line sequence\n"
        "ACDEFGHIKL\n"
        "MNPQRSTVWY\n"
        ">seq3 another entry\n"
        "MALWMRLLPLLALL\n"
    )
    expected = [
        "MKTLLLTLVVVTIVCLDLGYT",
        "ACDEFGHIKLMNPQRSTVWY",
        "MALWMRLLPLLALL",
    ]
    with tempfile.NamedTemporaryFile(mode='w', suffix='.fasta', delete=False) as f:
        f.write(fasta_content)
        tmp_path = f.name
    try:
        parsed = parse_fasta(tmp_path)
    finally:
        # Clean up even when parse_fasta raises: the previous version only
        # unlinked after a successful parse, leaking the delete=False file.
        os.unlink(tmp_path)
    assert parsed == expected, f"parse_fasta mismatch:\n got: {parsed}\n expected: {expected}"
    print("test_parse_fasta: OK")
|
|
|
|
class FixedLengthTokenizer:
    """Adapter that pins a wrapped tokenizer to a constant output length.

    Every call pads/truncates to exactly ``max_length`` tokens, so batch=1
    and batch=N calls produce tensors of identical shape. That keeps
    max_seqlen_in_batch constant and removes floating-point variability
    from differing softmax vector lengths / flash-attention tile sizes.
    """

    def __init__(self, tokenizer, max_length: int = MAX_EMBED_LEN):
        # All calls delegate to the real tokenizer held here.
        self._tok = tokenizer
        self.max_length = max_length

    def __call__(self, sequences, **kwargs):
        # Caller kwargs are intentionally ignored so nothing can override
        # the fixed padding/truncation policy.
        fixed_policy = dict(
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        return self._tok(sequences, **fixed_policy)
|
|
|
|
def random_sequences(n: int, min_len: int = 8, max_len: int = 64) -> list[str]:
    """Variable-length sequences; used for the NaN test."""
    sequences: list[str] = []
    for _ in range(n):
        # Leading "M" mirrors a start methionine; body length is random.
        body_len = random.randint(min_len, max_len)
        body = "".join(random.choices(CANONICAL_AAS, k=body_len))
        sequences.append("M" + body)
    return sequences
|
|
|
|
def random_sequences_fixed_len(n: int, length: int = 64) -> list[str]:
    """Fixed-length sequences; used for the match test with E1 (sequence mode)."""
    sequences: list[str] = []
    for _ in range(n):
        # "M" prefix plus (length - 1) random residues -> exactly `length` chars.
        tail = random.choices(CANONICAL_AAS, k=length - 1)
        sequences.append("M" + "".join(tail))
    return sequences
|
|
|
|
def assert_no_nan(embeddings: dict[str, torch.Tensor], label: str) -> None:
    """Fail with a labeled message if any embedding tensor contains NaN."""
    for sequence, embedding in embeddings.items():
        has_nan = bool(torch.isnan(embedding).any())
        assert not has_nan, (
            f"[{label}] NaN found in embedding for sequence '{sequence[:20]}...'"
        )
|
|
|
|
def assert_embeddings_match(
    a: dict[str, torch.Tensor],
    b: dict[str, torch.Tensor],
    label: str,
    atol: float = 5e-3,
) -> None:
    """Compare real-token embeddings from two runs.

    full_embeddings=True already strips padding via emb[mask.bool()], so both
    dicts contain only non-pad token rows and the comparison is over those rows.
    """
    assert set(a) == set(b), f"[{label}] Key sets differ between batch and single runs"
    for seq, emb in a.items():
        first = emb.float()
        second = b[seq].float()
        assert first.shape == second.shape, (
            f"[{label}] Shape mismatch for '{seq[:20]}': {first.shape} vs {second.shape}"
        )
        max_diff = (first - second).abs().max().item()
        assert max_diff <= atol, (
            f"[{label}] Max abs diff {max_diff:.5f} > {atol} for '{seq[:20]}'"
        )
|
|
|
|
def test_dplm2_multimodal_layout_guard() -> None:
    """Check _has_packed_multimodal_layout on plain, packed, and mismatched inputs."""
    type_kwargs = dict(aa_type=1, struct_type=0, pad_type=2)
    # Pure amino-acid rows followed by padding: not a packed layout.
    plain = torch.tensor([
        [1, 1, 1, 1, 1, 1, 0, 2],
        [1, 1, 1, 1, 1, 0, 2, 2],
    ])
    # AA block, pad, then a structure-token block: packed multimodal layout.
    packed = torch.tensor([
        [1, 1, 1, 2, 0, 0, 0, 2],
        [1, 1, 2, 2, 0, 0, 2, 2],
    ])
    # AA and struct blocks of unequal real length: must be rejected.
    mismatched = torch.tensor([
        [1, 1, 1, 2, 0, 0, 2, 2],
    ])

    assert not _has_packed_multimodal_layout(plain, **type_kwargs)
    assert _has_packed_multimodal_layout(packed, **type_kwargs)
    assert not _has_packed_multimodal_layout(mismatched, **type_kwargs)
    print("test_dplm2_multimodal_layout_guard: OK")
|
|
|
|
def test_dplm2_special_token_normalization() -> None:
    """Check _normalize_dplm2_input_ids remaps ids at/above vocab_size and keeps -100."""
    raw_ids = torch.tensor([[8231, 5, 23, 13, 8229, 1, 8232, -100]])
    normalized = _normalize_dplm2_input_ids(raw_ids, vocab_size=8229)
    expected = torch.tensor([[0, 5, 23, 13, 2, 1, 32, -100]])
    mismatch_msg = (
        f"DPLM2 special-token normalization mismatch:\n"
        f" got: {normalized.tolist()}\n"
        f" expected: {expected.tolist()}"
    )
    assert torch.equal(normalized, expected), mismatch_msg
    print("test_dplm2_special_token_normalization: OK")
|
|
|
|
def test_model(name: str, model_cls, model_path: str, use_model_tokenizer: bool, batch_size: int) -> None:
    """Run the NaN check and the batch-vs-single match test for one model.

    Loads the checkpoint in bfloat16, embeds random sequences at
    ``batch_size`` (NaN check), then — only for models whose tokenizer we
    control — re-embeds in float32 at ``batch_size`` and at 1 and requires
    the non-pad embeddings to agree.
    """
    print(f"\n--- {name} ({model_path}) ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model_cls.from_pretrained(
        model_path,
        dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    ).eval()

    if use_model_tokenizer:
        # Fixed-length padding keeps tensor shapes identical across batch
        # sizes, so the match test below is not perturbed by padding length.
        tokenizer = FixedLengthTokenizer(model.tokenizer)
        sequences = random_sequences(n=8)
    else:
        # E1 path: no external tokenizer; feed fixed-length sequences instead.
        tokenizer = None
        sequences = random_sequences_fixed_len(n=8)

    nan_kwargs = dict(
        tokenizer=tokenizer,
        full_embeddings=True,
        embed_dtype=torch.bfloat16,
        save=False,
    )

    # --- NaN check: embeddings at the requested batch size must be finite.
    nan_embs = model.embed_dataset(sequences=sequences, batch_size=batch_size, **nan_kwargs)
    assert_no_nan(nan_embs, f"{name} NaN check batch_size={batch_size}")
    shapes = [tuple(e.shape) for e in list(nan_embs.values())[:3]]
    print(f" NaN check batch_size={batch_size}: OK sample shapes={shapes}")

    # The match test needs a controllable tokenizer; models that tokenize
    # internally (use_model_tokenizer=False, i.e. E1) only get the NaN check.
    if not use_model_tokenizer:
        return

    # Switch to float32 for the comparison — presumably because bfloat16
    # accumulation differences between batch sizes would exceed the
    # tolerance in assert_embeddings_match (TODO confirm).
    model.to(torch.float32)
    batch_embs = model.embed_dataset(
        sequences=sequences, batch_size=batch_size,
        tokenizer=tokenizer, full_embeddings=True, embed_dtype=torch.float32, save=False,
    )
    single_embs = model.embed_dataset(
        sequences=sequences, batch_size=1,
        tokenizer=tokenizer, full_embeddings=True, embed_dtype=torch.float32, save=False,
    )
    assert_no_nan(batch_embs, f"{name} match test batch_size={batch_size}")
    assert_no_nan(single_embs, f"{name} match test batch_size=1")
    assert_embeddings_match(batch_embs, single_embs, name)
    print(f" Match test batch_size={batch_size} vs 1: OK (non-pad tokens only)")
|
|
|
|
| if __name__ == "__main__": |
| import argparse |
| parser = argparse.ArgumentParser(description="Test embed_dataset produces no NaN with batch_size > 1.") |
| parser.add_argument("--models", nargs="+", default=["ESM2", "ESM++", "E1", "DPLM", "DPLM2"]) |
| parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE) |
| args = parser.parse_args() |
|
|
| random.seed(SEED) |
| test_parse_fasta() |
| test_dplm2_multimodal_layout_guard() |
| test_dplm2_special_token_normalization() |
|
|
| valid_names = {cfg[0] for cfg in MODEL_CONFIGS} |
| for name in args.models: |
| assert name in valid_names, f"Unknown model '{name}'. Choose from {sorted(valid_names)}" |
|
|
| configs_by_name = {cfg[0]: cfg for cfg in MODEL_CONFIGS} |
| for model_name in args.models: |
| name, model_cls, model_path, use_model_tokenizer = configs_by_name[model_name] |
| test_model(name, model_cls, model_path, use_model_tokenizer, args.batch_size) |
|
|
| print("\nAll tests passed!") |
|
|