Spaces:
Sleeping
Sleeping
Switch Twin to aspect-specific checkpoints with runtime switcher
Browse files
Replaced the old collapsed-training Twin checkpoint with the newer BP/CC/MF
aspect-specific family (train_point_{ASPECT}_20251221_std_ft_bs32ga4), which
has real dynamic range in its distance output (cos 0.59-0.96 vs old 0.999).
- Three checkpoints hosted in one HF model repo: genomenet/twin-point-1024
- Added BP/CC/MF radio button to the Compare tab (default BP)
- Lazy-load: only one aspect cached at a time (memory constraint on cpu-basic)
- Each aspect checkpoint contains fine-tuned ESM2 backbone (~2.7 GB per file)
- Twin output dim back to 1024 (these use projection_dim=512, unlike the old one)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- README.md +14 -7
- app.py +73 -40
- twin_model.py +2 -1
README.md
CHANGED
|
@@ -17,7 +17,7 @@ Compare protein sequence embeddings from ESM2 (baseline) and Twin network (fine-
|
|
| 17 |
| Model | Parameters | Embedding | Training |
|
| 18 |
|-------|------------|-----------|----------|
|
| 19 |
| ESM2 | 650M | 1280-dim | Pretrained on UniRef50 |
|
| 20 |
-
| Twin | ESM2 + custom |
|
| 21 |
|
| 22 |
## Usage
|
| 23 |
|
|
@@ -37,12 +37,19 @@ env vars.
|
|
| 37 |
|
| 38 |
## Twin model
|
| 39 |
|
| 40 |
-
Two-tower contrastive encoder (AA
|
| 41 |
-
trained on Resnik GO-semantic similarity.
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
## Acknowledgements
|
| 48 |
|
|
|
|
| 17 |
| Model | Parameters | Embedding | Training |
|
| 18 |
|-------|------------|-----------|----------|
|
| 19 |
| ESM2 | 650M | 1280-dim | Pretrained on UniRef50 |
|
| 20 |
+
| Twin | ESM2 + custom | 1024-dim | Resnik-contrastive fine-tune, one checkpoint per GO aspect (BP/CC/MF) |
|
| 21 |
|
| 22 |
## Usage
|
| 23 |
|
|
|
|
| 37 |
|
| 38 |
## Twin model
|
| 39 |
|
| 40 |
+
Two-tower contrastive encoder (custom AA Transformer + fine-tuned ESM2 backbone),
|
| 41 |
+
trained on Resnik GO-semantic similarity. **Three aspect-specific checkpoints**
|
| 42 |
+
(~2.7 GB each) are hosted at
|
| 43 |
+
[`genomenet/twin-point-1024`](https://huggingface.co/genomenet/twin-point-1024):
|
| 44 |
+
|
| 45 |
+
- `bp_cp_best.pt` — Biological Process
|
| 46 |
+
- `cc_cp_best.pt` — Cellular Component
|
| 47 |
+
- `mf_cp_best.pt` — Molecular Function
|
| 48 |
+
|
| 49 |
+
The Space loads one aspect at a time (cpu-basic memory budget); switching aspects
|
| 50 |
+
evicts the previous model and loads the next (~15 s when read from the disk cache after
|
| 51 |
+
first use). Override via `TWIN_REPO_ID` / `TWIN_DEFAULT_ASPECT` / `TWIN_ESM_BACKBONE`
|
| 52 |
+
env vars. Architecture code lives in `twin_model.py` alongside `app.py`.
|
| 53 |
|
| 54 |
## Acknowledgements
|
| 55 |
|
app.py
CHANGED
|
@@ -20,7 +20,7 @@ from plotly.subplots import make_subplots
|
|
| 20 |
# Model config
|
| 21 |
ESM2_MODEL = "esm2_t33_650M_UR50D" # 650M params, 1280-dim
|
| 22 |
ESM2_DIM = 1280
|
| 23 |
-
TWIN_DIM =
|
| 24 |
|
| 25 |
# FAISS index config (UniRef50 GO-annotated, ESM2 650M embeddings)
|
| 26 |
FAISS_REPO_ID = os.environ.get("FAISS_REPO_ID", "genomenet/esm2-uniref50-faiss")
|
|
@@ -29,16 +29,21 @@ FAISS_IDS_FILE = "ids.npy"
|
|
| 29 |
FAISS_METADATA_FILE = "metadata.json"
|
| 30 |
FAISS_NPROBE = int(os.environ.get("FAISS_NPROBE", "32"))
|
| 31 |
|
| 32 |
-
# Twin model config (
|
| 33 |
-
TWIN_REPO_ID = os.environ.get("TWIN_REPO_ID", "genomenet/twin-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
TWIN_ESM_BACKBONE = os.environ.get("TWIN_ESM_BACKBONE", "facebook/esm2_t33_650M_UR50D")
|
| 36 |
|
| 37 |
# Model cache
|
| 38 |
_esm2_model = None
|
| 39 |
_esm2_alphabet = None
|
| 40 |
-
|
| 41 |
-
|
| 42 |
_faiss_index = None
|
| 43 |
_faiss_ids = None
|
| 44 |
_faiss_metadata = None
|
|
@@ -56,29 +61,47 @@ def get_esm2():
|
|
| 56 |
print("ESM2 loaded.")
|
| 57 |
return _esm2_model, _esm2_alphabet
|
| 58 |
|
| 59 |
-
def get_twin():
|
| 60 |
-
"""Download + load the fine-tuned Twin model
|
| 61 |
-
global _twin_model, _twin_seq_len
|
| 62 |
-
if _twin_model is not None:
|
| 63 |
-
return _twin_model, _twin_seq_len
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
import torch
|
| 66 |
from huggingface_hub import hf_hub_download
|
| 67 |
from twin_model import load_twin_model
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 76 |
model, seq_len, emb_dim = load_twin_model(ckpt_path, device, TWIN_ESM_BACKBONE)
|
| 77 |
if emb_dim != TWIN_DIM:
|
| 78 |
-
print(f" WARN Twin output dim is {emb_dim}, expected {TWIN_DIM}")
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
| 82 |
|
| 83 |
def get_faiss():
|
| 84 |
"""Download + load FAISS index and UniRef50 id mapping from HF dataset repo."""
|
|
@@ -154,11 +177,11 @@ def embed_esm2(sequence):
|
|
| 154 |
return embedding
|
| 155 |
|
| 156 |
@torch.no_grad()
|
| 157 |
-
def embed_twin(sequence):
|
| 158 |
-
"""Compute Twin embedding
|
| 159 |
from twin_model import ensure_aa_sequence, preprocess_sequences_batch
|
| 160 |
|
| 161 |
-
model, seq_len = get_twin()
|
| 162 |
device = next(model.parameters()).device
|
| 163 |
cleaned = ensure_aa_sequence(sequence)
|
| 164 |
input_ids = preprocess_sequences_batch([cleaned], seq_len, device)
|
|
@@ -297,7 +320,7 @@ def create_distribution_plot(esm2_emb, twin_emb):
|
|
| 297 |
# Example protein (human insulin)
|
| 298 |
EXAMPLE_PROTEIN = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
|
| 299 |
|
| 300 |
-
def process(sequence: str, top_k: int = 10):
|
| 301 |
"""Process protein sequence, compare embeddings, and search FAISS."""
|
| 302 |
sequence = strip_fasta_header(sequence.strip())
|
| 303 |
|
|
@@ -308,7 +331,7 @@ def process(sequence: str, top_k: int = 10):
|
|
| 308 |
|
| 309 |
# Compute embeddings
|
| 310 |
esm2_emb = embed_esm2(sequence)
|
| 311 |
-
twin_emb = embed_twin(sequence)
|
| 312 |
|
| 313 |
# Compute stats
|
| 314 |
esm2_stats = compute_stats(esm2_emb)
|
|
@@ -325,7 +348,7 @@ def process(sequence: str, top_k: int = 10):
|
|
| 325 |
|
| 326 |
summary = f"""### Results
|
| 327 |
|
| 328 |
-
| | ESM2 | Twin |
|
| 329 |
|---|---|---|
|
| 330 |
| Dimension | {esm2_stats['dim']} | {twin_stats['dim']} |
|
| 331 |
| L2 Norm | {esm2_stats['l2_norm']:.2f} | {twin_stats['l2_norm']:.2f} |
|
|
@@ -337,7 +360,7 @@ Sequence: {len(sequence)} aa
|
|
| 337 |
|
| 338 |
# Create visualizations
|
| 339 |
esm2_heatmap = create_embedding_heatmap(esm2_emb, "ESM2 Embedding")
|
| 340 |
-
twin_heatmap = create_embedding_heatmap(twin_emb, "Twin Embedding")
|
| 341 |
comparison_plot = create_comparison_plot(esm2_stats, twin_stats)
|
| 342 |
distribution_plot = create_distribution_plot(esm2_emb, twin_emb)
|
| 343 |
|
|
@@ -364,6 +387,13 @@ with gr.Blocks(
|
|
| 364 |
minimum=1, maximum=50, value=10, step=1,
|
| 365 |
label="Nearest neighbors (top-k)"
|
| 366 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
btn = gr.Button("compare embeddings", variant="primary")
|
| 368 |
output = gr.Markdown()
|
| 369 |
with gr.Row():
|
|
@@ -376,7 +406,7 @@ with gr.Blocks(
|
|
| 376 |
|
| 377 |
with gr.Row():
|
| 378 |
esm2_heatmap = gr.Plot(label="ESM2 Embedding (1280-dim)")
|
| 379 |
-
twin_heatmap = gr.Plot(label="Twin Embedding (
|
| 380 |
|
| 381 |
gr.Markdown("### Nearest UniRef50 neighbors (ESM2 embedding, cosine)")
|
| 382 |
hits_table = gr.Dataframe(
|
|
@@ -388,7 +418,7 @@ with gr.Blocks(
|
|
| 388 |
|
| 389 |
btn.click(
|
| 390 |
process,
|
| 391 |
-
inputs=[seq_input, top_k_slider],
|
| 392 |
outputs=[output, esm2_download, twin_download, esm2_heatmap, twin_heatmap, comparison_plot, distribution_plot, hits_table],
|
| 393 |
api_name="compare"
|
| 394 |
)
|
|
@@ -405,12 +435,13 @@ client = Client("genomenet/functional-distance")
|
|
| 405 |
result = client.predict(
|
| 406 |
sequence="MALWMRLLPLLALLALWG...", # protein sequence
|
| 407 |
top_k=10,
|
|
|
|
| 408 |
api_name="/compare"
|
| 409 |
)
|
| 410 |
|
| 411 |
summary, esm2_path, twin_path, *plots, hits = result
|
| 412 |
esm2_emb = np.load(esm2_path) # (1280,)
|
| 413 |
-
twin_emb = np.load(twin_path) # (
|
| 414 |
# hits: DataFrame with columns [rank, uniref50_id, cosine, uniprot]
|
| 415 |
```
|
| 416 |
|
|
@@ -426,7 +457,7 @@ on UniProt.
|
|
| 426 |
| Model | Dimension | Description |
|
| 427 |
|-------|-----------|-------------|
|
| 428 |
| ESM2 | 1280 | `esm2_t33_650M_UR50D` pretrained on UniRef50 |
|
| 429 |
-
| Twin |
|
| 430 |
|
| 431 |
### Comparison
|
| 432 |
|
|
@@ -444,11 +475,13 @@ this GO supervision.
|
|
| 444 |
- Pretrained on UniRef50 with masked language modeling
|
| 445 |
- General-purpose protein representation
|
| 446 |
|
| 447 |
-
**Twin Network** (`
|
| 448 |
-
- Two-tower contrastive encoder
|
| 449 |
-
-
|
| 450 |
-
|
| 451 |
-
|
|
|
|
|
|
|
| 452 |
|
| 453 |
### Gene Ontology
|
| 454 |
|
|
@@ -476,10 +509,10 @@ if __name__ == "__main__":
|
|
| 476 |
_ = get_faiss()
|
| 477 |
except Exception as e:
|
| 478 |
print(f"FAISS load failed (will retry on first request): {e}")
|
| 479 |
-
print(f"Loading Twin
|
| 480 |
try:
|
| 481 |
-
_ = get_twin()
|
| 482 |
-
print("Twin ready!")
|
| 483 |
except Exception as e:
|
| 484 |
print(f"Twin load failed (will retry on first request): {e}")
|
| 485 |
demo.launch(
|
|
|
|
| 20 |
# Model config
|
| 21 |
ESM2_MODEL = "esm2_t33_650M_UR50D" # 650M params, 1280-dim
|
| 22 |
ESM2_DIM = 1280
|
| 23 |
+
TWIN_DIM = 1024 # 2 * projection_dim (512), two-tower concat
|
| 24 |
|
| 25 |
# FAISS index config (UniRef50 GO-annotated, ESM2 650M embeddings)
|
| 26 |
FAISS_REPO_ID = os.environ.get("FAISS_REPO_ID", "genomenet/esm2-uniref50-faiss")
|
|
|
|
| 29 |
FAISS_METADATA_FILE = "metadata.json"
|
| 30 |
FAISS_NPROBE = int(os.environ.get("FAISS_NPROBE", "32"))
|
| 31 |
|
| 32 |
+
# Twin model config (3 aspect-specific checkpoints in one HF model repo)
|
| 33 |
+
TWIN_REPO_ID = os.environ.get("TWIN_REPO_ID", "genomenet/twin-point-1024")
|
| 34 |
+
TWIN_CHECKPOINT_FILES = {
|
| 35 |
+
"BP": "bp_cp_best.pt", # Biological Process
|
| 36 |
+
"CC": "cc_cp_best.pt", # Cellular Component
|
| 37 |
+
"MF": "mf_cp_best.pt", # Molecular Function
|
| 38 |
+
}
|
| 39 |
+
TWIN_DEFAULT_ASPECT = os.environ.get("TWIN_DEFAULT_ASPECT", "BP")
|
| 40 |
TWIN_ESM_BACKBONE = os.environ.get("TWIN_ESM_BACKBONE", "facebook/esm2_t33_650M_UR50D")
|
| 41 |
|
| 42 |
# Model cache
|
| 43 |
_esm2_model = None
|
| 44 |
_esm2_alphabet = None
|
| 45 |
+
# Only one aspect cached at a time (each Twin is ~2.7 GB on CPU, can't fit all 3 on a cpu-basic Space)
|
| 46 |
+
_twin_cache = {"aspect": None, "model": None, "seq_len": None}
|
| 47 |
_faiss_index = None
|
| 48 |
_faiss_ids = None
|
| 49 |
_faiss_metadata = None
|
|
|
|
| 61 |
print("ESM2 loaded.")
|
| 62 |
return _esm2_model, _esm2_alphabet
|
| 63 |
|
| 64 |
+
def get_twin(aspect=None):
|
| 65 |
+
"""Download + load the fine-tuned Twin model for the requested GO aspect.
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
+
Only one aspect is kept in memory at a time — switching aspects evicts the
|
| 68 |
+
previous model (each is ~2.7 GB; three won't fit on cpu-basic).
|
| 69 |
+
"""
|
| 70 |
+
global _twin_cache
|
| 71 |
+
aspect = (aspect or TWIN_DEFAULT_ASPECT).upper()
|
| 72 |
+
if aspect not in TWIN_CHECKPOINT_FILES:
|
| 73 |
+
raise ValueError(f"Unknown aspect {aspect!r}; expected one of {list(TWIN_CHECKPOINT_FILES)}")
|
| 74 |
+
|
| 75 |
+
if _twin_cache["aspect"] == aspect and _twin_cache["model"] is not None:
|
| 76 |
+
return _twin_cache["model"], _twin_cache["seq_len"]
|
| 77 |
+
|
| 78 |
+
import gc
|
| 79 |
import torch
|
| 80 |
from huggingface_hub import hf_hub_download
|
| 81 |
from twin_model import load_twin_model
|
| 82 |
|
| 83 |
+
# Evict any previously loaded aspect to free ~2.7 GB before loading the next.
|
| 84 |
+
if _twin_cache["model"] is not None:
|
| 85 |
+
print(f"Evicting Twin/{_twin_cache['aspect']} to load Twin/{aspect}...")
|
| 86 |
+
_twin_cache["model"] = None
|
| 87 |
+
_twin_cache["seq_len"] = None
|
| 88 |
+
_twin_cache["aspect"] = None
|
| 89 |
+
gc.collect()
|
| 90 |
+
if torch.cuda.is_available():
|
| 91 |
+
torch.cuda.empty_cache()
|
| 92 |
+
|
| 93 |
+
filename = TWIN_CHECKPOINT_FILES[aspect]
|
| 94 |
+
print(f"Downloading Twin/{aspect} checkpoint ({filename}) from {TWIN_REPO_ID}...")
|
| 95 |
+
ckpt_path = hf_hub_download(repo_id=TWIN_REPO_ID, filename=filename, repo_type="model")
|
| 96 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 97 |
model, seq_len, emb_dim = load_twin_model(ckpt_path, device, TWIN_ESM_BACKBONE)
|
| 98 |
if emb_dim != TWIN_DIM:
|
| 99 |
+
print(f" WARN Twin/{aspect} output dim is {emb_dim}, expected {TWIN_DIM}")
|
| 100 |
+
|
| 101 |
+
_twin_cache["aspect"] = aspect
|
| 102 |
+
_twin_cache["model"] = model
|
| 103 |
+
_twin_cache["seq_len"] = seq_len
|
| 104 |
+
return model, seq_len
|
| 105 |
|
| 106 |
def get_faiss():
|
| 107 |
"""Download + load FAISS index and UniRef50 id mapping from HF dataset repo."""
|
|
|
|
| 177 |
return embedding
|
| 178 |
|
| 179 |
@torch.no_grad()
|
| 180 |
+
def embed_twin(sequence, aspect=None):
|
| 181 |
+
"""Compute Twin embedding for the given GO aspect (BP/CC/MF)."""
|
| 182 |
from twin_model import ensure_aa_sequence, preprocess_sequences_batch
|
| 183 |
|
| 184 |
+
model, seq_len = get_twin(aspect)
|
| 185 |
device = next(model.parameters()).device
|
| 186 |
cleaned = ensure_aa_sequence(sequence)
|
| 187 |
input_ids = preprocess_sequences_batch([cleaned], seq_len, device)
|
|
|
|
| 320 |
# Example protein (human insulin)
|
| 321 |
EXAMPLE_PROTEIN = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
|
| 322 |
|
| 323 |
+
def process(sequence: str, top_k: int = 10, twin_aspect: str = "BP"):
|
| 324 |
"""Process protein sequence, compare embeddings, and search FAISS."""
|
| 325 |
sequence = strip_fasta_header(sequence.strip())
|
| 326 |
|
|
|
|
| 331 |
|
| 332 |
# Compute embeddings
|
| 333 |
esm2_emb = embed_esm2(sequence)
|
| 334 |
+
twin_emb = embed_twin(sequence, aspect=twin_aspect)
|
| 335 |
|
| 336 |
# Compute stats
|
| 337 |
esm2_stats = compute_stats(esm2_emb)
|
|
|
|
| 348 |
|
| 349 |
summary = f"""### Results
|
| 350 |
|
| 351 |
+
| | ESM2 | Twin ({twin_aspect}) |
|
| 352 |
|---|---|---|
|
| 353 |
| Dimension | {esm2_stats['dim']} | {twin_stats['dim']} |
|
| 354 |
| L2 Norm | {esm2_stats['l2_norm']:.2f} | {twin_stats['l2_norm']:.2f} |
|
|
|
|
| 360 |
|
| 361 |
# Create visualizations
|
| 362 |
esm2_heatmap = create_embedding_heatmap(esm2_emb, "ESM2 Embedding")
|
| 363 |
+
twin_heatmap = create_embedding_heatmap(twin_emb, f"Twin Embedding ({twin_aspect})")
|
| 364 |
comparison_plot = create_comparison_plot(esm2_stats, twin_stats)
|
| 365 |
distribution_plot = create_distribution_plot(esm2_emb, twin_emb)
|
| 366 |
|
|
|
|
| 387 |
minimum=1, maximum=50, value=10, step=1,
|
| 388 |
label="Nearest neighbors (top-k)"
|
| 389 |
)
|
| 390 |
+
twin_aspect_radio = gr.Radio(
|
| 391 |
+
choices=["BP", "CC", "MF"],
|
| 392 |
+
value=TWIN_DEFAULT_ASPECT,
|
| 393 |
+
label="Twin GO aspect",
|
| 394 |
+
info="Biological Process (BP), Cellular Component (CC), or Molecular Function (MF). "
|
| 395 |
+
"First switch loads the aspect's checkpoint (~15 s)."
|
| 396 |
+
)
|
| 397 |
btn = gr.Button("compare embeddings", variant="primary")
|
| 398 |
output = gr.Markdown()
|
| 399 |
with gr.Row():
|
|
|
|
| 406 |
|
| 407 |
with gr.Row():
|
| 408 |
esm2_heatmap = gr.Plot(label="ESM2 Embedding (1280-dim)")
|
| 409 |
+
twin_heatmap = gr.Plot(label="Twin Embedding (1024-dim)")
|
| 410 |
|
| 411 |
gr.Markdown("### Nearest UniRef50 neighbors (ESM2 embedding, cosine)")
|
| 412 |
hits_table = gr.Dataframe(
|
|
|
|
| 418 |
|
| 419 |
btn.click(
|
| 420 |
process,
|
| 421 |
+
inputs=[seq_input, top_k_slider, twin_aspect_radio],
|
| 422 |
outputs=[output, esm2_download, twin_download, esm2_heatmap, twin_heatmap, comparison_plot, distribution_plot, hits_table],
|
| 423 |
api_name="compare"
|
| 424 |
)
|
|
|
|
| 435 |
result = client.predict(
|
| 436 |
sequence="MALWMRLLPLLALLALWG...", # protein sequence
|
| 437 |
top_k=10,
|
| 438 |
+
twin_aspect="BP", # "BP" | "CC" | "MF"
|
| 439 |
api_name="/compare"
|
| 440 |
)
|
| 441 |
|
| 442 |
summary, esm2_path, twin_path, *plots, hits = result
|
| 443 |
esm2_emb = np.load(esm2_path) # (1280,)
|
| 444 |
+
twin_emb = np.load(twin_path) # (1024,)
|
| 445 |
# hits: DataFrame with columns [rank, uniref50_id, cosine, uniprot]
|
| 446 |
```
|
| 447 |
|
|
|
|
| 457 |
| Model | Dimension | Description |
|
| 458 |
|-------|-----------|-------------|
|
| 459 |
| ESM2 | 1280 | `esm2_t33_650M_UR50D` pretrained on UniRef50 |
|
| 460 |
+
| Twin | 1024 | Resnik-contrastive fine-tune; one checkpoint per GO aspect (BP/CC/MF) |
|
| 461 |
|
| 462 |
### Comparison
|
| 463 |
|
|
|
|
| 475 |
- Pretrained on UniRef50 with masked language modeling
|
| 476 |
- General-purpose protein representation
|
| 477 |
|
| 478 |
+
**Twin Network** (`train_point_{BP,CC,MF}_20251221_std_ft_bs32ga4`):
|
| 479 |
+
- Two-tower contrastive encoder: custom AA Transformer + **fine-tuned** ESM2 backbone
|
| 480 |
+
- **One checkpoint per GO aspect**: Biological Process (BP), Cellular Component (CC),
|
| 481 |
+
Molecular Function (MF). Pick aspect via the Twin GO aspect radio button.
|
| 482 |
+
- Trained on Resnik GO-semantic similarity within each aspect
|
| 483 |
+
- Output: `concat(custom_proj, esm_proj)` → 1024-dim; L2 distance on L2-normalized
|
| 484 |
+
embeddings ≈ functional distance in that aspect
|
| 485 |
|
| 486 |
### Gene Ontology
|
| 487 |
|
|
|
|
| 509 |
_ = get_faiss()
|
| 510 |
except Exception as e:
|
| 511 |
print(f"FAISS load failed (will retry on first request): {e}")
|
| 512 |
+
print(f"Loading default Twin aspect ({TWIN_DEFAULT_ASPECT}) from {TWIN_REPO_ID}...")
|
| 513 |
try:
|
| 514 |
+
_ = get_twin(TWIN_DEFAULT_ASPECT)
|
| 515 |
+
print(f"Twin/{TWIN_DEFAULT_ASPECT} ready!")
|
| 516 |
except Exception as e:
|
| 517 |
print(f"Twin load failed (will retry on first request): {e}")
|
| 518 |
demo.launch(
|
twin_model.py
CHANGED
|
@@ -17,7 +17,8 @@ similarity (folder name: `aa_esm_resnik_1024_contrastive_padding_1gpu_old`):
|
|
| 17 |
|
| 18 |
Final embedding = concat(custom_proj, esm_proj) with shape (2 * projection_dim,)
|
| 19 |
Output size is `2 * projection_dim`, read from the checkpoint's `args` at load time.
|
| 20 |
-
For the `
|
|
|
|
| 21 |
"""
|
| 22 |
|
| 23 |
from __future__ import annotations
|
|
|
|
| 17 |
|
| 18 |
Final embedding = concat(custom_proj, esm_proj) with shape (2 * projection_dim,)
|
| 19 |
Output size is `2 * projection_dim`, read from the checkpoint's `args` at load time.
|
| 20 |
+
For the `train_point_{BP,CC,MF}_20251221_std_ft_bs32ga4/cp_best.pt` checkpoints
|
| 21 |
+
(current default family) this is **1024** (projection_dim=512).
|
| 22 |
"""
|
| 23 |
|
| 24 |
from __future__ import annotations
|