# ============================================================================
# Full initialization, loader, HLA mapping helpers, and prediction functions.
# Paste this file into your project as init.py. Importing this module will
# initialize both MHC-I and MHC-II engines (ENGINE_MHC1, ENGINE_MHC2).
# ============================================================================
import json
import os
import re
from collections import defaultdict
from contextlib import nullcontext

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from huggingface_hub import hf_hub_download, login
from peft import PeftModel
from transformers import AutoModel, AutoTokenizer
# OPTIM: #from sklearn.metrics import precision_recall_curve

# --- CONFIGURATION ---
TOKEN = os.getenv("HF_TOKEN")
if TOKEN is None:
    raise ValueError("HF_TOKEN environment variable is not set.")
login(TOKEN)

DEFAULT_EL_THRESHOLD = 0.60
DEFAULT_BA_THRESHOLD = 0.60

# Device for the two module-level engines built at import time. Defaults to
# "cpu" (the previous hard-wired behaviour); e.g. INFERENCE_DEVICE=cuda.
INFERENCE_DEVICE = os.getenv("INFERENCE_DEVICE", "cpu")

MHC1_CONFIG = {
    "model": "O047/esm2_MHC-I_Reforged_Single",
    "mapping": "O047/MHC-I_HLA_Mapping",
    "ba_db": "O047/MHC-I_BA_Data",
    "el_db": "O047/MHC-I_EL_Data",
    "eval_db": "O047/MHC-I_EVAL",
    "reg_head": "regHead_MHC-I.pt",
    "clf_head": "clfHead_MHC-I.pt",
}

MHC2_CONFIG = {
    "model": "O047/esm2_MHC-II_Reforged_Single",
    "mapping": "O047/MHC-II_HLA_Mapping",
    "ba_db": "O047/MHC-II_BA_Data",
    "el_db": "O047/MHC-II_EL_Data",
    "eval_db": "O047/MHC-II_EVAL",
    "reg_head": "regHead_MHC-II.pt",
    "clf_head": "clfHead_MHC-II.pt",
}


# --- Advanced HLA parsing and mapping helpers (preserve all variants) ---

# Gene name followed by one or two numeric allele fields.
#  * Multi-letter genes may end in a single digit (DRB1, DQA1, DPB1, ...); that
#    alternative is tried first so compact "DRB10401" -> DRB1 + 04 + 01.
#  * Otherwise the gene is letters only, so "A0201" / "A02:01" -> A + 02 + 01
#    instead of the digits being swallowed into the gene name.
# Fully compact strings are still ambiguous in general (e.g. "DRA0101"); the
# compact-key lookup in resolve_allele_key covers those.
_ALLELE_RE = re.compile(
    r"([A-Za-z]{2,}\d|[A-Za-z]+)[\*\-_:]?(\d{2,3})(?:[\-:_]?(\d{2,3}))?$"
)

# Separators between the chains of an allele-pair string.
_CHAIN_SPLIT_RE = re.compile(r"[\/\-]+")

# Zero-width split point at the START of a gene name (a letter run followed by
# a digit, not itself preceded by a letter), e.g. before "DPB1" in
# "DPA1_04_01_DPB1_85_01". The look-behind keeps it from firing at every
# letter inside "DRB1".
_GENE_START_RE = re.compile(r"(?<![A-Za-z])(?=[A-Za-z]+[0-9])")


