# cafa6_project / app.py
# eloise54 — commit 965a3b9: "added evidence code and ancestors"
############################################ INSTALL PACKAGES ############################################
import functools
import os
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from urllib.parse import quote
def install(package):
    """Install *package* into the running interpreter's environment with pip.

    ``--upgrade`` makes pip pick the newest version matching the requirement
    spec even when an older one is already installed. Raises
    ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", "--upgrade"]
    pip_command.append(package)
    subprocess.check_call(pip_command)
install("gradio>=3.44.0")
install("biopython==1.86")
install("cachetools==5.4.0")
install("mlflow==3.7.0")

# NOTE: despite the name, this is the GitLab repository holding the model artifacts.
HF_REPO_URL = "https://gitlab.com/nn_projects/cafa6_project"
CLONE_DIR = "/root/cafa6_project"
if not os.path.exists(CLONE_DIR):
    # Argument list (no shell) + check_call: a failed clone stops here with a
    # CalledProcessError instead of being ignored (os.system discards the exit
    # status) and resurfacing as a confusing os.chdir error below.
    subprocess.check_call(["git", "clone", HF_REPO_URL, CLONE_DIR])
# All relative paths used later (mlruns/..., numpy_dataset/) resolve against the clone.
os.chdir(CLONE_DIR)
############################################ DEFINE CONSTANTS ############################################
import gc
import pandas as pd
import numpy as np
from collections import defaultdict
from tqdm.auto import tqdm
import mlflow
import torch
import random
import requests
import re
from transformers import set_seed
from torch.utils.data import Dataset, DataLoader
from Bio import SeqIO
import gradio as gr
# Paths, relative to the current working directory (the cloned repo, see os.chdir(CLONE_DIR)).
input_path = './'
data_dir = "numpy_dataset/"
test_embeddings_data = "prot_t5_embeddings_right_pooling_False_test_mini"

# Inference settings.
test_batch_size = 64
MAX_SEQ_LEN = 512   # tokenizer max_length in generate_embedding_from_uniprot
HIDDEN_DIM = 1024
THRESH = 0.003      # minimum ensemble score for a GO term in ensemble_predict

# Seed the Python, NumPy, torch (CPU and CUDA) and transformers RNGs, in that order.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed, set_seed):
    _seed_fn(SEED)
############################################ LOAD MODELS ############################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("USING DEVICE: ", device)

# Artifacts of MLflow run `run_id`, read from the local mlruns/ store (experiment 11).
run_id = '9ee8f63638d0494ea20b63710c19a8b3'
SUBMISSION_INPUT_PATH = f"{input_path}mlruns/11/{run_id}/artifacts/"
SUBMISSION_INPUT = f"{SUBMISSION_INPUT_PATH}submission.tsv"
# NOTE(review): this yields ".../artifacts//diamond/" (double slash); harmless on POSIX.
OUTPUT_PATH = f"{SUBMISSION_INPUT_PATH}/diamond/"

# URI lists for the ensemble: one model URI and one label-array URI per member
# (per the file names, the "C" / fold-4 set).
models_uri = np.load(f"{SUBMISSION_INPUT_PATH}models_uri_C_fold_4.npy")
mlb_arrays_uri = np.load(f"{SUBMISSION_INPUT_PATH}mlb_arrays_uri_C_fold_4.npy")
#LOAD MODELS
def _load_ensemble_member(model_uri):
    """Load one MLflow PyTorch model onto `device` and put it in eval mode."""
    member = mlflow.pytorch.load_model(model_uri, map_location=torch.device(device))
    member.eval()
    member.to(device)
    return member


MODELS = [_load_ensemble_member(model_uri) for model_uri in models_uri]

#LOAD ONE HOT MLB ARRAYS
# Label array of each ensemble member, index-aligned with MODELS
# (they are zipped together with the adaptors in ensemble_predict).
MLB_ARRAYS = list(map(lambda array_path: np.load(array_path, allow_pickle=True), mlb_arrays_uri))
#CREATE MATRIX ADAPTATORS TO MAP UNIQUE GO IDS TO EACH MODEL'S MLB ARRAY PREDICTION
# Each adaptor is a 0/1 matrix of shape (len(unique_go_ids), len(mlb_array)) with
# adaptor[i, j] = 1 iff unique_go_ids[i] == mlb_array[j]. Multiplying it by one
# model's per-label output scatters that output into the shared GO-ID space.
concatenated_array = np.concatenate(MLB_ARRAYS)
# Sorted so the GO-ID order (and therefore tie-breaking between equal scores) is
# identical on every run; iterating a bare set depends on PYTHONHASHSEED.
unique_go_ids = np.array(sorted(set(concatenated_array)))
print(concatenated_array.shape)
print(unique_go_ids.shape)

_go_id_to_row = {go_id: row for row, go_id in enumerate(unique_go_ids)}
matrix_adaptators = []
for n, mlb_array in enumerate(MLB_ARRAYS):
    # One dict lookup per label instead of comparing every (unique id, label)
    # pair in a Python double loop (O(U * M) element-wise tensor writes per model).
    rows = torch.tensor([_go_id_to_row[go_id] for go_id in mlb_array], dtype=torch.long)
    cols = torch.arange(len(mlb_array))
    prob_matrix_adaptator = torch.zeros(len(unique_go_ids), len(mlb_array))
    prob_matrix_adaptator[rows, cols] = 1.0
    print(n, " ", prob_matrix_adaptator.shape)
    matrix_adaptators.append(prob_matrix_adaptator.to(device))
############################################ UNIPROTKB PROT5 EMBEDDINGS ############################################
from transformers import T5Tokenizer, T5EncoderModel
model_type = "Rostlab/prot_t5_xl_uniref50"
# do not put to lower case: prot T5 needs upper case letters.
_tokenizer_options = {"do_lower_case": False, "truncation_side": "right"}
tokenizer = T5Tokenizer.from_pretrained(model_type, **_tokenizer_options)
protT5 = T5EncoderModel.from_pretrained(model_type, trust_remote_code=True).to(device)

max_sequence_len = 512  # prot_t5_xl pretraining done on max 512 seq len
batch_size = 64
# NOTE(review): generate_embedding_from_uniprot tokenizes with MAX_SEQ_LEN;
# max_sequence_len and batch_size are not used by the code visible here.

# Ids of the tokenizer's special tokens; generate_embedding_from_uniprot zeroes
# them out of the residue mask.
SPECIAL_IDS = set(tokenizer.all_special_ids)

# Inference only: eval mode and no gradients for any encoder parameter.
protT5.eval()
protT5.requires_grad_(False)
# Translation table sending the residue letters U, Z, O and B to the generic "X".
_NONSTANDARD_TO_X = str.maketrans("UZOB", "XXXX")


def preprocess_sequence(seq: str) -> str:
    """Prepare a protein sequence for the ProtT5 tokenizer.

    Each of U, Z, O and B is replaced by X, and residues are separated by
    single spaces (ProtT5 tokenizes one residue per word).
    Example: "MKU" -> "M K X".
    """
    return " ".join(seq.translate(_NONSTANDARD_TO_X))
def fetch_uniprot_sequence(uniprot_id: str) -> str:
    """Download the amino-acid sequence of a UniProtKB entry as plain text.

    Args:
        uniprot_id: UniProtKB accession (e.g. "A1X283"). Surrounding whitespace
            is ignored.

    Returns:
        The sequence with the FASTA header dropped and sequence lines concatenated.

    Raises:
        ValueError: if the ID is blank, UniProt does not answer 200, or the
            returned FASTA contains no sequence.
        requests.RequestException: on network failure or timeout.
    """
    accession = uniprot_id.strip()
    if not accession:
        raise ValueError("Empty UniProt ID")
    # Percent-encode the user-supplied ID so it cannot alter the URL path or query.
    url = f"https://rest.uniprot.org/uniprotkb/{quote(accession, safe='')}.fasta"
    r = requests.get(url, timeout=10)
    if r.status_code != 200:
        raise ValueError(f"UniProt ID '{uniprot_id}' not found")
    sequence = "".join(
        line.strip() for line in r.text.splitlines() if not line.startswith(">")
    )
    if not sequence:
        # Feeding an empty sequence to the model would yield a meaningless prediction.
        raise ValueError(f"UniProt ID '{uniprot_id}' has no sequence")
    return sequence
def generate_embedding_from_uniprot(uniprot_id: str):
    """Fetch a UniProtKB sequence and embed it with the frozen ProtT5 encoder.

    The sequence is passed through preprocess_sequence, then tokenized with
    special tokens, truncation and padding to MAX_SEQ_LEN. It is encoded
    without gradient tracking.

    Returns:
        (masked_embeddings, mask_2d):
        - masked_embeddings: encoder output of shape (1, MAX_SEQ_LEN, D), with
          padding and special-token positions zeroed.
        - mask_2d: the matching (1, MAX_SEQ_LEN) mask, 1 for ordinary tokens and
          0 for padding or special tokens.
    """
    spaced_sequence = preprocess_sequence(fetch_uniprot_sequence(uniprot_id))
    encoded = tokenizer(
        spaced_sequence,
        return_tensors="pt",
        truncation=True,
        add_special_tokens=True,
        padding="max_length",
        max_length=MAX_SEQ_LEN,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        hidden = protT5(**encoded).last_hidden_state  # (1, L, D)

    token_ids = encoded["input_ids"]
    # Start from the attention mask (drops padding), then also drop special tokens.
    residue_mask = encoded["attention_mask"].clone()  # (1, L)
    for special_id in SPECIAL_IDS:
        residue_mask.masked_fill_(token_ids == special_id, 0)

    masked = hidden * residue_mask.unsqueeze(-1).float()  # (1, L, D)
    return masked, residue_mask
############################################ PREDICTION CODE ############################################
def ensemble_predict(embedding, mask, topk=20, *, models=None, adaptors=None,
                     go_ids=None, thresh=None):
    """Average the ensemble's GO-term scores for one protein and return the top-k.

    Each model's per-label output is mapped into the shared GO-ID space by its
    0/1 adaptor matrix. A GO term's score is the mean over the models whose
    label set contains it. Terms scoring <= ``thresh`` are discarded.

    Args:
        embedding, mask: inputs forwarded to every model (a batch of one), as
            returned by generate_embedding_from_uniprot.
        topk: maximum number of GO terms to return. Float values, such as those
            a UI slider may send, are truncated to int.
        models, adaptors, go_ids, thresh: optional overrides for the module-level
            MODELS, matrix_adaptators, unique_go_ids and THRESH.

    Returns:
        DataFrame with columns "GO_ID" and "Score" (rounded to 3 decimals),
        sorted by descending score, with at most ``topk`` rows, every score > thresh.
    """
    models = MODELS if models is None else models
    adaptors = matrix_adaptators if adaptors is None else adaptors
    go_ids = unique_go_ids if go_ids is None else go_ids
    thresh = THRESH if thresh is None else thresh
    topk = int(topk)

    scores = torch.zeros(1, len(go_ids), device=embedding.device)
    counts = torch.zeros_like(scores)
    # Inference only: without no_grad every model call built an autograd graph.
    with torch.no_grad():
        for model, adaptor in zip(models, adaptors):
            # preds must be (1, num_labels): the only shape for which
            # adaptor (num_go_ids, num_labels) @ preds.T is valid.
            preds = model(embedding, mask)
            adapted = (adaptor @ preds.transpose(0, 1)).T  # (1, num_go_ids)
            scores += adapted
            # Count the models that *cover* each GO term (non-empty adaptor row).
            # The old "prediction is non-zero" test dropped models that
            # predicted ~0 from the denominator and inflated the mean.
            counts += (adaptor.sum(dim=1) > 0).float().unsqueeze(0)
    scores = (scores / torch.clamp(counts, min=1)).squeeze(0)

    # Rank only terms above the threshold, so sub-threshold terms are never
    # listed as zero-score "predictions" when fewer than topk pass.
    keep = torch.nonzero(scores > thresh).squeeze(1)
    order = torch.argsort(scores[keep], descending=True)[:topk]
    idx = keep[order]
    return pd.DataFrame({
        "GO_ID": go_ids[idx.cpu().numpy()],
        "Score": scores[idx].round(decimals=3).cpu().numpy(),
    })
def predict(uniprot_id, topk):
    """Embed *uniprot_id* with ProtT5 and return the ensemble's top-k GO terms as a DataFrame."""
    return ensemble_predict(*generate_embedding_from_uniprot(uniprot_id), topk)
