| |
| """Headless BERTose/AFFINose Hugging Face release smoke test. |
| |
| This script is meant for cluster or cloud notebook validation. It downloads the |
| public release repositories without a user token, runs single and batch BERTose |
| embedding, BERTose IAR, and AFFINose protein-glycan scoring, then writes CSV and |
| JSON outputs with timing. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import re |
| import sys |
| import time |
| import warnings |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Sequence, Tuple |
|
|
| import h5py |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn.functional as F |
| from huggingface_hub import snapshot_download |
| from tqdm.auto import tqdm |
|
|
|
|
# Public Hugging Face repositories exercised by this smoke test, keyed by the
# role each one plays (download_repos() iterates this mapping in order).
REPOS = dict(
    encoder="supanthadey1/bertose-glycan-encoder",
    resolver="supanthadey1/bertose-iar-resolver",
    affinose="supanthadey1/affinose-interaction-model",
    notebook="supanthadey1/bertose-affinose-inference",
)
|
|
|
|
# Twenty WURCS 2.0 glycans embedded by the BERTose embedding check when no
# --embedding-csv is supplied (main() keeps the first --embedding-n of them).
# When these defaults are used, the selected strings also serve as the glycan
# partners for default/synthetic AFFINose pairs (see choose_affinose_pairs).
DEFAULT_WURCS = [
    "WURCS=2.0/2,2,1/[a212h-1b_1-5][a2211m-1a_1-5]/1-2/a2-b1",
    "WURCS=2.0/2,4,3/[a1221m-1b_1-5][a1221m-1a_1-5]/1-2-2-2/a3-b1_b3-c1_c3-d1",
    "WURCS=2.0/3,3,2/[a2112h-1b_1-5_2*NCC/3=O_4*OSO/3=O/3=O_6*OSO/3=O/3=O][a2122A-1b_1-5][a1221m-1a_1-5_2*OSO/3=O/3=O_4*OSO/3=O/3=O]/1-2-3/a3-b1_b3-c1",
    "WURCS=2.0/2,2,1/[a2122h-1a_1-5_2*OSO/3=O/3=O_3*OSO/3=O/3=O_4*OSO/3=O/3=O][a2122m-1a_1-5_2*OSO/3=O/3=O_3*OSO/3=O/3=O]/1-2/a6-b1",
    "WURCS=2.0/2,3,2/[a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O]/1-1-2/a4-b1_b3-c1",
    "WURCS=2.0/3,3,2/[a2122A-1b_1-5][a211h-1a_1-5][a2122h-1b_1-5]/1-2-3/a2-b1_b2-c1",
    "WURCS=2.0/2,4,3/[a2122h-1a_1-5][a2112h-1a_1-5]/1-1-2-2/a4-b1_b4-c1_b6-d1",
    "WURCS=2.0/2,3,2/[a2112h-1b_1-5][a2112h-1a_1-4]/1-2-2/a4-b1_b2-c1",
    "WURCS=2.0/1,1,0/[a122h-1a_1-5_2*OC_3*OC_4*OC]/1/",
    "WURCS=2.0/1,1,0/[a122h-1a_1-4_2*OCC/3=O_3*OCC/3=O]/1/",
    "WURCS=2.0/1,1,0/[had22111m-2b_2-6_5*N_7*N]/1/",
    "WURCS=2.0/1,1,0/[Aad22111m-2b_2-6_5*N_7*N]/1/",
    "WURCS=2.0/4,7,6/[a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a2122h-1b_1-5][Aad22111m-2b_2-6_5*N_7*N]/1-2-1-2-3-4-3/a3-b1_b3-c1_b6-g1_c3-d1_d6-e1_e6-f2",
    "WURCS=2.0/4,7,6/[a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a2122h-1b_1-5][Aad22111m-2b_2-6_7*N_5*N]/1-2-3-1-2-3-4/a3-b1_b3-d1_b6-c1_d3-e1_e6-f1_f6-g2",
    "WURCS=2.0/1,1,0/[Aad22111m-2a_2-6_5*N_7*N]/1/",
    "WURCS=2.0/5,8,7/[a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a2122h-1b_1-5][Aad22111m-2a_2-6_5*N_7*N][Aad22111m-2b_2-6_5*N_7*N]/1-2-1-2-3-4-3-5/a3-b1_b3-c1_b6-g1_c3-d1_d6-e1_e6-f2_g6-h2",
    "WURCS=2.0/5,8,7/[a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a2122h-1b_1-5][Aad22111m-2b_2-6_7*N_5*N][Aad22111m-2a_2-6_7*N_5*N]/1-2-3-4-1-2-3-5/a3-b1_b3-e1_b6-c1_c6-d2_e3-f1_f6-g1_g6-h2",
    "WURCS=2.0/3,3,2/[a2112m-1b_1-5_2*NCC/3=O][a212h-1b_1-5][Aad22111m-2a_2-6_5*N_7*NC=O]/1-2-3/a3-b1_b4-c2",
    "WURCS=2.0/3,3,2/[a212h-1b_1-5][a2112m-1b_1-5_2*NCC/3=O][Aad22111m-2a_2-6_5*N_7*NC=O]/1-2-3/a4-b1_b3-c2",
    "WURCS=2.0/3,3,2/[a212h-1b_1-5][a2112m-1b_1-5_2*NCC/3=O_4*OCC/3=O][Aad22111m-2a_2-6_5*N_7*NC=O]/1-2-3/a4-b1_b3-c2",
]
|
|
|
|
# Twenty WURCS 2.0 glycans used by the BERTose IAR check when no --iar-csv is
# supplied (main() keeps the first --iar-n of them). Every entry contains at
# least one residue code starting with "u" (e.g. "[u2122h]"); presumably these
# tokenize to ids that BertoseAmbiguityResolver treats as ambiguous (it flags
# vocabulary tokens starting with "u") — TODO confirm against the tokenizer.
DEFAULT_AMBIGUOUS_WURCS = [
    "WURCS=2.0/2,2,1/[u2112h][a1221m-1a_1-5]/1-2/a2-b1",
    "WURCS=2.0/3,3,2/[u2112h_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3/a3-b1_b2-c1",
    "WURCS=2.0/4,6,5/[u2122h][a2112h-1b_1-5][a1221m-1a_1-5][a2112h-1a_1-5_2*NCC/3=O]/1-2-3-4-2-3/a4-b1_b2-c1_b3-d1_d3-e1_e2-f1",
    "WURCS=2.0/4,6,5/[u2122h_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5][a2112h-1a_1-5_2*NCC/3=O]/1-2-3-4-2-3/a4-b1_b2-c1_b3-d1_d3-e1_e2-f1",
    "WURCS=2.0/4,4,3/[u2112h][a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4/a3-b1_b3-c1_c2-d1",
    "WURCS=2.0/5,6,5/[u2122h][a2112h-1b_1-5][a2112h-1a_1-5][a2112h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-4-2-5/a4-b1_b4-c1_c3-d1_d3-e1_e2-f1",
    "WURCS=2.0/4,5,4/[u2122h][a2112h-1b_1-5][a2112h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-2-4/a4-b1_b3-c1_c3-d1_d2-e1",
    "WURCS=2.0/5,6,5/[u2122h][a2112h-1b_1-5][Aad21122h-2a_2-6_5*NCC/3=O][a2112h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-4-2-5/a4-b1_b3-c2_b4-d1_d3-e1_e2-f1",
    "WURCS=2.0/3,3,2/[u2122h_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3/a3-b1_b2-c1",
    "WURCS=2.0/6,11,10/[u2122h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a1122h-1b_1-5][a1122h-1a_1-5][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4-2-5-6-4-2-5-6/a4-b1_b4-c1_c3-d1_c6-h1_d2-e1_e3-f1_f2-g1_h2-i1_i3-j1_j2-k1",
    "WURCS=2.0/6,12,11/[u2122h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a1122h-1b_1-5][a1122h-1a_1-5][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4-2-5-6-4-2-5-6-6/a4-b1_a6-l1_b4-c1_c3-d1_c6-h1_d2-e1_e3-f1_f2-g1_h2-i1_i3-j1_j2-k1",
    "WURCS=2.0/4,5,4/[u2122h][a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-2-4/a4-b1_b3-c1_c3-d1_d2-e1",
    "WURCS=2.0/4,4,3/[u2112h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4/a3-b1_b3-c1_c2-d1",
    "WURCS=2.0/4,7,6/[u2112h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4-2-3-4/a3-b1_a6-e1_b3-c1_c2-d1_e3-f1_f2-g1",
    "WURCS=2.0/4,8,7/[u2122h][a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-2-4-3-4-2/a4-b1_b3-c1_b6-f1_c3-d1_d2-e1_f3-g1_f4-h1",
    "WURCS=2.0/4,7,6/[u2122h][a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-2-4-3-2/a4-b1_b3-c1_b6-f1_c3-d1_d2-e1_f4-g1",
    "WURCS=2.0/5,8,7/[u2122h][a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5][Aad21122h-2a_2-6_5*NCC/3=O]/1-2-3-2-4-3-2-5/a4-b1_b3-c1_b6-f1_c3-d1_d2-e1_f4-g1_g6-h2",
    "WURCS=2.0/3,3,2/[u2122h_2*NCC/3=O_6*OSO/3=O/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3/a3-b1_b2-c1",
    "WURCS=2.0/3,4,3/[u2122h_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-3/a3-b1_a4-d1_b2-c1",
    "WURCS=2.0/6,13,12/[u2122h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a1122h-1b_1-5][a1122h-1a_1-5][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4-2-5-6-6-4-2-5-6-6/a4-b1_b4-c1_c3-d1_c6-i1_d2-e1_e3-f1_e4-h1_f2-g1_i2-j1_j3-k1_j4-m1_k2-l1",
]
|
|
|
|
def get_device(name: str) -> torch.device:
    """Translate a ``--device`` option into a ``torch.device``.

    ``"auto"`` picks CUDA when ``torch.cuda.is_available()`` and CPU otherwise;
    every other value is passed to ``torch.device`` unchanged.
    """
    if name != "auto":
        return torch.device(name)
    has_cuda = torch.cuda.is_available()
    return torch.device("cuda" if has_cuda else "cpu")
|
|
|
|
def ensure_writable_hf_home(cache_dir: Path) -> None:
    """Keep Hugging Face/Xet cache writes inside the selected work directory.

    Candidates are tried in order: the current ``HF_HOME`` (when set and
    non-empty), then ``cache_dir.parent / "hf_home"``. The first directory that
    can be created and written to (a probe file is written and deleted) becomes
    ``HF_HOME``. ``HF_HUB_CACHE`` / ``HF_XET_CACHE`` default to its ``hub`` /
    ``xet`` sub-directories, but values already in the environment are kept.

    NOTE(review): huggingface_hub may read these variables when it is first
    imported; confirm that setting them at this point still takes effect.

    Raises:
        RuntimeError: if none of the candidates is writable.
    """
    fallback = cache_dir.parent / "hf_home"
    configured = os.environ.get("HF_HOME")
    candidates = ([Path(configured).expanduser()] if configured else []) + [fallback]

    def _writable(directory: Path) -> bool:
        try:
            directory.mkdir(parents=True, exist_ok=True)
            probe = directory / ".write_probe"
            probe.write_text("ok", encoding="utf-8")
            probe.unlink()
        except OSError:
            return False
        return True

    for home in candidates:
        if not _writable(home):
            continue
        os.environ["HF_HOME"] = str(home)
        os.environ.setdefault("HF_HUB_CACHE", str(home / "hub"))
        os.environ.setdefault("HF_XET_CACHE", str(home / "xet"))
        return

    raise RuntimeError(f"Could not create a writable Hugging Face cache under {fallback}")
|
|
|
|
def download_repos(cache_dir: Path) -> Dict[str, Path]:
    """Anonymously snapshot every repository listed in ``REPOS``.

    Each repo is downloaded (``token=False``) into ``cache_dir / <repo name>``.
    The result maps each ``REPOS`` key (e.g. ``"encoder"``) to the path that
    ``snapshot_download`` returns for it.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    ensure_writable_hf_home(cache_dir)

    def _fetch(repo_id: str) -> Path:
        target_dir = cache_dir / repo_id.split("/")[-1]
        downloaded = snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            local_dir=str(target_dir),
            token=False,
        )
        return Path(downloaded)

    return {key: _fetch(repo_id) for key, repo_id in REPOS.items()}
