############################################ INSTALL PACKAGES ############################################
import sys
import subprocess
import os


def install(package):
    """Install (or upgrade) *package* into the current interpreter via pip."""
    # --upgrade forces the latest version even if an older one is present.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", package])


install("gradio>=3.44.0")
install("biopython==1.86")
install("cachetools==5.4.0")
install("mlflow==3.7.0")

HF_REPO_URL = "https://gitlab.com/nn_projects/cafa6_project"
CLONE_DIR = "/root/cafa6_project"

if not os.path.exists(CLONE_DIR):
    # Fixed: argument-list subprocess call instead of os.system() — no shell
    # interpolation of the URL, and a failed clone raises instead of being
    # silently ignored (os.chdir below would otherwise fail confusingly).
    subprocess.check_call(["git", "clone", HF_REPO_URL, CLONE_DIR])
os.chdir(CLONE_DIR)

############################################ DEFINE CONSTANTS ############################################
import gc
import random
import re
from collections import defaultdict

import numpy as np
import pandas as pd
import requests
import torch
import mlflow
import gradio as gr
from tqdm.auto import tqdm
from transformers import set_seed
from torch.utils.data import Dataset, DataLoader
from Bio import SeqIO

input_path = './'
data_dir = "numpy_dataset/"
test_embeddings_data = "prot_t5_embeddings_right_pooling_False_test_mini"
test_batch_size = 64
SEED = 42
MAX_SEQ_LEN = 512   # ProtT5-XL pretraining context length
HIDDEN_DIM = 1024
THRESH = 0.003      # minimum ensemble score kept in the prediction output

# Seed every RNG in play so inference is reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
set_seed(SEED)

############################################ LOAD MODELS ############################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("USING DEVICE: ", device)

run_id = '9ee8f63638d0494ea20b63710c19a8b3'
SUBMISSION_INPUT_PATH = input_path + 'mlruns/11/' + run_id + '/artifacts/'
SUBMISSION_INPUT = SUBMISSION_INPUT_PATH + 'submission.tsv'
OUTPUT_PATH = SUBMISSION_INPUT_PATH + '/diamond/'

models_uri = np.load(SUBMISSION_INPUT_PATH + "models_uri_C_fold_4.npy")
mlb_arrays_uri = np.load(SUBMISSION_INPUT_PATH + "mlb_arrays_uri_C_fold_4.npy")

# LOAD MODELS
MODELS = []
for uri in models_uri:
    model = mlflow.pytorch.load_model(uri, map_location=torch.device(device))
    model.eval()
    model.to(device)
    MODELS.append(model)

# LOAD ONE HOT MLB ARRAYS
MLB_ARRAYS = [np.load(uri, allow_pickle=True) for uri in mlb_arrays_uri]

# CREATE MATRIX ADAPTATORS TO MAP UNIQUE GO IDS TO EACH MODEL'S MLB ARRAY PREDICTION
concatenated_array = np.concatenate(MLB_ARRAYS)
# Fixed: sort the union so the GO-id ordering (and therefore every adaptator
# matrix and the final score-vector layout) is deterministic across runs;
# iterating a raw set() varies between interpreter invocations, defeating SEED.
unique_go_ids = np.array(sorted(set(concatenated_array)))
print(concatenated_array.shape)
print(unique_go_ids.shape)

# Fixed: the original O(U*L) nested equality scan per model is replaced by a
# dict lookup, building each adaptator in O(U + L) with an identical result.
# Every mlb_array entry is in unique_go_ids by construction (it is the union).
go_id_to_row = {go_id: i for i, go_id in enumerate(unique_go_ids)}

matrix_adaptators = []
for n, mlb_array in enumerate(MLB_ARRAYS):
    # prob_matrix_adaptator[i, j] == 1 iff model n's label j is the i-th unique GO id.
    prob_matrix_adaptator = torch.zeros(len(unique_go_ids), len(mlb_array))
    for j, go_id in enumerate(mlb_array):
        prob_matrix_adaptator[go_id_to_row[go_id], j] = 1.0
    print(n, " ", prob_matrix_adaptator.shape)
    matrix_adaptators.append(prob_matrix_adaptator.to(device))

############################################ UNIPROTKB PROT5 EMBEDDINGS ############################################
from transformers import T5Tokenizer, T5EncoderModel