############################################ UNIPROTKB AND GRADIO UTILS ############################################
def fetch_human_uniprot_examples(n=500):
    """Return the accessions of up to *n* reviewed human (organism 9606) UniProtKB entries.

    Only the first page of UniProt's search results, of size *n*, is read.
    Raises ``requests.HTTPError`` on a non-2xx response.
    """
    response = requests.get(
        "https://rest.uniprot.org/uniprotkb/search",
        params={
            "query": "organism_id:9606 AND reviewed:true",
            "fields": "accession",
            "format": "json",
            "size": n,
        },
        timeout=10,
    )
    response.raise_for_status()
    return [entry["primaryAccession"] for entry in response.json()["results"]]
@functools.lru_cache(maxsize=4096)
def _fetch_go_ancestors_cached(go_id):
    """Query QuickGO for the is_a/part_of ancestors of *go_id*; raise on any failure.

    Only successful look-ups are cached, because lru_cache does not cache
    exceptions. The result is a frozenset so the shared cached value cannot be
    mutated by a caller.
    """
    url = (
        f"https://www.ebi.ac.uk/QuickGO/services/ontology/go/terms/"
        f"{go_id}/ancestors?relations=is_a,part_of"
    )
    r = requests.get(url, headers={"Accept": "application/json"}, timeout=10)
    r.raise_for_status()
    data = r.json()
    return frozenset(
        ancestor
        for term in data.get("results", [])
        for ancestor in term.get("ancestors", [])
    )


