#!/usr/bin/env python3
"""Headless BERTose/AFFINose Hugging Face release smoke test.

This script is meant for cluster or cloud notebook validation. It downloads the
public release repositories without a user token, runs single and batch
BERTose embedding, BERTose IAR, and AFFINose protein-glycan scoring, then
writes CSV and JSON outputs with timing.

``smoke_summary.json`` is written even when a stage fails (with ``status`` set
to ``"failed"`` and the error recorded) so partial timing data is never lost.
"""

from __future__ import annotations

import argparse
import json
import os
import re
import sys
import time
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from tqdm.auto import tqdm

# Public Hugging Face repositories exercised by the smoke test.
REPOS = {
    "encoder": "supanthadey1/bertose-glycan-encoder",
    "resolver": "supanthadey1/bertose-iar-resolver",
    "affinose": "supanthadey1/affinose-interaction-model",
    "notebook": "supanthadey1/bertose-affinose-inference",
}

# Glycans used for the embedding stage; also the fallback glycans for AFFINose
# pairs when no interaction CSV is supplied.
DEFAULT_WURCS = [
    "WURCS=2.0/2,2,1/[a212h-1b_1-5][a2211m-1a_1-5]/1-2/a2-b1",
    "WURCS=2.0/2,4,3/[a1221m-1b_1-5][a1221m-1a_1-5]/1-2-2-2/a3-b1_b3-c1_c3-d1",
    "WURCS=2.0/3,3,2/[a2112h-1b_1-5_2*NCC/3=O_4*OSO/3=O/3=O_6*OSO/3=O/3=O][a2122A-1b_1-5][a1221m-1a_1-5_2*OSO/3=O/3=O_4*OSO/3=O/3=O]/1-2-3/a3-b1_b3-c1",
    "WURCS=2.0/2,2,1/[a2122h-1a_1-5_2*OSO/3=O/3=O_3*OSO/3=O/3=O_4*OSO/3=O/3=O][a2122m-1a_1-5_2*OSO/3=O/3=O_3*OSO/3=O/3=O]/1-2/a6-b1",
    "WURCS=2.0/2,3,2/[a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O]/1-1-2/a4-b1_b3-c1",
    "WURCS=2.0/3,3,2/[a2122A-1b_1-5][a211h-1a_1-5][a2122h-1b_1-5]/1-2-3/a2-b1_b2-c1",
    "WURCS=2.0/2,4,3/[a2122h-1a_1-5][a2112h-1a_1-5]/1-1-2-2/a4-b1_b4-c1_b6-d1",
    "WURCS=2.0/2,3,2/[a2112h-1b_1-5][a2112h-1a_1-4]/1-2-2/a4-b1_b2-c1",
    "WURCS=2.0/1,1,0/[a122h-1a_1-5_2*OC_3*OC_4*OC]/1/",
    "WURCS=2.0/1,1,0/[a122h-1a_1-4_2*OCC/3=O_3*OCC/3=O]/1/",
    "WURCS=2.0/1,1,0/[had22111m-2b_2-6_5*N_7*N]/1/",
    "WURCS=2.0/1,1,0/[Aad22111m-2b_2-6_5*N_7*N]/1/",
    "WURCS=2.0/4,7,6/[a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a2122h-1b_1-5][Aad22111m-2b_2-6_5*N_7*N]/1-2-1-2-3-4-3/a3-b1_b3-c1_b6-g1_c3-d1_d6-e1_e6-f2",
    "WURCS=2.0/4,7,6/[a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a2122h-1b_1-5][Aad22111m-2b_2-6_7*N_5*N]/1-2-3-1-2-3-4/a3-b1_b3-d1_b6-c1_d3-e1_e6-f1_f6-g2",
    "WURCS=2.0/1,1,0/[Aad22111m-2a_2-6_5*N_7*N]/1/",
    "WURCS=2.0/5,8,7/[a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a2122h-1b_1-5][Aad22111m-2a_2-6_5*N_7*N][Aad22111m-2b_2-6_5*N_7*N]/1-2-1-2-3-4-3-5/a3-b1_b3-c1_b6-g1_c3-d1_d6-e1_e6-f2_g6-h2",
    "WURCS=2.0/5,8,7/[a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a2122h-1b_1-5][Aad22111m-2b_2-6_7*N_5*N][Aad22111m-2a_2-6_7*N_5*N]/1-2-3-4-1-2-3-5/a3-b1_b3-e1_b6-c1_c6-d2_e3-f1_f6-g1_g6-h2",
    "WURCS=2.0/3,3,2/[a2112m-1b_1-5_2*NCC/3=O][a212h-1b_1-5][Aad22111m-2a_2-6_5*N_7*NC=O]/1-2-3/a3-b1_b4-c2",
    "WURCS=2.0/3,3,2/[a212h-1b_1-5][a2112m-1b_1-5_2*NCC/3=O][Aad22111m-2a_2-6_5*N_7*NC=O]/1-2-3/a4-b1_b3-c2",
    "WURCS=2.0/3,3,2/[a212h-1b_1-5][a2112m-1b_1-5_2*NCC/3=O_4*OCC/3=O][Aad22111m-2a_2-6_5*N_7*NC=O]/1-2-3/a4-b1_b3-c2",
]