def _format_single_allele_token(token: str) -> str:
    """
    Convert a single allele token into canonical 'HLA-<GENE>*<f1>[:<f2>]' form.

    'DRB1_0401', 'DRB10401', 'DRB1*04:01', 'DRB1-04-01' and 'hla-DRB1*04:01'
    all become 'HLA-DRB1*04:01'; 'A02:01' becomes 'HLA-A*02:01'.
    Unparseable tokens are returned with spaces removed and '/' turned into '-'.
    Returns None for None.
    """
    if token is None:
        return None
    s = str(token).strip().replace(" ", "").replace("/", "-")
    s = re.sub(r"[-/]+", "-", s)
    # Already canonical except possibly for the prefix case. Slicing (instead of
    # split("HLA-")) also fixes 'hla-...', which used to become 'HLA-hla-...'.
    if s.upper().startswith("HLA-") and "*" in s and ":" in s:
        return "HLA-" + s[4:]
    bare = re.sub(r"hla-", "", s, flags=re.IGNORECASE)
    m = _ALLELE_RE.match(bare)
    if m:
        gene = m.group(1).upper()
        part1 = m.group(2)
        part2 = m.group(3)
        if part2:
            return f"HLA-{gene}*{part1}:{part2}"
        return f"HLA-{gene}*{part1}"
    if "*" in bare:
        # Non-standard field layout: keep the fields, normalise their separators.
        left, right = bare.split("*", 1)
        right = right.replace("_", ":").replace("-", ":")
        if ":" not in right and len(right) >= 4:
            right = right[:2] + ":" + right[2:]
        return f"HLA-{left.upper()}*{right}"
    return s


def _normalize_allele(a: str) -> str:
    """
    Normalize allele or allele-pair strings into canonical lookup keys.

    Handles:
      - single alleles: 'DRB1_0401' / 'HLA-DRB1*04:01' -> 'HLA-DRB1*04:01'
      - chain pairs:    'DPA1_04_01_DPB1_85_01'        -> 'HLA-DPA1*04:01-DPB1*85:01'
      - separators: '/' and '-' split chains; '_' and ' ' are tolerated
    Only the first chain carries the 'HLA-' prefix; a bare 'HLA' token is the
    prefix, not a chain. Returns None for None and '' for blank input.
    """
    if a is None:
        return None
    s = str(a).strip()
    if s == "":
        return s

    formatted_tokens = []
    for chain in _CHAIN_SPLIT_RE.split(s):
        if chain.upper() == "HLA":
            continue  # species prefix; the formatter re-adds it
        # Split concatenated chains (e.g. 'DPA1_04_01_DPB1_85_01') at gene starts.
        for piece in _GENE_START_RE.split(chain):
            piece = piece.strip(" _:")
            if piece:
                formatted_tokens.append(_format_single_allele_token(piece))

    if not formatted_tokens:
        return s.strip("-/_: ")
    head, *tail = formatted_tokens
    tail = [t[4:] if t.upper().startswith("HLA-") else t for t in tail]
    return "-".join([head, *tail])


def _compact_key(key: str) -> str:
    """Drop '*', ':', '-' and the literal 'HLA' so spelling variants of one allele collide."""
    return key.replace("*", "").replace(":", "").replace("-", "").replace("HLA", "")


def build_hla_map_preserve_variants(map_ds):
    """
    Build a lookup map from allele spellings to pseudosequences.

    Rows (dicts with 'allele' and 'pseudosequence') are grouped by their
    normalized allele. Every original allele string, the normalized key and its
    compact form map to the group's pseudosequence: the first non-blank one in
    the group, else the first entry's value. Rows with allele None are skipped.

    Returns:
        (hla_map, stats) where stats has raw_rows, registered_keys,
        normalized_groups and duplicate_groups (groups with >1 row).
    """
    rows = list(map_ds)
    groups = defaultdict(list)
    for r in rows:
        orig = r.get("allele")
        if orig is None:
            continue
        orig_str = str(orig).strip()
        groups[_normalize_allele(orig_str)].append((orig_str, r.get("pseudosequence")))

    hla_map = {}
    duplicate_groups = 0
    for norm_key, entries in groups.items():
        chosen_pseudo = next(
            (pseudo for _, pseudo in entries if pseudo and str(pseudo).strip()),
            entries[0][1],
        )
        for orig, _ in entries:
            if orig:
                hla_map[orig] = chosen_pseudo
        if norm_key:
            hla_map[norm_key] = chosen_pseudo
            compact = _compact_key(norm_key)
            if compact:
                hla_map[compact] = chosen_pseudo
        if len(entries) > 1:
            duplicate_groups += 1

    stats = {
        "raw_rows": len(rows),
        "registered_keys": len(hla_map),
        "normalized_groups": len(groups),
        "duplicate_groups": duplicate_groups,
    }
    return hla_map, stats