def fetch_go_ancestors(go_id):
    """Return the set of GO IDs QuickGO lists as is_a/part_of ancestors of *go_id*.

    Best-effort behavior:
    - On a network/HTTP error or an unparsable JSON body, a warning is printed
      and an empty set is returned.
    - Failures are not cached, so a later call retries.
    - Successful results are cached per GO ID, so repeated predictions do not
      refetch them.
    """
    try:
        return set(_fetch_go_ancestors_cached(go_id))
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decode errors, which previously escaped the try.
        print(f"Warning: failed to fetch ancestors for {go_id}: {e}")
        return set()
# GO evidence codes kept when building the QuickGO reference annotations,
# chosen (per the UI note) to match CAFA6 evaluation.
_CAFA6_EVIDENCE_DESCRIPTIONS = {
    "EXP": "Inferred from Experiment",
    "IDA": "Inferred from Direct Assay",
    "IPI": "Inferred from Physical Interaction",
    "IMP": "Inferred from Mutant Phenotype",
    "IGI": "Inferred from Genetic Interaction",
    "IEP": "Inferred from Expression Pattern",
    "TAS": "Traceable Author Statement",
    "IC": "Inferred by Curator",
}
CAFA6_EVIDENCE_CODES = set(_CAFA6_EVIDENCE_DESCRIPTIONS)
def fetch_quickgo_annotations_with_ancestors(uniprot_id):
    """Return the QuickGO GO terms of *uniprot_id* (CAFA6 evidence codes only) plus all their ancestors.

    Steps:
    1. Page through QuickGO's annotation search and keep the GO IDs whose
       evidence code is in CAFA6_EVIDENCE_CODES.
    2. Add each kept term's is_a/part_of ancestors via fetch_go_ancestors.

    Best-effort: if a page request fails, a warning is printed and the terms
    collected so far are used.
    """
    page_size = 200
    # Percent-encode the user-supplied ID before putting it in the query string.
    accession = quote(uniprot_id.strip(), safe="")
    direct_go_ids = set()
    page = 1
    while True:
        url = (
            f"https://www.ebi.ac.uk/QuickGO/services/annotation/search?"
            f"geneProductId=UniProtKB:{accession}&limit={page_size}&page={page}"
        )
        try:
            response = requests.get(url, headers={"Accept": "application/json"}, timeout=10)
            response.raise_for_status()
            data = response.json()  # inside the try: a bad body must not crash the UI
        except (requests.RequestException, ValueError) as e:
            print(f"Warning: failed to fetch annotations: {e}")
            break
        results = data.get("results", [])
        if not results:
            break
        for item in results:
            go_id = item.get("goId")
            if go_id and item.get("goEvidence") in CAFA6_EVIDENCE_CODES:
                direct_go_ids.add(go_id)
        # NOTE(review): assumes QuickGO's pageInfo.total is the page count; confirm
        # against the API docs. Stopping on the last page avoids requesting one
        # past it. Without pageInfo, the empty-page check above ends the loop.
        total_pages = (data.get("pageInfo") or {}).get("total")
        if isinstance(total_pages, int) and page >= total_pages:
            break
        page += 1

    # Ancestor look-ups are independent HTTP calls (one per direct term); issue a
    # few concurrently instead of strictly one after another.
    all_go_ids = set(direct_go_ids)
    with ThreadPoolExecutor(max_workers=8) as pool:
        for ancestors in pool.map(fetch_go_ancestors, direct_go_ids):
            all_go_ids.update(ancestors)
    return all_go_ids