# Glycans used for the IAR (ambiguity resolution) stage; each one contains a
# residue code starting with "u".
DEFAULT_AMBIGUOUS_WURCS = [
    "WURCS=2.0/2,2,1/[u2112h][a1221m-1a_1-5]/1-2/a2-b1",
    "WURCS=2.0/3,3,2/[u2112h_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3/a3-b1_b2-c1",
    "WURCS=2.0/4,6,5/[u2122h][a2112h-1b_1-5][a1221m-1a_1-5][a2112h-1a_1-5_2*NCC/3=O]/1-2-3-4-2-3/a4-b1_b2-c1_b3-d1_d3-e1_e2-f1",
    "WURCS=2.0/4,6,5/[u2122h_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5][a2112h-1a_1-5_2*NCC/3=O]/1-2-3-4-2-3/a4-b1_b2-c1_b3-d1_d3-e1_e2-f1",
    "WURCS=2.0/4,4,3/[u2112h][a2112h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4/a3-b1_b3-c1_c2-d1",
    "WURCS=2.0/5,6,5/[u2122h][a2112h-1b_1-5][a2112h-1a_1-5][a2112h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-4-2-5/a4-b1_b4-c1_c3-d1_d3-e1_e2-f1",
    "WURCS=2.0/4,5,4/[u2122h][a2112h-1b_1-5][a2112h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-2-4/a4-b1_b3-c1_c3-d1_d2-e1",
    "WURCS=2.0/5,6,5/[u2122h][a2112h-1b_1-5][Aad21122h-2a_2-6_5*NCC/3=O][a2112h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-4-2-5/a4-b1_b3-c2_b4-d1_d3-e1_e2-f1",
    "WURCS=2.0/3,3,2/[u2122h_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3/a3-b1_b2-c1",
    "WURCS=2.0/6,11,10/[u2122h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a1122h-1b_1-5][a1122h-1a_1-5][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4-2-5-6-4-2-5-6/a4-b1_b4-c1_c3-d1_c6-h1_d2-e1_e3-f1_f2-g1_h2-i1_i3-j1_j2-k1",
    "WURCS=2.0/6,12,11/[u2122h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a1122h-1b_1-5][a1122h-1a_1-5][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4-2-5-6-4-2-5-6-6/a4-b1_a6-l1_b4-c1_c3-d1_c6-h1_d2-e1_e3-f1_f2-g1_h2-i1_i3-j1_j2-k1",
    "WURCS=2.0/4,5,4/[u2122h][a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-2-4/a4-b1_b3-c1_c3-d1_d2-e1",
    "WURCS=2.0/4,4,3/[u2112h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4/a3-b1_b3-c1_c2-d1",
    "WURCS=2.0/4,7,6/[u2112h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4-2-3-4/a3-b1_a6-e1_b3-c1_c2-d1_e3-f1_f2-g1",
    "WURCS=2.0/4,8,7/[u2122h][a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-2-4-3-4-2/a4-b1_b3-c1_b6-f1_c3-d1_d2-e1_f3-g1_f4-h1",
    "WURCS=2.0/4,7,6/[u2122h][a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5]/1-2-3-2-4-3-2/a4-b1_b3-c1_b6-f1_c3-d1_d2-e1_f4-g1",
    "WURCS=2.0/5,8,7/[u2122h][a2112h-1b_1-5][a2122h-1b_1-5_2*NCC/3=O][a1221m-1a_1-5][Aad21122h-2a_2-6_5*NCC/3=O]/1-2-3-2-4-3-2-5/a4-b1_b3-c1_b6-f1_c3-d1_d2-e1_f4-g1_g6-h2",
    "WURCS=2.0/3,3,2/[u2122h_2*NCC/3=O_6*OSO/3=O/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3/a3-b1_b2-c1",
    "WURCS=2.0/3,4,3/[u2122h_2*NCC/3=O][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-3/a3-b1_a4-d1_b2-c1",
    "WURCS=2.0/6,13,12/[u2122h_2*NCC/3=O][a2122h-1b_1-5_2*NCC/3=O][a1122h-1b_1-5][a1122h-1a_1-5][a2112h-1b_1-5][a1221m-1a_1-5]/1-2-3-4-2-5-6-6-4-2-5-6-6/a4-b1_b4-c1_c3-d1_c6-i1_d2-e1_e3-f1_e4-h1_f2-g1_i2-j1_j3-k1_j4-m1_k2-l1",
]


def get_device(name: str) -> torch.device:
    """Resolve ``--device``: ``"auto"`` picks CUDA when available, else CPU."""
    if name == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(name)


def ensure_writable_hf_home(cache_dir: Path) -> None:
    """Keep Hugging Face/Xet cache writes inside the selected work directory.

    An existing, writable ``HF_HOME`` is kept; otherwise ``<cache_dir>/../hf_home``
    is used. ``HF_HUB_CACHE`` and ``HF_XET_CACHE`` are pointed inside the chosen
    home unless the caller already set them.

    NOTE: ``huggingface_hub`` reads most ``HF_*`` variables once, when it is
    first imported (already done at the top of this module). The assignments
    here therefore mainly affect subprocesses and later-imported libraries.
    :func:`download_repos` passes the hub cache explicitly so it also applies
    in this process.

    Raises:
        RuntimeError: if neither candidate directory is writable.
    """
    fallback = cache_dir.parent / "hf_home"
    current = os.environ.get("HF_HOME")
    candidates = [Path(current).expanduser()] if current else []
    candidates.append(fallback)
    for candidate in candidates:
        try:
            candidate.mkdir(parents=True, exist_ok=True)
            # Probe with a real write: mkdir succeeds on read-only mounts that already exist.
            probe = candidate / ".write_probe"
            probe.write_text("ok", encoding="utf-8")
            probe.unlink()
        except OSError:
            continue
        os.environ["HF_HOME"] = str(candidate)
        os.environ.setdefault("HF_HUB_CACHE", str(candidate / "hub"))
        os.environ.setdefault("HF_XET_CACHE", str(candidate / "xet"))
        return
    raise RuntimeError(f"Could not create a writable Hugging Face cache under {fallback}")


