Spaces:
Running
Running
| ############################################ INSTALL PACKAGES ############################################ | |
| import sys | |
| import subprocess | |
| import os | |
def install(package):
    """pip-install ``package`` into the running interpreter, upgrading if present.

    ``package`` is a pip requirement string (e.g. ``"mlflow==3.7.0"``).
    Raises subprocess.CalledProcessError when pip exits with a non-zero status.
    """
    # --upgrade forces pip to move an already-installed package to the
    # newest version allowed by the requirement string.
    pip_command = [sys.executable, "-m", "pip", "install", "--upgrade"]
    subprocess.check_call(pip_command + [package])
install("gradio>=3.44.0")
install("biopython==1.86")
install("cachetools==5.4.0")
install("mlflow==3.7.0")

# Project repository; the relative "mlruns/..." artifact paths used later are
# presumably resolved inside it after the chdir below.
HF_REPO_URL = "https://gitlab.com/nn_projects/cafa6_project"
CLONE_DIR = "/root/cafa6_project"
if not os.path.exists(CLONE_DIR):
    # List form (no shell) + check_call: a failed clone stops here with a
    # CalledProcessError instead of surfacing later as a confusing chdir error.
    subprocess.check_call(["git", "clone", HF_REPO_URL, CLONE_DIR])
os.chdir(CLONE_DIR)
| ############################################ DEFINE CONSTANTS ############################################ | |
| import gc | |
| import pandas as pd | |
| import numpy as np | |
| from collections import defaultdict | |
| from tqdm.auto import tqdm | |
| import mlflow | |
| import torch | |
| import random | |
| import requests | |
| import re | |
| from transformers import set_seed | |
| from torch.utils.data import Dataset, DataLoader | |
| from Bio import SeqIO | |
| import gradio as gr | |
# Paths / dataset names from the offline pipeline. data_dir,
# test_embeddings_data and test_batch_size are not referenced elsewhere in
# this script.
input_path = './'
data_dir = "numpy_dataset/"
test_embeddings_data = "prot_t5_embeddings_right_pooling_False_test_mini"
test_batch_size = 64

# Inference hyper-parameters.
SEED = 42
MAX_SEQ_LEN = 512  # tokenizer padding / truncation length
HIDDEN_DIM = 1024
THRESH = 0.003  # ensemble scores not above this are discarded

# Seed every RNG in play (same order as before) for reproducible runs.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    set_seed,
):
    _seed_fn(SEED)
del _seed_fn
| ############################################ LOAD MODELS ############################################ | |
# Run on the GPU when one is available; models and adaptors are moved there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("USING DEVICE: ", device)

# MLflow run whose artifacts list the ensemble members and their label arrays.
run_id = '9ee8f63638d0494ea20b63710c19a8b3'
SUBMISSION_INPUT_PATH = f"{input_path}mlruns/11/{run_id}/artifacts/"
SUBMISSION_INPUT = f"{SUBMISSION_INPUT_PATH}submission.tsv"
OUTPUT_PATH = f"{SUBMISSION_INPUT_PATH}/diamond/"
models_uri = np.load(f"{SUBMISSION_INPUT_PATH}models_uri_C_fold_4.npy")
mlb_arrays_uri = np.load(f"{SUBMISSION_INPUT_PATH}mlb_arrays_uri_C_fold_4.npy")

# LOAD MODELS: one PyTorch model per URI, switched to eval mode for inference.
MODELS = []
for uri in models_uri:
    model = mlflow.pytorch.load_model(uri, map_location=torch.device(device))
    model.eval()
    MODELS.append(model.to(device))

# LOAD ONE HOT MLB ARRAYS: label vocabulary of each model, in output order.
# allow_pickle is needed for object arrays; these are the run's own artifacts.
MLB_ARRAYS = [np.load(path, allow_pickle=True) for path in mlb_arrays_uri]
# CREATE MATRIX ADAPTATORS TO MAP UNIQUE GO IDS TO EACH MODEL'S MLB ARRAY PREDICTION
concatenated_array = np.concatenate(MLB_ARRAYS)
# Sorted so the shared label order is deterministic: iterating a Python set of
# strings depends on the per-process hash seed, so list(set(...)) reshuffled
# the GO-id order on every start.
unique_go_ids = np.array(sorted(set(concatenated_array)))
print(concatenated_array.shape)
print(unique_go_ids.shape)

