bertose-affinose-inference / scripts /run_release_smoke_test.py
supanthadey1's picture
Add headless release smoke test
9729ebd verified
#!/usr/bin/env python3
"""Headless BERTose/AFFINose Hugging Face release smoke test.
This script is meant for cluster or cloud notebook validation. It downloads the
public release repositories without a user token, runs single and batch BERTose
embedding, BERTose IAR, and AFFINose protein-glycan scoring, then writes CSV and
JSON outputs with timing.
"""
from __future__ import annotations
import argparse
import json
import os
import re
import sys
import time
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from tqdm.auto import tqdm
# Hugging Face repository ids fetched by ``download_repos``, keyed by the role
# each repository plays in this smoke test.
REPOS = {
    "encoder": "supanthadey1/bertose-glycan-encoder",
    "resolver": "supanthadey1/bertose-iar-resolver",
    "affinose": "supanthadey1/affinose-interaction-model",
    "notebook": "supanthadey1/bertose-affinose-inference",
}
# Built-in WURCS strings used for the embedding test when no --embedding-csv is
# given. ``main`` also hands the selected embedding strings to
# ``choose_affinose_pairs`` as the default glycans for AFFINose pairs.
DEFAULT_WURCS = [
    "WURCS=2.0/2,2,1/[a212h-1b_1-5][a2211m-1a_1-5]/1-2/a2-b1",
    "WURCS=2.0/2,4,3/[a1221m-1b_1-5][a1221m-1a_1-5]/1-2-2-2/a3-b1_b3-c1_c3-d1",
    "WURCS=2.0/3,3,2/[a2112h-1b_1-5_2*NCC/3=O_4*OSO/3=O/3=O_6*OSO/3=O/3=O][a2122A-1b_1-5][a1221m-1a_1-5_2*OSO/3=O/3=O_4*OSO/3=O/3=O]/1-2-3/a3-b1_b3-c1",
    "WURCS=2.0/2,2,1/[a2122h-1a_1-5_2*OSO/3=O/3=O_3*OSO/3=O/3=O_4*OSO/3=O/3=O][a2122m-1a_1-5_2*OSO/3=O/3=O_3*OSO/3=O/3=O]/1-2/a6-b1",
    "WURCS=2.0/2,3,2/[a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O]/1-1-2/a4-b1_b3-c1",
    "WURCS=2.0/3,3,2/[a2122A-1b_1-5][a211h-1a_1-5][a2122h-1b_1-5]/1-2-3/a2-b1_b2-c1",
    "WURCS=2.0/2,4,3/[a2122h-1a_1-5][a2112h-1a_1-5]/1-1-2-2/a4-b1_b4-c1_b6-d1",
    "WURCS=2.0/2,3,2/[a2112h-1b_1-5][a2112h-1a_1-4]/1-2-2/a4-b1_b2-c1",
    "WURCS=2.0/1,1,0/[a122h-1a_1-5_2*OC_3*OC_4*OC]/1/",
    "WURCS=2.0/1,1,0/[a122h-1a_1-4_2*OCC/3=O_3*OCC/3=O]/1/",
    "WURCS=2.0/1,1,0/[had22111m-2b_2-6_5*N_7*N]/1/",
    "WURCS=2.0/1,1,0/[Aad22111m-2b_2-6_5*N_7*N]/1/",
    "WURCS=2.0/4,7,6/[a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a2122h-1b_1-5][Aad22111m-2b_2-6_5*N_7*N]/1-2-1-2-3-4-3/a3-b1_b3-c1_b6-g1_c3-d1_d6-e1_e6-f2",
    "WURCS=2.0/4,7,6/[a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a2122h-1b_1-5][Aad22111m-2b_2-6_7*N_5*N]/1-2-3-1-2-3-4/a3-b1_b3-d1_b6-c1_d3-e1_e6-f1_f6-g2",
    "WURCS=2.0/1,1,0/[Aad22111m-2a_2-6_5*N_7*N]/1/",
    "WURCS=2.0/5,8,7/[a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a2122h-1b_1-5][Aad22111m-2a_2-6_5*N_7*N][Aad22111m-2b_2-6_5*N_7*N]/1-2-1-2-3-4-3-5/a3-b1_b3-c1_b6-g1_c3-d1_d6-e1_e6-f2_g6-h2",
    "WURCS=2.0/5,8,7/[a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a2122h-1b_1-5][Aad22111m-2b_2-6_7*N_5*N][Aad22111m-2a_2-6_7*N_5*N]/1-2-3-4-1-2-3-5/a3-b1_b3-e1_b6-c1_c6-d2_e3-f1_f6-g1_g6-h2",
    "WURCS=2.0/3,3,2/[a2112m-1b_1-5_2*NCC/3=O][a212h-1b_1-5][Aad22111m-2a_2-6_5*N_7*NC=O]/1-2-3/a3-b1_b4-c2",
    "WURCS=2.0/3,3,2/[a212h-1b_1-5][a2112m-1b_1-5_2*NCC/3=O][Aad22111m-2a_2-6_5*N_7*NC=O]/1-2-3/a4-b1_b3-c2",
    "WURCS=2.0/3,3,2/[a212h-1b_1-5][a2112m-1b_1-5_2*NCC/3=O_4*OCC/3=O][Aad22111m-2a_2-6_5*N_7*NC=O]/1-2-3/a4-b1_b3-c2",
]
# Built-in WURCS strings used for the IAR test when no --iar-csv is given.
# Every entry's first residue code starts with "u"; the resolver's constructor
# treats vocabulary tokens starting with "u" as ambiguous.
DEFAULT_AMBIGUOUS_WURCS = [
    "WURCS=2.0/2,2,1/[u2112h][a1221m-1a_1-5]/1-2/a2-b1",
    "WURCS=2.0/3,3,2/[u2112h_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3/a3-b1_b2-c1",
    "WURCS=2.0/4,6,5/[u2122h][a2112h-1b_1-5][a1221m-1a_1-5][a2112h-1a_1-5_2*NCC/3=O]/1-2-3-4-2-3/a4-b1_b2-c1_b3-d1_d3-e1_e2-f1",
    "WURCS=2.0/4,6,5/[u2122h_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5][a2112h-1a_1-5_2*NCC/3=O]/1-2-3-4-2-3/a4-b1_b2-c1_b3-d1_d3-e1_e2-f1",
    "WURCS=2.0/4,4,3/[u2112h][a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4/a3-b1_b3-c1_c2-d1",
    "WURCS=2.0/5,6,5/[u2122h][a2112h-1b_1-5][a2112h-1a_1-5][a2112h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-4-2-5/a4-b1_b4-c1_c3-d1_d3-e1_e2-f1",
    "WURCS=2.0/4,5,4/[u2122h][a2112h-1b_1-5][a2112h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-2-4/a4-b1_b3-c1_c3-d1_d2-e1",
    "WURCS=2.0/5,6,5/[u2122h][a2112h-1b_1-5][Aad21122h-2a_2-6_5*NCC/3=O][a2112h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-4-2-5/a4-b1_b3-c2_b4-d1_d3-e1_e2-f1",
    "WURCS=2.0/3,3,2/[u2122h_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3/a3-b1_b2-c1",
    "WURCS=2.0/6,11,10/[u2122h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a1122h-1b_1-5][a1122h-1a_1-5][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4-2-5-6-4-2-5-6/a4-b1_b4-c1_c3-d1_c6-h1_d2-e1_e3-f1_f2-g1_h2-i1_i3-j1_j2-k1",
    "WURCS=2.0/6,12,11/[u2122h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a1122h-1b_1-5][a1122h-1a_1-5][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4-2-5-6-4-2-5-6-6/a4-b1_a6-l1_b4-c1_c3-d1_c6-h1_d2-e1_e3-f1_f2-g1_h2-i1_i3-j1_j2-k1",
    "WURCS=2.0/4,5,4/[u2122h][a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-2-4/a4-b1_b3-c1_c3-d1_d2-e1",
    "WURCS=2.0/4,4,3/[u2112h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4/a3-b1_b3-c1_c2-d1",
    "WURCS=2.0/4,7,6/[u2112h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4-2-3-4/a3-b1_a6-e1_b3-c1_c2-d1_e3-f1_f2-g1",
    "WURCS=2.0/4,8,7/[u2122h][a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-2-4-3-4-2/a4-b1_b3-c1_b6-f1_c3-d1_d2-e1_f3-g1_f4-h1",
    "WURCS=2.0/4,7,6/[u2122h][a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-2-4-3-2/a4-b1_b3-c1_b6-f1_c3-d1_d2-e1_f4-g1",
    "WURCS=2.0/5,8,7/[u2122h][a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5][Aad21122h-2a_2-6_5*NCC/3=O]/1-2-3-2-4-3-2-5/a4-b1_b3-c1_b6-f1_c3-d1_d2-e1_f4-g1_g6-h2",
    "WURCS=2.0/3,3,2/[u2122h_2*NCC/3=O_6*OSO/3=O/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3/a3-b1_b2-c1",
    "WURCS=2.0/3,4,3/[u2122h_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-3/a3-b1_a4-d1_b2-c1",
    "WURCS=2.0/6,13,12/[u2122h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a1122h-1b_1-5][a1122h-1a_1-5][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4-2-5-6-6-4-2-5-6-6/a4-b1_b4-c1_c3-d1_c6-i1_d2-e1_e3-f1_e4-h1_f2-g1_i2-j1_j3-k1_j4-m1_k2-l1",
]
def get_device(name: str) -> torch.device:
    """Translate a device name into a ``torch.device``.

    ``"auto"`` picks CUDA when ``torch.cuda.is_available()`` reports it and the
    CPU otherwise; any other value is passed to ``torch.device`` unchanged.
    """
    if name != "auto":
        return torch.device(name)
    resolved = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(resolved)