model_type = "Rostlab/prot_t5_xl_uniref50"
# do not put to lower case, prot T5 needs upper case letters
tokenizer = T5Tokenizer.from_pretrained(model_type, do_lower_case=False, truncation_side="right")
protT5 = T5EncoderModel.from_pretrained(model_type, trust_remote_code=True).to(device)
max_sequence_len = 512  # prot_t5_xl pretraining done on max 512 seq len
batch_size = 64
SPECIAL_IDS = set(tokenizer.all_special_ids)

# Freeze params, inference only
protT5.eval()
for param in protT5.parameters():
    param.requires_grad = False


def preprocess_sequence(seq: str) -> str:
    """Format an amino-acid string for ProtT5: space-separate the residues
    and map the rare/ambiguous residues U, Z, O, B to X."""
    seq = " ".join(seq)
    return re.sub(r"[UZOB]", "X", seq)


def fetch_uniprot_sequence(uniprot_id: str) -> str:
    """Download the FASTA record for *uniprot_id* and return the bare sequence.

    Raises:
        ValueError: when UniProt does not answer with HTTP 200.
    """
    url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
    r = requests.get(url, timeout=10)
    if r.status_code != 200:
        raise ValueError(f"UniProt ID '{uniprot_id}' not found")
    fasta = r.text.splitlines()
    # Drop FASTA header lines (">...") and join the sequence lines.
    return "".join(line for line in fasta if not line.startswith(">"))


def generate_embedding_from_uniprot(uniprot_id: str):
    """Embed one UniProt accession with ProtT5.

    Returns:
        (masked_embeddings, mask_2d): (1, MAX_SEQ_LEN, D) hidden states with
        padding/special-token positions zeroed, and the matching
        (1, MAX_SEQ_LEN) 0/1 mask with special tokens cleared.
    """
    seq = preprocess_sequence(fetch_uniprot_sequence(uniprot_id))
    tokens = tokenizer(
        seq,
        return_tensors="pt",
        truncation=True,
        add_special_tokens=True,
        padding="max_length",
        max_length=MAX_SEQ_LEN,
    )
    tokens = {k: v.to(device) for k, v in tokens.items()}
    with torch.no_grad():
        outputs = protT5(**tokens)
    raw_embeddings = outputs.last_hidden_state  # (1, L, D)

    # Clear special tokens out of the attention mask so downstream pooling
    # only sees real residues.
    input_ids = tokens["input_ids"]
    mask_2d = tokens["attention_mask"].clone()  # (1, L)
    for sid in SPECIAL_IDS:
        mask_2d[input_ids == sid] = 0

    mask_3d = mask_2d.unsqueeze(-1).float()       # (1, L, 1)
    masked_embeddings = raw_embeddings * mask_3d  # (1, L, D)
    return masked_embeddings, mask_2d

############################################ PREDICTION CODE ############################################
def ensemble_predict(embedding, mask, topk=20):
    """Average the fold models' scores in the shared GO-id space and return
    the top-*topk* terms above THRESH as a DataFrame with GO_ID and Score."""
    scores = torch.zeros(1, len(unique_go_ids), device=device)
    counts = torch.zeros_like(scores)
    for model, adaptor in zip(MODELS, matrix_adaptators):
        preds = model(embedding, mask)
        # Transpose so labels sit on dim 0; then adaptor (U, L) @ preds (L, 1)
        # scatters model-local scores into the shared (U,) GO space.
        # NOTE(review): assumes model output is (1, num_labels) — the original
        # shape comments were self-inconsistent; confirm against the model.
        preds = preds.transpose(0, 1)
        adapted = adaptor @ preds
        adapted = adapted.T  # (1, U)
        scores += adapted
        # Count only models that actually scored this GO term (non-zero entry).
        counts += (torch.abs(adapted) > 1e-9).float()
    # Average each term over the models that predicted it (clamp avoids /0).
    scores /= torch.clamp(counts, min=1)
    scores = scores.squeeze(0)
    # Fixed: renamed the threshold mask so it no longer shadows the `mask`
    # (attention mask) parameter above.
    keep = scores > THRESH
    scores = scores * keep.float()
    idx = torch.argsort(scores, descending=True)[:topk]
    return pd.DataFrame({
        "GO_ID": unique_go_ids[idx.cpu().numpy()],
        "Score": scores[idx].round(decimals=3).detach().cpu().numpy(),
    })


def predict(uniprot_id, topk):
    """End-to-end: UniProt accession -> ProtT5 embedding -> top-k GO predictions."""
    embedding, mask = generate_embedding_from_uniprot(uniprot_id)
    return ensemble_predict(embedding, mask, topk)

