############################################ INSTALL PACKAGES ############################################
import sys
import subprocess
import os


def install(package):
    """Install *package* with pip, upgrading to the newest matching version."""
    # --upgrade forces install of the latest version satisfying the spec
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", package])


install("gradio>=3.44.0")
install("biopython==1.86")
install("cachetools==5.4.0")
install("mlflow==3.7.0")

HF_REPO_URL = "https://gitlab.com/nn_projects/cafa6_project"
CLONE_DIR = "/root/cafa6_project"

# Clone the project repo once, then run from inside it
if not os.path.exists(CLONE_DIR):
    os.system(f"git clone {HF_REPO_URL} {CLONE_DIR}")
os.chdir(CLONE_DIR)

############################################ DEFINE CONSTANTS ############################################
import gc
import pandas as pd
import numpy as np
from collections import defaultdict
from tqdm.auto import tqdm
import mlflow
import torch
import random
import requests
import re
from transformers import set_seed
from torch.utils.data import Dataset, DataLoader
from Bio import SeqIO
import gradio as gr

input_path = './'
data_dir = "numpy_dataset/"
test_embeddings_data = "prot_t5_embeddings_right_pooling_False_test_mini"
test_batch_size = 64
SEED = 42
MAX_SEQ_LEN = 512
HIDDEN_DIM = 1024
THRESH = 0.003  # minimum averaged probability for a GO term to be kept

# Seed every RNG we rely on so inference output is reproducible
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
set_seed(SEED)

############################################ LOAD MODELS ############################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("USING DEVICE: ", device)

run_id = '9ee8f63638d0494ea20b63710c19a8b3'
SUBMISSION_INPUT_PATH = input_path + 'mlruns/11/' + run_id + '/artifacts/'
SUBMISSION_INPUT = SUBMISSION_INPUT_PATH + 'submission.tsv'
OUTPUT_PATH = SUBMISSION_INPUT_PATH + '/diamond/'

models_uri = np.load(SUBMISSION_INPUT_PATH + "models_uri_C_fold_4.npy")
mlb_arrays_uri = np.load(SUBMISSION_INPUT_PATH + "mlb_arrays_uri_C_fold_4.npy")

# LOAD MODELS
MODELS = []
for uri in models_uri:
    model = mlflow.pytorch.load_model(uri, map_location=torch.device(device))
    model.eval()
    model.to(device)
    MODELS.append(model)

# LOAD ONE HOT MLB ARRAYS (per-model label vocabularies of GO ids)
MLB_ARRAYS = [np.load(uri, allow_pickle=True) for uri in mlb_arrays_uri]

# CREATE MATRIX ADAPTATORS TO MAP UNIQUE GO IDS TO EACH MODEL'S MLB ARRAY PREDICTION
concatenated_array = np.concatenate(MLB_ARRAYS)
# sorted() makes the union order deterministic — a bare set() would give a
# different unique_go_ids ordering on every run despite the seeding above
unique_go_ids = np.array(sorted(set(concatenated_array)))
print(concatenated_array.shape)
print(unique_go_ids.shape)

# Build each 0/1 adaptator via an index lookup instead of the O(U*L)
# all-pairs comparison; the resulting matrices are identical.
go_id_to_row = {go_id: i for i, go_id in enumerate(unique_go_ids)}
matrix_adaptators = []
for n, mlb_array in enumerate(MLB_ARRAYS):
    # (num_unique_go, num_model_labels): row i, col j is 1 iff they name the same GO id
    prob_matrix_adaptator = torch.zeros(len(unique_go_ids), len(mlb_array))
    for j, go_id in enumerate(mlb_array):
        prob_matrix_adaptator[go_id_to_row[go_id], j] = 1.0
    print(n, " ", prob_matrix_adaptator.shape)
    matrix_adaptators.append(prob_matrix_adaptator.to(device))

############################################ UNIPROTKB PROT5 EMBEDDINGS ############################################
from transformers import T5Tokenizer, T5EncoderModel

model_type = "Rostlab/prot_t5_xl_uniref50"
# do not put to lower case, prot T5 needs upper case letters
tokenizer = T5Tokenizer.from_pretrained(model_type, do_lower_case=False, truncation_side="right")
protT5 = T5EncoderModel.from_pretrained(model_type, trust_remote_code=True).to(device)
max_sequence_len = 512  # prot_t5_xl pretraining done on max 512 seq len
batch_size = 64
SPECIAL_IDS = set(tokenizer.all_special_ids)