|
|
|
|
def add_source_paths(paths: Dict[str, Path]) -> None:
    """Expose the downloaded repos' ``src`` directories and the encoder root.

    The ``src`` directories of the affinose, encoder and resolver checkouts are
    prepended to ``sys.path`` in that order, skipping any already present. With
    none present beforehand, the resolver's ``src`` ends up first. Both
    ``BERTOSE_ROOT`` and ``BERTOSE_REPO_ROOT`` are set to the encoder checkout.
    """
    for repo_key in ("affinose", "encoder", "resolver"):
        src_dir = str(paths[repo_key] / "src")
        if src_dir in sys.path:
            continue
        sys.path.insert(0, src_dir)

    encoder_root = str(paths["encoder"])
    for env_name in ("BERTOSE_ROOT", "BERTOSE_REPO_ROOT"):
        os.environ[env_name] = encoder_root
|
|
|
|
def count_layers(state_dict: Dict[str, torch.Tensor], prefix: str, default: int) -> int:
    """Infer a layer count from ``<prefix>.<index>.`` keys of a state dict.

    Returns one more than the largest index found, or ``default`` when no key
    matches the pattern.
    """
    layer_key = re.compile(rf"^{re.escape(prefix)}\.(\d+)\.")
    highest = -1
    for name in state_dict:
        found = layer_key.match(name)
        if found:
            highest = max(highest, int(found.group(1)))
    return highest + 1 if highest >= 0 else default
