Spaces:
Build error
Build error
| import os | |
| import re | |
| import numpy as np | |
| import torch | |
| import joblib | |
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModel | |
| from huggingface_hub import hf_hub_download | |
| from keras.models import load_model | |
| from goatools.obo_parser import GODag | |
# ─────────────────── CONFIG ─────────────────── #
# Hugging Face Space that hosts the model and data files fetched at start-up.
SPACE_ID = "melvinalves/protein_function_prediction"

# Maximum number of residues per chunk passed to the encoders.
CHUNK_PB = 512    # used for the two Rostlab/prot_bert* encoders
CHUNK_ESM = 1024  # used for the facebook/esm2 encoder

# Reporting: probability cut-off for listing a GO term as a hit, and how
# many of the highest-scoring GO terms are listed regardless of the cut-off.
THRESH = 0.37
TOP_N = 10
# ─────────────────── HELPERS ─────────────────── #
def download_file(path):
    """Fetch ``path`` from the Space identified by ``SPACE_ID``.

    Returns whatever ``hf_hub_download`` returns for that file (per the
    huggingface_hub documentation, the local path of the downloaded copy).
    """
    request = {"repo_id": SPACE_ID, "repo_type": "space", "filename": path}
    return hf_hub_download(**request)
def load_keras(name):
    """Download ``models/<name>`` from the Space and load it as a Keras model.

    The model is loaded with ``compile=False``.
    """
    local_file = download_file(f"models/{name}")
    return load_model(local_file, compile=False)
@st.cache_resource(show_spinner=False)
def load_hf_encoder(model):
    """Return the ``(tokenizer, model)`` pair for a Hugging Face checkpoint.

    ``embed_seq`` calls this once per sequence and per encoder, so without
    caching the full set of transformer weights would be re-read for every
    submitted sequence and on every Streamlit rerun. ``st.cache_resource``
    keeps one shared instance per checkpoint name for the life of the
    server process. The trade-off is that every encoder that has been used
    stays resident in memory.

    Args:
        model: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        Tuple ``(tokenizer, encoder)``; the encoder has been switched to
        evaluation mode via ``eval()``.
    """
    tok = AutoTokenizer.from_pretrained(model, do_lower_case=False)
    mdl = AutoModel.from_pretrained(model)
    mdl.eval()
    return tok, mdl
def embed_seq(model, seq, chunk):
    """Embed ``seq`` with the encoder ``model``, processing it in chunks.

    The sequence is cut into consecutive pieces of at most ``chunk``
    characters. Each piece is tokenised with its characters separated by
    spaces, run through the encoder without gradient tracking, and the
    hidden state at token position 0 is kept. The per-chunk vectors are
    averaged.

    Returns:
        The mean of the per-chunk vectors, computed with ``keepdims=True``
        so the averaged axis is retained (presumably giving shape
        ``(1, hidden_size)`` - confirm against the encoders used).
    """
    tokenizer, encoder = load_hf_encoder(model)
    chunk_vectors = []
    for start in range(0, len(seq), chunk):
        piece = seq[start:start + chunk]
        encoded = tokenizer(" ".join(piece), return_tensors="pt", truncation=False)
        with torch.no_grad():
            output = encoder(**encoded)
        first_position = output.last_hidden_state[:, 0, :]
        chunk_vectors.append(first_position.squeeze().numpy())
    return np.mean(chunk_vectors, axis=0, keepdims=True)
def load_go_info():
    """Build a lookup of GO term id -> ``(name, definition)``.

    Downloads ``data/go.obo`` from the Space and parses it with ``GODag``,
    requesting the optional ``defn`` attribute so each term carries its
    definition text.
    """
    ontology = GODag(download_file("data/go.obo"), optional_attrs=['defn'])
    info = {}
    for term_id, term in ontology.items():
        info[term_id] = (term.name, term.defn)
    return info


# Loaded once at module level and consulted when rendering results.
GO_INFO = load_go_info()
# ─────────────────── CARGA MODELOS ─────────────────── #
# One Keras model per encoder, plus the stacking model that is later fed
# the concatenation of their outputs. Loaded in the order listed.
mlp_pb, mlp_bfd, mlp_esm, stacking = (
    load_keras(filename)
    for filename in (
        "mlp_protbert.h5",
        "mlp_protbertbfd.h5",
        "mlp_esm2.h5",
        "ensemble_stack.h5",
    )
)

# Fitted label encoder (presumably a scikit-learn MultiLabelBinarizer -
# confirm); its classes_ attribute holds the GO ids in output-column order.
# NOTE(review): joblib.load unpickles the file, so it must stay a trusted
# artefact of this Space.
mlb = joblib.load(download_file("data/mlb_597.pkl"))
GO = mlb.classes_
# ─────────────────── UI ─────────────────── #
# Page title and input widgets. The user-facing Portuguese strings are
# stored with their proper accented characters (they were mis-encoded).
st.title("Predição de Funções Moleculares de Proteínas")