# Freeze params, inference only
protT5.eval()
for param in protT5.parameters():
    param.requires_grad = False


def preprocess_sequence(seq: str) -> str:
    """Space-separate residues and map rare amino acids (U/Z/O/B) to X for ProtT5."""
    seq = " ".join(seq)
    seq = re.sub(r"[UZOB]", "X", seq)
    return seq


def fetch_uniprot_sequence(uniprot_id: str) -> str:
    """Download the FASTA sequence for *uniprot_id* from UniProtKB.

    Raises:
        ValueError: when the accession does not resolve (non-200 response).
    """
    url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
    r = requests.get(url, timeout=10)
    if r.status_code != 200:
        raise ValueError(f"UniProt ID '{uniprot_id}' not found")
    fasta = r.text.splitlines()
    # drop FASTA header lines, keep only the raw residues
    return "".join(line for line in fasta if not line.startswith(">"))


def generate_embedding_from_uniprot(uniprot_id: str):
    """Embed a UniProt protein with ProtT5.

    Returns:
        (masked_embeddings, mask_2d): (1, L, D) residue embeddings with padding
        and special tokens zeroed out, plus the (1, L) 0/1 mask used to zero them.
    """
    seq = fetch_uniprot_sequence(uniprot_id)
    seq = preprocess_sequence(seq)
    tokens = tokenizer(
        seq,
        return_tensors="pt",
        truncation=True,
        add_special_tokens=True,
        padding="max_length",
        max_length=MAX_SEQ_LEN,
    )
    tokens = {k: v.to(device) for k, v in tokens.items()}
    with torch.no_grad():
        outputs = protT5(**tokens)
    raw_embeddings = outputs.last_hidden_state  # (1, L, D)
    input_ids = tokens["input_ids"]
    mask_2d = tokens["attention_mask"].clone()  # (1, L)
    # exclude special tokens (eos/pad) from the embedding, not just padding
    for sid in SPECIAL_IDS:
        mask_2d[input_ids == sid] = 0
    mask_3d = mask_2d.unsqueeze(-1).float()      # (1, L, 1)
    masked_embeddings = raw_embeddings * mask_3d  # (1, L, D)
    return masked_embeddings, mask_2d


############################################ PREDICTION CODE ############################################
def ensemble_predict(embedding, mask, topk=20):
    """Average the ensemble's GO probabilities and return the top-k above THRESH.

    Each model's output is projected onto the shared unique_go_ids axis via its
    adaptator; the mean is taken only over models that actually scored a term
    (counts of non-zero contributions), not over the whole ensemble.
    """
    scores = torch.zeros(1, len(unique_go_ids), device=device)
    counts = torch.zeros_like(scores)
    for model, adaptor in zip(MODELS, matrix_adaptators):
        # NOTE(review): for `adaptor @ preds` to be valid, `preds` after the
        # transpose must be (num_model_labels, 1) — the original inline shape
        # comments looked swapped; confirm against the model's forward().
        preds = model(embedding, mask)
        preds = preds.transpose(0, 1)
        adapted = adaptor @ preds   # (num_unique_go, 1)
        adapted = adapted.T         # (1, num_unique_go)
        scores += adapted
        counts += (torch.abs(adapted) > 1e-9).float()
    scores /= torch.clamp(counts, min=1)  # clamp avoids division by zero
    scores = scores.squeeze(0)
    above_thresh = scores > THRESH
    scores = scores * above_thresh.float()
    idx = torch.argsort(scores, descending=True)[:topk]
    return pd.DataFrame({
        "GO_ID": unique_go_ids[idx.cpu().numpy()],
        "Score": scores[idx].round(decimals=3).detach().cpu().numpy(),
    })


def predict(uniprot_id, topk):
    """Embed *uniprot_id* with ProtT5 and return the top-k ensemble GO predictions."""
    embedding, mask = generate_embedding_from_uniprot(uniprot_id)
    return ensemble_predict(embedding, mask, topk)


############################################ UNIPROTKB AND GRADIO UTILS ############################################
def fetch_human_uniprot_examples(n=500):
    """Return up to *n* reviewed human (taxon 9606) UniProt accessions."""
    url = "https://rest.uniprot.org/uniprotkb/search"
    params = {
        "query": "organism_id:9606 AND reviewed:true",
        "fields": "accession",
        "format": "json",
        "size": n,
    }
    r = requests.get(url, params=params, timeout=10)
    r.raise_for_status()
    data = r.json()
    return [e["primaryAccession"] for e in data["results"]]