# Row of each GO id in the shared space; replaces the former O(U * L)
# pure-Python comparison loop per model with O(L) dictionary lookups.
go_id_to_row = {go_id: row for row, go_id in enumerate(unique_go_ids)}

matrix_adaptators = []
for n, mlb_array in enumerate(MLB_ARRAYS):
    # Binary (len(unique_go_ids), len(mlb_array)) matrix: entry [i, j] is 1 iff
    # the model's j-th label equals unique_go_ids[i]. Left-multiplying a
    # (len(mlb_array), 1) prediction column scatters it into the shared space.
    prob_matrix_adaptator = torch.zeros(len(unique_go_ids), len(mlb_array))
    rows = torch.tensor([go_id_to_row[go_id] for go_id in mlb_array], dtype=torch.long)
    cols = torch.arange(len(mlb_array))
    prob_matrix_adaptator[rows, cols] = 1.0
    print(n, " ", prob_matrix_adaptator.shape)
    matrix_adaptators.append(prob_matrix_adaptator.to(device))
| ############################################ UNIPROTKB PROT5 EMBEDDINGS ############################################ | |
from transformers import T5Tokenizer, T5EncoderModel

model_type = "Rostlab/prot_t5_xl_uniref50"
# Keep case: ProtT5 expects upper-case amino-acid letters.
tokenizer = T5Tokenizer.from_pretrained(
    model_type,
    do_lower_case=False,
    truncation_side="right",
)
# NOTE(review): T5EncoderModel is a built-in transformers class, so
# trust_remote_code is likely unnecessary here — confirm before removing.
protT5 = T5EncoderModel.from_pretrained(model_type, trust_remote_code=True).to(device)
max_sequence_len = 512  # prot_t5_xl pretraining done on max 512 seq len
batch_size = 64
# Token ids whose positions are masked out of the pooled embedding.
SPECIAL_IDS = set(tokenizer.all_special_ids)

# Inference only: eval mode and no gradients for any encoder parameter.
protT5.eval()
protT5.requires_grad_(False)
def preprocess_sequence(seq: str) -> str:
    """Format a raw amino-acid string for the ProtT5 tokenizer.

    Residues are separated by single spaces, and the letters U, Z, O and B
    are each replaced by X.
    """
    rare_residues_to_x = str.maketrans("UZOB", "XXXX")
    return " ".join(seq.translate(rare_residues_to_x))
def fetch_uniprot_sequence(uniprot_id: str) -> str:
    """Download the amino-acid sequence of a UniProtKB entry.

    Args:
        uniprot_id: accession typed by the user; surrounding whitespace is
            ignored and the value is URL-quoted before use.

    Returns:
        The sequence with FASTA header lines removed and lines concatenated.

    Raises:
        ValueError: empty ID, non-200 response, or an entry with no sequence.
        requests.RequestException: network failure / timeout.
    """
    from urllib.parse import quote

    accession = uniprot_id.strip()
    if not accession:
        raise ValueError("Empty UniProt ID")
    # Quote the user input so characters like "/" or "?" cannot alter the URL.
    url = f"https://rest.uniprot.org/uniprotkb/{quote(accession, safe='')}.fasta"
    r = requests.get(url, timeout=10)
    if r.status_code == 404:
        raise ValueError(f"UniProt ID '{uniprot_id}' not found")
    if r.status_code != 200:
        raise ValueError(
            f"UniProt request for '{uniprot_id}' failed (HTTP {r.status_code})"
        )
    sequence = "".join(
        line.strip() for line in r.text.splitlines() if not line.startswith(">")
    )
    if not sequence:
        # An empty sequence would otherwise be embedded as padding only.
        raise ValueError(f"UniProt ID '{uniprot_id}' has no sequence")
    return sequence