# Inject a small CSS rule so the FASTA text area uses a slightly smaller font.
st.markdown(
    """
<style> textarea { font-size: 0.9rem !important; } </style>
""",
    unsafe_allow_html=True,
)

# Raw FASTA text typed or pasted by the user, and the button that starts
# the prediction run.
fasta_input = st.text_area("Insere uma ou mais sequências FASTA:", height=300)
predict_clicked = st.button("Prever GO terms")
# ─────────────────── PARSE DE MÚLTIPLAS SEQUÊNCIAS ─────────────────── #
def parse_fasta_multiple(fasta_str):
    """Parse FASTA text into a list of ``(header, sequence)`` tuples.

    A record starts at a line whose first non-blank character is ``>``;
    the rest of that line is the header. All following lines up to the next
    header are joined into the sequence, with every whitespace character
    removed and letters upper-cased. Records with no residues are skipped.

    Differences from a plain split on ``>``:
      * ``>`` is only a record separator at the start of a line, so a
        ``>`` inside a header no longer breaks the record apart.
      * Residues that appear before any header (for example a raw sequence
        pasted without a ``>`` line) are kept instead of being mistaken for
        a header; such records, and records with an empty header, are
        labelled ``seq_<n>`` where ``n`` is their 1-based position.

    Args:
        fasta_str: Text containing zero or more FASTA records.

    Returns:
        List of ``(header, sequence)`` tuples in input order.
    """
    records = []
    header = None
    residues = []
    # The trailing ">" sentinel forces the last record to be flushed by the
    # same code path that handles every other header line.
    for raw_line in [*fasta_str.splitlines(), ">"]:
        line = raw_line.strip()
        if not line:
            continue
        if not line.startswith(">"):
            # str.split() with no argument drops spaces, tabs, etc.
            residues.append("".join(line.split()))
            continue
        seq = "".join(residues).upper()
        if seq:
            records.append((header or f"seq_{len(records) + 1}", seq))
        header = line[1:].strip()
        residues = []
    return records
# ─────────────────── RESULTS ─────────────────── #
# Extracts the text of a GO definition: optional surrounding double quotes
# and an optional trailing "[...]" reference list are dropped.
# Compiled once instead of on every rendered term.
_GO_DEFINITION_RE = re.compile(r'^\s*"?(.+?)"?\s*(\[[^\]]*\])?\s*$')


def mostrar(tag, y_pred):
    """Render one prediction array inside an expander titled ``tag``.

    Two lists are shown: every GO term whose score is at least ``THRESH``
    (with its name and cleaned definition), and the ``TOP_N`` highest
    scoring GO terms with their scores.

    Args:
        tag: Label of the expander.
        y_pred: Prediction array; only row 0 is displayed. Assumed to have
            one column per entry of ``GO`` - confirm against the models.
    """
    with st.expander(tag, expanded=True):
        hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
        st.markdown(f"**GO terms com prob ≥ {THRESH}**")
        if hits:
            for go_id in hits:
                name, defin = GO_INFO.get(go_id, ("— sem nome —", ""))
                # A term may have no definition (None); treat it as empty.
                limpa_def = _GO_DEFINITION_RE.sub(r'\1', defin or "")
                st.write(f"**{go_id} — {name}**")
                st.caption(limpa_def)
        else:
            st.code("— nenhum —")

        st.markdown(f"**Top {TOP_N} GO terms mais prováveis**")
        # Negating the scores makes argsort return indices in descending order.
        for idx in np.argsort(-y_pred[0])[:TOP_N]:
            go_id = GO[idx]
            name, _ = GO_INFO.get(go_id, ("", ""))
            st.write(f"{go_id} — {name} : {y_pred[0][idx]:.4f}")


if predict_clicked:
    parsed_seqs = parse_fasta_multiple(fasta_input)
    if not parsed_seqs:
        st.warning("Não foi possível encontrar nenhuma sequência válida.")
        st.stop()

    for header, seq in parsed_seqs:
        with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
            # One embedding per encoder.
            emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
            emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
            emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)

            # Per-encoder predictions. Only the first 597 columns of the
            # ESM-2 model's output are kept (presumably to match the label
            # set used by the other models - confirm).
            y_pb = mlp_pb.predict(emb_pb)
            y_bfd = mlp_bfd.predict(emb_bfd)
            y_esm = mlp_esm.predict(emb_esm)[:, :597]

            # The stacking model takes the three prediction arrays
            # concatenated side by side.
            X = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
            y_ens = stacking.predict(X)

        # ─────────────────── CHOOSE WHICH TO SHOW ─────────────────── #
        # Uncomment to also display the individual models' predictions:
        # mostrar(f"{header} — ProtBERT (MLP)", y_pb)
        # mostrar(f"{header} — ProtBERT-BFD (MLP)", y_bfd)
        # mostrar(f"{header} — ESM-2 (MLP)", y_esm)
        mostrar(header, y_ens)