Spaces:
Running
Running
| # ============================================================================ | |
| # Full initialization, loader, HLA mapping helpers, and prediction functions. | |
| # Paste this file into your project as init.py. Importing this module will | |
| # initialize both MHC-I and MHC-II engines (ENGINE_MHC1, ENGINE_MHC2). | |
| # ============================================================================ | |
| import re | |
| import os | |
| import json | |
| from collections import defaultdict | |
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoTokenizer, AutoModel | |
| from peft import PeftModel | |
| from huggingface_hub import hf_hub_download, login | |
| from datasets import load_dataset | |
| # OPTIM: | |
| #from sklearn.metrics import precision_recall_curve | |
| import pandas as pd | |
| import numpy as np | |
| from contextlib import nullcontext | |
# --- CONFIGURATION ---
# Hugging Face access token. It is required at import time: it authenticates
# the hub session below and is passed as token= to the hub downloads and
# dataset loads for the project repositories.
TOKEN = os.environ.get("HF_TOKEN")
if TOKEN is None:
    raise ValueError("HF_TOKEN environment variable is not set.")
login(token=TOKEN)
# Fallback decision thresholds, used when a checkpoint ships no thresholds.json.
DEFAULT_EL_THRESHOLD = 0.60
DEFAULT_BA_THRESHOLD = 0.60

