# utils/inference_ESM2.py
import re

import numpy as np
import pandas as pd
import torch


# ------------------------
# Sequence utils
# ------------------------
def clean_sequence(seq: str) -> str:
    """Return *seq* upper-cased, with FASTA header lines and all whitespace removed."""
    seq = re.sub(r">.*\n", "", seq)  # drop FASTA header lines (">...\n")
    seq = re.sub(r"\s+", "", seq)    # strip spaces / newlines inside the sequence
    return seq.upper()


def generate_epitopes(seq: str, window: int = 16) -> list:
    """Slide a fixed-size window over *seq*.

    Returns a list of ``(subsequence, start, end)`` tuples where *start* and
    *end* are 1-based inclusive positions. Returns an empty list when
    ``len(seq) < window``.
    """
    return [
        (seq[i:i + window], i + 1, i + window)
        for i in range(len(seq) - window + 1)
    ]


# ------------------------
# ESM-2 embedding
# ------------------------
@torch.no_grad()
def esm2_embed_batch(seqs, tokenizer, model, device, batch_size=8):
    """Compute mean-pooled ESM-2 embeddings for a list of amino-acid sequences.

    Each sequence is space-separated per residue before tokenization.
    Hidden states are averaged over attention-unmasked positions (note:
    special tokens are included in the average, since only padding is
    masked out). Returns an ``(N, hidden_dim)`` numpy array.
    """
    seqs = [" ".join(list(s)) for s in seqs]
    embeddings = []
    for i in range(0, len(seqs), batch_size):
        batch = seqs[i:i + batch_size]
        tokens = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            add_special_tokens=True,
        ).to(device)
        outputs = model(**tokens)
        hidden = outputs.last_hidden_state
        # Zero out padding positions, then average over real tokens only.
        mask = tokens.attention_mask.unsqueeze(-1)
        pooled = (hidden * mask).sum(1) / mask.sum(1)
        embeddings.append(pooled.cpu().numpy())
    return np.vstack(embeddings)


# ------------------------
# Cascade inference
# ------------------------
@torch.no_grad()
def get_final_score(
    epitope,
    metadata_df,
    tokenizer, esm_model, device,
    model_s1,
    model_s2, encoder_s2,
    threshold_s1=0.0
):
    """Two-stage cascade score for a single epitope.

    Stage 1 scores the raw embedding + metadata columns; if its probability
    is below *threshold_s1* the cascade stops (returns ``(p1, None, None)``).
    Otherwise Stage 2 scores the embedding + one-hot-encoded metadata, and
    the final score is the weighted blend ``0.4 * p1 + 0.6 * p2``.

    Returns a ``(p1, p2, final)`` tuple of floats (``p2`` / ``final`` are
    ``None`` when Stage 1 rejects).
    """
    FEATURE_ORDER = ['assay', 'method', 'state', 'disease']
    # Force the column order (guards against Streamlit reruns reshuffling it)
    # and reset the index so the axis-1 concat below aligns row 0 with row 0
    # instead of producing NaN-padded rows when metadata_df carries a
    # non-zero index (e.g. a sliced session-state frame).
    metadata_df = metadata_df[FEATURE_ORDER].reset_index(drop=True)

    # 1. embedding (fresh frame, index 0..N-1 by construction)
    emb = esm2_embed_batch([epitope], tokenizer, esm_model, device)
    df_emb = pd.DataFrame(emb)

    # Stage 1: raw metadata columns appended to the embedding features.
    X_s1 = pd.concat([df_emb, metadata_df], axis=1)
    p1 = model_s1.predict_proba(X_s1)[0, 1]
    if p1 < threshold_s1:
        # Early exit: Stage 2 is never run for rejected epitopes.
        return p1, None, None

    # Stage 2: encoded metadata (encoder_s2 is assumed to have been fitted
    # on the same FEATURE_ORDER columns — TODO confirm against training code).
    X_cat_s2 = encoder_s2.transform(metadata_df)
    X_s2 = np.hstack([df_emb.values, X_cat_s2])
    p2 = model_s2.predict_proba(X_s2)[0, 1]

    # Fixed blend weights favouring the Stage-2 model.
    final = 0.4 * p1 + 0.6 * p2
    return p1, p2, final