def color_topk_predictions(pred_df, true_go_ids):
    """Add a "Color" column of inline CSS to *pred_df* (in place) and return it.

    Rows whose GO_ID is in *true_go_ids* get a green background; all other
    rows get a red one.
    """
    green = "background-color: #d4edda"
    red = "background-color: #f8d7da"
    pred_df["Color"] = [green if go_id in true_go_ids else red for go_id in pred_df["GO_ID"]]
    return pred_df
def predictions_to_html(pred_df):
    """Render a prediction table as a centered HTML table.

    Each row links its GO_ID to the QuickGO term page and shows the score with
    three decimals. The row's inline CSS comes from the optional "Color" column
    (as added by color_topk_predictions); rows have an empty style without it.

    Args:
        pred_df: DataFrame with "GO_ID" and "Score" columns, optionally "Color".

    Returns:
        The HTML string.
    """
    parts = [
        "<div style='text-align: center;'>",
        "<table border='1' style='border-collapse: collapse; margin: 0 auto;'>",
        "<tr><th>GO_ID</th><th>Score</th></tr>",
    ]
    colors = pred_df["Color"] if "Color" in pred_df.columns else [""] * len(pred_df)
    for go_id, score, color in zip(pred_df["GO_ID"], pred_df["Score"], colors):
        go_url = f"https://www.ebi.ac.uk/QuickGO/term/{go_id}"
        # Format the score explicitly. Previously iterrows() upcast the float32
        # Score to a Python float, rendering e.g. 0.12300000339746475.
        parts.append(
            f"<tr style='{color}'>"
            f"<td><a href='{go_url}' target='_blank' rel='noopener noreferrer'>{go_id}</a></td>"
            f"<td>{score:.3f}</td>"
            "</tr>"
        )
    parts.append("</table></div>")
    return "".join(parts)