def ensure_writable_hf_home(cache_dir: Path) -> None:
    """Point the Hugging Face cache environment variables at a writable directory.

    Candidates are tried in order: the existing ``HF_HOME`` (if set), then
    ``<cache_dir parent>/hf_home``. The first directory that can be created and
    written to becomes ``HF_HOME``; ``HF_HUB_CACHE`` and ``HF_XET_CACHE`` are
    filled in beneath it only when they are not already set.

    Raises:
        RuntimeError: if no candidate directory is writable.
    """
    default_home = cache_dir.parent / "hf_home"
    configured = os.environ.get("HF_HOME")

    search_order = []
    if configured:
        search_order.append(Path(configured).expanduser())
    search_order.append(default_home)

    for home in search_order:
        marker = home / ".write_probe"
        try:
            # Writing and removing a marker file proves the directory is usable.
            home.mkdir(parents=True, exist_ok=True)
            marker.write_text("ok", encoding="utf-8")
            marker.unlink()
        except OSError:
            continue
        os.environ["HF_HOME"] = str(home)
        os.environ.setdefault("HF_HUB_CACHE", str(home / "hub"))
        os.environ.setdefault("HF_XET_CACHE", str(home / "xet"))
        return

    raise RuntimeError(f"Could not create a writable Hugging Face cache under {default_home}")
