"""
AFFINose inference — standalone prediction pipeline.

Architecture: BERTose glycan encoding + ESM-C protein embeddings
+ cross-attention fusion.

    GLYCAN:  WURCS -> BPE -> BERTose (frozen) -> [B, Lg, 768] -> proj -> [B, Lg, 512]
    PROTEIN: ESM-C per-residue -> [B, Lp, 960] -> proj -> [B, Lp, 512]
                    |
        2x CrossAttentionBlock(d=512, 8H, FFN=1024)
                    |
        glycan_enriched  -> mean(valid tokens)   -> [B, 512]
        protein_enriched -> mean(valid residues) -> [B, 512]
                    |
        concat -> [B, 1024] -> MLP -> score
"""
|
|
| import argparse |
| import json |
| import logging |
| import sys |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import h5py |
| import numpy as np |
| import pandas as pd |
| import torch |
|
|
| from affinose_model import AffinoseInteractionModel, load_bertose_encoder |
| from affinose_dataset import load_bpe_tokenizer |
|
|
# Configure the root logger once at import time (timestamp - level - message),
# then expose a module-scoped logger for everything defined in this file.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
class AffinosePredictor:
    """Standalone predictor for glycan-protein binding affinity.

    On construction this loads the BPE tokenizer, the BERTose encoder, the
    AFFINose interaction model weights and a table of precomputed protein
    embeddings. Afterwards ``predict_single`` / ``predict_batch`` score
    (WURCS glycan string, protein id) pairs.

    Attributes:
        device: ``torch.device`` the model runs on.
        max_protein_length: maximum number of embedding rows kept per protein.
        tokenizer: object returned by ``load_bpe_tokenizer``.
        model: ``AffinoseInteractionModel`` in eval mode on ``device``.
        protein_embs: mapping of protein id to a 2-D float tensor
            (rows x features), kept on CPU and moved to ``device`` per batch.
    """

    # Tokenizer output entries that are forwarded to the model.
    # NOTE(review): assumes ``tokenizer.tokenize`` returns a mapping holding
    # these four integer sequences - confirm against affinose_dataset.
    _TOKEN_FIELDS = ("token_ids", "attention_mask", "branch_depths", "linkage_types")

    def __init__(self, checkpoint_path: str, bertose_checkpoint: str, vocab_path: str,
                 protein_emb_path: str, device: str = "auto",
                 max_protein_length: int = 1024):
        """Load tokenizer, encoder, interaction model and protein embeddings.

        Args:
            checkpoint_path: AFFINose checkpoint passed to ``torch.load``; its
                content is handed to ``load_state_dict`` as-is.
            bertose_checkpoint: checkpoint passed to ``load_bertose_encoder``.
            vocab_path: vocabulary passed to ``load_bpe_tokenizer``.
            protein_emb_path: HDF5 file with one dataset per protein.
            device: ``"auto"`` picks CUDA when available, else CPU; any other
                value is passed to ``torch.device``.
            max_protein_length: embeddings longer than this many rows are
                truncated to their first ``max_protein_length`` rows.
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        logger.info("Using device: %s", self.device)
        self.max_protein_length = max_protein_length

        logger.info("Loading BPE tokenizer from %s", vocab_path)
        self.tokenizer = load_bpe_tokenizer(vocab_path)

        logger.info("Loading BERTose from %s", bertose_checkpoint)
        bertose_config, seq_embeddings, seq_layers = load_bertose_encoder(
            bertose_checkpoint, freeze_layers=12)

        logger.info("Building AFFINose interaction model")
        self.model = AffinoseInteractionModel(
            seq_embeddings=seq_embeddings,
            seq_layers=seq_layers,
            glycan_dim=bertose_config.seq_hidden_size,
            protein_dim=960, shared_dim=512, num_heads=8,
            ffn_dim=1024, num_cross_layers=2, dropout=0.1,
            swe_slices=512, swe_ref_points=64, separate_swe=False,
            pooling_mode="mean", interaction_mode="concat",
            use_cross_attention=True,
        )

        logger.info("Loading checkpoint from %s", checkpoint_path)
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects, so only load checkpoints from a trusted source.
        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        # strict=False: key mismatches are tolerated and only reported below.
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        if missing:
            logger.warning("Missing checkpoint tensors: %d", len(missing))
        if unexpected:
            logger.warning("Unexpected checkpoint tensors: %d", len(unexpected))
        self.model.to(self.device)
        self.model.eval()
        logger.info("Model loaded and ready for inference")

        logger.info("Loading protein embeddings from %s", protein_emb_path)
        self.protein_embs = self._load_protein_embeddings(
            protein_emb_path, max_protein_length)
        logger.info(" Loaded %d protein embeddings", len(self.protein_embs))

    @staticmethod
    def _load_protein_embeddings(protein_emb_path, max_protein_length) -> Dict[str, torch.Tensor]:
        """Read every top-level HDF5 entry into a CPU float tensor keyed by protein id."""
        embeddings = {}
        with h5py.File(protein_emb_path, "r") as f:
            for key in f.keys():
                emb = torch.from_numpy(f[key][:]).float()
                if emb.dim() == 1:
                    # Promote a single vector to one row so it is always 2-D.
                    emb = emb.unsqueeze(0)
                # Slicing is a no-op for embeddings already short enough.
                emb = emb[:max_protein_length]
                # Stored keys use "|" where the lookup id uses "/" (presumably
                # because "/" is the HDF5 path separator - confirm with writer).
                embeddings[key.replace("|", "/")] = emb
        return embeddings

    def _tokenize_glycan(self, wurcs) -> Dict[str, torch.Tensor]:
        """Tokenize one WURCS string into long tensors for the model inputs."""
        tok = self.tokenizer.tokenize(wurcs, max_length=256)
        return {name: torch.tensor(tok[name], dtype=torch.long)
                for name in self._TOKEN_FIELDS}

    @staticmethod
    def _pad_proteins(protein_embs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Zero-pad embeddings to a common length; mask is 1.0 on real rows."""
        max_len = max(emb.shape[0] for emb in protein_embs)
        dim = protein_embs[0].shape[1]
        padded = torch.zeros(len(protein_embs), max_len, dim)
        mask = torch.zeros(len(protein_embs), max_len)
        for i, emb in enumerate(protein_embs):
            padded[i, :emb.shape[0]] = emb
            mask[i, :emb.shape[0]] = 1.0
        return padded, mask

    def _forward(self, tokenized, protein_embs):
        """Collate tokenized glycans and protein embeddings and run the model.

        ``torch.stack`` requires every tokenized glycan to have the same
        length (presumably the tokenizer pads to ``max_length`` - confirm).
        Returns the raw model output.
        """
        glycan = {name: torch.stack([tok[name] for tok in tokenized]).to(self.device)
                  for name in self._TOKEN_FIELDS}
        # The attention mask is passed to the model as float, not long.
        glycan["attention_mask"] = glycan["attention_mask"].float()
        prot_pad, prot_mask = self._pad_proteins(protein_embs)
        return self.model(
            token_ids=glycan["token_ids"],
            attention_mask=glycan["attention_mask"],
            branch_depths=glycan["branch_depths"],
            linkage_types=glycan["linkage_types"],
            protein_emb=prot_pad.to(self.device),
            protein_mask=prot_mask.to(self.device),
        )

    @torch.no_grad()
    def predict_single(self, wurcs, protein_id) -> float:
        """Score one glycan/protein pair.

        Args:
            wurcs: glycan as a WURCS string.
            protein_id: key into the loaded protein embeddings.

        Returns:
            The model output for the pair as a Python float.

        Raises:
            KeyError: if ``protein_id`` has no loaded embedding.
        """
        if protein_id not in self.protein_embs:
            raise KeyError(f"Protein '{protein_id}' not found in embeddings.")
        tokens = self._tokenize_glycan(wurcs)
        score = self._forward([tokens], [self.protein_embs[protein_id]])
        return score.item()

    @torch.no_grad()
    def predict_batch(self, wurcs_list, protein_ids, batch_size=32) -> List[float]:
        """Score aligned lists of glycans and protein ids in mini-batches.

        Args:
            wurcs_list: WURCS strings, one per pair.
            protein_ids: protein ids, aligned with ``wurcs_list``.
            batch_size: number of pairs per forward pass (must be >= 1).

        Returns:
            One float score per input pair, in input order.

        Raises:
            ValueError: if the two lists differ in length or ``batch_size`` < 1.
            KeyError: if any protein id has no loaded embedding (checked for
                the whole input before any model call).
            RuntimeError: if the model does not return one score per pair.
        """
        n = len(wurcs_list)
        if len(protein_ids) != n:
            raise ValueError(
                f"wurcs_list and protein_ids must have the same length "
                f"(got {n} and {len(protein_ids)})")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Fail fast: do not spend compute on early batches if a later id is unknown.
        missing = [pid for pid in protein_ids if pid not in self.protein_embs]
        if missing:
            raise KeyError(f"{len(missing)} protein ids missing from embeddings, first: {missing[0]!r}")

        all_scores = []
        for batch_idx, start in enumerate(range(0, n, batch_size)):
            end = min(start + batch_size, n)
            tokenized = [self._tokenize_glycan(w) for w in wurcs_list[start:end]]
            protein_embs = [self.protein_embs[pid] for pid in protein_ids[start:end]]
            scores = self._forward(tokenized, protein_embs)
            # Flatten so a 0-d or [B, 1] output still yields one float per pair.
            flat = scores.reshape(-1)
            if flat.numel() != end - start:
                raise RuntimeError(
                    f"Model returned {flat.numel()} scores for {end - start} pairs")
            all_scores.extend(flat.cpu().tolist())
            if batch_idx % 10 == 0:
                logger.info(" Predicted %d/%d pairs (%.1f%%)", end, n, 100 * end / n)
        return all_scores