|
|
|
|
def import_bertose_classes():
    """Import the BERTose model, config and tokenizer classes on demand.

    The import is deferred to call time. Presumably the modules live in the
    downloaded repos' ``src`` directories, which ``add_source_paths`` puts on
    ``sys.path``.

    Returns:
        ``(MultimodalGlycanBERT, MultimodalGlycanBERTConfig, WURCSBPETokenizer)``.
    """
    from bertose_model import MultimodalGlycanBERT as model_cls
    from bertose_model import MultimodalGlycanBERTConfig as config_cls
    from wurcs_bpe_tokenizer import WURCSBPETokenizer as tokenizer_cls

    return model_cls, config_cls, tokenizer_cls
|
|
|
|
def build_config_from_state_dict(state_dict: Dict[str, torch.Tensor]):
    """Rebuild the ``MultimodalGlycanBERTConfig`` implied by a checkpoint's tensors.

    Values read from the state dict:
    - the sequence vocab size, hidden size and max length, from the sequence
      embedding tables;
    - per-branch layer counts, from ``*_layers.<i>.`` keys (defaults 12/6/8);
    - the MS and structure vocab/hidden sizes, when their token-embedding tables
      exist (otherwise 242/384 and 1024/512);
    - the ``use_3d``, ``use_cross_attention`` and ``use_cnn_frontend`` flags,
      from key prefixes.

    Head counts (12/6/8), the MS/structure max lengths (150/200) and the CNN
    kernel size (3) are fixed values.
    """
    config_cls = import_bertose_classes()[1]

    def any_key_starts_with(prefix: str) -> bool:
        return any(name.startswith(prefix) for name in state_dict)

    seq_vocab_size, seq_hidden = state_dict["seq_embeddings.token_embeddings.weight"].shape
    seq_max_length = state_dict["seq_embeddings.position_embeddings.weight"].shape[0]

    ms_key = "ms_embeddings.token_embeddings.weight"
    if ms_key in state_dict:
        ms_total_vocab, ms_hidden_size = state_dict[ms_key].shape
        # NOTE(review): the MS table appears to also hold the sequence vocabulary,
        # so only the remainder is counted as MS vocab — confirm in bertose_model.
        ms_vocab_size = max(1, ms_total_vocab - seq_vocab_size)
    else:
        ms_vocab_size, ms_hidden_size = 242, 384

    struct_key = "struct_embeddings.token_embeddings.weight"
    if struct_key in state_dict:
        struct_vocab_size, struct_hidden_size = state_dict[struct_key].shape
    else:
        struct_vocab_size, struct_hidden_size = 1024, 512

    return config_cls(
        seq_vocab_size=seq_vocab_size,
        seq_hidden_size=seq_hidden,
        seq_num_layers=count_layers(state_dict, "seq_layers", 12),
        seq_num_heads=12,
        seq_max_length=seq_max_length,
        ms_vocab_size=ms_vocab_size,
        ms_hidden_size=ms_hidden_size,
        ms_num_layers=count_layers(state_dict, "ms_layers", 6),
        ms_num_heads=6,
        ms_max_length=150,
        struct_vocab_size=struct_vocab_size,
        struct_hidden_size=struct_hidden_size,
        struct_num_layers=count_layers(state_dict, "struct_layers", 8),
        struct_num_heads=8,
        struct_max_length=200,
        use_3d=any_key_starts_with("struct_embeddings."),
        use_cross_attention=any_key_starts_with("cross_attention."),
        use_cnn_frontend=any_key_starts_with("seq_embeddings.conv_layers."),
        cnn_kernel_size=3,
    )
