import re

import numpy as np
import pandas as pd
import torch
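
# Two-stage epitope scoring pipeline: mean-pooled ESM-2 embeddings feed a
# stage-1 classifier that gates candidates, and a stage-2 classifier whose
# probability is blended with stage 1 to produce the final score.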


def clean_sequence(seq: str) -> str:
    """Strip FASTA header lines and all whitespace; return the uppercase sequence."""
    seq = re.sub(r">.*\n", "", seq)  # drop FASTA description lines
    seq = re.sub(r"\s+", "", seq)    # drop newlines, spaces, tabs
    return seq.upper()
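
# Example: clean_sequence(">sp|P01234\nMKTA yiak\n") == "MKTAYIAK"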


def generate_epitopes(seq, window=16):
    """Slide a fixed-length window over seq; return (peptide, start, end) tuples with 1-based, inclusive coordinates."""
    return [
        (seq[i:i + window], i + 1, i + window)
        for i in range(len(seq) - window + 1)
    ]
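
# Example: generate_epitopes("MKTAYIAKQR", window=4)[0] == ("MKTA", 1, 4)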


@torch.no_grad()
def esm2_embed_batch(seqs, tokenizer, model, device, batch_size=8):
    """Return one mean-pooled ESM-2 embedding per sequence, shape (len(seqs), hidden_size)."""
    # Note: the HF ESM-2 tokenizer works on raw residue strings (see the
    # facebook/esm2 model cards); space-separating residues is a ProtBert
    # convention and would introduce unknown tokens here, so it is not done.
    embeddings = []

    for i in range(0, len(seqs), batch_size):
        batch = seqs[i:i + batch_size]

        tokens = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            add_special_tokens=True,
        ).to(device)

        outputs = model(**tokens)
        hidden = outputs.last_hidden_state          # (batch, seq_len, hidden)
        mask = tokens.attention_mask.unsqueeze(-1)  # (batch, seq_len, 1)

        # Mean-pool over real tokens only; padding is zeroed out. The CLS/EOS
        # positions are included, since attention_mask covers them too.
        pooled = (hidden * mask).sum(1) / mask.sum(1)
        embeddings.append(pooled.cpu().numpy())

    return np.vstack(embeddings)
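
# Usage sketch (checkpoint name is illustrative; any facebook/esm2_t* model works):
#   from transformers import AutoTokenizer, AutoModel
#   tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
#   esm_model = AutoModel.from_pretrained("facebook/esm2_t12_35M_UR50D").eval().to("cpu")
#   emb = esm2_embed_batch(["MKTAYIAKQR"], tokenizer, esm_model, "cpu")  # (1, 480)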


@torch.no_grad()
def get_final_score(
    epitope,
    metadata_df,
    tokenizer,
    esm_model,
    device,
    model_s1,
    model_s2,
    encoder_s2,
    threshold_s1=0.0,
):
    """Score one epitope with the two-stage cascade.

    Returns (p1, None, None) if stage 1 falls below threshold_s1, otherwise
    (p1, p2, final) where final blends the two stage probabilities. The
    default threshold_s1=0.0 disables the gate, since p1 >= 0 always.
    """
    FEATURE_ORDER = ["assay", "method", "state", "disease"]

    # Keep only the expected metadata columns, in a fixed order, and reset
    # the index so the axis=1 concat below aligns row 0 with row 0.
    metadata_df = metadata_df[FEATURE_ORDER].reset_index(drop=True)

    # Embed the single epitope; shape (1, hidden_size).
    emb = esm2_embed_batch([epitope], tokenizer, esm_model, device)
    df_emb = pd.DataFrame(emb)

    # Stage 1: embedding + raw metadata columns.
    X_s1 = pd.concat([df_emb, metadata_df], axis=1)
    p1 = model_s1.predict_proba(X_s1)[0, 1]

    # Gate: skip stage 2 for low-scoring candidates.
    if p1 < threshold_s1:
        return p1, None, None

    # Stage 2: embedding + encoded metadata (dense encoder output assumed).
    X_cat_s2 = encoder_s2.transform(metadata_df)
    X_s2 = np.hstack([df_emb.values, X_cat_s2])

    p2 = model_s2.predict_proba(X_s2)[0, 1]

    # Weighted blend of the two stage probabilities.
    final = 0.4 * p1 + 0.6 * p2
    return p1, p2, final
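
# Usage sketch (assumes fitted scikit-learn-style classifiers exposing
# predict_proba and a fitted dense-output categorical encoder over
# FEATURE_ORDER; all values below are illustrative):
#   meta = pd.DataFrame([{"assay": "ELISA", "method": "in vitro",
#                         "state": "native", "disease": "healthy"}])
#   p1, p2, final = get_final_score("MKTAYIAKQRQISFVK", meta, tokenizer,
#                                   esm_model, device, model_s1, model_s2,
#                                   encoder_s2, threshold_s1=0.5)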