def resolve_allele_key(query, hla_map):
    """
    Resolve a user-supplied allele string to a key present in hla_map.

    Tries, in order:
      1) the exact (stripped) string
      2) its normalized canonical form
      3) the compact form of (2)
      4) upper- and lower-case variants of the query
      5) a key whose compact form equals the query's (case-insensitive)
      6) the first key whose compact form contains the query's compact form
    Step 6 is a loose fallback and may pick a related but different allele.
    Returns None when nothing matches.
    """
    if query is None:
        return None
    q = str(query).strip()
    if q in hla_map:
        return q

    norm = _normalize_allele(q)
    if norm and norm in hla_map:
        return norm
    compact = _compact_key(norm) if norm else None
    if compact and compact in hla_map:
        return compact

    for variant in (q.upper(), q.lower()):
        if variant in hla_map:
            return variant

    q_comp = _compact_key(q)
    if not q_comp:
        return None
    q_comp_ci = _compact_key(q.upper())
    first_partial = None
    for key in hla_map:
        # Prefer an exact compact match over the first substring hit.
        if _compact_key(key.upper()) == q_comp_ci:
            return key
        if first_partial is None and q_comp in _compact_key(key):
            first_partial = key
    return first_partial


# --- MODEL ARCHITECTURES ---

class ProteinHead(nn.Module):
    """BA Head for Regression: 3-layer MLP mapping a pooled embedding to one score."""

    def __init__(self, input_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(input_dim, input_dim // 2),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(input_dim // 2, 1),
        )

    def forward(self, x):
        return self.mlp(x)


class ImprovedProteinHead(nn.Module):
    """
    EL Head for Classification.

    Attention-pools token states over dim=1, runs an MLP to one logit, then
    optionally scales it and adds a fixed bias; the result is clamped to
    [-10, 10]. The caller in this file passes last_hidden_state (B, L, H).
    """

    def __init__(self, input_dim, use_scale=False, scale_factor=1.0, use_bias=True, bias_value=-2.92):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, input_dim // 4),
            nn.Tanh(),
            nn.Linear(input_dim // 4, 1),
        )
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.LayerNorm(input_dim),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear(input_dim, input_dim),
            nn.LayerNorm(input_dim),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear(input_dim, 1),
        )
        self.use_scale = use_scale
        self.scale_factor = scale_factor
        self.use_bias = use_bias
        self.bias_value = bias_value

    def forward(self, x):
        attn_logits = self.attention(x)
        # NOTE(review): no attention mask is applied, so padding positions also
        # receive weight; confirm this matches how the head was trained.
        attn_weights = torch.softmax(attn_logits, dim=1, dtype=torch.float32).to(x.dtype)
        pooled = (x * attn_weights).sum(dim=1)
        out = self.mlp(pooled)
        if self.use_scale:
            out = out * self.scale_factor
        if self.use_bias:
            out = out + self.bias_value
        return torch.clamp(out, min=-10.0, max=10.0)


def _amp_context(device):
    """Return a CUDA autocast context for non-CPU devices when CUDA exists, else a no-op."""
    if str(device) != "cpu" and torch.cuda.is_available():
        return torch.autocast(device_type="cuda")
    return nullcontext()