############################################ GRADIO APP ############################################
# Markdown shown at the bottom of the app: BibTeX citations for the CAFA6 dataset
# the ensemble was trained on and for SPROF-GO. It is a raw string so the
# backslash in \url{...} is kept literally.
markdown_information= r"""
## Trained on CAFA6 Protein Function Prediction Dataset
```
@misc{cafa-6-protein-function-prediction,
author = {Iddo Friedberg and Predrag Radivojac and Paul D Thomas and An Phan and M. Clara De Paolis Kaluza and Damiano Piovesan and Parnal Joshi and Chris Mungall and Martyna Plomecka and Walter Reade and María Cruz},
title = {CAFA 6 Protein Function Prediction},
year = {2025},
howpublished = {\url{https://kaggle.com/competitions/cafa-6-protein-function-prediction}},
note = {Kaggle}
}
```
## SPROF-GO
```
@article{10.1093/bib/bbad117,
author = {Yuan, Qianmu and Xie, Junjie and Xie, Jiancong and Zhao, Huiying and Yang, Yuedong},
title = "{Fast and accurate protein function prediction from sequence through pretrained language model and homology-based label diffusion}",
journal = {Briefings in Bioinformatics},
year = {2023},
month = {03},
issn = {1477-4054},
doi = {10.1093/bib/bbad117},
url = {https://doi.org/10.1093/bib/bbad117}
}
```
"""
def predict_with_uniprot_highlight_html(uniprot_id, topk):
    """Predict GO terms for *uniprot_id* and render them as an HTML table.

    Predicted terms that appear in the protein's QuickGO annotations (CAFA6
    evidence codes, ancestors propagated) are highlighted green; the rest red.

    Raises:
        gr.Error: with a readable message when the ID is unknown or UniProt
            cannot be reached, so Gradio shows it to the user instead of a
            generic error.
    """
    uniprot_id = uniprot_id.strip()  # pasted IDs often carry stray whitespace
    try:
        embedding, mask = generate_embedding_from_uniprot(uniprot_id)
    except ValueError as e:
        raise gr.Error(str(e)) from e
    except requests.RequestException as e:
        raise gr.Error(f"Could not reach UniProt for '{uniprot_id}': {e}") from e
    pred_df = ensemble_predict(embedding, mask, topk)
    true_go_ids = fetch_quickgo_annotations_with_ancestors(uniprot_id)
    colored_df = color_topk_predictions(pred_df, true_go_ids)
    return predictions_to_html(colored_df)