# Per-engine Hugging Face resources:
#   model             - LoRA checkpoint repo (also holds the head files and thresholds)
#   mapping           - allele -> pseudosequence registry dataset
#   ba_db / el_db     - training datasets (only used by the disabled novelty checks)
#   eval_db           - evaluation dataset (only used by the disabled threshold search)
#   reg_head/clf_head - head checkpoint filenames inside the model repo
MHC1_CONFIG = dict(
    model="O047/esm2_MHC-I_Reforged_Single",
    mapping="O047/MHC-I_HLA_Mapping",
    ba_db="O047/MHC-I_BA_Data",
    el_db="O047/MHC-I_EL_Data",
    eval_db="O047/MHC-I_EVAL",
    reg_head="regHead_MHC-I.pt",
    clf_head="clfHead_MHC-I.pt",
)
MHC2_CONFIG = dict(
    model="O047/esm2_MHC-II_Reforged_Single",
    mapping="O047/MHC-II_HLA_Mapping",
    ba_db="O047/MHC-II_BA_Data",
    el_db="O047/MHC-II_EL_Data",
    eval_db="O047/MHC-II_EVAL",
    reg_head="regHead_MHC-II.pt",
    clf_head="clfHead_MHC-II.pt",
)
| # --- Advanced HLA parsing and mapping helpers (preserve all variants) --- | |
| _ALLELE_RE = re.compile(r"([A-Za-z0-9]+)[\*\-_:]?(\d{2,3})[:_]?(\d{2})?$") | |
| def _format_single_allele_token(token: str) -> str: | |
| """ | |
| Convert a single token like 'DRB1_0401', 'DRB10401', 'DRB1*04:01', 'DRB1-04-01' | |
| into canonical 'HLA-DRB1*04:01'. | |
| """ | |
| if token is None: | |
| return None | |
| s = str(token).strip() | |
| s = s.replace(" ", "").replace("/", "-") | |
| s = re.sub(r"[-/]+", "-", s) | |
| if s.upper().startswith("HLA-") and "*" in s and ":" in s: | |
| return s if s.startswith("HLA-") else "HLA-" + s.split("HLA-")[-1] | |
| m = _ALLELE_RE.match(s.replace("HLA-", "").replace("hla-", "")) | |
| if m: | |
| gene = m.group(1).upper() | |
| part1 = m.group(2) | |
| part2 = m.group(3) or "" | |
| if part2: | |
| formatted = f"HLA-{gene}*{part1}:{part2}" | |
| else: | |
| formatted = f"HLA-{gene}*{part1}" | |
| return formatted | |
| if "*" in s: | |
| left, right = s.split("*", 1) | |
| left = left.upper() | |
| right = right.replace("_", ":").replace("-", ":") | |
| if ":" not in right and len(right) >= 4: | |
| right = right[:2] + ":" + right[2:] | |
| return f"HLA-{left}*{right}" | |
| return s | |
| def _normalize_allele(a: str) -> str: | |
| """ | |
| Normalize allele or allele-pair strings into canonical lookup keys. | |
| Handles: | |
| - single alleles: 'DRB1_0401' -> 'HLA-DRB1*04:01' | |
| - chain pairs: 'DPA1_04_01_DPB1_85_01' -> 'HLA-DPA1*04:01-DPB1*85:01' | |
| - separators: '/', '_', '-', ' ' are tolerated | |
| """ | |
| if a is None: | |
| return None | |
| s = str(a).strip() | |
| if s == "": | |
| return s | |
| # Split on explicit chain separators (dash or slash) but keep order | |
| chain_tokens = re.split(r"[\/\-]+", s) | |
| formatted_tokens = [] | |
| for tok in chain_tokens: | |
| # split concatenated tokens heuristically | |
| subtoks = re.split(r"(?=[A-Za-z]+[0-9])", tok) | |
| subtoks = [st for st in subtoks if st] | |
| if len(subtoks) == 1: | |
| formatted_tokens.append(_format_single_allele_token(subtoks[0])) | |
| else: | |
| for st in subtoks: | |
| formatted_tokens.append(_format_single_allele_token(st)) | |
| canonical = "-".join(formatted_tokens) | |
| return canonical | |
def build_hla_map_preserve_variants(map_ds):
    """
    Build a lookup table from allele strings to pseudosequences.

    Each row's original allele string, its normalized canonical form
    (_normalize_allele), and a compact form of that (no '*', ':', '-' or
    'HLA') are all registered as keys. Rows whose alleles normalize to the same
    key form a group sharing one pseudosequence: the first non-blank one in row
    order. When compact keys of different groups coincide, the later group wins.

    Groups with no non-blank pseudosequence are NOT registered. A None or blank
    value would otherwise be formatted straight into the model prompt (as the
    text "None" or an empty allele) and yield meaningless scores; leaving the
    key out makes the allele report as "not found" instead.

    Args:
        map_ds: iterable of row mappings with 'allele' and 'pseudosequence'
            keys (e.g. a datasets.Dataset split).

    Returns:
        (hla_map, stats): the key -> pseudosequence dict, and counts under
        'raw_rows', 'registered_keys', 'normalized_groups',
        'duplicate_groups' and 'missing_pseudo_groups'.
    """
    rows = list(map_ds)
    groups = defaultdict(list)
    for row in rows:
        orig = row.get("allele")
        if orig is None:
            continue
        orig_str = str(orig).strip()
        groups[_normalize_allele(orig_str)].append((orig_str, row.get("pseudosequence")))

    hla_map = {}
    duplicate_groups = 0
    missing_pseudo_groups = 0
    for norm_key, entries in groups.items():
        if len(entries) > 1:
            duplicate_groups += 1
        chosen_pseudo = next(
            (pseudo for _, pseudo in entries if pseudo and str(pseudo).strip()),
            None,
        )
        if chosen_pseudo is None:
            missing_pseudo_groups += 1
            continue
        for orig, _ in entries:
            if orig:
                hla_map[orig] = chosen_pseudo
        if norm_key:
            hla_map[norm_key] = chosen_pseudo
            compact = norm_key.replace("*", "").replace(":", "").replace("-", "").replace("HLA", "")
            if compact:
                hla_map[compact] = chosen_pseudo

    stats = {
        "raw_rows": len(rows),
        "registered_keys": len(hla_map),
        "normalized_groups": len(groups),
        "duplicate_groups": duplicate_groups,
        "missing_pseudo_groups": missing_pseudo_groups,
    }
    return hla_map, stats