# --- THRESHOLD UTILITIES ---
"""
OPTIM:
def compute_optimal_threshold(y_true, y_probs, default=0.6):
    try:
        if len(np.unique(y_true)) < 2:
            return default
        p, r, t = precision_recall_curve(y_true, y_probs)
        f1 = (2 * p * r) / (p + r + 1e-8)
        idx = np.argmax(f1)
        return t[idx] if idx < len(t) else default
    except Exception:
        return default

def calculate_engine_thresholds(engine, eval_repo, sample_size=5000):
    print(f" [*] Calculating dynamic thresholds from {eval_repo}...")
    try:
        eval_df = load_dataset(eval_repo, split="test", token=TOKEN).to_pandas()
        if len(eval_df) > sample_size:
            eval_df = eval_df.sample(sample_size, random_state=42).reset_index(drop=True)
        if 'pseudosequence' in eval_df.columns:
            eval_df.rename(columns={'pseudosequence': 'allele'}, inplace=True)
        seqs = []
        for _, row in eval_df.iterrows():
            p = row["peptide"]
            a = row["allele"]
            c = row.get("context", "")
            seqs.append(f"{p} [SEP] {a} [SEP] {c}")
        ba_preds, el_probs = [], []
        BATCH_SIZE = 512
        device = engine.get("device", "cpu")
        use_cuda_amp = (device != "cpu") and torch.cuda.is_available()
        for i in range(0, len(seqs), BATCH_SIZE):
            batch = seqs[i:i+BATCH_SIZE]
            toks = engine['tokenizer'](batch, return_tensors="pt", padding=True, truncation=True, max_length=128)
            try:
                toks = toks.to(device)
            except Exception:
                toks = {k: v.to(device) for k, v in toks.items()}
            amp_ctx = torch.cuda.amp.autocast() if use_cuda_amp else nullcontext()
            with torch.no_grad(), amp_ctx:
                outputs = engine['model'](**toks)
                last_hidden = outputs.last_hidden_state
                ba_emb = last_hidden.mean(dim=1).to(dtype=torch.float32)
                ba_batch_preds = engine['regHead'](ba_emb).squeeze(-1).cpu().numpy()
                ba_preds.extend(np.asarray(ba_batch_preds).ravel().tolist())
                el_logits = engine['clfHead'](last_hidden.to(dtype=torch.float32)).squeeze(-1).cpu().numpy()
                el_batch_probs = 1.0 / (1.0 + np.exp(-np.asarray(el_logits).ravel()))
                el_probs.extend(el_batch_probs.tolist())
        y_true = eval_df["score"].values
        el_thresh = compute_optimal_threshold(y_true, el_probs, default=0.6)
        ba_thresh = compute_optimal_threshold(y_true, ba_preds, default=0.5)
        print(f" [>] Calculated EL Threshold: {el_thresh:.4f}")
        print(f" [>] Calculated BA Threshold: {ba_thresh:.4f}")
        return el_thresh, ba_thresh
    except Exception as e:
        print(f" [!] Threshold calculation failed ({e}). Defaulting to 0.6 for both.")
        return 0.6, 0.6
"""


def _report_incompatible_keys(head_name, load_result):
    """Warn when a strict=False load_state_dict left parameters unloaded or ignored keys."""
    missing = list(load_result.missing_keys)
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        print(
            f" [!] {head_name}: missing keys {missing}; unexpected keys {unexpected}. "
            "Missing parameters keep their default (untrained) initialization."
        )


def _load_head_weights(config, ckpt_folder, regHead, clfHead, device):
    """Download the BA/EL head checkpoints from ckpt_folder and load them in place."""
    reg_path = hf_hub_download(repo_id=config['model'], subfolder=ckpt_folder,
                               filename=config['reg_head'], token=TOKEN)
    clf_path = hf_hub_download(repo_id=config['model'], subfolder=ckpt_folder,
                               filename=config['clf_head'], token=TOKEN)
    # NOTE(review): torch.load unpickles the file; only load from trusted repos
    # (weights_only=True is available on torch >= 1.13).
    regHead_state = torch.load(reg_path, map_location=device)
    # Checkpoints whose final layer is named mlp.5.* (presumably an earlier head
    # layout) are renamed to mlp.6.*, the final Linear in ProteinHead.
    if 'mlp.6.weight' not in regHead_state and 'mlp.5.weight' in regHead_state:
        regHead_state['mlp.6.weight'] = regHead_state.pop('mlp.5.weight')
        regHead_state['mlp.6.bias'] = regHead_state.pop('mlp.5.bias')
    _report_incompatible_keys("regHead", regHead.load_state_dict(regHead_state, strict=False))
    _report_incompatible_keys(
        "clfHead", clfHead.load_state_dict(torch.load(clf_path, map_location=device), strict=False)
    )
    regHead.eval()
    clfHead.eval()