|
|
|
|
def load_bertose_backbone(checkpoint_path: Path, device: torch.device):
    """Build a BERTose backbone from a checkpoint and move it to ``device`` in eval mode.

    The architecture comes from ``build_config_from_state_dict``. Tensors under
    ``proj_head.`` are dropped before a non-strict ``load_state_dict``.

    Mismatches are handled as follows:
    - missing ``dist_proj.``/``distance_head.`` tensors and unexpected
      ``distance_head.`` tensors are tolerated silently;
    - any other mismatch is reported with ``warnings.warn``, including a
      preview of the offending key names so a bad checkpoint can be diagnosed.

    Returns:
        ``(model, config)``.
    """
    MultimodalGlycanBERT, _, _ = import_bertose_classes()
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and can
    # execute code embedded in the file. Only use this with trusted checkpoints
    # (such as the public release repos downloaded by this script).
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    state_dict = {key: value for key, value in state_dict.items() if not key.startswith("proj_head.")}
    config = build_config_from_state_dict(state_dict)
    model = MultimodalGlycanBERT(config)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    missing_report = [key for key in missing if not key.startswith(("dist_proj.", "distance_head."))]
    unexpected_report = [key for key in unexpected if not key.startswith(("proj_head.", "distance_head."))]

    def _preview(keys: List[str], limit: int = 5) -> str:
        shown = ", ".join(keys[:limit])
        return shown + (", ..." if len(keys) > limit else "")

    if missing_report:
        warnings.warn(
            f"{checkpoint_path.name}: {len(missing_report)} unexpected missing tensors "
            f"(e.g. {_preview(missing_report)})"
        )
    if unexpected_report:
        warnings.warn(
            f"{checkpoint_path.name}: {len(unexpected_report)} unexpected tensors "
            f"(e.g. {_preview(unexpected_report)})"
        )
    # NOTE(review): presumably silences a one-time diagnostic inside bertose_model
    # (not visible here) — confirm before relying on it.
    model._dist_none_printed = True
    model.to(device).eval()
    return model, config
|
|
|
|
def load_tokenizer(vocab_path: Path):
    """Create the BERTose ``WURCSBPETokenizer`` from a vocabulary file path."""
    tokenizer_cls = import_bertose_classes()[2]
    return tokenizer_cls(str(vocab_path))
|
|
|
|
def tensorize_tokenized(tokenized_rows: Sequence[Dict[str, Any]], device: torch.device) -> Dict[str, torch.Tensor]:
    """Stack per-sequence tokenizer fields into ``torch.long`` tensors on ``device``.

    Only ``token_ids``, ``attention_mask``, ``residue_ids``, ``branch_depths`` and
    ``linkage_types`` are collected, in that order; other row keys are ignored.
    Each field's per-row values must stack into a rectangular tensor (equal lengths).
    """
    fields = ("token_ids", "attention_mask", "residue_ids", "branch_depths", "linkage_types")
    batch: Dict[str, torch.Tensor] = {}
    for field in fields:
        column = [row[field] for row in tokenized_rows]
        batch[field] = torch.tensor(column, dtype=torch.long, device=device)
    return batch
|
|
|
|
def nonpad_tokens(tokenizer, token_ids: Sequence[int], attention_mask: Sequence[int]) -> List[str]:
    """Return the token strings for positions whose attention-mask value is 1.

    Ids absent from ``tokenizer.id_to_token`` become ``"[UNK]"``. Pairing stops
    at the shorter of the two sequences.
    """
    kept: List[str] = []
    for token_id, mask_value in zip(token_ids, attention_mask):
        if int(mask_value) != 1:
            continue
        kept.append(tokenizer.id_to_token.get(int(token_id), "[UNK]"))
    return kept