def fetch_go_ancestors(go_id):
    """Return the is_a/part_of ancestors of *go_id* from QuickGO (empty set on failure)."""
    url = (
        f"https://www.ebi.ac.uk/QuickGO/services/ontology/go/terms/"
        f"{go_id}/ancestors?relations=is_a,part_of"
    )
    try:
        r = requests.get(url, headers={"Accept": "application/json"}, timeout=10)
        r.raise_for_status()
    except Exception as e:
        # best effort: the app still works without ancestor expansion
        print(f"Warning: failed to fetch ancestors for {go_id}: {e}")
        return set()
    data = r.json()
    ancestors = set()
    for term in data.get("results", []):
        ancestors.update(term.get("ancestors", []))
    return ancestors


# Experimental / curated evidence codes kept to match CAFA6 evaluation
CAFA6_EVIDENCE_CODES = {
    "EXP",  # Inferred from Experiment
    "IDA",  # Inferred from Direct Assay
    "IPI",  # Inferred from Physical Interaction
    "IMP",  # Inferred from Mutant Phenotype
    "IGI",  # Inferred from Genetic Interaction
    "IEP",  # Inferred from Expression Pattern
    "TAS",  # Traceable Author Statement
    "IC",   # Inferred by Curator
}


def fetch_quickgo_annotations_with_ancestors(uniprot_id):
    """Fetch the protein's QuickGO annotations (CAFA6 evidence codes only),
    expanded with their is_a/part_of ancestors.
    """
    direct_go_ids = set()
    pageSize = 200
    page = 1
    while True:
        url = (
            f"https://www.ebi.ac.uk/QuickGO/services/annotation/search?"
            f"geneProductId=UniProtKB:{uniprot_id}&limit={pageSize}&page={page}"
        )
        try:
            response = requests.get(url, headers={"Accept": "application/json"}, timeout=10)
            response.raise_for_status()
        except Exception as e:
            print(f"Warning: failed to fetch annotations: {e}")
            break
        data = response.json()
        results = data.get("results", [])
        if not results:  # empty page => past the last page
            break
        for item in results:
            go_id = item.get("goId")
            evidence = item.get("goEvidence")  # fetch evidence code
            if go_id and evidence in CAFA6_EVIDENCE_CODES:
                direct_go_ids.add(go_id)
        page += 1
    # Expand with ancestors
    all_go_ids = set(direct_go_ids)
    for go_id in direct_go_ids:
        all_go_ids.update(fetch_go_ancestors(go_id))
    return all_go_ids


def color_topk_predictions(pred_df, true_go_ids):
    """Add a 'Color' column: green when the predicted GO id is annotated, red otherwise.

    Mutates and returns *pred_df*.
    """
    colors = []
    for go_id in pred_df["GO_ID"]:
        if go_id in true_go_ids:
            colors.append("background-color: #d4edda")  # green
        else:
            colors.append("background-color: #f8d7da")  # red
    pred_df["Color"] = colors
    return pred_df


def predictions_to_html(pred_df):
    """Render the colored prediction table as HTML with QuickGO term links.

    NOTE(review): the table markup was reconstructed from a garbled source
    (the original tags were stripped) — verify the rendered layout.
    """
    html = "<table>"
    html += "<tr><th>GO_ID</th><th>Score</th></tr>"
    for _, row in pred_df.iterrows():
        color = row['Color']
        go_id = row['GO_ID']
        go_url = f"https://www.ebi.ac.uk/QuickGO/term/{go_id}"
        html += f"<tr style='{color}'>"
        html += (
            f"<td>"
            f"<a href='{go_url}' target='_blank'>"
            f"{go_id}"
            f"</a></td>"
            f"<td>{row['Score']}</td>"
        )
        html += "</tr>"
    html += "</table>"
    return html