def _load_checkpoint_thresholds(config, ckpt_folder):
    """
    Load EL/BA thresholds from thresholds.json in the checkpoint folder.

    Returns (el_threshold, ba_threshold, source_label). A missing key falls back
    to its module default; any download/parse error falls back to both defaults.
    source_label reflects which keys were actually read, not whether the values
    happen to equal the defaults.
    """
    try:
        threshold_path = hf_hub_download(
            repo_id=config['model'],
            subfolder=ckpt_folder,
            filename="thresholds.json",
            token=TOKEN,
        )
        with open(threshold_path, "r") as f:
            threshold_data = json.load(f)
        el_thresh = float(threshold_data.get("el_threshold", DEFAULT_EL_THRESHOLD))
        ba_thresh = float(threshold_data.get("ba_threshold", DEFAULT_BA_THRESHOLD))
        el_found = "el_threshold" in threshold_data
        ba_found = "ba_threshold" in threshold_data
    except Exception as e:
        print("\n" + "=" * 80)
        print("WARNING: CHECKPOINT THRESHOLDS NOT FOUND")
        print(f"MODEL : {config['model']}")
        print(f"REASON: {e}")
        print(f"USING DEFAULT EL THRESHOLD = {DEFAULT_EL_THRESHOLD:.6f}")
        print(f"USING DEFAULT BA THRESHOLD = {DEFAULT_BA_THRESHOLD:.6f}")
        print("=" * 80 + "\n")
        return DEFAULT_EL_THRESHOLD, DEFAULT_BA_THRESHOLD, "DEFAULT (FALLBACK)"

    print("\n" + "=" * 80)
    print("THRESHOLDS LOADED FROM CHECKPOINT")
    print(f"MODEL : {config['model']}")
    print(f"EL THRESHOLD : {el_thresh:.6f}")
    print(f"BA THRESHOLD : {ba_thresh:.6f}")
    print("=" * 80 + "\n")

    if el_found and ba_found:
        source_label = "CHECKPOINT"
    elif el_found or ba_found:
        source_label = "PARTIAL (MIXED DEFAULT + CHECKPOINT)"
    else:
        source_label = "DEFAULT (FALLBACK)"
    return el_thresh, ba_thresh, source_label


def _print_threshold_report(config, device, source_label, el_thresh, ba_thresh):
    """Print the per-engine banner summarising which thresholds are in effect."""
    print("\n" + "#" * 90)
    print("#" + " " * 88 + "#")
    print("# ENGINE THRESHOLD FINALIZATION REPORT #")
    print("#" + " " * 88 + "#")
    print("#" * 90)
    print(f"# MODEL : {config['model']}")
    print(f"# DEVICE : {device}")
    print(f"# SOURCE : {source_label}")
    print("#" + "-" * 88 + "#")
    print(f"# EL THRESHOLD : {el_thresh:.8f}")
    print(f"# BA THRESHOLD : {ba_thresh:.8f}")
    print("#" + "-" * 88 + "#")
    print("# STATUS SUMMARY")
    if source_label == "CHECKPOINT":
        print("# -> USING TRAINED CHECKPOINT THRESHOLDS")
    elif source_label == "DEFAULT (FALLBACK)":
        print("# -> USING DEFAULT THRESHOLDS (NO CHECKPOINT FOUND)")
    else:
        print("# -> MIXED CONFIGURATION DETECTED")
    print("#" * 90 + "\n")


# --- ENGINE LOADER (uses preserved-variant HLA map) ---