def download_repos(cache_dir: Path) -> Dict[str, Path]:
    """Download every repository in ``REPOS`` anonymously into ``cache_dir``.

    Each repository lands in a sub-directory named after the last segment of its
    repo id. Returns a mapping from the ``REPOS`` key to the local path reported
    by ``snapshot_download``.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    ensure_writable_hf_home(cache_dir)

    def fetch(repo_id: str) -> Path:
        target = cache_dir / repo_id.rpartition("/")[2]
        # token=False forces an unauthenticated download of the public release.
        downloaded = snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            local_dir=str(target),
            token=False,
        )
        return Path(downloaded)

    return {role: fetch(repo_id) for role, repo_id in REPOS.items()}
def add_source_paths(paths: Dict[str, Path]) -> None:
    """Make the downloaded repositories' ``src`` folders importable.

    Prepends ``<repo>/src`` for the affinose, encoder and resolver repositories
    to ``sys.path`` (skipping entries already present) and exports the encoder
    root through ``BERTOSE_ROOT`` and ``BERTOSE_REPO_ROOT``.
    """
    for role in ("affinose", "encoder", "resolver"):
        src_dir = str(paths[role] / "src")
        if src_dir in sys.path:
            continue
        sys.path.insert(0, src_dir)

    encoder_root = str(paths["encoder"])
    os.environ.update({"BERTOSE_ROOT": encoder_root, "BERTOSE_REPO_ROOT": encoder_root})
def count_layers(state_dict: Dict[str, torch.Tensor], prefix: str, default: int) -> int:
    """Infer how many layers live under ``prefix`` from state-dict key names.

    Keys shaped like ``<prefix>.<index>.<rest>`` are scanned; the result is the
    highest index plus one, or ``default`` when no key matches.
    """
    layer_pattern = re.compile(rf"^{re.escape(prefix)}\.(\d+)\.")
    highest = -1
    for name in state_dict:
        found = layer_pattern.match(name)
        if found is not None:
            highest = max(highest, int(found.group(1)))
    return default if highest < 0 else highest + 1
def import_bertose_classes():
    """Import and return ``(MultimodalGlycanBERT, MultimodalGlycanBERTConfig, WURCSBPETokenizer)``.

    The imports run at call time rather than at module import.
    NOTE(review): ``bertose_model`` and ``wurcs_bpe_tokenizer`` presumably live in
    the ``src`` folders that ``add_source_paths`` puts on ``sys.path`` — confirm.
    """
    from bertose_model import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
    from wurcs_bpe_tokenizer import WURCSBPETokenizer
    return MultimodalGlycanBERT, MultimodalGlycanBERTConfig, WURCSBPETokenizer
def build_config_from_state_dict(state_dict: Dict[str, torch.Tensor]):
    """Reconstruct a ``MultimodalGlycanBERTConfig`` from checkpoint tensor shapes.

    Vocabulary sizes, hidden sizes and the sequence length are read from the
    embedding weights; layer counts come from ``count_layers``; optional
    components are switched on when keys with the matching prefix are present.
    Head counts, the MS/structure maximum lengths and the CNN kernel size are
    fixed values. Raises ``KeyError`` if the sequence embedding weights are absent.
    """
    config_cls = import_bertose_classes()[1]

    def has_prefix(prefix: str) -> bool:
        return any(name.startswith(prefix) for name in state_dict)

    seq_vocab_size, seq_hidden = state_dict["seq_embeddings.token_embeddings.weight"].shape
    seq_max_length = state_dict["seq_embeddings.position_embeddings.weight"].shape[0]

    # Mass-spec branch: fall back to fixed sizes when the checkpoint has no MS embeddings.
    ms_key = "ms_embeddings.token_embeddings.weight"
    if ms_key in state_dict:
        ms_total_vocab, ms_hidden_size = state_dict[ms_key].shape
        ms_vocab_size = max(1, ms_total_vocab - seq_vocab_size)
    else:
        ms_vocab_size, ms_hidden_size = 242, 384

    # Structure branch: same idea, with its own fallback sizes.
    use_3d = has_prefix("struct_embeddings.")
    struct_key = "struct_embeddings.token_embeddings.weight"
    if struct_key in state_dict:
        struct_vocab_size, struct_hidden_size = state_dict[struct_key].shape
    else:
        struct_vocab_size, struct_hidden_size = 1024, 512

    settings = dict(
        seq_vocab_size=seq_vocab_size,
        seq_hidden_size=seq_hidden,
        seq_num_layers=count_layers(state_dict, "seq_layers", 12),
        seq_num_heads=12,
        seq_max_length=seq_max_length,
        ms_vocab_size=ms_vocab_size,
        ms_hidden_size=ms_hidden_size,
        ms_num_layers=count_layers(state_dict, "ms_layers", 6),
        ms_num_heads=6,
        ms_max_length=150,
        struct_vocab_size=struct_vocab_size,
        struct_hidden_size=struct_hidden_size,
        struct_num_layers=count_layers(state_dict, "struct_layers", 8),
        struct_num_heads=8,
        struct_max_length=200,
        use_3d=use_3d,
        use_cross_attention=has_prefix("cross_attention."),
        use_cnn_frontend=has_prefix("seq_embeddings.conv_layers."),
        cnn_kernel_size=3,
    )
    return config_cls(**settings)
def load_bertose_backbone(checkpoint_path: Path, device: torch.device):
    """Load a BERTose checkpoint and return ``(model, config)`` in eval mode.

    ``proj_head.*`` tensors are dropped, the config is inferred from the
    remaining tensors, and weights are loaded non-strictly. A warning is issued
    for missing or unexpected keys outside the tolerated prefixes.
    """
    model_cls = import_bertose_classes()[0]

    # NOTE(review): weights_only=False unpickles arbitrary objects from the
    # downloaded file; only acceptable while the checkpoint source is trusted.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    raw_state = checkpoint.get("model_state_dict", checkpoint)

    state_dict = {}
    for name, tensor in raw_state.items():
        if name.startswith("proj_head."):
            continue
        state_dict[name] = tensor

    config = build_config_from_state_dict(state_dict)
    model = model_cls(config)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    tolerated_missing = ("dist_proj.", "distance_head.")
    tolerated_unexpected = ("proj_head.", "distance_head.")
    missing_report = [name for name in missing if not name.startswith(tolerated_missing)]
    unexpected_report = [name for name in unexpected if not name.startswith(tolerated_unexpected)]
    if missing_report:
        warnings.warn(f"{checkpoint_path.name}: {len(missing_report)} unexpected missing tensors")
    if unexpected_report:
        warnings.warn(f"{checkpoint_path.name}: {len(unexpected_report)} unexpected tensors")

    # Flag set on the model before use; presumably silences a one-time notice — confirm.
    model._dist_none_printed = True
    model.to(device).eval()
    return model, config
def load_tokenizer(vocab_path: Path):
    """Build a ``WURCSBPETokenizer`` from the vocabulary file at ``vocab_path``."""
    tokenizer_cls = import_bertose_classes()[2]
    return tokenizer_cls(str(vocab_path))
def tensorize_tokenized(tokenized_rows: Sequence[Dict[str, Any]], device: torch.device) -> Dict[str, torch.Tensor]:
    """Stack the per-row tokenizer fields into ``torch.long`` tensors on ``device``.

    Only ``token_ids``, ``attention_mask``, ``residue_ids``, ``branch_depths`` and
    ``linkage_types`` are converted; other keys in the rows are ignored.
    """
    batch: Dict[str, torch.Tensor] = {}
    for field in ("token_ids", "attention_mask", "residue_ids", "branch_depths", "linkage_types"):
        column = [row[field] for row in tokenized_rows]
        batch[field] = torch.tensor(column, dtype=torch.long, device=device)
    return batch
def nonpad_tokens(tokenizer, token_ids: Sequence[int], attention_mask: Sequence[int]) -> List[str]:
    """Return the token strings at positions whose attention-mask value is 1.

    Ids missing from ``tokenizer.id_to_token`` are rendered as ``"[UNK]"``.
    """
    kept: List[str] = []
    for token_id, flag in zip(token_ids, attention_mask):
        if int(flag) != 1:
            continue
        kept.append(tokenizer.id_to_token.get(int(token_id), "[UNK]"))
    return kept
class BertoseEmbedder:
    """Sequence-only glycan embedder built on the BERTose backbone.

    Only the model's ``seq_embeddings`` and ``seq_layers`` are used; the hidden
    states are pooled into one vector per WURCS string.
    """

    def __init__(self, checkpoint_path: Path, vocab_path: Path, device: torch.device):
        """Load the tokenizer from ``vocab_path`` and the backbone from ``checkpoint_path``."""
        self.device = device
        self.tokenizer = load_tokenizer(vocab_path)
        self.model, self.config = load_bertose_backbone(checkpoint_path, device=device)

    @torch.no_grad()
    def _encode_hidden(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the sequence embeddings and every sequence layer; return the hidden states."""
        hidden = self.model.seq_embeddings(
            batch["token_ids"],
            branch_depths=batch["branch_depths"],
            linkage_types=batch["linkage_types"],
        )
        for layer in self.model.seq_layers:
            hidden = layer(hidden, batch["attention_mask"])
        return hidden

    @torch.no_grad()
    def embed_batch(self, wurcs_list: Sequence[str], batch_size: int, pooling: str = "cls") -> np.ndarray:
        """Embed ``wurcs_list`` in chunks of ``batch_size``.

        Args:
            wurcs_list: WURCS strings to embed.
            batch_size: Number of strings per forward pass; must be >= 1.
            pooling: ``"cls"`` takes the hidden state at position 0; ``"mean"``
                averages hidden states over positions with attention mask 1.

        Returns:
            A float32 array with one row per input string. An empty input yields
            an array with zero rows instead of failing inside ``np.vstack``.

        Raises:
            ValueError: for an unknown ``pooling`` or a non-positive ``batch_size``.
        """
        # Validate before any model work so bad arguments fail fast.
        if pooling not in ("cls", "mean"):
            raise ValueError("pooling must be 'cls' or 'mean'")
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if len(wurcs_list) == 0:
            # Width falls back to 0 if the config does not expose seq_hidden_size.
            width = int(getattr(self.config, "seq_hidden_size", 0))
            return np.empty((0, width), dtype="float32")

        embeddings: List[np.ndarray] = []
        for start in tqdm(range(0, len(wurcs_list), batch_size), desc="Embedding glycans"):
            chunk = list(wurcs_list[start : start + batch_size])
            tokenized = [self.tokenizer.tokenize(wurcs, max_length=self.config.seq_max_length) for wurcs in chunk]
            batch = tensorize_tokenized(tokenized, self.device)
            hidden = self._encode_hidden(batch)
            if pooling == "cls":
                pooled = hidden[:, 0, :]
            else:
                mask = batch["attention_mask"].float().unsqueeze(-1)
                # clamp avoids division by zero for a row with an all-zero mask.
                pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
            embeddings.append(pooled.detach().cpu().numpy().astype("float32"))
        return np.vstack(embeddings)

    def embed_single(self, wurcs: str, pooling: str = "cls") -> np.ndarray:
        """Embed one WURCS string and return its 1-D embedding vector."""
        return self.embed_batch([wurcs], batch_size=1, pooling=pooling)[0]