def resolve_allele_key(query, hla_map):
    """
    Resolve a user-supplied allele string to a key present in hla_map.

    Tries, in order:
      1) the exact (stripped) string
      2) its canonical normalized form (_normalize_allele)
      3) the compact form of (2): no '*', ':', '-' or 'HLA'
      4) upper-/lower-case variants of the query
      5) a partial match on punctuation-free, upper-cased forms, accepted only
         where the query starts a key or starts a later chain of a pair key
         (e.g. 'DPB1*04:01' inside 'HLA-DPA1*01:03-DPB1*04:01')

    Returns the first matching key (dict iteration order for step 5), or None.
    """
    if query is None:
        return None
    q = str(query).strip()
    if q in hla_map:
        return q
    norm = _normalize_allele(q)
    if norm and norm in hla_map:
        return norm
    compact = norm.replace("*", "").replace(":", "").replace("-", "").replace("HLA", "") if norm else None
    if compact and compact in hla_map:
        return compact
    if q.upper() in hla_map:
        return q.upper()
    if q.lower() in hla_map:
        return q.lower()

    def _squash(text):
        # Upper-case first so both "hla" and "HLA" tags are removed.
        return (
            text.upper()
            .replace("*", "").replace(":", "").replace("-", "").replace("_", "")
            .replace("HLA", "")
        )

    # 5) Partial match. Matching anywhere (the old behaviour) let 'A*02:01'
    # ('A0201') resolve to a DQA1*02:01 key ('DQA10201'), a different gene.
    # Only accept a hit at the start of the key, or right after a digit when
    # the query starts with a letter (the boundary between two chains).
    q_comp = _squash(q)
    if not q_comp:
        return None
    for key in hla_map:
        key_comp = _squash(str(key))
        pos = key_comp.find(q_comp)
        while pos != -1:
            if pos == 0 or (key_comp[pos - 1].isdigit() and q_comp[0].isalpha()):
                return key
            pos = key_comp.find(q_comp, pos + 1)
    return None
# --- MODEL ARCHITECTURES ---
class ProteinHead(nn.Module):
    """Binding-affinity (BA) regression head.

    An MLP input_dim -> input_dim -> input_dim // 2 -> 1 with GELU activations
    and dropout (p=0.2) after the first two projections.
    """

    def __init__(self, input_dim):
        super().__init__()
        half_dim = input_dim // 2
        layers = [
            nn.Linear(input_dim, input_dim),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(input_dim, half_dim),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(half_dim, 1),
        ]
        # The Sequential positions (mlp.0 / mlp.3 / mlp.6 hold the Linear
        # weights) are the state_dict keys the checkpoints are loaded against.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Map x of shape (..., input_dim) to scores of shape (..., 1)."""
        return self.mlp(x)
class ImprovedProteinHead(nn.Module):
    """Eluted-ligand (EL) classification head.

    Attention-pools token embeddings into one vector, then maps it through a
    two-stage Linear/LayerNorm/GELU/Dropout(0.5) MLP to a single logit. The
    logit is optionally multiplied by scale_factor (use_scale) and shifted by
    bias_value (use_bias), then clamped to [-10, 10].
    """

    def __init__(self, input_dim, use_scale=False, scale_factor=1.0, use_bias=True, bias_value=-2.92):
        super().__init__()
        quarter_dim = input_dim // 4
        # Module order fixes the state_dict keys (attention.0/2, mlp.0/1/4/5/8)
        # that the checkpoint files are loaded against.
        self.attention = nn.Sequential(
            nn.Linear(input_dim, quarter_dim),
            nn.Tanh(),
            nn.Linear(quarter_dim, 1),
        )
        mlp_layers = []
        for _ in range(2):
            mlp_layers += [
                nn.Linear(input_dim, input_dim),
                nn.LayerNorm(input_dim),
                nn.GELU(),
                nn.Dropout(0.5),
            ]
        mlp_layers.append(nn.Linear(input_dim, 1))
        self.mlp = nn.Sequential(*mlp_layers)
        self.use_scale = use_scale
        self.scale_factor = scale_factor
        self.use_bias = use_bias
        self.bias_value = bias_value

    def forward(self, x):
        """Return clamped logits of shape (batch, 1).

        Assumes x is (batch, seq, input_dim): the attention softmax runs over
        dim=1. NOTE(review): no padding mask is applied, so padded positions
        also receive attention weight.
        """
        # Softmax in float32, then cast back to the input dtype (fp16-safe).
        weights = torch.softmax(self.attention(x), dim=1, dtype=torch.float32).to(x.dtype)
        logits = self.mlp(torch.sum(x * weights, dim=1))
        if self.use_scale:
            logits = logits * self.scale_factor
        if self.use_bias:
            logits = logits + self.bias_value
        return logits.clamp(min=-10.0, max=10.0)