def download_repos(cache_dir: Path) -> Dict[str, Path]:
    """Download every repo in :data:`REPOS` anonymously into ``cache_dir``.

    Returns:
        Mapping from the :data:`REPOS` key to the local snapshot directory.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    ensure_writable_hf_home(cache_dir)
    # Passed explicitly because huggingface_hub resolved its default cache at import time.
    hub_cache = os.environ.get("HF_HUB_CACHE")
    paths: Dict[str, Path] = {}
    for key, repo_id in REPOS.items():
        local_dir = cache_dir / repo_id.split("/")[-1]
        paths[key] = Path(
            snapshot_download(
                repo_id=repo_id,
                repo_type="model",
                local_dir=str(local_dir),
                cache_dir=hub_cache,
                token=False,  # never send a user token: this validates public access
            )
        )
    return paths


def add_source_paths(paths: Dict[str, Path]) -> None:
    """Make each downloaded repo's ``src/`` importable and export BERTose root env vars.

    Each path is inserted at position 0, so the last one inserted (``resolver``)
    takes precedence on name clashes.
    """
    for key in ["affinose", "encoder", "resolver"]:
        src = paths[key] / "src"
        if str(src) not in sys.path:
            sys.path.insert(0, str(src))
    os.environ["BERTOSE_ROOT"] = str(paths["encoder"])
    os.environ["BERTOSE_REPO_ROOT"] = str(paths["encoder"])


def count_layers(state_dict: Dict[str, torch.Tensor], prefix: str, default: int) -> int:
    """Return 1 + the highest ``<prefix>.<i>.`` index in ``state_dict`` keys, or ``default``."""
    pattern = re.compile(rf"^{re.escape(prefix)}\.(\d+)\.")
    indices = [int(match.group(1)) for key in state_dict for match in [pattern.match(key)] if match]
    return max(indices) + 1 if indices else default


def import_bertose_classes():
    """Import the BERTose model/config/tokenizer classes.

    Imported lazily because the modules live in the downloaded repos and are
    only importable after :func:`add_source_paths` has run.
    """
    from bertose_model import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
    from wurcs_bpe_tokenizer import WURCSBPETokenizer

    return MultimodalGlycanBERT, MultimodalGlycanBERTConfig, WURCSBPETokenizer


def build_config_from_state_dict(state_dict: Dict[str, torch.Tensor]):
    """Infer a ``MultimodalGlycanBERTConfig`` from checkpoint tensor shapes and key names.

    Vocabulary sizes, hidden sizes, max sequence length and layer counts are read
    from the tensors. Head counts and the MS/struct max lengths are fixed
    literals (they are not read from any tensor here). The literal MS/struct
    defaults apply when those embedding tables are absent.
    """
    _, MultimodalGlycanBERTConfig, _ = import_bertose_classes()
    seq_vocab_size, seq_hidden = state_dict["seq_embeddings.token_embeddings.weight"].shape
    seq_max_length = state_dict["seq_embeddings.position_embeddings.weight"].shape[0]

    ms_vocab_size = 242
    ms_hidden_size = 384
    if "ms_embeddings.token_embeddings.weight" in state_dict:
        ms_total_vocab, ms_hidden_size = state_dict["ms_embeddings.token_embeddings.weight"].shape
        # The MS table appears to be sized seq_vocab + ms_vocab; TODO confirm against the model code.
        ms_vocab_size = max(1, ms_total_vocab - seq_vocab_size)

    struct_vocab_size = 1024
    struct_hidden_size = 512
    use_3d = any(key.startswith("struct_embeddings.") for key in state_dict)
    if "struct_embeddings.token_embeddings.weight" in state_dict:
        struct_vocab_size, struct_hidden_size = state_dict["struct_embeddings.token_embeddings.weight"].shape

    return MultimodalGlycanBERTConfig(
        seq_vocab_size=seq_vocab_size,
        seq_hidden_size=seq_hidden,
        seq_num_layers=count_layers(state_dict, "seq_layers", 12),
        seq_num_heads=12,
        seq_max_length=seq_max_length,
        ms_vocab_size=ms_vocab_size,
        ms_hidden_size=ms_hidden_size,
        ms_num_layers=count_layers(state_dict, "ms_layers", 6),
        ms_num_heads=6,
        ms_max_length=150,
        struct_vocab_size=struct_vocab_size,
        struct_hidden_size=struct_hidden_size,
        struct_num_layers=count_layers(state_dict, "struct_layers", 8),
        struct_num_heads=8,
        struct_max_length=200,
        use_3d=use_3d,
        use_cross_attention=any(key.startswith("cross_attention.") for key in state_dict),
        use_cnn_frontend=any(key.startswith("seq_embeddings.conv_layers.") for key in state_dict),
        cnn_kernel_size=3,
    )


def load_bertose_backbone(checkpoint_path: Path, device: torch.device):
    """Load a BERTose checkpoint without its projection head.

    Returns:
        ``(model, config)`` with the model moved to ``device`` and in eval mode.
        A warning is emitted for missing or unexpected tensors beyond the known
        optional heads.
    """
    MultimodalGlycanBERT, _, _ = import_bertose_classes()
    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only load
    # checkpoints from trusted sources (here: the fixed release repos in REPOS).
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    state_dict = {key: value for key, value in state_dict.items() if not key.startswith("proj_head.")}
    config = build_config_from_state_dict(state_dict)
    model = MultimodalGlycanBERT(config)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    # Distance / projection heads are optional; only report anything else.
    missing_report = [key for key in missing if not key.startswith(("dist_proj.", "distance_head."))]
    unexpected_report = [key for key in unexpected if not key.startswith(("proj_head.", "distance_head."))]
    if missing_report:
        warnings.warn(f"{checkpoint_path.name}: {len(missing_report)} unexpected missing tensors", stacklevel=2)
    if unexpected_report:
        warnings.warn(f"{checkpoint_path.name}: {len(unexpected_report)} unexpected tensors", stacklevel=2)
    # Presumably suppresses a one-time log message inside the model code; TODO confirm.
    model._dist_none_printed = True
    model.to(device).eval()
    return model, config


def load_tokenizer(vocab_path: Path):
    """Construct the WURCS BPE tokenizer from a vocabulary JSON path."""
    _, _, WURCSBPETokenizer = import_bertose_classes()
    return WURCSBPETokenizer(str(vocab_path))


def tensorize_tokenized(tokenized_rows: Sequence[Dict[str, Any]], device: torch.device) -> Dict[str, torch.Tensor]:
    """Stack the per-row tokenizer fields into ``long`` tensors on ``device``."""
    keys = ["token_ids", "attention_mask", "residue_ids", "branch_depths", "linkage_types"]
    return {key: torch.tensor([row[key] for row in tokenized_rows], dtype=torch.long, device=device) for key in keys}


def nonpad_tokens(tokenizer, token_ids: Sequence[int], attention_mask: Sequence[int]) -> List[str]:
    """Return the token strings at positions where ``attention_mask == 1``."""
    return [
        tokenizer.id_to_token.get(int(tid), "[UNK]")
        for tid, keep in zip(token_ids, attention_mask)
        if int(keep) == 1
    ]


class BertoseEmbedder:
    """Sequence-only BERTose glycan embedder (CLS or mean pooling)."""

    def __init__(self, checkpoint_path: Path, vocab_path: Path, device: torch.device):
        self.device = device
        self.tokenizer = load_tokenizer(vocab_path)
        self.model, self.config = load_bertose_backbone(checkpoint_path, device=device)

    @torch.no_grad()
    def _encode_hidden(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the sequence embedding and sequence transformer layers only."""
        hidden = self.model.seq_embeddings(
            batch["token_ids"],
            branch_depths=batch["branch_depths"],
            linkage_types=batch["linkage_types"],
        )
        for layer in self.model.seq_layers:
            hidden = layer(hidden, batch["attention_mask"])
        return hidden

    @torch.no_grad()
    def embed_batch(self, wurcs_list: Sequence[str], batch_size: int, pooling: str = "cls") -> np.ndarray:
        """Embed WURCS strings in chunks of ``batch_size``.

        Args:
            wurcs_list: WURCS strings to embed.
            batch_size: chunk size per forward pass (must be >= 1).
            pooling: ``"cls"`` (first token) or ``"mean"`` (attention-masked mean).

        Returns:
            ``float32`` array with one row per input; ``(0, seq_hidden_size)`` for empty input.

        Raises:
            ValueError: on an unknown ``pooling`` or ``batch_size < 1``.
        """
        # Validate before any forward pass instead of after the first chunk.
        if pooling not in ("cls", "mean"):
            raise ValueError("pooling must be 'cls' or 'mean'")
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        if len(wurcs_list) == 0:
            # np.vstack([]) would raise; return a well-shaped empty result instead.
            return np.zeros((0, int(self.config.seq_hidden_size)), dtype="float32")

        embeddings: List[np.ndarray] = []
        for start in tqdm(range(0, len(wurcs_list), batch_size), desc="Embedding glycans"):
            chunk = list(wurcs_list[start : start + batch_size])
            tokenized = [self.tokenizer.tokenize(wurcs, max_length=self.config.seq_max_length) for wurcs in chunk]
            batch = tensorize_tokenized(tokenized, self.device)
            hidden = self._encode_hidden(batch)
            if pooling == "cls":
                pooled = hidden[:, 0, :]
            else:
                mask = batch["attention_mask"].float().unsqueeze(-1)
                pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
            embeddings.append(pooled.detach().cpu().numpy().astype("float32"))
        return np.vstack(embeddings)

    def embed_single(self, wurcs: str, pooling: str = "cls") -> np.ndarray:
        """Embed one WURCS string; returns a 1-D vector."""
        return self.embed_batch([wurcs], batch_size=1, pooling=pooling)[0]