############################################ UNIPROTKB AND GRADIO UTILS ############################################
def fetch_human_uniprot_examples(n=500):
    """Return up to *n* reviewed (Swiss-Prot) human UniProt accessions."""
    url = "https://rest.uniprot.org/uniprotkb/search"
    params = {
        "query": "organism_id:9606 AND reviewed:true",
        "fields": "accession",
        "format": "json",
        "size": n,
    }
    r = requests.get(url, params=params, timeout=10)
    r.raise_for_status()
    data = r.json()
    return [e["primaryAccession"] for e in data["results"]]


def fetch_go_ancestors(go_id):
    """Return the set of is_a/part_of ancestors of *go_id* via QuickGO.

    Best-effort: any fetch failure is reported and yields an empty set.
    """
    url = (
        f"https://www.ebi.ac.uk/QuickGO/services/ontology/go/terms/"
        f"{go_id}/ancestors?relations=is_a,part_of"
    )
    try:
        r = requests.get(url, headers={"Accept": "application/json"}, timeout=10)
        r.raise_for_status()
    except Exception as e:
        print(f"Warning: failed to fetch ancestors for {go_id}: {e}")
        return set()
    data = r.json()
    ancestors = set()
    for term in data.get("results", []):
        ancestors.update(term.get("ancestors", []))
    return ancestors


# Experimental / curator-backed GO evidence codes accepted as ground truth.
CAFA6_EVIDENCE_CODES = {
    "EXP",  # Inferred from Experiment
    "IDA",  # Inferred from Direct Assay
    "IPI",  # Inferred from Physical Interaction
    "IMP",  # Inferred from Mutant Phenotype
    "IGI",  # Inferred from Genetic Interaction
    "IEP",  # Inferred from Expression Pattern
    "TAS",  # Traceable Author Statement
    "IC",   # Inferred by Curator
}


def fetch_quickgo_annotations_with_ancestors(uniprot_id):
    """Collect experimentally-supported GO annotations for *uniprot_id*
    (paging through QuickGO) and expand the set with each term's ancestors."""
    direct_go_ids = set()
    pageSize = 200
    page = 1
    while True:
        url = (
            f"https://www.ebi.ac.uk/QuickGO/services/annotation/search?"
            f"geneProductId=UniProtKB:{uniprot_id}&limit={pageSize}&page={page}"
        )
        try:
            response = requests.get(url, headers={"Accept": "application/json"}, timeout=10)
            response.raise_for_status()
        except Exception as e:
            # Best-effort: keep whatever pages were already collected.
            print(f"Warning: failed to fetch annotations: {e}")
            break
        data = response.json()
        results = data.get("results", [])
        if not results:
            break
        for item in results:
            go_id = item.get("goId")
            evidence = item.get("goEvidence")  # fetch evidence code
            if go_id and evidence in CAFA6_EVIDENCE_CODES:
                direct_go_ids.add(go_id)
        page += 1

    # Expand with ancestors
    all_go_ids = set(direct_go_ids)
    for go_id in direct_go_ids:
        all_go_ids.update(fetch_go_ancestors(go_id))
    return all_go_ids


def color_topk_predictions(pred_df, true_go_ids):
    """Annotate *pred_df* in place with a Color column: green when the
    predicted GO id is in *true_go_ids*, red otherwise. Returns *pred_df*."""
    colors = []
    for go_id in pred_df["GO_ID"]:
        if go_id in true_go_ids:
            colors.append("background-color: #d4edda")  # green
        else:
            colors.append("background-color: #f8d7da")  # red
    pred_df["Color"] = colors
    return pred_df


def predictions_to_html(pred_df):
    """Render a colored prediction table (GO_ID linked to QuickGO, Score)
    as an HTML string for the Gradio UI.

    NOTE(review): the original HTML template was mangled by extraction; the
    markup below is reconstructed from the surviving fragments (GO_ID/Score
    header row, per-row Color style, anchor around the GO id) — confirm
    against the original file.
    """
    html = "<table>"
    html += "<tr><th>GO_ID</th><th>Score</th></tr>"
    for _, row in pred_df.iterrows():
        go_id = row["GO_ID"]
        html += f"<tr style=\"{row['Color']}\">"
        html += (
            f"<td>"
            f"<a href=\"https://www.ebi.ac.uk/QuickGO/term/{go_id}\" target=\"_blank\">"
            f"{go_id}"
            f"</a>"
            f"</td>"
        )
        html += f"<td>{row['Score']}</td>"
        html += "</tr>"
    html += "</table>"
    return html