# --- THRESHOLD UTILITIES ---
# Disabled code kept for reference (the "OPTIM" toggle). The string literal
# below is never executed. It held an F1-maximising threshold search over an
# evaluation split. It needs sklearn's precision_recall_curve, whose import is
# also commented out at the top of the file. load_inference_engine currently
# reads thresholds from thresholds.json in the checkpoint folder instead.
# NOTE(review): if re-enabled, both the EL and the BA thresholds are tuned
# against the same eval_df["score"] labels, and BA regression outputs are fed to
# a precision/recall search. Confirm that is intended for BA.
""" OPTIM:
def compute_optimal_threshold(y_true, y_probs, default=0.6):
    try:
        if len(np.unique(y_true)) < 2:
            return default
        p, r, t = precision_recall_curve(y_true, y_probs)
        f1 = (2 * p * r) / (p + r + 1e-8)
        idx = np.argmax(f1)
        return t[idx] if idx < len(t) else default
    except Exception:
        return default
def calculate_engine_thresholds(engine, eval_repo, sample_size=5000):
    print(f" [*] Calculating dynamic thresholds from {eval_repo}...")
    try:
        eval_df = load_dataset(eval_repo, split="test", token=TOKEN).to_pandas()
        if len(eval_df) > sample_size:
            eval_df = eval_df.sample(sample_size, random_state=42).reset_index(drop=True)
        if 'pseudosequence' in eval_df.columns:
            eval_df.rename(columns={'pseudosequence': 'allele'}, inplace=True)
        seqs = []
        for _, row in eval_df.iterrows():
            p = row["peptide"]
            a = row["allele"]
            c = row.get("context", "")
            seqs.append(f"{p} [SEP] {a} [SEP] {c}")
        ba_preds, el_probs = [], []
        BATCH_SIZE = 512
        device = engine.get("device", "cpu")
        use_cuda_amp = (device != "cpu") and torch.cuda.is_available()
        for i in range(0, len(seqs), BATCH_SIZE):
            batch = seqs[i:i+BATCH_SIZE]
            toks = engine['tokenizer'](batch, return_tensors="pt", padding=True, truncation=True, max_length=128)
            try:
                toks = toks.to(device)
            except Exception:
                toks = {k: v.to(device) for k, v in toks.items()}
            amp_ctx = torch.cuda.amp.autocast() if use_cuda_amp else nullcontext()
            with torch.no_grad(), amp_ctx:
                outputs = engine['model'](**toks)
                last_hidden = outputs.last_hidden_state
                ba_emb = last_hidden.mean(dim=1).to(dtype=torch.float32)
                ba_batch_preds = engine['regHead'](ba_emb).squeeze(-1).cpu().numpy()
                ba_preds.extend(np.asarray(ba_batch_preds).ravel().tolist())
                el_logits = engine['clfHead'](last_hidden.to(dtype=torch.float32)).squeeze(-1).cpu().numpy()
                el_batch_probs = 1.0 / (1.0 + np.exp(-np.asarray(el_logits).ravel()))
                el_probs.extend(el_batch_probs.tolist())
        y_true = eval_df["score"].values
        el_thresh = compute_optimal_threshold(y_true, el_probs, default=0.6)
        ba_thresh = compute_optimal_threshold(y_true, ba_preds, default=0.5)
        print(f" [>] Calculated EL Threshold: {el_thresh:.4f}")
        print(f" [>] Calculated BA Threshold: {ba_thresh:.4f}")
        return el_thresh, ba_thresh
    except Exception as e:
        print(f" [!] Threshold calculation failed ({e}). Defaulting to 0.6 for both.")
        return 0.6, 0.6
"""
# --- ENGINE LOADER (uses preserved-variant HLA map) ---
def _detect_latest_step(repo_id):
    """Return 'last_step' from the repo's latest_training_state.txt (a JSON file)."""
    state_path = hf_hub_download(repo_id=repo_id, filename="latest_training_state.txt", token=TOKEN)
    with open(state_path, 'r') as f:
        return json.load(f)['last_step']


