"""AFFINose inference — standalone prediction pipeline.

Architecture: BERTose glycan encoding + ESM-C protein embeddings +
cross-attention fusion.

    GLYCAN:  WURCS -> BPE -> BERTose (frozen) -> [B, Lg, 768] -> proj -> [B, Lg, 512]
    PROTEIN: ESM-C per-residue -> [B, Lp, 960] -> proj -> [B, Lp, 512]
                                   |
                 2x CrossAttentionBlock(d=512, 8H, FFN=1024)
                                   |
    glycan_enriched  -> mean(valid tokens)   -> [B, 512]
    protein_enriched -> mean(valid residues) -> [B, 512]
                                   |
                 concat -> [B, 1024] -> MLP -> score

Command-line usage (see ``main``):

    single pair:  --wurcs WURCS --protein_id ID
    batch:        --input_csv pairs.csv   (columns: glycan_wurcs, protein_id)
"""
import argparse
import json
import logging
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import h5py
import numpy as np
import pandas as pd
import torch

from affinose_model import AffinoseInteractionModel, load_bertose_encoder
from affinose_dataset import load_bpe_tokenizer

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# If a checkpoint is a dict that nests the weights under one of these keys
# (e.g. next to optimizer state), the nested dict is used; a bare state_dict
# is used as-is.
_STATE_DICT_WRAPPER_KEYS = ("model_state_dict", "state_dict")

# Tokenizer outputs forwarded to the model as keyword arguments.
_TOKEN_FIELDS = ("token_ids", "attention_mask", "branch_depths", "linkage_types")

# Passed to the BPE tokenizer as ``max_length``.
_GLYCAN_MAX_TOKENS = 256


