genomenet Claude Opus 4.7 (1M context) commited on
Commit
54f342f
·
1 Parent(s): 10564c4

Switch Twin to aspect-specific checkpoints with runtime switcher

Browse files

Replaced the old collapsed-training Twin checkpoint with the newer BP/CC/MF
aspect-specific family (train_point_{ASPECT}_20251221_std_ft_bs32ga4), which
has real dynamic range in its distance output (cos 0.59-0.96 vs old 0.999).

- Three checkpoints hosted in one HF model repo: genomenet/twin-point-1024
- Added BP/CC/MF radio button to the Compare tab (default BP)
- Lazy-load: only one aspect cached at a time (memory constraint on cpu-basic)
- Each aspect checkpoint contains fine-tuned ESM2 backbone (~2.7 GB per file)
- Twin output dim back to 1024 (these use projection_dim=512, unlike the old one)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (3) hide show
  1. README.md +14 -7
  2. app.py +73 -40
  3. twin_model.py +2 -1
README.md CHANGED
@@ -17,7 +17,7 @@ Compare protein sequence embeddings from ESM2 (baseline) and Twin network (fine-
17
  | Model | Parameters | Embedding | Training |
18
  |-------|------------|-----------|----------|
19
  | ESM2 | 650M | 1280-dim | Pretrained on UniRef50 |
20
- | Twin | ESM2 + custom | 512-dim | Contrastive on Resnik GO-semantic similarity (all aspects) |
21
 
22
  ## Usage
23
 
@@ -37,12 +37,19 @@ env vars.
37
 
38
  ## Twin model
39
 
40
- Two-tower contrastive encoder (AA-vocab Transformer + frozen ESM2 backbone),
41
- trained on Resnik GO-semantic similarity. The 2.5 GB checkpoint is hosted at
42
- [`genomenet/twin-aa-esm-resnik-1024`](https://huggingface.co/genomenet/twin-aa-esm-resnik-1024)
43
- and downloaded on Space startup. Override with `TWIN_REPO_ID` /
44
- `TWIN_CHECKPOINT_FILE` / `TWIN_ESM_BACKBONE` env vars. Architecture code
45
- lives in `twin_model.py` alongside `app.py`.
 
 
 
 
 
 
 
46
 
47
  ## Acknowledgements
48
 
 
17
  | Model | Parameters | Embedding | Training |
18
  |-------|------------|-----------|----------|
19
  | ESM2 | 650M | 1280-dim | Pretrained on UniRef50 |
20
+ | Twin | ESM2 + custom | 1024-dim | Resnik-contrastive fine-tune, one checkpoint per GO aspect (BP/CC/MF) |
21
 
22
  ## Usage
23
 
 
37
 
38
  ## Twin model
39
 
40
+ Two-tower contrastive encoder (custom AA Transformer + fine-tuned ESM2 backbone),
41
+ trained on Resnik GO-semantic similarity. **Three aspect-specific checkpoints**
42
+ (~2.7 GB each) are hosted at
43
+ [`genomenet/twin-point-1024`](https://huggingface.co/genomenet/twin-point-1024):
44
+
45
+ - `bp_cp_best.pt` Biological Process
46
+ - `cc_cp_best.pt` — Cellular Component
47
+ - `mf_cp_best.pt` — Molecular Function
48
+
49
+ The Space loads one aspect at a time (cpu-basic memory budget); switching aspects
50
+ evicts the previous model and downloads the next (~15 s from disk-cache after
51
+ first use). Override via `TWIN_REPO_ID` / `TWIN_DEFAULT_ASPECT` / `TWIN_ESM_BACKBONE`
52
+ env vars. Architecture code lives in `twin_model.py` alongside `app.py`.
53
 
54
  ## Acknowledgements
55
 
app.py CHANGED
@@ -20,7 +20,7 @@ from plotly.subplots import make_subplots
20
  # Model config
21
  ESM2_MODEL = "esm2_t33_650M_UR50D" # 650M params, 1280-dim
22
  ESM2_DIM = 1280
23
- TWIN_DIM = 512 # 2 * projection_dim (256), two-tower concat; 1024 in the run name is seq_len
24
 
25
  # FAISS index config (UniRef50 GO-annotated, ESM2 650M embeddings)
26
  FAISS_REPO_ID = os.environ.get("FAISS_REPO_ID", "genomenet/esm2-uniref50-faiss")
@@ -29,16 +29,21 @@ FAISS_IDS_FILE = "ids.npy"
29
  FAISS_METADATA_FILE = "metadata.json"
30
  FAISS_NPROBE = int(os.environ.get("FAISS_NPROBE", "32"))
31
 
32
- # Twin model config (checkpoint hosted as HF model repo)
33
- TWIN_REPO_ID = os.environ.get("TWIN_REPO_ID", "genomenet/twin-aa-esm-resnik-1024")
34
- TWIN_CHECKPOINT_FILE = os.environ.get("TWIN_CHECKPOINT_FILE", "model_aa_best.pt")
 
 
 
 
 
35
  TWIN_ESM_BACKBONE = os.environ.get("TWIN_ESM_BACKBONE", "facebook/esm2_t33_650M_UR50D")
36
 
37
  # Model cache
38
  _esm2_model = None
39
  _esm2_alphabet = None
40
- _twin_model = None
41
- _twin_seq_len = None
42
  _faiss_index = None
43
  _faiss_ids = None
44
  _faiss_metadata = None
@@ -56,29 +61,47 @@ def get_esm2():
56
  print("ESM2 loaded.")
57
  return _esm2_model, _esm2_alphabet
58
 
59
- def get_twin():
60
- """Download + load the fine-tuned Twin model (two-tower contrastive encoder)."""
61
- global _twin_model, _twin_seq_len
62
- if _twin_model is not None:
63
- return _twin_model, _twin_seq_len
64
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  import torch
66
  from huggingface_hub import hf_hub_download
67
  from twin_model import load_twin_model
68
 
69
- print(f"Downloading Twin checkpoint from {TWIN_REPO_ID}/{TWIN_CHECKPOINT_FILE}...")
70
- ckpt_path = hf_hub_download(
71
- repo_id=TWIN_REPO_ID,
72
- filename=TWIN_CHECKPOINT_FILE,
73
- repo_type="model",
74
- )
 
 
 
 
 
 
 
75
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76
  model, seq_len, emb_dim = load_twin_model(ckpt_path, device, TWIN_ESM_BACKBONE)
77
  if emb_dim != TWIN_DIM:
78
- print(f" WARN Twin output dim is {emb_dim}, expected {TWIN_DIM}")
79
- _twin_model = model
80
- _twin_seq_len = seq_len
81
- return _twin_model, _twin_seq_len
 
 
82
 
83
  def get_faiss():
84
  """Download + load FAISS index and UniRef50 id mapping from HF dataset repo."""
@@ -154,11 +177,11 @@ def embed_esm2(sequence):
154
  return embedding
155
 
156
  @torch.no_grad()
157
- def embed_twin(sequence):
158
- """Compute Twin embedding: concat(custom_proj, esm_proj), (TWIN_DIM,) float32."""
159
  from twin_model import ensure_aa_sequence, preprocess_sequences_batch
160
 
161
- model, seq_len = get_twin()
162
  device = next(model.parameters()).device
163
  cleaned = ensure_aa_sequence(sequence)
164
  input_ids = preprocess_sequences_batch([cleaned], seq_len, device)
@@ -297,7 +320,7 @@ def create_distribution_plot(esm2_emb, twin_emb):
297
  # Example protein (human insulin)
298
  EXAMPLE_PROTEIN = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
299
 
300
- def process(sequence: str, top_k: int = 10):
301
  """Process protein sequence, compare embeddings, and search FAISS."""
302
  sequence = strip_fasta_header(sequence.strip())
303
 
@@ -308,7 +331,7 @@ def process(sequence: str, top_k: int = 10):
308
 
309
  # Compute embeddings
310
  esm2_emb = embed_esm2(sequence)
311
- twin_emb = embed_twin(sequence)
312
 
313
  # Compute stats
314
  esm2_stats = compute_stats(esm2_emb)
@@ -325,7 +348,7 @@ def process(sequence: str, top_k: int = 10):
325
 
326
  summary = f"""### Results
327
 
328
- | | ESM2 | Twin |
329
  |---|---|---|
330
  | Dimension | {esm2_stats['dim']} | {twin_stats['dim']} |
331
  | L2 Norm | {esm2_stats['l2_norm']:.2f} | {twin_stats['l2_norm']:.2f} |
@@ -337,7 +360,7 @@ Sequence: {len(sequence)} aa
337
 
338
  # Create visualizations
339
  esm2_heatmap = create_embedding_heatmap(esm2_emb, "ESM2 Embedding")
340
- twin_heatmap = create_embedding_heatmap(twin_emb, "Twin Embedding")
341
  comparison_plot = create_comparison_plot(esm2_stats, twin_stats)
342
  distribution_plot = create_distribution_plot(esm2_emb, twin_emb)
343
 
@@ -364,6 +387,13 @@ with gr.Blocks(
364
  minimum=1, maximum=50, value=10, step=1,
365
  label="Nearest neighbors (top-k)"
366
  )
 
 
 
 
 
 
 
367
  btn = gr.Button("compare embeddings", variant="primary")
368
  output = gr.Markdown()
369
  with gr.Row():
@@ -376,7 +406,7 @@ with gr.Blocks(
376
 
377
  with gr.Row():
378
  esm2_heatmap = gr.Plot(label="ESM2 Embedding (1280-dim)")
379
- twin_heatmap = gr.Plot(label="Twin Embedding (512-dim)")
380
 
381
  gr.Markdown("### Nearest UniRef50 neighbors (ESM2 embedding, cosine)")
382
  hits_table = gr.Dataframe(
@@ -388,7 +418,7 @@ with gr.Blocks(
388
 
389
  btn.click(
390
  process,
391
- inputs=[seq_input, top_k_slider],
392
  outputs=[output, esm2_download, twin_download, esm2_heatmap, twin_heatmap, comparison_plot, distribution_plot, hits_table],
393
  api_name="compare"
394
  )
@@ -405,12 +435,13 @@ client = Client("genomenet/functional-distance")
405
  result = client.predict(
406
  sequence="MALWMRLLPLLALLALWG...", # protein sequence
407
  top_k=10,
 
408
  api_name="/compare"
409
  )
410
 
411
  summary, esm2_path, twin_path, *plots, hits = result
412
  esm2_emb = np.load(esm2_path) # (1280,)
413
- twin_emb = np.load(twin_path) # (512,)
414
  # hits: DataFrame with columns [rank, uniref50_id, cosine, uniprot]
415
  ```
416
 
@@ -426,7 +457,7 @@ on UniProt.
426
  | Model | Dimension | Description |
427
  |-------|-----------|-------------|
428
  | ESM2 | 1280 | `esm2_t33_650M_UR50D` pretrained on UniRef50 |
429
- | Twin | 512 | Fine-tuned on Gene Ontology annotations |
430
 
431
  ### Comparison
432
 
@@ -444,11 +475,13 @@ this GO supervision.
444
  - Pretrained on UniRef50 with masked language modeling
445
  - General-purpose protein representation
446
 
447
- **Twin Network** (`aa_esm_resnik_1024_contrastive_padding`):
448
- - Two-tower contrastive encoder (AA-vocab Transformer + frozen ESM2 backbone)
449
- - Trained on Resnik GO-semantic similarity across all GO aspects (MF, BP, CC combined)
450
- - Output: `concat(custom_proj, esm_proj)` 512-dim embedding optimized for
451
- functional similarity (GO-coherent nearest neighbors)
 
 
452
 
453
  ### Gene Ontology
454
 
@@ -476,10 +509,10 @@ if __name__ == "__main__":
476
  _ = get_faiss()
477
  except Exception as e:
478
  print(f"FAISS load failed (will retry on first request): {e}")
479
- print(f"Loading Twin model from {TWIN_REPO_ID}...")
480
  try:
481
- _ = get_twin()
482
- print("Twin ready!")
483
  except Exception as e:
484
  print(f"Twin load failed (will retry on first request): {e}")
485
  demo.launch(
 
20
  # Model config
21
  ESM2_MODEL = "esm2_t33_650M_UR50D" # 650M params, 1280-dim
22
  ESM2_DIM = 1280
23
+ TWIN_DIM = 1024 # 2 * projection_dim (512), two-tower concat
24
 
25
  # FAISS index config (UniRef50 GO-annotated, ESM2 650M embeddings)
26
  FAISS_REPO_ID = os.environ.get("FAISS_REPO_ID", "genomenet/esm2-uniref50-faiss")
 
29
  FAISS_METADATA_FILE = "metadata.json"
30
  FAISS_NPROBE = int(os.environ.get("FAISS_NPROBE", "32"))
31
 
32
+ # Twin model config (3 aspect-specific checkpoints in one HF model repo)
33
+ TWIN_REPO_ID = os.environ.get("TWIN_REPO_ID", "genomenet/twin-point-1024")
34
+ TWIN_CHECKPOINT_FILES = {
35
+ "BP": "bp_cp_best.pt", # Biological Process
36
+ "CC": "cc_cp_best.pt", # Cellular Component
37
+ "MF": "mf_cp_best.pt", # Molecular Function
38
+ }
39
+ TWIN_DEFAULT_ASPECT = os.environ.get("TWIN_DEFAULT_ASPECT", "BP")
40
  TWIN_ESM_BACKBONE = os.environ.get("TWIN_ESM_BACKBONE", "facebook/esm2_t33_650M_UR50D")
41
 
42
  # Model cache
43
  _esm2_model = None
44
  _esm2_alphabet = None
45
+ # Only one aspect cached at a time (each Twin is ~2.7 GB on CPU, can't fit all 3 on a cpu-basic Space)
46
+ _twin_cache = {"aspect": None, "model": None, "seq_len": None}
47
  _faiss_index = None
48
  _faiss_ids = None
49
  _faiss_metadata = None
 
61
  print("ESM2 loaded.")
62
  return _esm2_model, _esm2_alphabet
63
 
64
+ def get_twin(aspect=None):
65
+ """Download + load the fine-tuned Twin model for the requested GO aspect.
 
 
 
66
 
67
+ Only one aspect is kept in memory at a time — switching aspects evicts the
68
+ previous model (each is ~2.7 GB; three won't fit on cpu-basic).
69
+ """
70
+ global _twin_cache
71
+ aspect = (aspect or TWIN_DEFAULT_ASPECT).upper()
72
+ if aspect not in TWIN_CHECKPOINT_FILES:
73
+ raise ValueError(f"Unknown aspect {aspect!r}; expected one of {list(TWIN_CHECKPOINT_FILES)}")
74
+
75
+ if _twin_cache["aspect"] == aspect and _twin_cache["model"] is not None:
76
+ return _twin_cache["model"], _twin_cache["seq_len"]
77
+
78
+ import gc
79
  import torch
80
  from huggingface_hub import hf_hub_download
81
  from twin_model import load_twin_model
82
 
83
+ # Evict any previously loaded aspect to free ~2.7 GB before loading the next.
84
+ if _twin_cache["model"] is not None:
85
+ print(f"Evicting Twin/{_twin_cache['aspect']} to load Twin/{aspect}...")
86
+ _twin_cache["model"] = None
87
+ _twin_cache["seq_len"] = None
88
+ _twin_cache["aspect"] = None
89
+ gc.collect()
90
+ if torch.cuda.is_available():
91
+ torch.cuda.empty_cache()
92
+
93
+ filename = TWIN_CHECKPOINT_FILES[aspect]
94
+ print(f"Downloading Twin/{aspect} checkpoint ({filename}) from {TWIN_REPO_ID}...")
95
+ ckpt_path = hf_hub_download(repo_id=TWIN_REPO_ID, filename=filename, repo_type="model")
96
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
97
  model, seq_len, emb_dim = load_twin_model(ckpt_path, device, TWIN_ESM_BACKBONE)
98
  if emb_dim != TWIN_DIM:
99
+ print(f" WARN Twin/{aspect} output dim is {emb_dim}, expected {TWIN_DIM}")
100
+
101
+ _twin_cache["aspect"] = aspect
102
+ _twin_cache["model"] = model
103
+ _twin_cache["seq_len"] = seq_len
104
+ return model, seq_len
105
 
106
  def get_faiss():
107
  """Download + load FAISS index and UniRef50 id mapping from HF dataset repo."""
 
177
  return embedding
178
 
179
  @torch.no_grad()
180
+ def embed_twin(sequence, aspect=None):
181
+ """Compute Twin embedding for the given GO aspect (BP/CC/MF)."""
182
  from twin_model import ensure_aa_sequence, preprocess_sequences_batch
183
 
184
+ model, seq_len = get_twin(aspect)
185
  device = next(model.parameters()).device
186
  cleaned = ensure_aa_sequence(sequence)
187
  input_ids = preprocess_sequences_batch([cleaned], seq_len, device)
 
320
  # Example protein (human insulin)
321
  EXAMPLE_PROTEIN = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
322
 
323
+ def process(sequence: str, top_k: int = 10, twin_aspect: str = "BP"):
324
  """Process protein sequence, compare embeddings, and search FAISS."""
325
  sequence = strip_fasta_header(sequence.strip())
326
 
 
331
 
332
  # Compute embeddings
333
  esm2_emb = embed_esm2(sequence)
334
+ twin_emb = embed_twin(sequence, aspect=twin_aspect)
335
 
336
  # Compute stats
337
  esm2_stats = compute_stats(esm2_emb)
 
348
 
349
  summary = f"""### Results
350
 
351
+ | | ESM2 | Twin ({twin_aspect}) |
352
  |---|---|---|
353
  | Dimension | {esm2_stats['dim']} | {twin_stats['dim']} |
354
  | L2 Norm | {esm2_stats['l2_norm']:.2f} | {twin_stats['l2_norm']:.2f} |
 
360
 
361
  # Create visualizations
362
  esm2_heatmap = create_embedding_heatmap(esm2_emb, "ESM2 Embedding")
363
+ twin_heatmap = create_embedding_heatmap(twin_emb, f"Twin Embedding ({twin_aspect})")
364
  comparison_plot = create_comparison_plot(esm2_stats, twin_stats)
365
  distribution_plot = create_distribution_plot(esm2_emb, twin_emb)
366
 
 
387
  minimum=1, maximum=50, value=10, step=1,
388
  label="Nearest neighbors (top-k)"
389
  )
390
+ twin_aspect_radio = gr.Radio(
391
+ choices=["BP", "CC", "MF"],
392
+ value=TWIN_DEFAULT_ASPECT,
393
+ label="Twin GO aspect",
394
+ info="Biological Process (BP), Cellular Component (CC), or Molecular Function (MF). "
395
+ "First switch loads the aspect's checkpoint (~15 s)."
396
+ )
397
  btn = gr.Button("compare embeddings", variant="primary")
398
  output = gr.Markdown()
399
  with gr.Row():
 
406
 
407
  with gr.Row():
408
  esm2_heatmap = gr.Plot(label="ESM2 Embedding (1280-dim)")
409
+ twin_heatmap = gr.Plot(label="Twin Embedding (1024-dim)")
410
 
411
  gr.Markdown("### Nearest UniRef50 neighbors (ESM2 embedding, cosine)")
412
  hits_table = gr.Dataframe(
 
418
 
419
  btn.click(
420
  process,
421
+ inputs=[seq_input, top_k_slider, twin_aspect_radio],
422
  outputs=[output, esm2_download, twin_download, esm2_heatmap, twin_heatmap, comparison_plot, distribution_plot, hits_table],
423
  api_name="compare"
424
  )
 
435
  result = client.predict(
436
  sequence="MALWMRLLPLLALLALWG...", # protein sequence
437
  top_k=10,
438
+ twin_aspect="BP", # "BP" | "CC" | "MF"
439
  api_name="/compare"
440
  )
441
 
442
  summary, esm2_path, twin_path, *plots, hits = result
443
  esm2_emb = np.load(esm2_path) # (1280,)
444
+ twin_emb = np.load(twin_path) # (1024,)
445
  # hits: DataFrame with columns [rank, uniref50_id, cosine, uniprot]
446
  ```
447
 
 
457
  | Model | Dimension | Description |
458
  |-------|-----------|-------------|
459
  | ESM2 | 1280 | `esm2_t33_650M_UR50D` pretrained on UniRef50 |
460
+ | Twin | 1024 | Resnik-contrastive fine-tune; one checkpoint per GO aspect (BP/CC/MF) |
461
 
462
  ### Comparison
463
 
 
475
  - Pretrained on UniRef50 with masked language modeling
476
  - General-purpose protein representation
477
 
478
+ **Twin Network** (`train_point_{BP,CC,MF}_20251221_std_ft_bs32ga4`):
479
+ - Two-tower contrastive encoder: custom AA Transformer + **fine-tuned** ESM2 backbone
480
+ - **One checkpoint per GO aspect**: Biological Process (BP), Cellular Component (CC),
481
+ Molecular Function (MF). Pick aspect via the Twin GO aspect radio button.
482
+ - Trained on Resnik GO-semantic similarity within each aspect
483
+ - Output: `concat(custom_proj, esm_proj)` → 1024-dim; L2 distance on L2-normalized
484
+ embeddings ≈ functional distance in that aspect
485
 
486
  ### Gene Ontology
487
 
 
509
  _ = get_faiss()
510
  except Exception as e:
511
  print(f"FAISS load failed (will retry on first request): {e}")
512
+ print(f"Loading default Twin aspect ({TWIN_DEFAULT_ASPECT}) from {TWIN_REPO_ID}...")
513
  try:
514
+ _ = get_twin(TWIN_DEFAULT_ASPECT)
515
+ print(f"Twin/{TWIN_DEFAULT_ASPECT} ready!")
516
  except Exception as e:
517
  print(f"Twin load failed (will retry on first request): {e}")
518
  demo.launch(
twin_model.py CHANGED
@@ -17,7 +17,8 @@ similarity (folder name: `aa_esm_resnik_1024_contrastive_padding_1gpu_old`):
17
 
18
  Final embedding = concat(custom_proj, esm_proj) with shape (2 * projection_dim,)
19
  Output size is `2 * projection_dim`, read from the checkpoint's `args` at load time.
20
- For the `model_aa_best.pt` checkpoint this is **512** (projection_dim=256).
 
21
  """
22
 
23
  from __future__ import annotations
 
17
 
18
  Final embedding = concat(custom_proj, esm_proj) with shape (2 * projection_dim,)
19
  Output size is `2 * projection_dim`, read from the checkpoint's `args` at load time.
20
+ For the `train_point_{BP,CC,MF}_20251221_std_ft_bs32ga4/cp_best.pt` checkpoints
21
+ (current default family) this is **1024** (projection_dim=512).
22
  """
23
 
24
  from __future__ import annotations