def load_inference_engine(config, device="cpu"):
    """
    Load the HLA registry, tokenizer, LoRA model, prediction heads and thresholds.

    Args:
        config: one of MHC1_CONFIG / MHC2_CONFIG.
        device: torch device string for the model and heads.

    Returns:
        The engine dict consumed by _base_predict, or None when the checkpoint
        step or the head weights cannot be loaded.
    """
    print(f"\n[*] Initializing Engine for {config['model']} on {device}...")

    # 1. Load HLA Mapping (preserve variants)
    print(" [>] Fetching HLA Registry...")
    map_ds = load_dataset(config['mapping'], split="train", token=TOKEN)

    # Plain (detached) copy of the allele column for frontend use.
    allele_list = list(map_ds['allele'])
    frontend_alleles = pd.DataFrame({"Allele": allele_list})
    frontend_alleles_dict = {allele: idx for idx, allele in enumerate(allele_list)}

    hla_map, hla_stats = build_hla_map_preserve_variants(map_ds)
    print(f" [>] HLA mapping rows: {hla_stats['raw_rows']}; registered lookup keys: {hla_stats['registered_keys']}; normalized groups: {hla_stats['normalized_groups']}; duplicate groups: {hla_stats['duplicate_groups']}")

    # 2. Determine Latest Step
    try:
        state_path = hf_hub_download(repo_id=config['model'], filename="latest_training_state.txt", token=TOKEN)
        with open(state_path, 'r') as f:
            step = json.load(f)['last_step']
        print(f" [>] Auto-detected latest checkpoint: Step {step}")
    except Exception as e:
        print(f" [!] Could not auto-detect step ({e}). Exiting.")
        return None
    ckpt_folder = f"checkpoints/step_{step}"

    # 3. Load Tokenizer & Base
    print(" [>] Loading Base Model & Tokenizer (ESM-2 650M)...")
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    if "[SEP]" not in tokenizer.get_vocab():
        tokenizer.add_special_tokens({'sep_token': '[SEP]'})
    # Half precision only when the model actually runs on a CUDA device; the
    # engines default to "cpu", where fp16 is slow and may lack kernels.
    dtype = torch.float16 if str(device).startswith("cuda") else torch.float32
    base = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D", torch_dtype=dtype)
    # NOTE(review): if '[SEP]' was just added, its embedding row is created by
    # resize_token_embeddings; confirm the adapter checkpoint restores it.
    base.resize_token_embeddings(len(tokenizer))
    hidden_dim = base.config.hidden_size

    # 4. Load LoRA
    print(" [>] Loading LoRA Adapters...")
    model = PeftModel.from_pretrained(base, config['model'], subfolder=ckpt_folder, token=TOKEN).to(device)
    model.eval()

    # 5. Load Heads (filenames come from config)
    print(" [>] Loading Prediction Heads...")
    regHead = ProteinHead(hidden_dim).to(device).float()
    clfHead = ImprovedProteinHead(hidden_dim, use_bias=True, bias_value=-2.92).to(device).float()
    try:
        _load_head_weights(config, ckpt_folder, regHead, clfHead, device)
    except Exception as e:
        print(f" [!] Error loading heads: {e}")
        return None

    """ OPTIM 1:
    # 6. Preload Training DBs for Novelty Flags
    print(" [>] Loading Training Databases for novelty checks...")
    try:
        ba_ds = load_dataset(config['ba_db'], split="train", token=TOKEN)
        el_ds = load_dataset(config['el_db'], split="train", token=TOKEN)
        ba_set = set(ba_ds['peptide'])
        el_set = set(el_ds['peptide'])
    except Exception as e:
        print(f" [!] Error loading DBs: {e}. Novelty checks disabled.")
        ba_set, el_set = set(), set()
    """
    # 6. Novelty checks disabled for deployment
    print(" [>] Novelty database loading disabled.")
    ba_set = set()
    el_set = set()

    engine = {
        "model": model,
        "tokenizer": tokenizer,
        "regHead": regHead,
        "clfHead": clfHead,
        "hla_map": hla_map,
        "device": device,
        "hidden_dim": hidden_dim,
        "step": step,
        "ba_db": ba_set,
        "el_db": el_set,
        # Raw allele list for the frontend.
        "frontend_alleles": frontend_alleles,
        "frontend_alleles_hash": frontend_alleles_dict,
    }

    """ OPTIM 2:
    # thresholds etc...
    el_thresh, ba_thresh = calculate_engine_thresholds(engine, config['eval_db'])
    engine["el_threshold"] = el_thresh
    engine["ba_threshold"] = ba_thresh
    """
    # Threshold loading (checkpoint thresholds.json, falling back to defaults).
    el_thresh, ba_thresh, source_label = _load_checkpoint_thresholds(config, ckpt_folder)
    engine["el_threshold"] = el_thresh
    engine["ba_threshold"] = ba_thresh
    _print_threshold_report(config, device, source_label, el_thresh, ba_thresh)

    print(f"[*] {config['model']} Engine Initialization Complete.")
    return engine