def _load_prediction_heads(config, ckpt_folder, hidden_dim, device):
    """Build the BA/EL heads and load their weights from the checkpoint folder.

    Returns (regHead, clfHead) in eval mode. Exceptions from the download or
    torch.load propagate to the caller.
    """
    regHead = ProteinHead(hidden_dim).to(device).float()
    clfHead = ImprovedProteinHead(hidden_dim, use_bias=True, bias_value=-2.92).to(device).float()
    reg_path = hf_hub_download(repo_id=config['model'], subfolder=ckpt_folder, filename=config['reg_head'], token=TOKEN)
    clf_path = hf_hub_download(repo_id=config['model'], subfolder=ckpt_folder, filename=config['clf_head'], token=TOKEN)
    # NOTE(review): torch.load unpickles the file. These come from the project's
    # own hub repo; consider weights_only=True if the installed torch supports it.
    reg_state = torch.load(reg_path, map_location=device)
    if 'mlp.6.weight' not in reg_state and 'mlp.5.weight' in reg_state:
        # Presumably an older head layout that stored the final Linear at
        # mlp.5; the current ProteinHead keeps it at mlp.6.
        reg_state['mlp.6.weight'] = reg_state.pop('mlp.5.weight')
        reg_state['mlp.6.bias'] = reg_state.pop('mlp.5.bias')
    clf_state = torch.load(clf_path, map_location=device)
    for head_name, head, state in (("regHead", regHead, reg_state), ("clfHead", clfHead, clf_state)):
        # strict=False tolerates layout drift, but silently leaves any missing
        # parameter at its random initialisation, so report mismatches.
        incompatible = head.load_state_dict(state, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(f" [!] {head_name}: checkpoint/architecture mismatch; missing keys {incompatible.missing_keys}, "
                  f"unexpected keys {incompatible.unexpected_keys}. Missing parameters keep their random initialisation.")
        head.eval()
    return regHead, clfHead


def _load_thresholds(config, ckpt_folder):
    """Return (el_threshold, ba_threshold, source_label) for one checkpoint.

    Reads thresholds.json from the checkpoint folder. Any value that cannot be
    read falls back to DEFAULT_EL_THRESHOLD / DEFAULT_BA_THRESHOLD. The source
    is tracked explicitly: comparing against the defaults (the old approach)
    mislabelled checkpoint thresholds that happen to equal 0.60 as fallbacks.
    """
    el_thresh = DEFAULT_EL_THRESHOLD
    ba_thresh = DEFAULT_BA_THRESHOLD
    el_loaded = ba_loaded = False
    try:
        threshold_path = hf_hub_download(
            repo_id=config['model'],
            subfolder=ckpt_folder,
            filename="thresholds.json",
            token=TOKEN
        )
        with open(threshold_path, "r") as f:
            threshold_data = json.load(f)
        if not isinstance(threshold_data, dict):
            raise ValueError("thresholds.json is not a JSON object")
        if "el_threshold" in threshold_data:
            el_thresh = float(threshold_data["el_threshold"])
            el_loaded = True
        if "ba_threshold" in threshold_data:
            ba_thresh = float(threshold_data["ba_threshold"])
            ba_loaded = True
        print("\n" + "=" * 80)
        print("THRESHOLDS LOADED FROM CHECKPOINT")
        print(f"MODEL : {config['model']}")
        print(f"EL THRESHOLD : {el_thresh:.6f}")
        print(f"BA THRESHOLD : {ba_thresh:.6f}")
        print("=" * 80 + "\n")
    except Exception as e:
        print("\n" + "=" * 80)
        print("WARNING: CHECKPOINT THRESHOLDS NOT FOUND")
        print(f"MODEL : {config['model']}")
        print(f"REASON: {e}")
        print(f"USING DEFAULT EL THRESHOLD = {DEFAULT_EL_THRESHOLD:.6f}")
        print(f"USING DEFAULT BA THRESHOLD = {DEFAULT_BA_THRESHOLD:.6f}")
        print("=" * 80 + "\n")
    if el_loaded and ba_loaded:
        source_label = "CHECKPOINT"
    elif el_loaded or ba_loaded:
        source_label = "PARTIAL (MIXED DEFAULT + CHECKPOINT)"
    else:
        source_label = "DEFAULT (FALLBACK)"
    return el_thresh, ba_thresh, source_label


def _print_threshold_report(config, device, el_thresh, ba_thresh, source_label):
    """Print the per-engine threshold finalization banner."""
    print("\n" + "#" * 90)
    print("#" + " " * 88 + "#")
    print("# ENGINE THRESHOLD FINALIZATION REPORT #")
    print("#" + " " * 88 + "#")
    print("#" * 90)
    print(f"# MODEL : {config['model']}")
    print(f"# DEVICE : {device}")
    print(f"# SOURCE : {source_label}")
    print("#" + "-" * 88 + "#")
    print(f"# EL THRESHOLD : {el_thresh:.8f}")
    print(f"# BA THRESHOLD : {ba_thresh:.8f}")
    print("#" + "-" * 88 + "#")
    print("# STATUS SUMMARY")
    if source_label == "CHECKPOINT":
        print("# -> USING TRAINED CHECKPOINT THRESHOLDS")
    elif source_label == "DEFAULT (FALLBACK)":
        print("# -> USING DEFAULT THRESHOLDS (NO CHECKPOINT FOUND)")
    else:
        print("# -> MIXED CONFIGURATION DETECTED")
    print("#" * 90 + "\n")


def load_inference_engine(config, device="cpu"):
    """Load one MHC engine: HLA registry, LoRA-adapted ESM-2, heads, thresholds.

    Args:
        config: one of the *_CONFIG dicts (uses 'model', 'mapping',
            'reg_head', 'clf_head').
        device: torch device the model and heads are moved to.

    Returns:
        The engine dict consumed by _base_predict (model, tokenizer, regHead,
        clfHead, hla_map, device, hidden_dim, step, ba_db, el_db,
        frontend_alleles, frontend_alleles_hash, el_threshold, ba_threshold).
        Returns None when the latest checkpoint step or the heads cannot be
        loaded. Errors in the registry, tokenizer or model loads propagate.
    """
    print(f"\n[*] Initializing Engine for {config['model']} on {device}...")

    # 1. Load HLA Mapping (preserve variants)
    print(" [>] Fetching HLA Registry...")
    map_ds = load_dataset(config['mapping'], split="train", token=TOKEN)
    # Detached copy of the raw allele column (registry order) for the frontend.
    allele_list = list(map_ds['allele'])
    frontend_alleles = pd.DataFrame({"Allele": allele_list})
    frontend_alleles_dict = {allele: idx for idx, allele in enumerate(allele_list)}
    hla_map, hla_stats = build_hla_map_preserve_variants(map_ds)
    print(f" [>] HLA mapping rows: {hla_stats['raw_rows']}; registered lookup keys: {hla_stats['registered_keys']}; normalized groups: {hla_stats['normalized_groups']}; duplicate groups: {hla_stats['duplicate_groups']}")

    # 2. Determine Latest Step
    try:
        step = _detect_latest_step(config['model'])
        print(f" [>] Auto-detected latest checkpoint: Step {step}")
    except Exception as e:
        print(f" [!] Could not auto-detect step ({e}). Exiting.")
        return None
    ckpt_folder = f"checkpoints/step_{step}"

    # 3. Load Tokenizer & Base
    print(" [>] Loading Base Model & Tokenizer (ESM-2 650M)...")
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    if "[SEP]" not in tokenizer.get_vocab():
        tokenizer.add_special_tokens({'sep_token': '[SEP]'})
    # Half precision only when the model will actually run on a CUDA device.
    # Previously fp16 was chosen whenever CUDA merely existed, so the default
    # device="cpu" ran a Half model on CPU (slow; some CPU kernels may lack
    # Half support).
    dtype = torch.float16 if torch.device(device).type == "cuda" else torch.float32
    base = AutoModel.from_pretrained(
        "facebook/esm2_t33_650M_UR50D",
        torch_dtype=dtype
    )
    base.resize_token_embeddings(len(tokenizer))
    hidden_dim = base.config.hidden_size

    # 4. Load LoRA
    print(" [>] Loading LoRA Adapters...")
    model = PeftModel.from_pretrained(base, config['model'], subfolder=ckpt_folder, token=TOKEN).to(device)
    model.eval()

    # 5. Load Heads (filenames come from config)
    print(" [>] Loading Prediction Heads...")
    try:
        regHead, clfHead = _load_prediction_heads(config, ckpt_folder, hidden_dim, device)
    except Exception as e:
        print(f" [!] Error loading heads: {e}")
        return None

    """ OPTIM 1:
    # 6. Preload Training DBs for Novelty Flags
    print(" [>] Loading Training Databases for novelty checks...")
    try:
        ba_ds = load_dataset(config['ba_db'], split="train", token=TOKEN)
        el_ds = load_dataset(config['el_db'], split="train", token=TOKEN)
        ba_set = set(ba_ds['peptide'])
        el_set = set(el_ds['peptide'])
    except Exception as e:
        print(f" [!] Error loading DBs: {e}. Novelty checks disabled.")
        ba_set, el_set = set(), set()
    """
    # 6. Novelty checks disabled for deployment
    print(" [>] Novelty database loading disabled.")
    ba_set = set()
    el_set = set()

    engine = {
        "model": model,
        "tokenizer": tokenizer,
        "regHead": regHead,
        "clfHead": clfHead,
        "hla_map": hla_map,
        "device": device,
        "hidden_dim": hidden_dim,
        "step": step,
        "ba_db": ba_set,
        "el_db": el_set,
        # Raw allele list (and allele -> index) for the frontend.
        "frontend_alleles": frontend_alleles,
        "frontend_alleles_hash": frontend_alleles_dict
    }

    """ OPTIM 2:
    # thresholds etc...
    el_thresh, ba_thresh = calculate_engine_thresholds(engine, config['eval_db'])
    engine["el_threshold"] = el_thresh
    engine["ba_threshold"] = ba_thresh
    """
    # 7. Threshold loading (checkpoint thresholds.json, else defaults)
    el_thresh, ba_thresh, source_label = _load_thresholds(config, ckpt_folder)
    engine["el_threshold"] = el_thresh
    engine["ba_threshold"] = ba_thresh
    _print_threshold_report(config, device, el_thresh, ba_thresh, source_label)

    print(f"[*] {config['model']} Engine Initialization Complete.")
    return engine
| # --- PREDICTION (keeps all allele variants by default) --- | |
| # Set to True to deduplicate by pseudosequence (faster); False to keep all variants (traceable) | |
| DEDUP_BY_PSEUDO = False | |
| def _base_predict(peptides, alleles, contexts, engine, model_name="Model"): | |
| """Internal common prediction logic (robust, device-safe, single-forward).""" | |
| if engine is None: | |
| print(f"[!] {model_name} Engine not initialized.") | |
| return pd.DataFrame() | |
| # Normalize inputs | |
| if peptides is None or alleles is None: | |
| print("[!] Error: At least one peptide and one allele required.") | |
| return pd.DataFrame() | |
| if isinstance(peptides, str): | |
| peptides = [peptides] | |
| if isinstance(alleles, str): | |
| alleles = [alleles] | |
| if len(peptides) == 0 or len(alleles) == 0: | |
| print("[!] Error: At least one peptide and one allele required.") | |
| return pd.DataFrame() | |
| if contexts is None: | |
| contexts = [""] * len(peptides) | |
| elif isinstance(contexts, str): | |
| contexts = [contexts] * len(peptides) | |
| elif len(contexts) != len(peptides): | |
| print("[!] Warning: Context list length mismatch. Defaulting to empty contexts.") | |
| contexts = [""] * len(peptides) | |
| # 1. Resolve alleles -> pseudosequences (use resolver and preserve variants per DEDUP_BY_PSEUDO) | |
| valid_alleles = {} | |
| known_pseudos = set() | |
| for hla in alleles: | |
| matched_key = resolve_allele_key(hla, engine['hla_map']) | |
| if matched_key: | |
| pseudo = engine['hla_map'][matched_key] | |
| if DEDUP_BY_PSEUDO: | |
| if pseudo not in known_pseudos: | |
| valid_alleles[matched_key] = pseudo | |
| known_pseudos.add(pseudo) | |
| else: | |
| # preserve every matched variant (traceability) | |
| valid_alleles[matched_key] = pseudo | |
| if not valid_alleles: | |
| print(f"[!] Error: None of the provided alleles were found in the {model_name} registry.") | |
| return pd.DataFrame() | |
| # 2. Build cartesian product | |
| batch_data = [] | |
| for i, pep in enumerate(peptides): | |
| ctx = contexts[i] | |
| for allele, pseudo in valid_alleles.items(): | |
| batch_data.append({"Peptide": pep, "Allele": allele, "pseudo": pseudo, "Context": ctx}) | |
| df = pd.DataFrame(batch_data) | |
| if df.empty: | |
| return pd.DataFrame() | |
| results = [] | |
| BATCH_SIZE = 64 | |
| device = engine.get("device", "cpu") | |
| use_cuda_amp = (device != "cpu") and torch.cuda.is_available() | |
| print(f"[*] {model_name}: Processing {len(peptides)} peptides × {len(valid_alleles)} unique alleles = {len(df)} predictions") | |
| for start in range(0, len(df), BATCH_SIZE): | |
| batch = df.iloc[start:start + BATCH_SIZE].reset_index(drop=True) | |
| seqs = [f"{p} [SEP] {ps} [SEP] {c}" for p, ps, c in zip(batch["Peptide"], batch["pseudo"], batch["Context"])] | |
| toks = engine['tokenizer'](seqs, return_tensors="pt", padding=True, truncation=True, max_length=128) | |
| try: | |
| toks = toks.to(device) | |
| except Exception: | |
| toks = {k: v.to(device) for k, v in toks.items()} | |
| try: | |
| with torch.no_grad(): | |
| if use_cuda_amp: | |
| amp_ctx = torch.cuda.amp.autocast() | |
| else: | |
| amp_ctx = nullcontext() | |
| with amp_ctx: | |
| outputs = engine['model'](**toks) | |
| last_hidden = outputs.last_hidden_state # shape: (B, L, H) | |
| # BA: pooled mean over sequence | |
| ba_emb = last_hidden.mean(dim=1).to(dtype=torch.float32) | |
| ba_preds = engine['regHead'](ba_emb).squeeze(-1).cpu().numpy() | |
| ba_preds = np.asarray(ba_preds).ravel() | |
| # EL: classification head expects sequence input (attention inside head) | |
| el_logits = engine['clfHead'](last_hidden.to(dtype=torch.float32)).squeeze(-1).cpu().numpy() | |
| el_logits = np.asarray(el_logits).ravel() | |
| el_probs = 1.0 / (1.0 + np.exp(-el_logits)) | |
| except RuntimeError as e: | |
| print(f"[!] Inference error on batch starting at {start}: {e}") | |
| torch.cuda.empty_cache() | |
| n = len(batch) | |
| ba_preds = np.full(n, np.nan) | |
| el_probs = np.full(n, np.nan) | |
| batch_res = batch.copy() | |
| batch_res['BA_Score'] = ba_preds | |
| batch_res['EL_Prob'] = el_probs | |
| results.append(batch_res) | |
| if not results: | |
| return pd.DataFrame() | |
| final_df = pd.concat(results, ignore_index=True) | |
| # 4. Thresholding & postprocessing | |
| el_t = engine.get('el_threshold', DEFAULT_EL_THRESHOLD) | |
| ba_t = engine.get('ba_threshold', DEFAULT_BA_THRESHOLD) | |
| # handle NaNs safely before casting | |
| final_df['EL_Class'] = (final_df['EL_Prob'].fillna(-1) >= el_t).astype(int) | |
| final_df['BA_Class'] = (final_df['BA_Score'].fillna(-1) >= ba_t).astype(int) | |
| """ OPTIM 3: | |
| final_df['Seen_in_BA'] = final_df['Peptide'].isin(engine.get('ba_db', set())) | |
| final_df['Seen_in_EL'] = final_df['Peptide'].isin(engine.get('el_db', set())) | |
| """ | |
| final_df['Seen_in_BA'] = False | |
| final_df['Seen_in_EL'] = False | |
| ################# | |
| # Cleanup | |
| final_df = final_df.drop(columns=['pseudo']) | |
| final_df = final_df.sort_values(by="EL_Prob", ascending=False).reset_index(drop=True) | |
| print(f"[*] {model_name} Inference Complete.") | |
| return final_df | |
def predict_mhc1(peptides: list, alleles: list, contexts: list = None):
    """Run MHC-I predictions with ENGINE_MHC1; arguments and result as in _base_predict."""
    return _base_predict(peptides, alleles, contexts, engine=ENGINE_MHC1, model_name="MHC-I")


def predict_mhc2(peptides: list, alleles: list, contexts: list = None):
    """Run MHC-II predictions with ENGINE_MHC2; arguments and result as in _base_predict."""
    return _base_predict(peptides, alleles, contexts, engine=ENGINE_MHC2, model_name="MHC-II")
# --- (EXECUTION) INITIALIZATION HERE ---
# Importing this module builds both engines eagerly, MHC-I first. Either may be
# None if its checkpoint step or prediction heads could not be loaded.
print(f"\n{'=' * 80}")
print("INITIALIZING DUAL INFERENCE ENGINES (ESM-2 650M)")
print(f"{'=' * 80}")
ENGINE_MHC1, ENGINE_MHC2 = (
    load_inference_engine(engine_config) for engine_config in (MHC1_CONFIG, MHC2_CONFIG)
)