# Fetched once at start-up. If UniProt is unavailable, fall back to the app's
# default ID instead of crashing at import time; the examples are only a
# convenience.
try:
    HUMAN_EXAMPLES = fetch_human_uniprot_examples()
except (requests.RequestException, ValueError, KeyError) as e:
    print(f"Warning: failed to fetch human UniProt examples: {e}")
    HUMAN_EXAMPLES = ["A1X283"]
def format_human_examples_md(examples):
    """Render accessions as Markdown links to their QuickGO annotation pages.

    Links are separated by " / ", with no dangling separator after the last
    one. An empty input gives "".
    """
    return " / ".join(
        f"[{acc}](https://www.ebi.ac.uk/QuickGO/annotations?geneProductId={acc})"
        for acc in examples
    )
def fetch_human_examples_md():
    """Return Markdown links for 40 randomly chosen HUMAN_EXAMPLES accessions (fewer if the list is shorter)."""
    pool = list(HUMAN_EXAMPLES)  # shuffle a copy; the module-level list stays intact
    random.shuffle(pool)
    chosen = pool[:40]
    return format_human_examples_md(chosen)
def make_title_md(uniprot_id):
    """Return the Markdown heading of the prediction table, linking *uniprot_id* to its QuickGO annotations."""
    link = f"https://www.ebi.ac.uk/QuickGO/annotations?geneProductId={uniprot_id}"
    return f"### Prediction Table of [{uniprot_id}]({link})"
# Gradio UI: an input column (UniProt ID, top-k slider, Predict button), a column
# of random human example accessions, and an HTML prediction table colored
# against QuickGO annotations.
with gr.Blocks() as demo:
    gr.Markdown("# 🦠 [SPROF-GO](https://github.com/biomed-AI/SPROF-GO) Ensemble Trained on [CAFA6](https://www.kaggle.com/competitions/cafa-6-protein-function-prediction/overview)")
    # ===================== Inputs =====================
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            gr.Markdown("## Inference")
            gr.Markdown("⚠️ No label diffusion for fast inference.")
            uniprot_input = gr.Textbox(label="UniProtKB Protein ID", value="A1X283")
            topk_slider = gr.Slider(5, 50, value=10, step=5, label="Top-K GO terms")
            run_btn = gr.Button("Predict")
        with gr.Column(scale=1):
            gr.Markdown("## Human Protein Examples")
            # Initial list: the first 50 fetched accessions; the button below re-samples 40 at random.
            human_examples_md_comp = gr.Markdown(format_human_examples_md(HUMAN_EXAMPLES[:50]))
            example_btn = gr.Button("🔄 Fetch more examples")
            example_btn.click(fetch_human_examples_md, outputs=human_examples_md_comp)
    # ===================== Output =====================
    gr.HTML("<hr style='margin:20px 0;'>") # horizontal divider
    title_md = gr.Markdown(make_title_md("A1X283"))
    html_output = gr.HTML(label="Predicted GO terms")
    gr.Markdown(
        "Top-K predictions colored green if predicted GO term is in "
        "[QuickGO](https://www.ebi.ac.uk/QuickGO/annotations) annotations, "
        "red otherwise"
    )
    gr.Markdown(" Ancestors are propagated and will be colored green if any descendant is in QuickGO annotations")
    gr.Markdown(" I have only kept the following evidence codes: EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC to match CAFA6 evaluation")
    gr.HTML("<hr style='margin:20px 0;'>") # horizontal divider
    gr.Markdown(markdown_information)
    # Two separate listeners on the Predict button: one updates the title (a
    # cheap string), the other runs the full prediction + QuickGO look-up.
    run_btn.click(
        fn=make_title_md,
        inputs=uniprot_input,
        outputs=title_md
    )
    run_btn.click(
        fn=predict_with_uniprot_highlight_html,
        inputs=[uniprot_input, topk_slider],
        outputs=html_output
    )
    # On page load, render the title and the prediction for the default inputs.
    demo.load(
        fn=lambda uid, k: (make_title_md(uid), predict_with_uniprot_highlight_html(uid, k)),
        inputs=[uniprot_input, topk_slider],
        outputs=[title_md, html_output]
    )
if __name__ == "__main__":
    # Listen on all interfaces at port 7860; no public share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)