class AffinosePredictor:
    """Standalone predictor for glycan-protein binding affinity.

    Loads the BPE tokenizer, the BERTose glycan encoder, the trained AFFINose
    interaction weights and a table of precomputed per-residue protein
    embeddings, then scores (WURCS, protein_id) pairs.

    Args:
        checkpoint_path: Path to the trained interaction-model checkpoint.
        bertose_checkpoint: Path to the BERTose encoder checkpoint.
        vocab_path: Path to the BPE vocabulary.
        protein_emb_path: HDF5 file holding one dataset of per-residue
            embeddings per protein (presumably [L, 960] ESM-C vectors, since
            the model is built with protein_dim=960 — TODO confirm).
        device: "auto" (CUDA if available, else CPU) or any torch device string.
        max_protein_length: Embeddings longer than this many rows are
            truncated to their first ``max_protein_length`` rows.
    """

    def __init__(self, checkpoint_path, bertose_checkpoint, vocab_path,
                 protein_emb_path, device="auto", max_protein_length=1024):
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        logger.info(f"Using device: {self.device}")
        self.max_protein_length = max_protein_length

        logger.info(f"Loading BPE tokenizer from {vocab_path}")
        self.tokenizer = load_bpe_tokenizer(vocab_path)

        self.model = self._build_model(bertose_checkpoint)
        self._load_weights(checkpoint_path)
        self.model.to(self.device)
        self.model.eval()
        logger.info("Model loaded and ready for inference")

        logger.info(f"Loading protein embeddings from {protein_emb_path}")
        self.protein_embs: Dict[str, torch.Tensor] = self._load_protein_embeddings(
            protein_emb_path)
        logger.info(f"  Loaded {len(self.protein_embs)} protein embeddings")

    # ------------------------------------------------------------------ #
    # Construction helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _build_model(bertose_checkpoint):
        """Construct the interaction model around the BERTose encoder layers."""
        logger.info(f"Loading BERTose from {bertose_checkpoint}")
        bertose_config, seq_embeddings, seq_layers = load_bertose_encoder(
            bertose_checkpoint, freeze_layers=12)
        logger.info("Building AFFINose interaction model")
        # These hyperparameters must agree with the tensor shapes stored in the
        # interaction checkpoint (presumably the training configuration).
        return AffinoseInteractionModel(
            seq_embeddings=seq_embeddings,
            seq_layers=seq_layers,
            glycan_dim=bertose_config.seq_hidden_size,
            protein_dim=960,
            shared_dim=512,
            num_heads=8,
            ffn_dim=1024,
            num_cross_layers=2,
            dropout=0.1,
            swe_slices=512,
            swe_ref_points=64,
            separate_swe=False,
            pooling_mode="mean",
            interaction_mode="concat",
            use_cross_attention=True,
        )

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the weight mapping from a raw or wrapped checkpoint object."""
        if isinstance(checkpoint, dict):
            for key in _STATE_DICT_WRAPPER_KEYS:
                nested = checkpoint.get(key)
                if isinstance(nested, dict):
                    logger.info("Using weights stored under checkpoint key %r", key)
                    return nested
        return checkpoint

    def _load_weights(self, checkpoint_path):
        """Load trained weights from ``checkpoint_path`` into ``self.model``.

        Raises:
            RuntimeError: if none of the model's tensors were found in the
                checkpoint (the model would otherwise run with random weights).
        """
        logger.info(f"Loading checkpoint from {checkpoint_path}")
        # SECURITY: weights_only=False unpickles arbitrary Python objects, which
        # can execute code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        state_dict = self._extract_state_dict(checkpoint)

        expected = len(self.model.state_dict())
        # strict=False tolerates partial checkpoints.
        # NOTE(review): confirm which tensors are expected to be absent.
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        if expected and len(missing) == expected:
            raise RuntimeError(
                f"No model tensors were found in checkpoint {checkpoint_path!r}; "
                f"it has {len(unexpected)} unrecognised entries "
                f"(first: {list(unexpected)[:5]})"
            )
        if missing:
            logger.warning("Missing checkpoint tensors: %d (first: %s)",
                           len(missing), list(missing)[:5])
        if unexpected:
            logger.warning("Unexpected checkpoint tensors: %d (first: %s)",
                           len(unexpected), list(unexpected)[:5])

    def _load_protein_embeddings(self, protein_emb_path):
        """Read every HDF5 dataset into a float32 [rows, dim] tensor keyed by protein id."""
        embeddings = {}
        with h5py.File(protein_emb_path, "r") as f:
            for key in f.keys():
                emb = torch.from_numpy(f[key][:]).float()
                if emb.dim() == 1:
                    emb = emb.unsqueeze(0)
                # Truncate to bound memory/attention cost; no-op for short proteins.
                emb = emb[: self.max_protein_length]
                # Dataset names store "|" where the protein id has "/"
                # (presumably because "/" is HDF5's group separator).
                embeddings[key.replace("|", "/")] = emb
        return embeddings

    # ------------------------------------------------------------------ #
    # Inference
    # ------------------------------------------------------------------ #
    def _tokenize_glycan(self, wurcs):
        """Tokenize one WURCS string into the long tensors listed in _TOKEN_FIELDS."""
        tok = self.tokenizer.tokenize(wurcs, max_length=_GLYCAN_MAX_TOKENS)
        return {k: torch.tensor(tok[k], dtype=torch.long) for k in _TOKEN_FIELDS}

    @torch.no_grad()
    def _score_batch(self, wurcs_list, protein_ids):
        """Run one forward pass over aligned, already-validated pairs.

        Glycan token tensors are stacked as-is (the tokenizer is called with a
        fixed max_length). Protein embeddings are zero-padded to the longest
        protein in the batch, with a matching 0/1 mask.

        Returns:
            One float per pair, in input order.
        """
        tokenized = [self._tokenize_glycan(w) for w in wurcs_list]
        glycan = {
            field: torch.stack([t[field] for t in tokenized]).to(self.device)
            for field in _TOKEN_FIELDS
        }
        glycan["attention_mask"] = glycan["attention_mask"].float()

        embs = [self.protein_embs[pid] for pid in protein_ids]
        prot_pad = torch.nn.utils.rnn.pad_sequence(embs, batch_first=True)
        lengths = torch.tensor([e.shape[0] for e in embs])
        prot_mask = (torch.arange(prot_pad.shape[1])[None, :] < lengths[:, None]).float()

        scores = self.model(
            token_ids=glycan["token_ids"],
            attention_mask=glycan["attention_mask"],
            branch_depths=glycan["branch_depths"],
            linkage_types=glycan["linkage_types"],
            protein_emb=prot_pad.to(self.device),
            protein_mask=prot_mask.to(self.device),
        )
        # Flatten so 0-d, [B] and [B, 1] outputs all become a flat list of floats.
        flat = scores.reshape(-1).cpu().tolist()
        if len(flat) != len(embs):
            raise RuntimeError(
                f"Model returned {len(flat)} scores for a batch of {len(embs)} pairs")
        return flat

    @torch.no_grad()
    def predict_single(self, wurcs: str, protein_id: str) -> float:
        """Score one glycan-protein pair.

        Args:
            wurcs: Glycan in WURCS notation.
            protein_id: Key into the loaded protein embeddings.

        Returns:
            The model's scalar output as a Python float.

        Raises:
            KeyError: if ``protein_id`` has no loaded embedding.
        """
        if protein_id not in self.protein_embs:
            raise KeyError(f"Protein '{protein_id}' not found in embeddings.")
        return self._score_batch([wurcs], [protein_id])[0]

    @torch.no_grad()
    def predict_batch(self, wurcs_list: List[str], protein_ids: List[str],
                      batch_size: int = 32) -> List[float]:
        """Score aligned (wurcs_list[i], protein_ids[i]) pairs in mini-batches.

        Returns:
            One float per pair, in input order (empty list for empty input).

        Raises:
            ValueError: if the two lists differ in length or batch_size < 1.
            KeyError: if any protein id has no loaded embedding. This is
                checked before any forward pass runs.
        """
        n = len(wurcs_list)
        if len(protein_ids) != n:
            raise ValueError(
                f"wurcs_list has {n} entries but protein_ids has {len(protein_ids)}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        missing = [pid for pid in protein_ids if pid not in self.protein_embs]
        if missing:
            raise KeyError(
                f"{len(missing)} protein ids missing from embeddings, first: {missing[0]!r}")

        all_scores: List[float] = []
        for start in range(0, n, batch_size):
            end = min(start + batch_size, n)
            all_scores.extend(
                self._score_batch(wurcs_list[start:end], protein_ids[start:end]))
            if (start // batch_size) % 10 == 0 or end == n:
                logger.info(f"  Predicted {end}/{n} pairs ({100*end/n:.1f}%)")
        return all_scores


def _default_output_path(input_csv: str) -> str:
    """Return ``<same dir>/<stem>_predictions.csv`` for *input_csv*.

    Only the file name is changed, and the result never equals the input path.
    A plain ``str.replace(".csv", ...)`` would overwrite the input when it lacks
    a ``.csv`` suffix.
    """
    in_path = Path(input_csv)
    return str(in_path.with_name(f"{in_path.stem}_predictions.csv"))


def main():
    """Command-line entry point: score one pair or every row of a CSV."""
    parser = argparse.ArgumentParser(description="AFFINose interaction inference")
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--bertose_checkpoint", required=True)
    parser.add_argument("--vocab_path", required=True)
    parser.add_argument("--protein_emb_path", required=True)
    parser.add_argument("--device", default="auto")
    parser.add_argument("--wurcs", type=str, default=None)
    parser.add_argument("--protein_id", type=str, default=None)
    parser.add_argument("--input_csv", type=str, default=None)
    parser.add_argument("--output_csv", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_protein_length", type=int, default=1024,
                        help="Truncate protein embeddings to this many residues.")
    args = parser.parse_args()

    # Validate the requested mode before paying for model loading.
    single_mode = bool(args.wurcs and args.protein_id)
    if not single_mode and not args.input_csv:
        parser.error("Provide --wurcs + --protein_id or --input_csv")

    df = None
    if not single_mode:
        # Read ids as strings so numeric-looking ids still match embedding keys.
        df = pd.read_csv(args.input_csv, dtype={"protein_id": str})
        missing_cols = sorted({"glycan_wurcs", "protein_id"} - set(df.columns))
        if missing_cols:
            parser.error(f"--input_csv is missing required column(s): {missing_cols}")

    predictor = AffinosePredictor(
        checkpoint_path=args.checkpoint,
        bertose_checkpoint=args.bertose_checkpoint,
        vocab_path=args.vocab_path,
        protein_emb_path=args.protein_emb_path,
        device=args.device,
        max_protein_length=args.max_protein_length,
    )

    if single_mode:
        score = predictor.predict_single(args.wurcs, args.protein_id)
        print(f"\nPrediction: {score:.4f} (0=no binding, 1=strong)")
        return

    df["predicted_score"] = predictor.predict_batch(
        df["glycan_wurcs"].tolist(),
        df["protein_id"].tolist(),
        batch_size=args.batch_size,
    )
    out = args.output_csv or _default_output_path(args.input_csv)
    df.to_csv(out, index=False)
    logger.info(f"Saved {len(df)} predictions to {out}")


if __name__ == "__main__":
    main()