def generate_embedding_from_uniprot(uniprot_id: str):
    """Fetch a UniProt sequence and embed it with ProtT5.

    The sequence is padded/truncated to ``MAX_SEQ_LEN`` tokens. Returns
    ``(masked_embeddings, mask_2d)``:

    * ``masked_embeddings`` — encoder last hidden state, shape
      ``(1, MAX_SEQ_LEN, hidden_size)``, zeroed at padding and special-token
      positions;
    * ``mask_2d`` — matching ``(1, MAX_SEQ_LEN)`` 0/1 token mask.
    """
    spaced_sequence = preprocess_sequence(fetch_uniprot_sequence(uniprot_id))
    encoded = tokenizer(
        spaced_sequence,
        return_tensors="pt",
        truncation=True,
        add_special_tokens=True,
        padding="max_length",
        max_length=MAX_SEQ_LEN,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        hidden_states = protT5(**encoded).last_hidden_state

    input_ids = encoded["input_ids"]
    special_ids = torch.tensor(
        list(SPECIAL_IDS), dtype=input_ids.dtype, device=input_ids.device
    )
    # Attention mask with tokenizer special-token positions also set to 0.
    token_mask = encoded["attention_mask"].masked_fill(
        torch.isin(input_ids, special_ids), 0
    )
    return hidden_states * token_mask.unsqueeze(-1).float(), token_mask
| ############################################ PREDICTION CODE ############################################ | |
def ensemble_predict(embedding, mask, topk=20, *, models=None, adaptors=None,
                     go_ids=None, threshold=None):
    """Score GO terms for one protein with the model ensemble.

    Each model's label-space output is scattered into the shared GO-id space
    by its adaptor matrix. A term's score is the mean over the models whose
    adapted prediction for it is non-zero.

    Args:
        embedding: per-token embedding passed unchanged to every model.
        mask: per-token mask passed unchanged to every model.
        topk: maximum number of terms returned; cast to ``int`` because UI
            sliders may deliver floats.
        models, adaptors, go_ids, threshold: optional overrides for the
            module-level ``MODELS``, ``matrix_adaptators``, ``unique_go_ids``
            and ``THRESH``.

    Returns:
        DataFrame with columns ``GO_ID`` and ``Score`` (rounded to 3 decimals),
        sorted by descending score. Only terms scoring above the threshold are
        kept, so it may contain fewer than ``topk`` rows (possibly none).
    """
    models = MODELS if models is None else models
    adaptors = matrix_adaptators if adaptors is None else adaptors
    go_ids = np.asarray(unique_go_ids if go_ids is None else go_ids)
    threshold = THRESH if threshold is None else threshold

    scores = torch.zeros(1, len(go_ids), device=embedding.device)
    counts = torch.zeros_like(scores)
    # Inference only: avoid building an autograd graph for every request.
    with torch.no_grad():
        for model, adaptor in zip(models, adaptors):
            # preds must be (1, n_labels) for the matmul below to line up.
            preds = model(embedding, mask)
            adapted = (adaptor @ preds.transpose(0, 1)).T  # (1, n_unique)
            scores += adapted
            # NOTE(review): a model that covers a term but predicts ~0 for it
            # is not counted; coverage-based averaging would use
            # adaptor.sum(dim=1) instead — confirm which was intended.
            counts += (torch.abs(adapted) > 1e-9).float()
    scores = (scores / torch.clamp(counts, min=1)).squeeze(0)

    k = max(0, min(int(topk), scores.numel()))
    top_scores, top_idx = torch.topk(scores, k)
    # Drop below-threshold terms instead of listing arbitrary 0.0-score GO ids.
    keep = top_scores > threshold
    top_scores, top_idx = top_scores[keep], top_idx[keep]

    return pd.DataFrame({
        "GO_ID": go_ids[top_idx.cpu().numpy()],
        "Score": top_scores.round(decimals=3).cpu().numpy(),
    })
def predict(uniprot_id, topk):
    """Return the ensemble's top-``topk`` GO predictions for ``uniprot_id``.

    Unlike predict_with_uniprot_highlight_html, this applies no QuickGO
    highlighting and returns the raw prediction DataFrame.
    """
    return ensemble_predict(*generate_embedding_from_uniprot(uniprot_id), topk)
| ############################################ UNIPROTKB AND GRADIO UTILS ############################################ | |
def fetch_human_uniprot_examples(n=500):
    """Return up to ``n`` reviewed human (taxon 9606) UniProtKB accessions.

    Only the first page of the UniProt search is read (``size=n``).
    Raises requests.HTTPError on a non-2xx response.
    """
    search_params = {
        "query": "organism_id:9606 AND reviewed:true",
        "fields": "accession",
        "format": "json",
        "size": n,
    }
    response = requests.get(
        "https://rest.uniprot.org/uniprotkb/search",
        params=search_params,
        timeout=10,
    )
    response.raise_for_status()
    entries = response.json()["results"]
    return [entry["primaryAccession"] for entry in entries]
def fetch_go_ancestors(go_id):
    """Return the QuickGO ancestors of ``go_id`` over is_a / part_of relations.

    Best-effort: on a network, HTTP or JSON-decoding error a warning is
    printed and an empty set is returned, so a flaky QuickGO only loses
    highlighting instead of failing the prediction.
    """
    url = (
        f"https://www.ebi.ac.uk/QuickGO/services/ontology/go/terms/"
        f"{go_id}/ancestors?relations=is_a,part_of"
    )
    try:
        r = requests.get(url, headers={"Accept": "application/json"}, timeout=10)
        r.raise_for_status()
        # Parsing inside the try: a malformed body must not escape the helper.
        data = r.json()
    except (requests.RequestException, ValueError) as e:
        print(f"Warning: failed to fetch ancestors for {go_id}: {e}")
        return set()

    ancestors = set()
    for term in data.get("results") or []:
        ancestors.update(term.get("ancestors") or [])
    return ancestors
# GO evidence codes whose QuickGO annotations count as ground truth when
# highlighting predictions (annotations with other codes are ignored).
CAFA6_EVIDENCE_CODES = set(
    (
        "EXP",  # Inferred from Experiment
        "IDA",  # Inferred from Direct Assay
        "IPI",  # Inferred from Physical Interaction
        "IMP",  # Inferred from Mutant Phenotype
        "IGI",  # Inferred from Genetic Interaction
        "IEP",  # Inferred from Expression Pattern
        "TAS",  # Traceable Author Statement
        "IC",   # Inferred by Curator
    )
)
def fetch_quickgo_annotations_with_ancestors(uniprot_id):
    """Return QuickGO GO terms annotated to ``uniprot_id``, plus their ancestors.

    Only annotations whose evidence code is in ``CAFA6_EVIDENCE_CODES`` are
    kept. Each kept term is expanded with fetch_go_ancestors (is_a / part_of).
    Best-effort: on a request or decoding error a warning is printed and the
    terms collected so far are used.
    """
    page_size = 200
    direct_go_ids = set()
    page = 1
    while True:
        try:
            # params= lets requests encode the user-supplied ID safely.
            response = requests.get(
                "https://www.ebi.ac.uk/QuickGO/services/annotation/search",
                params={
                    "geneProductId": f"UniProtKB:{uniprot_id.strip()}",
                    "limit": page_size,
                    "page": page,
                },
                headers={"Accept": "application/json"},
                timeout=10,
            )
            response.raise_for_status()
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            print(f"Warning: failed to fetch annotations: {e}")
            break

        results = data.get("results") or []
        for item in results:
            go_id = item.get("goId")
            if go_id and item.get("goEvidence") in CAFA6_EVIDENCE_CODES:
                direct_go_ids.add(go_id)

        # Stop on a short page or at the reported last page instead of always
        # requesting one page past the end.
        total_pages = (data.get("pageInfo") or {}).get("total")
        if len(results) < page_size or (total_pages is not None and page >= total_pages):
            break
        page += 1

    # Expand with ancestors
    all_go_ids = set(direct_go_ids)
    for go_id in direct_go_ids:
        all_go_ids.update(fetch_go_ancestors(go_id))
    return all_go_ids
def color_topk_predictions(pred_df, true_go_ids):
    """Add a ``Color`` CSS column to ``pred_df`` in place and return it.

    Rows whose ``GO_ID`` is in ``true_go_ids`` get a green background;
    all other rows get a red one.
    """
    green = "background-color: #d4edda"
    red = "background-color: #f8d7da"
    pred_df["Color"] = [
        green if go_id in true_go_ids else red for go_id in pred_df["GO_ID"]
    ]
    return pred_df
def predictions_to_html(pred_df):
    """Render ``pred_df`` (columns GO_ID, Score, Color) as a centred HTML table.

    Each GO id links to its QuickGO term page, and ``Color`` becomes the row's
    inline style. Text values are HTML-escaped. Scores are converted to Python
    floats and rounded to 3 decimals. Without this, iterrows-style upcasting of
    a float32 column prints values like 0.8999999761581421.
    """
    from html import escape

    parts = [
        "<div style='text-align: center;'>",
        "<table border='1' style='border-collapse: collapse; margin: 0 auto;'>",
        "<tr><th>GO_ID</th><th>Score</th></tr>",
    ]
    for go_id, score, color in zip(pred_df["GO_ID"], pred_df["Score"], pred_df["Color"]):
        safe_go_id = escape(str(go_id))
        go_url = f"https://www.ebi.ac.uk/QuickGO/term/{safe_go_id}"
        parts.append(
            f"<tr style='{escape(str(color))}'>"
            f"<td>"
            f"<a href='{go_url}' target='_blank' rel='noopener noreferrer'>"
            f"{safe_go_id}"
            f"</a>"
            f"</td>"
            f"<td>{round(float(score), 3)}</td>"
            "</tr>"
        )
    parts.append("</table></div>")
    return "".join(parts)
| ############################################ GRADIO APP ############################################ | |
# Markdown rendered at the bottom of the app: BibTeX citations for the CAFA6
# dataset and the SPROF-GO model. Raw string so "\url" is kept verbatim.
markdown_information= r"""
## Trained on CAFA6 Protein Function Prediction Dataset
```
@misc{cafa-6-protein-function-prediction,
author = {Iddo Friedberg and Predrag Radivojac and Paul D Thomas and An Phan and M. Clara De Paolis Kaluza and Damiano Piovesan and Parnal Joshi and Chris Mungall and Martyna Plomecka and Walter Reade and María Cruz},
title = {CAFA 6 Protein Function Prediction},
year = {2025},
howpublished = {\url{https://kaggle.com/competitions/cafa-6-protein-function-prediction}},
note = {Kaggle}
}
```
## SPROF-GO
```
@article{10.1093/bib/bbad117,
author = {Yuan, Qianmu and Xie, Junjie and Xie, Jiancong and Zhao, Huiying and Yang, Yuedong},
title = "{Fast and accurate protein function prediction from sequence through pretrained language model and homology-based label diffusion}",
journal = {Briefings in Bioinformatics},
year = {2023},
month = {03},
issn = {1477-4054},
doi = {10.1093/bib/bbad117},
url = {https://doi.org/10.1093/bib/bbad117}
}
```
"""
def predict_with_uniprot_highlight_html(uniprot_id, topk):
    """Predict GO terms for ``uniprot_id`` and render them as colour-coded HTML.

    Predictions present in the protein's QuickGO annotations (CAFA6 evidence
    codes, with ancestors) are shaded green; the rest are shaded red.
    """
    embedding, mask = generate_embedding_from_uniprot(uniprot_id)
    predictions = ensemble_predict(embedding, mask, topk)
    annotated_go_ids = fetch_quickgo_annotations_with_ancestors(uniprot_id)
    highlighted = color_topk_predictions(predictions, annotated_go_ids)
    return predictions_to_html(highlighted)
# Example accessions for the UI. Best-effort, like the QuickGO helpers: a
# UniProt outage at start-up leaves the examples panel empty instead of
# preventing the app from launching.
try:
    HUMAN_EXAMPLES = fetch_human_uniprot_examples()
except (requests.RequestException, ValueError, KeyError) as e:
    print(f"Warning: failed to fetch human UniProt examples: {e}")
    HUMAN_EXAMPLES = []
def format_human_examples_md(examples):
    """Render accessions as Markdown links to their QuickGO annotation pages.

    Links are separated by " / ", with no separator after the last one.
    An empty input returns an empty string.
    """
    return " / ".join(
        f"[{acc}](https://www.ebi.ac.uk/QuickGO/annotations?geneProductId={acc})"
        for acc in examples
    )
def fetch_human_examples_md():
    """Return Markdown links for up to 40 randomly chosen human accessions.

    A shuffled copy is used, so ``HUMAN_EXAMPLES`` itself is not reordered.
    """
    pool = list(HUMAN_EXAMPLES)
    random.shuffle(pool)
    return format_human_examples_md(pool[:40])
def make_title_md(uniprot_id):
    """Return a Markdown heading linking ``uniprot_id`` to its QuickGO annotations."""
    template = (
        "### Prediction Table of "
        "[{0}](https://www.ebi.ac.uk/QuickGO/annotations?geneProductId={0})"
    )
    return template.format(uniprot_id)
# Gradio UI layout:
# - top: inputs (UniProt ID, top-K slider) and clickable human example accessions;
# - middle: the colour-coded prediction table;
# - bottom: citations.
with gr.Blocks() as demo:
    gr.Markdown("# 🦠 [SPROF-GO](https://github.com/biomed-AI/SPROF-GO) Ensemble Trained on [CAFA6](https://www.kaggle.com/competitions/cafa-6-protein-function-prediction/overview)")
    # ===================== Inputs =====================
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            gr.Markdown("## Inference")
            gr.Markdown("⚠️ No label diffusion for fast inference.")
            uniprot_input = gr.Textbox(label="UniProtKB Protein ID", value="A1X283")
            topk_slider = gr.Slider(5, 50, value=10, step=5, label="Top-K GO terms")
            run_btn = gr.Button("Predict")
        with gr.Column(scale=1):
            gr.Markdown("## Human Protein Examples")
            # Initial view: the first 50 fetched accessions in API order. The
            # refresh button shows 40 shuffled ones (fetch_human_examples_md).
            human_examples_md_comp = gr.Markdown(format_human_examples_md(HUMAN_EXAMPLES[:50]))
            example_btn = gr.Button("🔄 Fetch more examples")
            example_btn.click(fetch_human_examples_md, outputs=human_examples_md_comp)
    # ===================== Output =====================
    gr.HTML("<hr style='margin:20px 0;'>")  # horizontal divider
    title_md = gr.Markdown(make_title_md("A1X283"))
    html_output = gr.HTML(label="Predicted GO terms")
    gr.Markdown(
        "Top-K predictions colored green if predicted GO term is in "
        "[QuickGO](https://www.ebi.ac.uk/QuickGO/annotations) annotations, "
        "red otherwise"
    )
    gr.Markdown(" Ancestors are propagated and will be colored green if any descendant is in QuickGO annotations")
    gr.Markdown(" I have only kept the following evidence codes: EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC to match CAFA6 evaluation")
    gr.HTML("<hr style='margin:20px 0;'>")  # horizontal divider
    gr.Markdown(markdown_information)
    # Two separate listeners on the same button. The title presumably updates
    # even if the prediction request fails.
    run_btn.click(
        fn=make_title_md,
        inputs=uniprot_input,
        outputs=title_md
    )
    run_btn.click(
        fn=predict_with_uniprot_highlight_html,
        inputs=[uniprot_input, topk_slider],
        outputs=html_output
    )
    # Fill the title and table for the default inputs when a page loads. This
    # runs a full prediction (ProtT5 + network calls) per page load.
    demo.load(
        fn=lambda uid, k: (make_title_md(uid), predict_with_uniprot_highlight_html(uid, k)),
        inputs=[uniprot_input, topk_slider],
        outputs=[title_md, html_output]
    )
if __name__ == "__main__":
    # Listen on all interfaces, port 7860, without a public share link.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    demo.launch(**launch_options)