# Source: epitope/utils/inference_ESM2.py (Hugging Face Hub, commit f2fdddb, uploaded by yunuk0)
# utils/inference_ESM2.py
import numpy as np
import pandas as pd
import torch
import re
# ------------------------
# Sequence utils
# ------------------------
def clean_sequence(seq: str) -> str:
    """Return *seq* as an uppercase amino-acid string.

    Removes FASTA header lines (lines starting with ``>``) and all
    whitespace. Uses a MULTILINE-anchored pattern so a header line is
    removed even when it is not newline-terminated (the original
    pattern required a trailing newline and missed that case).

    Args:
        seq: Raw sequence text, possibly in FASTA format.

    Returns:
        The cleaned, uppercased sequence (may be empty).
    """
    seq = re.sub(r"^>.*$", "", seq, flags=re.MULTILINE)
    seq = re.sub(r"\s+", "", seq)
    return seq.upper()
def generate_epitopes(seq, window=16):
    """Slide a fixed-size window across *seq* and collect every subsequence.

    Args:
        seq: Protein sequence string.
        window: Window length (default 16).

    Returns:
        List of ``(subsequence, start, end)`` tuples with 1-based,
        inclusive coordinates. Empty when ``len(seq) < window``.
    """
    epitopes = []
    for start in range(len(seq) - window + 1):
        end = start + window
        epitopes.append((seq[start:end], start + 1, end))
    return epitopes
# ------------------------
# ESM-2 embedding
# ------------------------
@torch.no_grad()
def esm2_embed_batch(seqs, tokenizer, model, device, batch_size=8):
    """Compute mean-pooled last-hidden-state embeddings for sequences.

    Args:
        seqs: List of amino-acid sequence strings.
        tokenizer: Hugging Face tokenizer for the ESM-2 model.
        model: ESM-2 encoder returning ``last_hidden_state``.
        device: Torch device for tokenized inputs.
        batch_size: Number of sequences tokenized per forward pass.

    Returns:
        ``np.ndarray`` of shape ``(len(seqs), hidden_dim)``; an empty
        2-D array when ``seqs`` is empty.
    """
    if not seqs:
        # np.vstack([]) raises ValueError; return an empty 2-D array instead.
        return np.empty((0, 0), dtype=np.float32)
    # NOTE(review): residues are space-separated before tokenization —
    # confirm this matches this tokenizer's expected input format.
    spaced = [" ".join(s) for s in seqs]
    embeddings = []
    for start in range(0, len(spaced), batch_size):
        batch = spaced[start:start + batch_size]
        tokens = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            add_special_tokens=True
        ).to(device)
        hidden = model(**tokens).last_hidden_state
        # Attention-mask-weighted mean over the token axis; padding tokens
        # contribute nothing. NOTE(review): special tokens (CLS/EOS) are
        # included in the mean because add_special_tokens=True — confirm
        # this is intended.
        mask = tokens.attention_mask.unsqueeze(-1)
        pooled = (hidden * mask).sum(1) / mask.sum(1)
        embeddings.append(pooled.cpu().numpy())
    return np.vstack(embeddings)
# ------------------------
# Cascade inference
# ------------------------
@torch.no_grad()
def get_final_score(
    epitope,
    metadata_df,
    tokenizer,
    esm_model,
    device,
    model_s1,
    model_s2,
    encoder_s2,
    threshold_s1=0.0
):
    """Score a single epitope with the two-stage cascade classifier.

    Args:
        epitope: Epitope sequence string.
        metadata_df: Single-row DataFrame with 'assay', 'method',
            'state', 'disease' columns (extra columns are dropped).
        tokenizer / esm_model / device: Passed to ``esm2_embed_batch``.
        model_s1: Stage-1 classifier with ``predict_proba``.
        model_s2: Stage-2 classifier with ``predict_proba``.
        encoder_s2: Fitted categorical encoder with ``transform``.
        threshold_s1: Stage-1 probability below which stage 2 is skipped.

    Returns:
        ``(p1, None, None)`` if stage 1 falls below the threshold,
        otherwise ``(p1, p2, final)`` where ``final = 0.4*p1 + 0.6*p2``.
    """
    FEATURE_ORDER = ['assay', 'method', 'state', 'disease']
    # Force column order (guards against Streamlit reruns reshuffling the
    # frame) and reset the index: pd.concat(axis=1) aligns on index, so a
    # filtered metadata_df with a non-zero index would otherwise misalign
    # with the embedding frame (index 0) and produce NaN-padded rows.
    metadata_df = metadata_df[FEATURE_ORDER].reset_index(drop=True)

    # 1. Mean-pooled ESM-2 embedding of the epitope.
    emb = esm2_embed_batch([epitope], tokenizer, esm_model, device)
    df_emb = pd.DataFrame(emb)

    # Stage 1: embedding columns + raw metadata columns.
    X_s1 = pd.concat([df_emb, metadata_df], axis=1)
    p1 = model_s1.predict_proba(X_s1)[0, 1]
    if p1 < threshold_s1:
        return p1, None, None

    # Stage 2: embedding + encoded metadata features.
    X_cat_s2 = encoder_s2.transform(metadata_df)
    X_s2 = np.hstack([df_emb.values, X_cat_s2])
    p2 = model_s2.predict_proba(X_s2)[0, 1]

    # Fixed-weight blend of the two stage probabilities.
    final = 0.4 * p1 + 0.6 * p2
    return p1, p2, final