# --- PREDICTION (keeps all allele variants by default) ---

# Set to True to deduplicate by pseudosequence (faster); False to keep all variants (traceable)
DEDUP_BY_PSEUDO = False


def _base_predict(peptides, alleles, contexts, engine, model_name="Model"):
    """
    Internal common prediction logic (robust, device-safe, single-forward).

    Scores every (peptide, resolved allele) pair; each batch gets one forward
    pass whose hidden states feed both the BA and the EL head.

    Args:
        peptides: a peptide string or a list of peptides.
        alleles: an allele string or a list; each is resolved with
            resolve_allele_key, unresolved ones are reported and skipped.
        contexts: None, one string for all peptides, or a list aligned with
            peptides (a length mismatch falls back to empty contexts).
        engine: dict returned by load_inference_engine (or None).
        model_name: label used in log messages.

    Returns:
        DataFrame (Peptide, Allele, Context, BA_Score, EL_Prob, EL_Class,
        BA_Class, Seen_in_BA, Seen_in_EL) sorted by EL_Prob descending, or an
        empty DataFrame on invalid input. Rows from a batch that raised a
        RuntimeError get NaN scores and class 0.
    """
    if engine is None:
        print(f"[!] {model_name} Engine not initialized.")
        return pd.DataFrame()

    # Normalize inputs
    if peptides is None or alleles is None:
        print("[!] Error: At least one peptide and one allele required.")
        return pd.DataFrame()
    if isinstance(peptides, str):
        peptides = [peptides]
    if isinstance(alleles, str):
        alleles = [alleles]
    if len(peptides) == 0 or len(alleles) == 0:
        print("[!] Error: At least one peptide and one allele required.")
        return pd.DataFrame()

    if contexts is None:
        contexts = [""] * len(peptides)
    elif isinstance(contexts, str):
        contexts = [contexts] * len(peptides)
    elif len(contexts) != len(peptides):
        print("[!] Warning: Context list length mismatch. Defaulting to empty contexts.")
        contexts = [""] * len(peptides)

    # 1. Resolve alleles -> pseudosequences (preserve variants unless DEDUP_BY_PSEUDO)
    valid_alleles = {}
    known_pseudos = set()
    unresolved = []
    for hla in alleles:
        matched_key = resolve_allele_key(hla, engine['hla_map'])
        if not matched_key:
            unresolved.append(hla)
            continue
        pseudo = engine['hla_map'][matched_key]
        if DEDUP_BY_PSEUDO:
            if pseudo not in known_pseudos:
                valid_alleles[matched_key] = pseudo
                known_pseudos.add(pseudo)
        else:
            # preserve every matched variant (traceability)
            valid_alleles[matched_key] = pseudo

    if unresolved:
        print(f"[!] Warning: {len(unresolved)} allele(s) not found in the {model_name} registry and skipped: "
              + ", ".join(map(str, unresolved)))
    if not valid_alleles:
        print(f"[!] Error: None of the provided alleles were found in the {model_name} registry.")
        return pd.DataFrame()

    # 2. Build cartesian product
    batch_data = [
        {"Peptide": pep, "Allele": allele, "pseudo": pseudo, "Context": contexts[i]}
        for i, pep in enumerate(peptides)
        for allele, pseudo in valid_alleles.items()
    ]
    df = pd.DataFrame(batch_data)
    if df.empty:
        return pd.DataFrame()

    # 3. Batched inference
    results = []
    BATCH_SIZE = 64
    device = engine.get("device", "cpu")

    print(f"[*] {model_name}: Processing {len(peptides)} peptides × {len(valid_alleles)} unique alleles = {len(df)} predictions")

    for start in range(0, len(df), BATCH_SIZE):
        batch = df.iloc[start:start + BATCH_SIZE].reset_index(drop=True)
        seqs = [f"{p} [SEP] {ps} [SEP] {c}" for p, ps, c in zip(batch["Peptide"], batch["pseudo"], batch["Context"])]
        toks = engine['tokenizer'](seqs, return_tensors="pt", padding=True, truncation=True, max_length=128)
        try:
            toks = toks.to(device)
        except Exception:
            toks = {k: v.to(device) for k, v in toks.items()}

        try:
            with torch.no_grad(), _amp_context(device):
                outputs = engine['model'](**toks)
                last_hidden = outputs.last_hidden_state  # shape: (B, L, H)

                # BA: mean over the sequence axis.
                # NOTE(review): padding positions are included, so a row's score
                # can depend on the longest sequence in its batch; confirm the
                # training pooling before switching to attention_mask weighting.
                ba_emb = last_hidden.mean(dim=1).to(dtype=torch.float32)
                ba_preds = np.asarray(engine['regHead'](ba_emb).squeeze(-1).cpu().numpy()).ravel()

                # EL: classification head pools the sequence itself (attention inside head)
                el_logits = engine['clfHead'](last_hidden.to(dtype=torch.float32)).squeeze(-1).cpu().numpy()
                el_logits = np.asarray(el_logits).ravel()
                el_probs = 1.0 / (1.0 + np.exp(-el_logits))
        except RuntimeError as e:
            print(f"[!] Inference error on batch starting at {start}: {e}")
            torch.cuda.empty_cache()
            n = len(batch)
            ba_preds = np.full(n, np.nan)
            el_probs = np.full(n, np.nan)

        batch_res = batch.copy()
        batch_res['BA_Score'] = ba_preds
        batch_res['EL_Prob'] = el_probs
        results.append(batch_res)

    if not results:
        return pd.DataFrame()

    final_df = pd.concat(results, ignore_index=True)

    # 4. Thresholding & postprocessing. NaN >= t is False, so failed rows get
    # class 0 for any threshold (the old fillna(-1) broke for thresholds <= -1).
    el_t = engine.get('el_threshold', DEFAULT_EL_THRESHOLD)
    ba_t = engine.get('ba_threshold', DEFAULT_BA_THRESHOLD)
    final_df['EL_Class'] = (final_df['EL_Prob'] >= el_t).astype(int)
    final_df['BA_Class'] = (final_df['BA_Score'] >= ba_t).astype(int)

    """ OPTIM 3:
    final_df['Seen_in_BA'] = final_df['Peptide'].isin(engine.get('ba_db', set()))
    final_df['Seen_in_EL'] = final_df['Peptide'].isin(engine.get('el_db', set()))
    """
    final_df['Seen_in_BA'] = False
    final_df['Seen_in_EL'] = False

    # Cleanup
    final_df = final_df.drop(columns=['pseudo'])
    final_df = final_df.sort_values(by="EL_Prob", ascending=False).reset_index(drop=True)

    print(f"[*] {model_name} Inference Complete.")
    return final_df


def predict_mhc1(peptides: list, alleles: list, contexts: list = None):
    """Score peptides against MHC-I alleles with the module-level ENGINE_MHC1."""
    return _base_predict(peptides, alleles, contexts, ENGINE_MHC1, "MHC-I")


def predict_mhc2(peptides: list, alleles: list, contexts: list = None):
    """Score peptides against MHC-II alleles with the module-level ENGINE_MHC2."""
    return _base_predict(peptides, alleles, contexts, ENGINE_MHC2, "MHC-II")


# --- (EXECUTION) INITIALIZATION HERE ---
print("\n" + "=" * 80)
print("INITIALIZING DUAL INFERENCE ENGINES (ESM-2 650M)")
print("=" * 80)

ENGINE_MHC1 = load_inference_engine(MHC1_CONFIG, device=INFERENCE_DEVICE)
ENGINE_MHC2 = load_inference_engine(MHC2_CONFIG, device=INFERENCE_DEVICE)