############################################ GRADIO APP ############################################
markdown_information = r"""
## Trained on CAFA6 Protein Function Prediction Dataset
```
@misc{cafa-6-protein-function-prediction, author = {Iddo Friedberg and Predrag Radivojac and Paul D Thomas and An Phan and M. Clara De Paolis Kaluza and Damiano Piovesan and Parnal Joshi and Chris Mungall and Martyna Plomecka and Walter Reade and María Cruz}, title = {CAFA 6 Protein Function Prediction}, year = {2025}, howpublished = {\url{https://kaggle.com/competitions/cafa-6-protein-function-prediction}}, note = {Kaggle} }
```
## SPROF-GO
```
@article{10.1093/bib/bbad117, author = {Yuan, Qianmu and Xie, Junjie and Xie, Jiancong and Zhao, Huiying and Yang, Yuedong}, title = "{Fast and accurate protein function prediction from sequence through pretrained language model and homology-based label diffusion}", journal = {Briefings in Bioinformatics}, year = {2023}, month = {03}, issn = {1477-4054}, doi = {10.1093/bib/bbad117}, url = {https://doi.org/10.1093/bib/bbad117} }
```
"""


def predict_with_uniprot_highlight_html(uniprot_id, topk):
    """Predict top-k GO terms and render them color-coded against QuickGO truth."""
    embedding, mask = generate_embedding_from_uniprot(uniprot_id)
    pred_df = ensemble_predict(embedding, mask, topk)
    true_go_ids = fetch_quickgo_annotations_with_ancestors(uniprot_id)
    colored_df = color_topk_predictions(pred_df, true_go_ids)
    return predictions_to_html(colored_df)


HUMAN_EXAMPLES = fetch_human_uniprot_examples()


def format_human_examples_md(examples):
    """Render accessions as a '/'-separated markdown list of QuickGO links."""
    md = ""
    md += " ".join(
        f"[{acc}](https://www.ebi.ac.uk/QuickGO/annotations?geneProductId={acc}) /"
        for acc in examples
    )
    return md


def fetch_human_examples_md():
    """Shuffle the cached accessions and show a fresh subset of 40."""
    examples = HUMAN_EXAMPLES.copy()
    random.shuffle(examples)
    return format_human_examples_md(examples[:40])


def make_title_md(uniprot_id):
    """Markdown heading linking the queried accession to its QuickGO page."""
    return f"### Prediction Table of [{uniprot_id}](https://www.ebi.ac.uk/QuickGO/annotations?geneProductId={uniprot_id})"


with gr.Blocks() as demo:
    gr.Markdown(
        "# 🦠 [SPROF-GO](https://github.com/biomed-AI/SPROF-GO) Ensemble Trained on "
        "[CAFA6](https://www.kaggle.com/competitions/cafa-6-protein-function-prediction/overview)"
    )
    # ===================== Inputs =====================
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            gr.Markdown("## Inference")
            gr.Markdown("⚠️ No label diffusion for fast inference.")
            uniprot_input = gr.Textbox(label="UniProtKB Protein ID", value="A1X283")
            topk_slider = gr.Slider(5, 50, value=10, step=5, label="Top-K GO terms")
            run_btn = gr.Button("Predict")
        with gr.Column(scale=1):
            gr.Markdown("## Human Protein Examples")
            human_examples_md_comp = gr.Markdown(format_human_examples_md(HUMAN_EXAMPLES[:50]))
            example_btn = gr.Button("🔄 Fetch more examples")
            example_btn.click(fetch_human_examples_md, outputs=human_examples_md_comp)
    # ===================== Output =====================
    gr.HTML("<hr>")  # horizontal divider (reconstructed — tag was stripped in source)
    title_md = gr.Markdown(make_title_md("A1X283"))
    html_output = gr.HTML(label="Predicted GO terms")
    gr.Markdown(
        "Top-K predictions colored green if predicted GO term is in "
        "[QuickGO](https://www.ebi.ac.uk/QuickGO/annotations) annotations, "
        "red otherwise"
    )
    gr.Markdown(" Ancestors are propagated and will be colored green if any descendant is in QuickGO annotations")
    gr.Markdown(" I have only kept the following evidence codes: EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC to match CAFA6 evaluation")
    gr.HTML("<hr>")  # horizontal divider (reconstructed — tag was stripped in source)
    gr.Markdown(markdown_information)

    # Two click handlers: one updates the title, one runs the (slow) prediction
    run_btn.click(fn=make_title_md, inputs=uniprot_input, outputs=title_md)
    run_btn.click(
        fn=predict_with_uniprot_highlight_html,
        inputs=[uniprot_input, topk_slider],
        outputs=html_output,
    )
    # Populate title + table once on page load with the default accession
    demo.load(
        fn=lambda uid, k: (make_title_md(uid), predict_with_uniprot_highlight_html(uid, k)),
        inputs=[uniprot_input, topk_slider],
        outputs=[title_md, html_output],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)