|
|
|
|
class BertoseEmbedder:
    """Pooled BERTose sequence embeddings for WURCS strings.

    Only the sequence branch of the backbone is run (``seq_embeddings`` followed
    by each module in ``seq_layers``).
    """

    def __init__(self, checkpoint_path: Path, vocab_path: Path, device: torch.device):
        self.device = device
        self.tokenizer = load_tokenizer(vocab_path)
        self.model, self.config = load_bertose_backbone(checkpoint_path, device=device)

    @torch.no_grad()
    def _encode_hidden(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the sequence embedding layer and every sequence layer; return the final hidden states."""
        hidden = self.model.seq_embeddings(
            batch["token_ids"],
            branch_depths=batch["branch_depths"],
            linkage_types=batch["linkage_types"],
        )
        for layer in self.model.seq_layers:
            hidden = layer(hidden, batch["attention_mask"])
        return hidden

    @torch.no_grad()
    def embed_batch(self, wurcs_list: Sequence[str], batch_size: int, pooling: str = "cls") -> np.ndarray:
        """Embed WURCS strings, ``batch_size`` sequences per forward pass.

        Args:
            wurcs_list: glycans to embed; each is tokenized and truncated/padded to
                ``config.seq_max_length`` by the tokenizer.
            batch_size: sequences per forward pass; must be >= 1.
            pooling: ``"cls"`` keeps position 0; ``"mean"`` averages hidden states
                over positions whose attention mask is set.

        Returns:
            A ``float32`` array with one row per input. An empty input yields shape
            ``(0, config.seq_hidden_size)`` (assumes the encoder output width equals
            ``seq_hidden_size`` — TODO confirm).

        Raises:
            ValueError: for an unknown ``pooling`` or ``batch_size < 1``.
        """
        # Validate before any tokenization so bad arguments fail fast.
        if pooling not in ("cls", "mean"):
            raise ValueError("pooling must be 'cls' or 'mean'")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if len(wurcs_list) == 0:
            # np.vstack([]) raises, so return a well-formed empty matrix instead.
            return np.zeros((0, int(self.config.seq_hidden_size)), dtype="float32")

        embeddings: List[np.ndarray] = []
        for start in tqdm(range(0, len(wurcs_list), batch_size), desc="Embedding glycans"):
            chunk = list(wurcs_list[start : start + batch_size])
            tokenized = [self.tokenizer.tokenize(wurcs, max_length=self.config.seq_max_length) for wurcs in chunk]
            batch = tensorize_tokenized(tokenized, self.device)
            hidden = self._encode_hidden(batch)
            if pooling == "cls":
                pooled = hidden[:, 0, :]
            else:
                mask = batch["attention_mask"].float().unsqueeze(-1)
                # clamp avoids division by zero for an all-padding row.
                pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
            embeddings.append(pooled.detach().cpu().numpy().astype("float32"))
        return np.vstack(embeddings)

    def embed_single(self, wurcs: str, pooling: str = "cls") -> np.ndarray:
        """Embed one WURCS string and return its 1-D embedding vector."""
        return self.embed_batch([wurcs], batch_size=1, pooling=pooling)[0]
|
|
|
|
class BertoseAmbiguityResolver:
    """Iteratively replace ambiguous WURCS tokens with confident masked-LM predictions.

    Each round does the following:
    - every still-unresolved ambiguous position is replaced by ``[MASK]``;
    - the backbone runs once, using sequence inputs only;
    - a prediction is accepted for a position when one of its top-k candidates
      has probability >= ``threshold`` and is neither a special token nor
      itself ambiguous.

    Accepted ids are written back and serve as context for the next round.
    Rounds stop after ``iterations``, when a round accepts nothing, or when no
    ambiguous position is left.
    """

    def __init__(
        self,
        checkpoint_path: Path,
        vocab_path: Path,
        ambiguity_path: Path,
        device: torch.device,
        threshold: float = 0.80,
        top_k: int = 5,
    ):
        self.device = device
        self.threshold = threshold
        self.top_k = top_k
        self.tokenizer = load_tokenizer(vocab_path)
        with open(ambiguity_path, encoding="utf-8") as handle:
            ambiguity = json.load(handle)
        # Ids listed in the release's ambiguity file ...
        self.ambiguous_ids = set(int(item) for item in ambiguity.get("ambiguous_ids", []))
        # ... plus any non-special vocabulary token containing "?" or "x" or starting with "u".
        self.ambiguous_ids.update(
            int(token_id)
            for token, token_id in self.tokenizer.token_to_id.items()
            if token not in {"[PAD]", "[UNK]", "[START]", "[END]", "[MASK]"}
            and ("?" in token or "x" in token or token.startswith("u"))
        )
        # Bracketed tokens such as "[PAD]" are never acceptable predictions.
        self.special_ids = {
            int(token_id)
            for token, token_id in self.tokenizer.token_to_id.items()
            if token.startswith("[") and token.endswith("]")
        }
        self.invalid_prediction_ids = self.ambiguous_ids | self.special_ids
        self.mask_id = self.tokenizer.token_to_id.get("[MASK]", 4)
        self.model, self.config = load_bertose_backbone(checkpoint_path, device=device)

    def _ambiguous_positions(self, tokenized: Dict[str, Any]) -> List[int]:
        """Indices of ambiguous token ids within the unpadded part of one tokenized row."""
        length = int(tokenized.get("length", sum(tokenized["attention_mask"])))
        return [idx for idx, tid in enumerate(tokenized["token_ids"][:length]) if int(tid) in self.ambiguous_ids]

    @torch.no_grad()
    def _forward_logits(self, rows: Sequence[Dict[str, Any]], current_ids: Sequence[List[int]]) -> torch.Tensor:
        """Run the backbone on sequence inputs only and return ``outputs["seq_logits"]``.

        ``current_ids`` supplies the (possibly masked) token ids; the remaining
        sequence features come from the tokenizer rows. MS and structure inputs
        are zero (residue ids ``-1``) placeholders with ``has_ms``/``has_3d`` False.
        """
        batch_size = len(rows)
        device = self.device
        # Placeholder widths follow the config (150 / 200 as built by
        # build_config_from_state_dict) instead of repeating the literals here.
        ms_length = int(getattr(self.config, "ms_max_length", 150))
        struct_length = int(getattr(self.config, "struct_max_length", 200))

        def seq_field(name: str) -> torch.Tensor:
            return torch.tensor([row[name] for row in rows], dtype=torch.long, device=device)

        outputs = self.model(
            seq_token_ids=torch.tensor(current_ids, dtype=torch.long, device=device),
            seq_attention_mask=seq_field("attention_mask"),
            seq_residue_ids=seq_field("residue_ids"),
            seq_branch_depths=seq_field("branch_depths"),
            seq_linkage_types=seq_field("linkage_types"),
            ms_token_ids=torch.zeros(batch_size, ms_length, dtype=torch.long, device=device),
            ms_attention_mask=torch.zeros(batch_size, ms_length, dtype=torch.long, device=device),
            has_ms=torch.zeros(batch_size, dtype=torch.bool, device=device),
            struct_token_ids=torch.zeros(batch_size, struct_length, dtype=torch.long, device=device),
            struct_attention_mask=torch.zeros(batch_size, struct_length, dtype=torch.long, device=device),
            struct_residue_ids=torch.full((batch_size, struct_length), -1, dtype=torch.long, device=device),
            has_3d=torch.zeros(batch_size, dtype=torch.bool, device=device),
            return_dict=True,
        )
        return outputs["seq_logits"]

    def _score_position(
        self,
        position_logits: torch.Tensor,
        sample_id: str,
        wurcs: str,
        iteration: int,
        pos: int,
        original_id: int,
    ) -> Tuple[Dict[str, Any], Optional[int]]:
        """Score one masked position; return its detail record and the accepted id (or ``None``)."""
        probs = F.softmax(position_logits, dim=-1)
        # torch.topk rejects k larger than the vocabulary, so clamp it.
        k = min(self.top_k, int(probs.shape[-1]))
        top_probs, top_ids = torch.topk(probs, k=k)
        top_ids_list = [int(item) for item in top_ids.detach().cpu().tolist()]
        top_probs_list = [float(item) for item in top_probs.detach().cpu().tolist()]
        id_to_token = self.tokenizer.id_to_token
        detail: Dict[str, Any] = {
            "sample_id": sample_id,
            "wurcs": wurcs,
            "iteration": iteration,
            "position": pos,
            "original_id": original_id,
            "original_token": id_to_token.get(original_id, "[UNK]"),
            "top_ids": json.dumps(top_ids_list),
            "top_tokens": json.dumps([id_to_token.get(item, "[UNK]") for item in top_ids_list]),
            "top_probs": json.dumps(top_probs_list),
            "accepted": False,
        }
        for candidate_id, candidate_prob in zip(top_ids_list, top_probs_list):
            if candidate_prob >= self.threshold and candidate_id not in self.invalid_prediction_ids:
                detail.update(
                    {
                        "accepted": True,
                        "resolved_id": candidate_id,
                        "resolved_token": id_to_token.get(candidate_id, "[UNK]"),
                        "confidence": candidate_prob,
                    }
                )
                return detail, candidate_id
        return detail, None

    def _resolve_chunk(
        self, chunk_wurcs: List[str], chunk_ids: List[str], iterations: int
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Run up to ``iterations`` mask-predict rounds on one chunk; return (summary rows, detail rows)."""
        tokenized_rows = [self.tokenizer.tokenize(wurcs, max_length=self.config.seq_max_length) for wurcs in chunk_wurcs]
        current_ids = [list(row["token_ids"]) for row in tokenized_rows]
        unresolved = [set(self._ambiguous_positions(row)) for row in tokenized_rows]
        initial_counts = [len(item) for item in unresolved]
        accepted_by_row: List[List[Dict[str, Any]]] = [[] for _ in tokenized_rows]
        details: List[Dict[str, Any]] = []

        for iteration in range(1, iterations + 1):
            if not any(unresolved):
                # Nothing left to predict; skip a forward pass whose result would be unused.
                break
            masked_ids = []
            for row_idx, row_ids in enumerate(current_ids):
                masked = list(row_ids)
                for pos in unresolved[row_idx]:
                    masked[pos] = self.mask_id
                masked_ids.append(masked)
            logits = self._forward_logits(tokenized_rows, masked_ids)

            progress = 0
            for row_idx, positions in enumerate(unresolved):
                # Iterate a sorted copy so accepted positions can be removed safely.
                for pos in sorted(positions):
                    detail, accepted_id = self._score_position(
                        logits[row_idx, pos],
                        chunk_ids[row_idx],
                        chunk_wurcs[row_idx],
                        iteration,
                        pos,
                        int(tokenized_rows[row_idx]["token_ids"][pos]),
                    )
                    if accepted_id is not None:
                        current_ids[row_idx][pos] = accepted_id
                        positions.remove(pos)
                        progress += 1
                        accepted_by_row[row_idx].append(detail.copy())
                    details.append(detail)
            if progress == 0:
                break

        summaries = [
            {
                "sample_id": chunk_ids[row_idx],
                "wurcs": chunk_wurcs[row_idx],
                "initial_ambiguous_tokens": initial_counts[row_idx],
                "resolved_tokens": len(accepted_by_row[row_idx]),
                "remaining_ambiguous_tokens": len(unresolved[row_idx]),
                "final_token_sequence": " ".join(nonpad_tokens(self.tokenizer, current_ids[row_idx], row["attention_mask"])),
                "accepted_updates_json": json.dumps(accepted_by_row[row_idx]),
            }
            for row_idx, row in enumerate(tokenized_rows)
        ]
        return summaries, details

    def resolve_batch(self, wurcs_list: Sequence[str], ids: Optional[Sequence[str]], iterations: int, batch_size: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Resolve ambiguous tokens for many glycans, ``batch_size`` per forward pass.

        Args:
            wurcs_list: WURCS strings to resolve.
            ids: one identifier per glycan; ``None`` uses ``"0"``, ``"1"``, ...
            iterations: maximum number of mask-predict rounds per chunk.
            batch_size: glycans per chunk; must be >= 1.

        Returns:
            ``(summary_df, details_df)``:
            - ``summary_df`` has one row per glycan;
            - ``details_df`` has one row per (round, ambiguous position) scored.

        Raises:
            ValueError: if ``ids`` and ``wurcs_list`` differ in length or ``batch_size < 1``.
        """
        ids = list(ids) if ids is not None else [str(i) for i in range(len(wurcs_list))]
        if len(ids) != len(wurcs_list):
            raise ValueError(f"ids has {len(ids)} entries but wurcs_list has {len(wurcs_list)}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        summaries: List[Dict[str, Any]] = []
        details: List[Dict[str, Any]] = []
        for start in tqdm(range(0, len(wurcs_list), batch_size), desc="Resolving glycans"):
            chunk_summaries, chunk_details = self._resolve_chunk(
                list(wurcs_list[start : start + batch_size]),
                ids[start : start + batch_size],
                iterations,
            )
            summaries.extend(chunk_summaries)
            details.extend(chunk_details)
        return pd.DataFrame(summaries), pd.DataFrame(details)

    def resolve_single(self, wurcs: str, iterations: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Resolve one glycan; its sample id is ``"single"``."""
        return self.resolve_batch([wurcs], ids=["single"], iterations=iterations, batch_size=1)
|
|
|
|
def read_wurcs_csv(path: Optional[Path], default: Sequence[str], column: str, n: int) -> List[str]:
    """Return the first ``n`` WURCS strings from ``path[column]`` or from ``default``.

    ``default`` is used when ``path`` is ``None`` or does not exist. A missing
    explicit path emits a warning, so a typo is not silently replaced by the
    built-in glycans. Values are stripped of surrounding whitespace, and only
    those starting with ``"WURCS="`` are kept.

    Raises:
        ValueError: if the CSV lacks ``column`` or fewer than ``n`` valid strings remain.
    """
    if path is not None and path.exists():
        frame = pd.read_csv(path)
        if column not in frame.columns:
            raise ValueError(f"{path} must contain column {column!r}")
        values = frame[column].dropna().astype(str).tolist()
    else:
        if path is not None:
            warnings.warn(f"{path} does not exist; falling back to {len(default)} built-in WURCS strings")
        values = list(default)
    # Stray whitespace (common in hand-edited CSVs) would otherwise fail the prefix check.
    values = [item.strip() for item in values]
    values = [item for item in values if item.startswith("WURCS=")]
    if len(values) < n:
        raise ValueError(f"Need at least {n} WURCS strings, found {len(values)}")
    return values[:n]
|
|
|
|
| def choose_affinose_pairs( |
| csv_path: Optional[Path], |
| protein_h5: Optional[Path], |
| output_dir: Path, |
| n: int, |
| default_wurcs: Sequence[str], |
| ) -> Tuple[pd.DataFrame, Path, str]: |
| if protein_h5 and protein_h5.exists(): |
| with h5py.File(protein_h5, "r") as handle: |
| available = set(handle.keys()) |
| if csv_path and csv_path.exists(): |
| df = pd.read_csv(csv_path, usecols=lambda col: col in {"protein_id", "glycan_wurcs"}) |
| df = df.dropna(subset=["protein_id", "glycan_wurcs"]).copy() |
| df["protein_id"] = df["protein_id"].astype(str) |
| df["glycan_wurcs"] = df["glycan_wurcs"].astype(str) |
| df = df[df["protein_id"].isin(available) & df["glycan_wurcs"].str.startswith("WURCS=")] |
| df = df.drop_duplicates(["protein_id", "glycan_wurcs"]) |
| diverse = df.drop_duplicates("protein_id").head(n) |
| if len(diverse) < n: |
| extra = df[~df.index.isin(diverse.index)].head(n - len(diverse)) |
| diverse = pd.concat([diverse, extra], ignore_index=False) |
| pairs = diverse.head(n).copy() |
| else: |
| keys = sorted(available)[:n] |
| pairs = pd.DataFrame({"protein_id": keys, "glycan_wurcs": list(default_wurcs)[:n]}) |
| if len(pairs) < n: |
| raise ValueError(f"Need {n} AFFINose pairs with available protein embeddings, found {len(pairs)}") |
| subset_h5 = output_dir / "affinose_subset_esmc_embeddings.h5" |
| with h5py.File(protein_h5, "r") as src, h5py.File(subset_h5, "w") as dst: |
| for protein_id in pairs["protein_id"].astype(str).unique(): |
| dst.create_dataset(protein_id, data=src[protein_id][:], compression="gzip") |
| return pairs.reset_index(drop=True), subset_h5, "real_subset" |
|
|
| synthetic_h5 = output_dir / "synthetic_esmc_embeddings.h5" |
| rng = np.random.default_rng(123) |
| protein_ids = [f"smoke_protein_{idx}" for idx in range(n)] |
| with h5py.File(synthetic_h5, "w") as dst: |
| for idx, protein_id in enumerate(protein_ids): |
| dst.create_dataset(protein_id, data=rng.normal(size=(64 + idx, 960)).astype("float32")) |
| pairs = pd.DataFrame({"protein_id": protein_ids, "glycan_wurcs": list(default_wurcs)[:n]}) |
| return pairs, synthetic_h5, "synthetic" |
|
|
|
|
def write_embedding_csv(path: Path, sample_ids: Sequence[str], wurcs: Sequence[str], embeddings: np.ndarray) -> None:
    """Write one CSV row per glycan: ``sample_id``, ``wurcs``, then ``emb_000`` ... ``emb_<d-1>``.

    Raises:
        ValueError: if ``embeddings`` is not 2-D, or its row count differs from
            ``len(sample_ids)``/``len(wurcs)``. Without this check,
            ``pd.concat(axis=1)`` would silently pad the shorter side with NaN rows.
    """
    matrix = np.asarray(embeddings)
    if matrix.ndim != 2:
        raise ValueError(f"embeddings must be 2-D, got shape {matrix.shape}")
    if not (len(sample_ids) == len(wurcs) == matrix.shape[0]):
        raise ValueError(
            f"length mismatch: {len(sample_ids)} sample_ids, {len(wurcs)} wurcs, {matrix.shape[0]} embedding rows"
        )
    emb_cols = [f"emb_{idx:03d}" for idx in range(matrix.shape[1])]
    df = pd.concat(
        [
            pd.DataFrame({"sample_id": sample_ids, "wurcs": wurcs}),
            pd.DataFrame(matrix, columns=emb_cols),
        ],
        axis=1,
    )
    df.to_csv(path, index=False)
|
|
|
|
def finite_or_raise(values: Sequence[float], name: str) -> None:
    """Raise ``RuntimeError`` naming ``name`` if any value is NaN or infinite.

    ``values`` is converted to a float64 array first, so any array-like works.
    """
    as_float = np.asarray(values, dtype="float64")
    if np.isfinite(as_float).all():
        return
    raise RuntimeError(f"{name} contains non-finite values")
|
|
|
|
| def main() -> None: |
| parser = argparse.ArgumentParser(description="Validate public BERTose/AFFINose Hugging Face inference.") |
| parser.add_argument("--output-dir", type=Path, default=Path("release_smoke_outputs")) |
| parser.add_argument("--hf-cache-dir", type=Path, default=Path("hf_release_cache")) |
| parser.add_argument("--device", default="auto") |
| parser.add_argument("--embedding-csv", type=Path, default=None) |
| parser.add_argument("--iar-csv", type=Path, default=None) |
| parser.add_argument("--affinose-csv", type=Path, default=None) |
| parser.add_argument("--protein-emb-h5", type=Path, default=None) |
| parser.add_argument("--embedding-n", type=int, default=20) |
| parser.add_argument("--iar-n", type=int, default=20) |
| parser.add_argument("--affinose-n", type=int, default=5) |
| parser.add_argument("--embedding-batch-size", type=int, default=20) |
| parser.add_argument("--iar-batch-size", type=int, default=8) |
| parser.add_argument("--affinose-batch-size", type=int, default=5) |
| parser.add_argument("--iar-iterations", type=int, default=3) |
| args = parser.parse_args() |
|
|
| torch.set_grad_enabled(False) |
| args.output_dir.mkdir(parents=True, exist_ok=True) |
| device = get_device(args.device) |
| summary: Dict[str, Any] = { |
| "device": str(device), |
| "cuda_available": torch.cuda.is_available(), |
| "torch_version": torch.__version__, |
| "repo_ids": REPOS, |
| "timing_seconds": {}, |
| } |
|
|
| t0 = time.perf_counter() |
| paths = download_repos(args.hf_cache_dir) |
| summary["timing_seconds"]["download_repos"] = time.perf_counter() - t0 |
| summary["repo_paths"] = {key: str(path) for key, path in paths.items()} |
| add_source_paths(paths) |
|
|
| embedding_wurcs = read_wurcs_csv(args.embedding_csv, DEFAULT_WURCS, "wurcs", args.embedding_n) |
| iar_wurcs = read_wurcs_csv(args.iar_csv, DEFAULT_AMBIGUOUS_WURCS, "wurcs", args.iar_n) |
|
|
| encoder_ckpt = paths["encoder"] / "checkpoints" / "bertose_glycan_encoder.pt" |
| encoder_vocab = paths["encoder"] / "vocab" / "bpe_vocabulary.json" |
| resolver_ckpt = paths["resolver"] / "checkpoints" / "bertose_iar_resolver.pt" |
| resolver_vocab = paths["resolver"] / "vocab" / "bpe_vocabulary.json" |
| ambiguity_path = paths["resolver"] / "vocab" / "bpe_ambiguity_tokens.json" |
|
|
| t0 = time.perf_counter() |
| embedder = BertoseEmbedder(encoder_ckpt, encoder_vocab, device) |
| summary["timing_seconds"]["load_bertose_embedder"] = time.perf_counter() - t0 |
|
|
| t0 = time.perf_counter() |
| single_embedding = embedder.embed_single(embedding_wurcs[0]) |
| summary["timing_seconds"]["bertose_embedding_single"] = time.perf_counter() - t0 |
| if single_embedding.shape != (768,): |
| raise RuntimeError(f"Expected single embedding shape (768,), got {single_embedding.shape}") |
| finite_or_raise(single_embedding, "single_embedding") |
|
|
| t0 = time.perf_counter() |
| batch_embeddings = embedder.embed_batch(embedding_wurcs, batch_size=args.embedding_batch_size) |
| elapsed = time.perf_counter() - t0 |
| summary["timing_seconds"]["bertose_embedding_batch"] = elapsed |
| summary["bertose_embedding_batch_shape"] = list(batch_embeddings.shape) |
| summary["bertose_embedding_items_per_second"] = len(embedding_wurcs) / max(elapsed, 1e-9) |
| if batch_embeddings.shape != (len(embedding_wurcs), 768): |
| raise RuntimeError(f"Expected batch embedding shape ({len(embedding_wurcs)}, 768), got {batch_embeddings.shape}") |
| finite_or_raise(batch_embeddings.ravel(), "batch_embeddings") |
| write_embedding_csv(args.output_dir / "bertose_embeddings_20.csv", [f"glycan_{idx:02d}" for idx in range(len(embedding_wurcs))], embedding_wurcs, batch_embeddings) |
|
|
| del embedder |
| if device.type == "cuda": |
| torch.cuda.empty_cache() |
|
|
| t0 = time.perf_counter() |
| resolver = BertoseAmbiguityResolver(resolver_ckpt, resolver_vocab, ambiguity_path, device) |
| summary["timing_seconds"]["load_bertose_iar"] = time.perf_counter() - t0 |
|
|
| t0 = time.perf_counter() |
| single_summary, single_details = resolver.resolve_single(iar_wurcs[0], iterations=args.iar_iterations) |
| summary["timing_seconds"]["bertose_iar_single"] = time.perf_counter() - t0 |
| if single_summary.empty: |
| raise RuntimeError("BERTose IAR single output is empty") |
| single_summary.to_csv(args.output_dir / "bertose_iar_single_summary.csv", index=False) |
| single_details.to_csv(args.output_dir / "bertose_iar_single_details.csv", index=False) |
|
|
| t0 = time.perf_counter() |
| iar_summary, iar_details = resolver.resolve_batch( |
| iar_wurcs, |
| ids=[f"ambiguous_{idx:02d}" for idx in range(len(iar_wurcs))], |
| iterations=args.iar_iterations, |
| batch_size=args.iar_batch_size, |
| ) |
| elapsed = time.perf_counter() - t0 |
| summary["timing_seconds"]["bertose_iar_batch"] = elapsed |
| summary["bertose_iar_batch_rows"] = len(iar_summary) |
| summary["bertose_iar_detail_rows"] = len(iar_details) |
| summary["bertose_iar_items_per_second"] = len(iar_wurcs) / max(elapsed, 1e-9) |
| if len(iar_summary) != len(iar_wurcs): |
| raise RuntimeError(f"Expected {len(iar_wurcs)} IAR summary rows, got {len(iar_summary)}") |
| iar_summary.to_csv(args.output_dir / "bertose_iar_batch_summary.csv", index=False) |
| iar_details.to_csv(args.output_dir / "bertose_iar_batch_details.csv", index=False) |
|
|
| del resolver |
| if device.type == "cuda": |
| torch.cuda.empty_cache() |
|
|
| pairs, protein_h5, protein_source = choose_affinose_pairs( |
| args.affinose_csv, |
| args.protein_emb_h5, |
| args.output_dir, |
| args.affinose_n, |
| embedding_wurcs, |
| ) |
| pairs_path = args.output_dir / "affinose_pairs.csv" |
| pairs.to_csv(pairs_path, index=False) |
| summary["affinose_protein_embedding_source"] = protein_source |
| summary["affinose_pairs"] = len(pairs) |
|
|
| t0 = time.perf_counter() |
| from affinose_inference import AffinosePredictor |
|
|
| predictor = AffinosePredictor( |
| checkpoint_path=str(paths["affinose"] / "checkpoints" / "affinose_interaction_model.pt"), |
| bertose_checkpoint=str(encoder_ckpt), |
| vocab_path=str(paths["affinose"] / "vocab" / "bpe_vocabulary.json"), |
| protein_emb_path=str(protein_h5), |
| device=str(device), |
| ) |
| summary["timing_seconds"]["load_affinose"] = time.perf_counter() - t0 |
|
|
| t0 = time.perf_counter() |
| first = pairs.iloc[0] |
| single_score = predictor.predict_single(first["glycan_wurcs"], first["protein_id"]) |
| summary["timing_seconds"]["affinose_single"] = time.perf_counter() - t0 |
| finite_or_raise([single_score], "affinose_single_score") |
| summary["affinose_single_score"] = float(single_score) |
|
|
| t0 = time.perf_counter() |
| scores = predictor.predict_batch( |
| pairs["glycan_wurcs"].astype(str).tolist(), |
| pairs["protein_id"].astype(str).tolist(), |
| batch_size=args.affinose_batch_size, |
| ) |
| elapsed = time.perf_counter() - t0 |
| finite_or_raise(scores, "affinose_batch_scores") |
| summary["timing_seconds"]["affinose_batch"] = elapsed |
| summary["affinose_items_per_second"] = len(scores) / max(elapsed, 1e-9) |
| out_pairs = pairs.copy() |
| out_pairs["affinose_score"] = [float(score) for score in scores] |
| out_pairs.to_csv(args.output_dir / "affinose_predictions_5.csv", index=False) |
|
|
| summary["status"] = "ok" |
| summary_path = args.output_dir / "smoke_summary.json" |
| summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") |
| print(json.dumps(summary, indent=2)) |
|
|
|
|
# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|