|
|
|
|
def main():
    """Command-line entry point for single-pair or CSV batch prediction.

    Two modes:
      * ``--wurcs`` together with ``--protein_id``: print one score.
      * ``--input_csv``: score every row (columns ``glycan_wurcs`` and
        ``protein_id``), add a ``predicted_score`` column and write the
        result to ``--output_csv`` or ``<input stem>_predictions.csv`` next
        to the input.

    All argument and CSV validation happens before the model is loaded so
    that usage errors are reported immediately.
    """
    parser = argparse.ArgumentParser(description="AFFINose interaction inference")
    parser.add_argument("--checkpoint", required=True,
                        help="AFFINose model checkpoint")
    parser.add_argument("--bertose_checkpoint", required=True,
                        help="BERTose encoder checkpoint")
    parser.add_argument("--vocab_path", required=True,
                        help="BPE tokenizer vocabulary")
    parser.add_argument("--protein_emb_path", required=True,
                        help="HDF5 file with protein embeddings")
    parser.add_argument("--device", default="auto",
                        help="'auto' or a torch device string")
    parser.add_argument("--wurcs", type=str, default=None,
                        help="glycan WURCS string (single-pair mode)")
    parser.add_argument("--protein_id", type=str, default=None,
                        help="protein id (single-pair mode)")
    parser.add_argument("--input_csv", type=str, default=None,
                        help="CSV with glycan_wurcs and protein_id columns")
    parser.add_argument("--output_csv", type=str, default=None,
                        help="where to write predictions (batch mode)")
    parser.add_argument("--batch_size", type=int, default=32)
    args = parser.parse_args()

    # Single-pair mode takes precedence when both of its arguments are given.
    single_mode = bool(args.wurcs and args.protein_id)
    if not single_mode and not args.input_csv:
        parser.error("Provide --wurcs + --protein_id or --input_csv")

    df = None
    if not single_mode:
        if args.batch_size < 1:
            parser.error("--batch_size must be >= 1")
        # Read and validate the CSV before the (slow) model load.
        df = pd.read_csv(args.input_csv)
        absent = [col for col in ("glycan_wurcs", "protein_id") if col not in df.columns]
        if absent:
            parser.error(f"{args.input_csv} is missing required column(s): {', '.join(absent)}")

    predictor = AffinosePredictor(
        checkpoint_path=args.checkpoint, bertose_checkpoint=args.bertose_checkpoint,
        vocab_path=args.vocab_path, protein_emb_path=args.protein_emb_path,
        device=args.device,
    )

    if single_mode:
        score = predictor.predict_single(args.wurcs, args.protein_id)
        print(f"\nPrediction: {score:.4f} (0=no binding, 1=strong)")
        return

    scores = predictor.predict_batch(
        df["glycan_wurcs"].tolist(), df["protein_id"].tolist(),
        batch_size=args.batch_size)
    df["predicted_score"] = scores
    if args.output_csv:
        out = args.output_csv
    else:
        # Build the default from the file stem so an input without a
        # lowercase ".csv" suffix can never resolve to the input itself.
        src = Path(args.input_csv)
        out = str(src.with_name(f"{src.stem}_predictions.csv"))
    df.to_csv(out, index=False)
    logger.info(f"Saved {len(df)} predictions to {out}")
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|