class BertoseAmbiguityResolver:
    """Iteratively replace ambiguous WURCS tokens with confident model predictions.

    Ambiguous positions are masked, the model is run, and a position is resolved
    when one of its top-k predictions reaches ``threshold`` and is neither an
    ambiguous nor a special token. Unresolved positions are retried on the next
    iteration with the accepted tokens in place.
    """

    def __init__(
        self,
        checkpoint_path: Path,
        vocab_path: Path,
        ambiguity_path: Path,
        device: torch.device,
        threshold: float = 0.80,
        top_k: int = 5,
    ):
        """Load tokenizer, ambiguity list and model.

        Args:
            checkpoint_path: Resolver checkpoint file.
            vocab_path: Tokenizer vocabulary file.
            ambiguity_path: JSON file; its ``ambiguous_ids`` list (if any) seeds
                the set of ambiguous token ids.
            device: Device the model and input tensors are placed on.
            threshold: Minimum probability for a prediction to be accepted.
            top_k: Number of candidates examined (and reported) per position.
        """
        self.device = device
        self.threshold = threshold
        self.top_k = top_k
        self.tokenizer = load_tokenizer(vocab_path)
        with open(ambiguity_path, encoding="utf-8") as handle:
            ambiguity = json.load(handle)
        self.ambiguous_ids = {int(item) for item in ambiguity.get("ambiguous_ids", [])}
        # Besides the listed ids, any non-special token containing "?" or "x",
        # or starting with "u", is treated as ambiguous.
        self.ambiguous_ids.update(
            int(token_id)
            for token, token_id in self.tokenizer.token_to_id.items()
            if token not in {"[PAD]", "[UNK]", "[START]", "[END]", "[MASK]"}
            and ("?" in token or "x" in token or token.startswith("u"))
        )
        self.special_ids = {
            int(token_id)
            for token, token_id in self.tokenizer.token_to_id.items()
            if token.startswith("[") and token.endswith("]")
        }
        # Predictions may never resolve to another ambiguous token or a special token.
        self.invalid_prediction_ids = self.ambiguous_ids | self.special_ids
        self.mask_id = self.tokenizer.token_to_id.get("[MASK]", 4)
        self.model, self.config = load_bertose_backbone(checkpoint_path, device=device)

    def _ambiguous_positions(self, tokenized: Dict[str, Any]) -> List[int]:
        """Return indices of ambiguous tokens within the row's real (non-padded) length.

        The length is the row's ``length`` entry when present, otherwise the sum
        of its attention mask.
        """
        length = int(tokenized.get("length", sum(tokenized["attention_mask"])))
        return [idx for idx, tid in enumerate(tokenized["token_ids"][:length]) if int(tid) in self.ambiguous_ids]

    @torch.no_grad()
    def _forward_logits(self, rows: Sequence[Dict[str, Any]], current_ids: Sequence[List[int]]) -> torch.Tensor:
        """Run the model on ``current_ids`` with empty MS/structure inputs; return ``seq_logits``."""
        batch_size = len(rows)
        # Dummy-modality lengths follow the model config, falling back to the
        # values the config builder in this file uses (150 / 200).
        ms_len = int(getattr(self.config, "ms_max_length", 150))
        struct_len = int(getattr(self.config, "struct_max_length", 200))

        def column(key: str) -> torch.Tensor:
            return torch.tensor([row[key] for row in rows], dtype=torch.long, device=self.device)

        outputs = self.model(
            seq_token_ids=torch.tensor(current_ids, dtype=torch.long, device=self.device),
            seq_attention_mask=column("attention_mask"),
            seq_residue_ids=column("residue_ids"),
            seq_branch_depths=column("branch_depths"),
            seq_linkage_types=column("linkage_types"),
            ms_token_ids=torch.zeros(batch_size, ms_len, dtype=torch.long, device=self.device),
            ms_attention_mask=torch.zeros(batch_size, ms_len, dtype=torch.long, device=self.device),
            has_ms=torch.zeros(batch_size, dtype=torch.bool, device=self.device),
            struct_token_ids=torch.zeros(batch_size, struct_len, dtype=torch.long, device=self.device),
            struct_attention_mask=torch.zeros(batch_size, struct_len, dtype=torch.long, device=self.device),
            struct_residue_ids=torch.full((batch_size, struct_len), -1, dtype=torch.long, device=self.device),
            has_3d=torch.zeros(batch_size, dtype=torch.bool, device=self.device),
            return_dict=True,
        )
        return outputs["seq_logits"]

    def resolve_batch(self, wurcs_list: Sequence[str], ids: Optional[Sequence[str]], iterations: int, batch_size: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Resolve ambiguous tokens for every WURCS string.

        Args:
            wurcs_list: WURCS strings to process.
            ids: Sample ids, one per string; defaults to ``"0"``, ``"1"``, ...
            iterations: Maximum number of mask-and-predict rounds per chunk.
            batch_size: Number of strings per chunk; must be >= 1.

        Returns:
            ``(summary, details)``: one summary row per input string and one
            detail row per examined position per iteration.

        Raises:
            ValueError: if ``ids`` and ``wurcs_list`` differ in length or
                ``batch_size`` is not positive.
        """
        ids = list(ids) if ids is not None else [str(i) for i in range(len(wurcs_list))]
        if len(ids) != len(wurcs_list):
            raise ValueError(f"ids has {len(ids)} entries but wurcs_list has {len(wurcs_list)}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        summaries: List[Dict[str, Any]] = []
        details: List[Dict[str, Any]] = []
        if len(wurcs_list) == 0:
            return pd.DataFrame(summaries), pd.DataFrame(details)

        id_to_token = self.tokenizer.id_to_token
        for start in tqdm(range(0, len(wurcs_list), batch_size), desc="Resolving glycans"):
            chunk_wurcs = list(wurcs_list[start : start + batch_size])
            chunk_ids = list(ids[start : start + batch_size])
            tokenized_rows = [self.tokenizer.tokenize(wurcs, max_length=self.config.seq_max_length) for wurcs in chunk_wurcs]
            current_ids = [list(row["token_ids"]) for row in tokenized_rows]
            unresolved = [set(self._ambiguous_positions(row)) for row in tokenized_rows]
            initial_counts = [len(item) for item in unresolved]
            accepted_by_row: List[List[Dict[str, Any]]] = [[] for _ in tokenized_rows]

            for iteration in range(1, iterations + 1):
                # Nothing left to resolve in this chunk: skip the forward pass.
                if not any(unresolved):
                    break

                masked_ids = []
                for row_idx, row_ids in enumerate(current_ids):
                    masked = list(row_ids)
                    for pos in unresolved[row_idx]:
                        masked[pos] = self.mask_id
                    masked_ids.append(masked)
                logits = self._forward_logits(tokenized_rows, masked_ids)

                progress = 0
                for row_idx, positions in enumerate(unresolved):
                    # Iterate over a sorted copy; ``positions`` shrinks as tokens are accepted.
                    for pos in sorted(positions):
                        probs = F.softmax(logits[row_idx, pos], dim=-1)
                        # Never ask for more candidates than the vocabulary holds.
                        top_probs, top_ids = torch.topk(probs, k=min(self.top_k, probs.shape[-1]))
                        top_ids_list = [int(item) for item in top_ids.detach().cpu().tolist()]
                        top_probs_list = [float(item) for item in top_probs.detach().cpu().tolist()]
                        top_tokens = [id_to_token.get(item, "[UNK]") for item in top_ids_list]
                        # First (most probable) candidate that is confident and valid.
                        accepted_candidate = next(
                            (
                                (candidate_id, candidate_prob)
                                for candidate_id, candidate_prob in zip(top_ids_list, top_probs_list)
                                if candidate_prob >= self.threshold and candidate_id not in self.invalid_prediction_ids
                            ),
                            None,
                        )
                        original_id = int(tokenized_rows[row_idx]["token_ids"][pos])
                        detail = {
                            "sample_id": chunk_ids[row_idx],
                            "wurcs": chunk_wurcs[row_idx],
                            "iteration": iteration,
                            "position": pos,
                            "original_id": original_id,
                            "original_token": id_to_token.get(original_id, "[UNK]"),
                            "top_ids": json.dumps(top_ids_list),
                            "top_tokens": json.dumps(top_tokens),
                            "top_probs": json.dumps(top_probs_list),
                            "accepted": False,
                        }
                        if accepted_candidate is not None:
                            accepted_id, accepted_prob = accepted_candidate
                            current_ids[row_idx][pos] = accepted_id
                            positions.remove(pos)
                            progress += 1
                            detail.update(
                                {
                                    "accepted": True,
                                    "resolved_id": accepted_id,
                                    "resolved_token": id_to_token.get(accepted_id, "[UNK]"),
                                    "confidence": accepted_prob,
                                }
                            )
                            accepted_by_row[row_idx].append(detail.copy())
                        details.append(detail)
                # A round that resolved nothing will not do better next time.
                if progress == 0:
                    break

            for row_idx, row in enumerate(tokenized_rows):
                summaries.append(
                    {
                        "sample_id": chunk_ids[row_idx],
                        "wurcs": chunk_wurcs[row_idx],
                        "initial_ambiguous_tokens": initial_counts[row_idx],
                        "resolved_tokens": len(accepted_by_row[row_idx]),
                        "remaining_ambiguous_tokens": len(unresolved[row_idx]),
                        "final_token_sequence": " ".join(nonpad_tokens(self.tokenizer, current_ids[row_idx], row["attention_mask"])),
                        "accepted_updates_json": json.dumps(accepted_by_row[row_idx]),
                    }
                )
        return pd.DataFrame(summaries), pd.DataFrame(details)

    def resolve_single(self, wurcs: str, iterations: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Resolve one WURCS string, labelled with the sample id ``"single"``."""
        return self.resolve_batch([wurcs], ids=["single"], iterations=iterations, batch_size=1)
def read_wurcs_csv(path: Optional[Path], default: Sequence[str], column: str, n: int) -> List[str]:
    """Return the first ``n`` WURCS strings from a CSV column or from ``default``.

    Args:
        path: CSV file to read, or ``None`` to use ``default``.
        default: Fallback strings used when ``path`` is ``None`` or missing.
        column: Name of the CSV column holding the WURCS strings.
        n: Number of strings required; must be non-negative.

    Returns:
        Exactly ``n`` strings that start with ``"WURCS="``, in input order.

    Raises:
        ValueError: if ``n`` is negative, the CSV lacks ``column``, or fewer
            than ``n`` valid strings are available.
    """
    if n < 0:
        # A negative n would otherwise silently slice items off the end.
        raise ValueError(f"n must be non-negative, got {n}")

    if path and path.exists():
        df = pd.read_csv(path)
        if column not in df.columns:
            raise ValueError(f"{path} must contain column {column!r}")
        vals = df[column].dropna().astype(str).tolist()
    else:
        if path:
            # The caller asked for a specific file; make the fallback visible.
            warnings.warn(f"{path} does not exist; falling back to the built-in WURCS list")
        vals = list(default)

    vals = [item for item in vals if item.startswith("WURCS=")]
    if len(vals) < n:
        raise ValueError(f"Need at least {n} WURCS strings, found {len(vals)}")
    return vals[:n]
def choose_affinose_pairs(
    csv_path: Optional[Path],
    protein_h5: Optional[Path],
    output_dir: Path,
    n: int,
    default_wurcs: Sequence[str],
) -> Tuple[pd.DataFrame, Path, str]:
    """Pick ``n`` protein/glycan pairs and prepare a matching protein-embedding file.

    With an existing ``protein_h5``: pairs come from ``csv_path`` (restricted to
    proteins present in the HDF5 file, preferring distinct proteins) or, without
    a CSV, from the first sorted HDF5 keys paired with ``default_wurcs``. The
    embeddings of the chosen proteins are copied into
    ``output_dir/affinose_subset_esmc_embeddings.h5``.

    Otherwise: synthetic random embeddings are written to
    ``output_dir/synthetic_esmc_embeddings.h5`` and paired with ``default_wurcs``.

    Returns:
        ``(pairs, h5_path, source)`` where ``pairs`` has columns ``protein_id``
        and ``glycan_wurcs`` and ``source`` is ``"real_subset"`` or ``"synthetic"``.

    Raises:
        ValueError: if fewer than ``n`` pairs can be assembled, or the CSV lacks
            a required column.
    """
    fallback_wurcs = list(default_wurcs)[:n]

    if protein_h5 and protein_h5.exists():
        with h5py.File(protein_h5, "r") as handle:
            available = set(handle.keys())

        if csv_path and csv_path.exists():
            required = {"protein_id", "glycan_wurcs"}
            df = pd.read_csv(csv_path, usecols=lambda col: col in required)
            absent = sorted(required - set(df.columns))
            if absent:
                raise ValueError(f"{csv_path} must contain column(s) {absent}")
            df = df.dropna(subset=["protein_id", "glycan_wurcs"]).copy()
            df["protein_id"] = df["protein_id"].astype(str)
            df["glycan_wurcs"] = df["glycan_wurcs"].astype(str)
            df = df[df["protein_id"].isin(available) & df["glycan_wurcs"].str.startswith("WURCS=")]
            df = df.drop_duplicates(["protein_id", "glycan_wurcs"])
            # Prefer one pair per protein, then top up with repeats if needed.
            diverse = df.drop_duplicates("protein_id").head(n)
            if len(diverse) < n:
                extra = df[~df.index.isin(diverse.index)].head(n - len(diverse))
                diverse = pd.concat([diverse, extra], ignore_index=False)
            pairs = diverse.head(n).copy()
        else:
            keys = sorted(available)[:n]
            # Trim both columns to a common length so a shortage reaches the
            # explicit check below instead of a pandas length-mismatch error.
            usable = min(len(keys), len(fallback_wurcs))
            pairs = pd.DataFrame({"protein_id": keys[:usable], "glycan_wurcs": fallback_wurcs[:usable]})

        if len(pairs) < n:
            raise ValueError(f"Need {n} AFFINose pairs with available protein embeddings, found {len(pairs)}")

        subset_h5 = output_dir / "affinose_subset_esmc_embeddings.h5"
        with h5py.File(protein_h5, "r") as src, h5py.File(subset_h5, "w") as dst:
            for protein_id in pairs["protein_id"].astype(str).unique():
                dst.create_dataset(protein_id, data=src[protein_id][:], compression="gzip")
        return pairs.reset_index(drop=True), subset_h5, "real_subset"

    if protein_h5:
        # An explicit path that does not exist should not fall back silently.
        warnings.warn(f"{protein_h5} does not exist; using synthetic protein embeddings")
    if len(fallback_wurcs) < n:
        # Checked before any file is written.
        raise ValueError(f"Need {n} glycan WURCS strings for synthetic AFFINose pairs, found {len(fallback_wurcs)}")

    synthetic_h5 = output_dir / "synthetic_esmc_embeddings.h5"
    rng = np.random.default_rng(123)  # fixed seed keeps the synthetic data reproducible
    protein_ids = [f"smoke_protein_{idx}" for idx in range(n)]
    with h5py.File(synthetic_h5, "w") as dst:
        for idx, protein_id in enumerate(protein_ids):
            # Each synthetic protein gets a different number of rows (64 + idx).
            dst.create_dataset(protein_id, data=rng.normal(size=(64 + idx, 960)).astype("float32"))
    pairs = pd.DataFrame({"protein_id": protein_ids, "glycan_wurcs": fallback_wurcs})
    return pairs, synthetic_h5, "synthetic"
def write_embedding_csv(path: Path, sample_ids: Sequence[str], wurcs: Sequence[str], embeddings: np.ndarray) -> None:
    """Write ids, WURCS strings and embedding columns to a CSV file.

    The output has columns ``sample_id``, ``wurcs`` and then one ``emb_NNN``
    column (zero-padded index) per embedding dimension; no index column is written.
    """
    column_names = [f"emb_{idx:03d}" for idx in range(embeddings.shape[1])]
    id_frame = pd.DataFrame({"sample_id": sample_ids, "wurcs": wurcs})
    vector_frame = pd.DataFrame(embeddings, columns=column_names)
    combined = pd.concat([id_frame, vector_frame], axis=1)
    combined.to_csv(path, index=False)
def finite_or_raise(values: Sequence[float], name: str) -> None:
    """Raise ``RuntimeError`` if ``values`` contains NaN or an infinity.

    ``name`` is used only to label the error message.
    """
    as_array = np.asarray(values, dtype="float64")
    if bool(np.all(np.isfinite(as_array))):
        return
    raise RuntimeError(f"{name} contains non-finite values")
def main() -> None:
    """Run the full smoke test: download, embed, resolve, score, and report.

    Steps run in order and each is timed into ``summary["timing_seconds"]``:
    repository download, BERTose embedding (single and batch), BERTose IAR
    (single and batch), and AFFINose scoring (single and batch). CSV outputs
    and ``smoke_summary.json`` are written to ``--output-dir``; the summary is
    also printed. Any failed check raises and aborts before the summary is written.
    """
    parser = argparse.ArgumentParser(description="Validate public BERTose/AFFINose Hugging Face inference.")
    parser.add_argument("--output-dir", type=Path, default=Path("release_smoke_outputs"))
    parser.add_argument("--hf-cache-dir", type=Path, default=Path("hf_release_cache"))
    parser.add_argument("--device", default="auto")
    parser.add_argument("--embedding-csv", type=Path, default=None)
    parser.add_argument("--iar-csv", type=Path, default=None)
    parser.add_argument("--affinose-csv", type=Path, default=None)
    parser.add_argument("--protein-emb-h5", type=Path, default=None)
    parser.add_argument("--embedding-n", type=int, default=20)
    parser.add_argument("--iar-n", type=int, default=20)
    parser.add_argument("--affinose-n", type=int, default=5)
    parser.add_argument("--embedding-batch-size", type=int, default=20)
    parser.add_argument("--iar-batch-size", type=int, default=8)
    parser.add_argument("--affinose-batch-size", type=int, default=5)
    parser.add_argument("--iar-iterations", type=int, default=3)
    args = parser.parse_args()
    # Inference only: disable autograd globally for the rest of the process.
    torch.set_grad_enabled(False)
    args.output_dir.mkdir(parents=True, exist_ok=True)
    device = get_device(args.device)
    # Accumulates environment info, timings and result statistics for the JSON report.
    summary: Dict[str, Any] = {
        "device": str(device),
        "cuda_available": torch.cuda.is_available(),
        "torch_version": torch.__version__,
        "repo_ids": REPOS,
        "timing_seconds": {},
    }
    # --- Download the release repositories and expose their sources ---
    t0 = time.perf_counter()
    paths = download_repos(args.hf_cache_dir)
    summary["timing_seconds"]["download_repos"] = time.perf_counter() - t0
    summary["repo_paths"] = {key: str(path) for key, path in paths.items()}
    add_source_paths(paths)
    # --- Select input glycans (CSV if given, built-in lists otherwise) ---
    embedding_wurcs = read_wurcs_csv(args.embedding_csv, DEFAULT_WURCS, "wurcs", args.embedding_n)
    iar_wurcs = read_wurcs_csv(args.iar_csv, DEFAULT_AMBIGUOUS_WURCS, "wurcs", args.iar_n)
    encoder_ckpt = paths["encoder"] / "checkpoints" / "bertose_glycan_encoder.pt"
    encoder_vocab = paths["encoder"] / "vocab" / "bpe_vocabulary.json"
    resolver_ckpt = paths["resolver"] / "checkpoints" / "bertose_iar_resolver.pt"
    resolver_vocab = paths["resolver"] / "vocab" / "bpe_vocabulary.json"
    ambiguity_path = paths["resolver"] / "vocab" / "bpe_ambiguity_tokens.json"
    # --- BERTose embedding: single item, then batch ---
    t0 = time.perf_counter()
    embedder = BertoseEmbedder(encoder_ckpt, encoder_vocab, device)
    summary["timing_seconds"]["load_bertose_embedder"] = time.perf_counter() - t0
    t0 = time.perf_counter()
    single_embedding = embedder.embed_single(embedding_wurcs[0])
    summary["timing_seconds"]["bertose_embedding_single"] = time.perf_counter() - t0
    # The expected embedding width (768) is hard-coded here.
    if single_embedding.shape != (768,):
        raise RuntimeError(f"Expected single embedding shape (768,), got {single_embedding.shape}")
    finite_or_raise(single_embedding, "single_embedding")
    t0 = time.perf_counter()
    batch_embeddings = embedder.embed_batch(embedding_wurcs, batch_size=args.embedding_batch_size)
    elapsed = time.perf_counter() - t0
    summary["timing_seconds"]["bertose_embedding_batch"] = elapsed
    summary["bertose_embedding_batch_shape"] = list(batch_embeddings.shape)
    # max(..., 1e-9) guards the throughput division against a zero elapsed time.
    summary["bertose_embedding_items_per_second"] = len(embedding_wurcs) / max(elapsed, 1e-9)
    if batch_embeddings.shape != (len(embedding_wurcs), 768):
        raise RuntimeError(f"Expected batch embedding shape ({len(embedding_wurcs)}, 768), got {batch_embeddings.shape}")
    finite_or_raise(batch_embeddings.ravel(), "batch_embeddings")
    # NOTE: the file name says "20" regardless of the actual --embedding-n.
    write_embedding_csv(args.output_dir / "bertose_embeddings_20.csv", [f"glycan_{idx:02d}" for idx in range(len(embedding_wurcs))], embedding_wurcs, batch_embeddings)
    # Drop the embedder before loading the next model to release its memory.
    del embedder
    if device.type == "cuda":
        torch.cuda.empty_cache()
    # --- BERTose IAR: single item, then batch ---
    t0 = time.perf_counter()
    resolver = BertoseAmbiguityResolver(resolver_ckpt, resolver_vocab, ambiguity_path, device)
    summary["timing_seconds"]["load_bertose_iar"] = time.perf_counter() - t0
    t0 = time.perf_counter()
    single_summary, single_details = resolver.resolve_single(iar_wurcs[0], iterations=args.iar_iterations)
    summary["timing_seconds"]["bertose_iar_single"] = time.perf_counter() - t0
    if single_summary.empty:
        raise RuntimeError("BERTose IAR single output is empty")
    single_summary.to_csv(args.output_dir / "bertose_iar_single_summary.csv", index=False)
    single_details.to_csv(args.output_dir / "bertose_iar_single_details.csv", index=False)
    t0 = time.perf_counter()
    iar_summary, iar_details = resolver.resolve_batch(
        iar_wurcs,
        ids=[f"ambiguous_{idx:02d}" for idx in range(len(iar_wurcs))],
        iterations=args.iar_iterations,
        batch_size=args.iar_batch_size,
    )
    elapsed = time.perf_counter() - t0
    summary["timing_seconds"]["bertose_iar_batch"] = elapsed
    summary["bertose_iar_batch_rows"] = len(iar_summary)
    summary["bertose_iar_detail_rows"] = len(iar_details)
    summary["bertose_iar_items_per_second"] = len(iar_wurcs) / max(elapsed, 1e-9)
    if len(iar_summary) != len(iar_wurcs):
        raise RuntimeError(f"Expected {len(iar_wurcs)} IAR summary rows, got {len(iar_summary)}")
    iar_summary.to_csv(args.output_dir / "bertose_iar_batch_summary.csv", index=False)
    iar_details.to_csv(args.output_dir / "bertose_iar_batch_details.csv", index=False)
    del resolver
    if device.type == "cuda":
        torch.cuda.empty_cache()
    # --- AFFINose: choose pairs (real subset or synthetic proteins), then score ---
    pairs, protein_h5, protein_source = choose_affinose_pairs(
        args.affinose_csv,
        args.protein_emb_h5,
        args.output_dir,
        args.affinose_n,
        embedding_wurcs,
    )
    pairs_path = args.output_dir / "affinose_pairs.csv"
    pairs.to_csv(pairs_path, index=False)
    summary["affinose_protein_embedding_source"] = protein_source
    summary["affinose_pairs"] = len(pairs)
    t0 = time.perf_counter()
    # Imported here because the module only becomes importable after add_source_paths().
    from affinose_inference import AffinosePredictor
    predictor = AffinosePredictor(
        checkpoint_path=str(paths["affinose"] / "checkpoints" / "affinose_interaction_model.pt"),
        bertose_checkpoint=str(encoder_ckpt),
        vocab_path=str(paths["affinose"] / "vocab" / "bpe_vocabulary.json"),
        protein_emb_path=str(protein_h5),
        device=str(device),
    )
    summary["timing_seconds"]["load_affinose"] = time.perf_counter() - t0
    t0 = time.perf_counter()
    first = pairs.iloc[0]
    single_score = predictor.predict_single(first["glycan_wurcs"], first["protein_id"])
    summary["timing_seconds"]["affinose_single"] = time.perf_counter() - t0
    finite_or_raise([single_score], "affinose_single_score")
    summary["affinose_single_score"] = float(single_score)
    t0 = time.perf_counter()
    scores = predictor.predict_batch(
        pairs["glycan_wurcs"].astype(str).tolist(),
        pairs["protein_id"].astype(str).tolist(),
        batch_size=args.affinose_batch_size,
    )
    elapsed = time.perf_counter() - t0
    finite_or_raise(scores, "affinose_batch_scores")
    summary["timing_seconds"]["affinose_batch"] = elapsed
    summary["affinose_items_per_second"] = len(scores) / max(elapsed, 1e-9)
    out_pairs = pairs.copy()
    out_pairs["affinose_score"] = [float(score) for score in scores]
    # NOTE: the file name says "5" regardless of the actual --affinose-n.
    out_pairs.to_csv(args.output_dir / "affinose_predictions_5.csv", index=False)
    # --- Final report: written only when every step above succeeded ---
    summary["status"] = "ok"
    summary_path = args.output_dir / "smoke_summary.json"
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(json.dumps(summary, indent=2))
# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()