class BertoseAmbiguityResolver:
    """Iteratively replace ambiguous BPE tokens with confident BERTose predictions.

    Each iteration masks every still-unresolved ambiguous position, runs one
    sequence-only forward pass, and accepts a prediction when a top-k candidate
    reaches ``threshold`` and is neither an ambiguous nor a special token.
    A chunk stops after ``iterations`` rounds, or earlier when a round accepts
    nothing or no ambiguous positions remain.
    """

    def __init__(
        self,
        checkpoint_path: Path,
        vocab_path: Path,
        ambiguity_path: Path,
        device: torch.device,
        threshold: float = 0.80,
        top_k: int = 5,
    ):
        self.device = device
        self.threshold = threshold
        self.top_k = top_k
        self.tokenizer = load_tokenizer(vocab_path)
        with open(ambiguity_path, encoding="utf-8") as handle:
            ambiguity = json.load(handle)
        # Ambiguous ids = the release list plus a string heuristic over the vocab
        # ("?" or "x" anywhere, or a leading "u").
        # NOTE(review): the "x" check is broad; confirm no fully specified token contains "x".
        self.ambiguous_ids = {int(item) for item in ambiguity.get("ambiguous_ids", [])}
        reserved = {"[PAD]", "[UNK]", "[START]", "[END]", "[MASK]"}
        self.ambiguous_ids.update(
            int(token_id)
            for token, token_id in self.tokenizer.token_to_id.items()
            if token not in reserved and ("?" in token or "x" in token or token.startswith("u"))
        )
        self.special_ids = {
            int(token_id)
            for token, token_id in self.tokenizer.token_to_id.items()
            if token.startswith("[") and token.endswith("]")
        }
        # A prediction is only accepted if it is neither ambiguous nor special.
        self.invalid_prediction_ids = self.ambiguous_ids | self.special_ids
        self.mask_id = self.tokenizer.token_to_id.get("[MASK]", 4)
        self.model, self.config = load_bertose_backbone(checkpoint_path, device=device)

    def _ambiguous_positions(self, tokenized: Dict[str, Any]) -> List[int]:
        """Indices within the unpadded length whose token id is ambiguous."""
        length = int(tokenized.get("length", sum(tokenized["attention_mask"])))
        return [idx for idx, tid in enumerate(tokenized["token_ids"][:length]) if int(tid) in self.ambiguous_ids]

    @torch.no_grad()
    def _forward_logits(self, rows: Sequence[Dict[str, Any]], current_ids: Sequence[List[int]]) -> torch.Tensor:
        """Run a sequence-only forward pass (MS and 3D inputs zeroed/disabled)."""
        batch_size = len(rows)
        ms_len = int(getattr(self.config, "ms_max_length", 150))
        struct_len = int(getattr(self.config, "struct_max_length", 200))

        def as_long(key: str) -> torch.Tensor:
            return torch.tensor([row[key] for row in rows], dtype=torch.long, device=self.device)

        outputs = self.model(
            seq_token_ids=torch.tensor(current_ids, dtype=torch.long, device=self.device),
            seq_attention_mask=as_long("attention_mask"),
            seq_residue_ids=as_long("residue_ids"),
            seq_branch_depths=as_long("branch_depths"),
            seq_linkage_types=as_long("linkage_types"),
            ms_token_ids=torch.zeros(batch_size, ms_len, dtype=torch.long, device=self.device),
            ms_attention_mask=torch.zeros(batch_size, ms_len, dtype=torch.long, device=self.device),
            has_ms=torch.zeros(batch_size, dtype=torch.bool, device=self.device),
            struct_token_ids=torch.zeros(batch_size, struct_len, dtype=torch.long, device=self.device),
            struct_attention_mask=torch.zeros(batch_size, struct_len, dtype=torch.long, device=self.device),
            struct_residue_ids=torch.full((batch_size, struct_len), -1, dtype=torch.long, device=self.device),
            has_3d=torch.zeros(batch_size, dtype=torch.bool, device=self.device),
            return_dict=True,
        )
        return outputs["seq_logits"]

    def _top_candidates(self, position_logits: torch.Tensor) -> Tuple[List[int], List[float]]:
        """Top-k token ids and softmax probabilities for one position (k clamped to vocab size)."""
        probs = F.softmax(position_logits, dim=-1)
        k = min(self.top_k, int(probs.shape[-1]))
        top_probs, top_ids = torch.topk(probs, k=k)
        return (
            [int(item) for item in top_ids.detach().cpu().tolist()],
            [float(item) for item in top_probs.detach().cpu().tolist()],
        )

    def _resolve_chunk(
        self,
        chunk_ids: Sequence[str],
        chunk_wurcs: Sequence[str],
        iterations: int,
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Resolve one chunk of glycans; return (summary rows, per-position detail rows)."""
        tokenized_rows = [
            self.tokenizer.tokenize(wurcs, max_length=self.config.seq_max_length) for wurcs in chunk_wurcs
        ]
        current_ids = [list(row["token_ids"]) for row in tokenized_rows]
        unresolved = [set(self._ambiguous_positions(row)) for row in tokenized_rows]
        initial_counts = [len(positions) for positions in unresolved]
        accepted_by_row: List[List[Dict[str, Any]]] = [[] for _ in tokenized_rows]
        details: List[Dict[str, Any]] = []

        for iteration in range(1, iterations + 1):
            if not any(unresolved):
                # Nothing left to mask: a forward pass could not accept anything.
                break
            masked_ids = []
            for row_ids, positions in zip(current_ids, unresolved):
                masked = list(row_ids)
                for pos in positions:
                    masked[pos] = self.mask_id
                masked_ids.append(masked)
            logits = self._forward_logits(tokenized_rows, masked_ids)

            progress = 0
            for row_idx, positions in enumerate(unresolved):
                original_ids = tokenized_rows[row_idx]["token_ids"]
                # sorted() iterates a copy, so removing resolved positions below is safe.
                for pos in sorted(positions):
                    top_ids_list, top_probs_list = self._top_candidates(logits[row_idx, pos])
                    top_tokens = [self.tokenizer.id_to_token.get(item, "[UNK]") for item in top_ids_list]
                    accepted_candidate = next(
                        (
                            (candidate_id, candidate_prob)
                            for candidate_id, candidate_prob in zip(top_ids_list, top_probs_list)
                            if candidate_prob >= self.threshold and candidate_id not in self.invalid_prediction_ids
                        ),
                        None,
                    )
                    original_id = int(original_ids[pos])
                    detail = {
                        "sample_id": chunk_ids[row_idx],
                        "wurcs": chunk_wurcs[row_idx],
                        "iteration": iteration,
                        "position": pos,
                        "original_id": original_id,
                        "original_token": self.tokenizer.id_to_token.get(original_id, "[UNK]"),
                        "top_ids": json.dumps(top_ids_list),
                        "top_tokens": json.dumps(top_tokens),
                        "top_probs": json.dumps(top_probs_list),
                        "accepted": False,
                    }
                    if accepted_candidate is not None:
                        accepted_id, accepted_prob = accepted_candidate
                        current_ids[row_idx][pos] = accepted_id
                        positions.remove(pos)
                        progress += 1
                        detail.update(
                            {
                                "accepted": True,
                                "resolved_id": accepted_id,
                                "resolved_token": self.tokenizer.id_to_token.get(accepted_id, "[UNK]"),
                                "confidence": accepted_prob,
                            }
                        )
                        accepted_by_row[row_idx].append(detail.copy())
                    details.append(detail)
            if progress == 0:
                break

        summaries = [
            {
                "sample_id": chunk_ids[row_idx],
                "wurcs": chunk_wurcs[row_idx],
                "initial_ambiguous_tokens": initial_counts[row_idx],
                "resolved_tokens": len(accepted_by_row[row_idx]),
                "remaining_ambiguous_tokens": len(unresolved[row_idx]),
                "final_token_sequence": " ".join(
                    nonpad_tokens(self.tokenizer, current_ids[row_idx], row["attention_mask"])
                ),
                "accepted_updates_json": json.dumps(accepted_by_row[row_idx]),
            }
            for row_idx, row in enumerate(tokenized_rows)
        ]
        return summaries, details

    def resolve_batch(
        self,
        wurcs_list: Sequence[str],
        ids: Optional[Sequence[str]],
        iterations: int,
        batch_size: int,
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Resolve ambiguous tokens for many glycans.

        Args:
            wurcs_list: WURCS strings to resolve.
            ids: sample ids aligned with ``wurcs_list``; defaults to ``"0", "1", ...``.
            iterations: maximum mask-predict rounds per chunk.
            batch_size: glycans per forward pass (must be >= 1).

        Returns:
            ``(summary, details)`` DataFrames: one summary row per glycan and one
            detail row per (iteration, ambiguous position) examined.

        Raises:
            ValueError: if ``ids`` and ``wurcs_list`` lengths differ or ``batch_size < 1``.
        """
        ids = list(ids) if ids is not None else [str(i) for i in range(len(wurcs_list))]
        if len(ids) != len(wurcs_list):
            raise ValueError(f"ids has {len(ids)} entries but wurcs_list has {len(wurcs_list)}")
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")

        summaries: List[Dict[str, Any]] = []
        details: List[Dict[str, Any]] = []
        for start in tqdm(range(0, len(wurcs_list), batch_size), desc="Resolving glycans"):
            chunk_summaries, chunk_details = self._resolve_chunk(
                ids[start : start + batch_size],
                list(wurcs_list[start : start + batch_size]),
                iterations,
            )
            summaries.extend(chunk_summaries)
            details.extend(chunk_details)
        return pd.DataFrame(summaries), pd.DataFrame(details)

    def resolve_single(self, wurcs: str, iterations: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Resolve one glycan (sample id ``"single"``)."""
        return self.resolve_batch([wurcs], ids=["single"], iterations=iterations, batch_size=1)


def read_wurcs_csv(path: Optional[Path], default: Sequence[str], column: str, n: int) -> List[str]:
    """Return the first ``n`` WURCS strings from ``path[column]``, or from ``default``.

    ``default`` is used when ``path`` is None or does not exist (the latter
    with a warning). Only entries starting with ``"WURCS="`` are kept.

    Raises:
        ValueError: if the CSV lacks ``column`` or fewer than ``n`` WURCS strings remain.
    """
    if path is not None and path.exists():
        df = pd.read_csv(path)
        if column not in df.columns:
            raise ValueError(f"{path} must contain column {column!r}")
        vals = df[column].dropna().astype(str).tolist()
    else:
        if path is not None:
            warnings.warn(f"{path} does not exist; using built-in default WURCS strings", stacklevel=2)
        vals = list(default)
    vals = [item for item in vals if item.startswith("WURCS=")]
    if len(vals) < n:
        raise ValueError(f"Need at least {n} WURCS strings, found {len(vals)}")
    return vals[:n]


def _load_real_affinose_pairs(
    csv_path: Optional[Path],
    protein_h5: Path,
    output_dir: Path,
    n: int,
    default_wurcs: Sequence[str],
) -> Tuple[pd.DataFrame, Path, str]:
    """Pick ``n`` pairs whose protein has an embedding in ``protein_h5``; copy those embeddings to a subset H5."""
    with h5py.File(protein_h5, "r") as handle:
        available = set(handle.keys())

    if csv_path is not None and csv_path.exists():
        df = pd.read_csv(csv_path, usecols=lambda col: col in {"protein_id", "glycan_wurcs"})
        if not {"protein_id", "glycan_wurcs"}.issubset(df.columns):
            raise ValueError(f"{csv_path} must contain columns 'protein_id' and 'glycan_wurcs'")
        df = df.dropna(subset=["protein_id", "glycan_wurcs"]).copy()
        df["protein_id"] = df["protein_id"].astype(str)
        df["glycan_wurcs"] = df["glycan_wurcs"].astype(str)
        df = df[df["protein_id"].isin(available) & df["glycan_wurcs"].str.startswith("WURCS=")]
        df = df.drop_duplicates(["protein_id", "glycan_wurcs"])
        # Prefer one pair per distinct protein, then top up with remaining pairs.
        diverse = df.drop_duplicates("protein_id").head(n)
        if len(diverse) < n:
            extra = df[~df.index.isin(diverse.index)].head(n - len(diverse))
            diverse = pd.concat([diverse, extra], ignore_index=False)
        pairs = diverse.head(n).copy()
    else:
        if csv_path is not None:
            warnings.warn(f"{csv_path} does not exist; pairing H5 proteins with default glycans", stacklevel=3)
        keys = sorted(available)[:n]
        glycans = list(default_wurcs)[:n]
        # Truncate to a common length so the frame builds; the check below reports any shortfall.
        count = min(len(keys), len(glycans))
        pairs = pd.DataFrame({"protein_id": keys[:count], "glycan_wurcs": glycans[:count]})

    if len(pairs) < n:
        raise ValueError(f"Need {n} AFFINose pairs with available protein embeddings, found {len(pairs)}")

    subset_h5 = output_dir / "affinose_subset_esmc_embeddings.h5"
    with h5py.File(protein_h5, "r") as src, h5py.File(subset_h5, "w") as dst:
        for protein_id in pairs["protein_id"].astype(str).unique():
            dst.create_dataset(protein_id, data=src[protein_id][:], compression="gzip")
    return pairs.reset_index(drop=True), subset_h5, "real_subset"


def _make_synthetic_affinose_pairs(
    output_dir: Path,
    n: int,
    default_wurcs: Sequence[str],
) -> Tuple[pd.DataFrame, Path, str]:
    """Write ``n`` seeded random (L, 960) float32 protein embeddings and pair them with default glycans."""
    glycans = list(default_wurcs)[:n]
    if len(glycans) < n:
        raise ValueError(f"Need {n} glycans for synthetic AFFINose pairs, found {len(glycans)}")
    synthetic_h5 = output_dir / "synthetic_esmc_embeddings.h5"
    rng = np.random.default_rng(123)  # fixed seed: reproducible smoke inputs
    protein_ids = [f"smoke_protein_{idx}" for idx in range(n)]
    with h5py.File(synthetic_h5, "w") as dst:
        for idx, protein_id in enumerate(protein_ids):
            dst.create_dataset(protein_id, data=rng.normal(size=(64 + idx, 960)).astype("float32"))
    pairs = pd.DataFrame({"protein_id": protein_ids, "glycan_wurcs": glycans})
    return pairs, synthetic_h5, "synthetic"


def choose_affinose_pairs(
    csv_path: Optional[Path],
    protein_h5: Optional[Path],
    output_dir: Path,
    n: int,
    default_wurcs: Sequence[str],
) -> Tuple[pd.DataFrame, Path, str]:
    """Select ``n`` protein-glycan pairs and the H5 of protein embeddings to score them with.

    With an existing ``protein_h5``, pairs come from ``csv_path`` (filtered to
    proteins present in the H5), or from the H5 keys plus ``default_wurcs``.
    The needed embeddings are copied into a subset H5 under ``output_dir``.
    Otherwise synthetic embeddings are generated.

    Returns:
        ``(pairs, h5_path, source)`` where ``source`` is ``"real_subset"`` or ``"synthetic"``.

    Raises:
        ValueError: if ``n`` pairs cannot be formed.
    """
    if protein_h5 is not None and protein_h5.exists():
        return _load_real_affinose_pairs(csv_path, protein_h5, output_dir, n, default_wurcs)
    if protein_h5 is not None:
        warnings.warn(f"{protein_h5} does not exist; using synthetic protein embeddings", stacklevel=2)
    if csv_path is not None:
        warnings.warn(f"{csv_path} is ignored without a protein embedding H5", stacklevel=2)
    return _make_synthetic_affinose_pairs(output_dir, n, default_wurcs)


def write_embedding_csv(path: Path, sample_ids: Sequence[str], wurcs: Sequence[str], embeddings: np.ndarray) -> None:
    """Write ``sample_id``, ``wurcs`` and ``emb_000..`` columns to ``path``."""
    emb_cols = [f"emb_{idx:03d}" for idx in range(embeddings.shape[1])]
    df = pd.concat(
        [
            pd.DataFrame({"sample_id": sample_ids, "wurcs": wurcs}),
            pd.DataFrame(embeddings, columns=emb_cols),
        ],
        axis=1,
    )
    df.to_csv(path, index=False)


def finite_or_raise(values: Sequence[float], name: str) -> None:
    """Raise ``RuntimeError`` if any value is NaN or infinite."""
    arr = np.asarray(values, dtype="float64")
    if not np.isfinite(arr).all():
        raise RuntimeError(f"{name} contains non-finite values")


def _parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse CLI arguments; counts and batch sizes must be >= 1."""
    parser = argparse.ArgumentParser(description="Validate public BERTose/AFFINose Hugging Face inference.")
    parser.add_argument("--output-dir", type=Path, default=Path("release_smoke_outputs"))
    parser.add_argument("--hf-cache-dir", type=Path, default=Path("hf_release_cache"))
    parser.add_argument("--device", default="auto")
    parser.add_argument("--embedding-csv", type=Path, default=None)
    parser.add_argument("--iar-csv", type=Path, default=None)
    parser.add_argument("--affinose-csv", type=Path, default=None)
    parser.add_argument("--protein-emb-h5", type=Path, default=None)
    parser.add_argument("--embedding-n", type=int, default=20)
    parser.add_argument("--iar-n", type=int, default=20)
    parser.add_argument("--affinose-n", type=int, default=5)
    parser.add_argument("--embedding-batch-size", type=int, default=20)
    parser.add_argument("--iar-batch-size", type=int, default=8)
    parser.add_argument("--affinose-batch-size", type=int, default=5)
    parser.add_argument("--iar-iterations", type=int, default=3)
    args = parser.parse_args(argv)
    # Every stage indexes its first input, so zero/negative counts would fail much later.
    for option in (
        "embedding_n",
        "iar_n",
        "affinose_n",
        "embedding_batch_size",
        "iar_batch_size",
        "affinose_batch_size",
    ):
        if getattr(args, option) < 1:
            parser.error(f"--{option.replace('_', '-')} must be >= 1")
    if args.iar_iterations < 0:
        parser.error("--iar-iterations must be >= 0")
    return args


def _release_device_memory(device: torch.device) -> None:
    """Return cached CUDA memory between stages so the next model fits."""
    if device.type == "cuda":
        torch.cuda.empty_cache()


def _run_embedding_stage(
    args: argparse.Namespace,
    paths: Dict[str, Path],
    device: torch.device,
    summary: Dict[str, Any],
    embedding_wurcs: List[str],
) -> None:
    """Load the encoder, check single/batch embeddings (768-d, finite) and write the CSV."""
    timing = summary["timing_seconds"]
    encoder_ckpt = paths["encoder"] / "checkpoints" / "bertose_glycan_encoder.pt"
    encoder_vocab = paths["encoder"] / "vocab" / "bpe_vocabulary.json"

    t0 = time.perf_counter()
    embedder = BertoseEmbedder(encoder_ckpt, encoder_vocab, device)
    timing["load_bertose_embedder"] = time.perf_counter() - t0

    t0 = time.perf_counter()
    single_embedding = embedder.embed_single(embedding_wurcs[0])
    timing["bertose_embedding_single"] = time.perf_counter() - t0
    if single_embedding.shape != (768,):
        raise RuntimeError(f"Expected single embedding shape (768,), got {single_embedding.shape}")
    finite_or_raise(single_embedding, "single_embedding")

    t0 = time.perf_counter()
    batch_embeddings = embedder.embed_batch(embedding_wurcs, batch_size=args.embedding_batch_size)
    elapsed = time.perf_counter() - t0
    timing["bertose_embedding_batch"] = elapsed
    summary["bertose_embedding_batch_shape"] = list(batch_embeddings.shape)
    summary["bertose_embedding_items_per_second"] = len(embedding_wurcs) / max(elapsed, 1e-9)
    if batch_embeddings.shape != (len(embedding_wurcs), 768):
        raise RuntimeError(
            f"Expected batch embedding shape ({len(embedding_wurcs)}, 768), got {batch_embeddings.shape}"
        )
    finite_or_raise(batch_embeddings.ravel(), "batch_embeddings")
    # File name kept for compatibility with existing consumers even when --embedding-n != 20.
    write_embedding_csv(
        args.output_dir / "bertose_embeddings_20.csv",
        [f"glycan_{idx:02d}" for idx in range(len(embedding_wurcs))],
        embedding_wurcs,
        batch_embeddings,
    )


def _run_iar_stage(
    args: argparse.Namespace,
    paths: Dict[str, Path],
    device: torch.device,
    summary: Dict[str, Any],
    iar_wurcs: List[str],
) -> None:
    """Load the IAR resolver, run single and batch resolution, and write the CSVs."""
    timing = summary["timing_seconds"]
    resolver_ckpt = paths["resolver"] / "checkpoints" / "bertose_iar_resolver.pt"
    resolver_vocab = paths["resolver"] / "vocab" / "bpe_vocabulary.json"
    ambiguity_path = paths["resolver"] / "vocab" / "bpe_ambiguity_tokens.json"

    t0 = time.perf_counter()
    resolver = BertoseAmbiguityResolver(resolver_ckpt, resolver_vocab, ambiguity_path, device)
    timing["load_bertose_iar"] = time.perf_counter() - t0

    t0 = time.perf_counter()
    single_summary, single_details = resolver.resolve_single(iar_wurcs[0], iterations=args.iar_iterations)
    timing["bertose_iar_single"] = time.perf_counter() - t0
    if single_summary.empty:
        raise RuntimeError("BERTose IAR single output is empty")
    single_summary.to_csv(args.output_dir / "bertose_iar_single_summary.csv", index=False)
    single_details.to_csv(args.output_dir / "bertose_iar_single_details.csv", index=False)

    t0 = time.perf_counter()
    iar_summary, iar_details = resolver.resolve_batch(
        iar_wurcs,
        ids=[f"ambiguous_{idx:02d}" for idx in range(len(iar_wurcs))],
        iterations=args.iar_iterations,
        batch_size=args.iar_batch_size,
    )
    elapsed = time.perf_counter() - t0
    timing["bertose_iar_batch"] = elapsed
    summary["bertose_iar_batch_rows"] = len(iar_summary)
    summary["bertose_iar_detail_rows"] = len(iar_details)
    summary["bertose_iar_items_per_second"] = len(iar_wurcs) / max(elapsed, 1e-9)
    if len(iar_summary) != len(iar_wurcs):
        raise RuntimeError(f"Expected {len(iar_wurcs)} IAR summary rows, got {len(iar_summary)}")
    iar_summary.to_csv(args.output_dir / "bertose_iar_batch_summary.csv", index=False)
    iar_details.to_csv(args.output_dir / "bertose_iar_batch_details.csv", index=False)


def _run_affinose_stage(
    args: argparse.Namespace,
    paths: Dict[str, Path],
    device: torch.device,
    summary: Dict[str, Any],
    embedding_wurcs: List[str],
) -> None:
    """Choose pairs, load AFFINose, score single and batch pairs, and write predictions."""
    timing = summary["timing_seconds"]
    encoder_ckpt = paths["encoder"] / "checkpoints" / "bertose_glycan_encoder.pt"

    pairs, protein_h5, protein_source = choose_affinose_pairs(
        args.affinose_csv,
        args.protein_emb_h5,
        args.output_dir,
        args.affinose_n,
        embedding_wurcs,
    )
    pairs.to_csv(args.output_dir / "affinose_pairs.csv", index=False)
    summary["affinose_protein_embedding_source"] = protein_source
    summary["affinose_pairs"] = len(pairs)

    t0 = time.perf_counter()
    # Imported here: the module lives in the downloaded affinose repo's src/.
    from affinose_inference import AffinosePredictor

    predictor = AffinosePredictor(
        checkpoint_path=str(paths["affinose"] / "checkpoints" / "affinose_interaction_model.pt"),
        bertose_checkpoint=str(encoder_ckpt),
        vocab_path=str(paths["affinose"] / "vocab" / "bpe_vocabulary.json"),
        protein_emb_path=str(protein_h5),
        device=str(device),
    )
    timing["load_affinose"] = time.perf_counter() - t0

    t0 = time.perf_counter()
    first = pairs.iloc[0]
    single_score = predictor.predict_single(first["glycan_wurcs"], first["protein_id"])
    timing["affinose_single"] = time.perf_counter() - t0
    finite_or_raise([single_score], "affinose_single_score")
    summary["affinose_single_score"] = float(single_score)

    t0 = time.perf_counter()
    scores = predictor.predict_batch(
        pairs["glycan_wurcs"].astype(str).tolist(),
        pairs["protein_id"].astype(str).tolist(),
        batch_size=args.affinose_batch_size,
    )
    elapsed = time.perf_counter() - t0
    # Record timing before validating so a failure summary still includes it.
    timing["affinose_batch"] = elapsed
    finite_or_raise(scores, "affinose_batch_scores")
    summary["affinose_items_per_second"] = len(scores) / max(elapsed, 1e-9)

    out_pairs = pairs.copy()
    out_pairs["affinose_score"] = [float(score) for score in scores]
    # File name kept for compatibility with existing consumers even when --affinose-n != 5.
    out_pairs.to_csv(args.output_dir / "affinose_predictions_5.csv", index=False)


def _run_smoke_test(args: argparse.Namespace, device: torch.device, summary: Dict[str, Any]) -> None:
    """Download repos, then run the embedding, IAR and AFFINose stages, filling ``summary``."""
    t0 = time.perf_counter()
    paths = download_repos(args.hf_cache_dir)
    summary["timing_seconds"]["download_repos"] = time.perf_counter() - t0
    summary["repo_paths"] = {key: str(path) for key, path in paths.items()}
    add_source_paths(paths)

    embedding_wurcs = read_wurcs_csv(args.embedding_csv, DEFAULT_WURCS, "wurcs", args.embedding_n)
    iar_wurcs = read_wurcs_csv(args.iar_csv, DEFAULT_AMBIGUOUS_WURCS, "wurcs", args.iar_n)

    # Each stage's model is local to its helper, so it is released on return.
    _run_embedding_stage(args, paths, device, summary, embedding_wurcs)
    _release_device_memory(device)
    _run_iar_stage(args, paths, device, summary, iar_wurcs)
    _release_device_memory(device)
    _run_affinose_stage(args, paths, device, summary, embedding_wurcs)


def _write_summary(output_dir: Path, summary: Dict[str, Any]) -> None:
    """Write ``summary`` as indented JSON to ``<output_dir>/smoke_summary.json``."""
    summary_path = output_dir / "smoke_summary.json"
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")


def main() -> None:
    """CLI entry point: run every smoke stage and write ``smoke_summary.json``.

    On failure the summary is still written with ``status: "failed"`` and the
    error message, then the exception is re-raised (non-zero exit).
    """
    args = _parse_args()
    torch.set_grad_enabled(False)
    args.output_dir.mkdir(parents=True, exist_ok=True)
    device = get_device(args.device)
    summary: Dict[str, Any] = {
        "device": str(device),
        "cuda_available": torch.cuda.is_available(),
        "torch_version": torch.__version__,
        "repo_ids": REPOS,
        "timing_seconds": {},
    }
    try:
        _run_smoke_test(args, device, summary)
    except Exception as exc:
        # Top-level boundary: persist partial results, then re-raise.
        summary["status"] = "failed"
        summary["error"] = f"{type(exc).__name__}: {exc}"
        _write_summary(args.output_dir, summary)
        raise
    summary["status"] = "ok"
    _write_summary(args.output_dir, summary)
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()