Spaces:
Build error
Build error
| # app.py – Streamlit app para predição de GO:MF | |
| # ProtBERT / ProtBERT-BFD fine-tuned (melvinalves/FineTune) | |
| # ESM-2 base (facebook/esm2_t33_650M_UR50D) | |
| import os, re, numpy as np, torch, joblib, streamlit as st | |
| from huggingface_hub import login | |
| from transformers import AutoTokenizer, AutoModel | |
| from keras.models import load_model | |
| from goatools.obo_parser import GODag | |
# AUTHENTICATION #
# The token is read from the HF_TOKEN environment variable; if the variable
# is not set, the lookup raises KeyError when the script starts.
login(token=os.environ["HF_TOKEN"])

# CONFIG #
SPACE_ID = "melvinalves/protein_function_prediction"  # Space holding models/ and data/
TOP_N = 20        # number of best-scoring GO terms listed in the ranking column
THRESH = 0.37     # minimum score for a GO term to count as a hit
CHUNK_PB = 512    # residues per chunk for the two ProtBERT encoders
CHUNK_ESM = 1024  # residues per chunk for the ESM-2 encoder

# HF REPOSITORIES
# Fine-tuned encoders are given as (repo_id, subfolder); the base model as a plain repo id.
FINETUNED_PB = ("melvinalves/FineTune", "fineTunedProtbert")
FINETUNED_BFD = ("melvinalves/FineTune", "fineTunedProtbertbfd")
BASE_ESM = "facebook/esm2_t33_650M_UR50D"
| # HELPERS # | |
def download_file(path):
    """Download ``path`` from the Space repository and return the result.

    Meant for the small artefacts (≤1 GB) that are stored in the Space
    itself; the value returned is whatever ``hf_hub_download`` returns.
    """
    from huggingface_hub import hf_hub_download

    local_copy = hf_hub_download(
        filename=path,
        repo_type="space",
        repo_id=SPACE_ID,
    )
    return local_copy
@st.cache_resource(show_spinner=False)
def load_keras(name):
    """Load a Keras model (MLP heads and the stacking model) from ``models/``.

    The result is cached with ``st.cache_resource``: Streamlit re-executes the
    whole script on every widget interaction, and without the cache each rerun
    would deserialize every model again.  ``show_spinner=False`` keeps the
    cached call from emitting UI elements before ``st.set_page_config`` runs.

    Args:
        name: File name of the model inside the Space's ``models/`` folder.

    Returns:
        The model returned by ``load_model`` (loaded with ``compile=False``).
    """
    return load_model(download_file(f"models/{name}"), compile=False)
@st.cache_resource(show_spinner=False)
def load_hf_encoder(repo_id, subfolder=None, base_tok=None):
    """Load a tokenizer and encoder pair, converting TF weights to PyTorch.

    The pair is cached with ``st.cache_resource`` so that each distinct
    ``(repo_id, subfolder, base_tok)`` combination is instantiated only once
    per server process.  Without the cache every embedding call reloads and
    converts the full checkpoint.  NOTE: the cached encoders stay resident in
    memory for the lifetime of the process.

    Args:
        repo_id: Hub repository that holds the encoder weights.
        subfolder: Optional subfolder inside ``repo_id`` with the weights.
        base_tok: Repository to load the tokenizer from; defaults to
            ``repo_id`` when ``None``.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode.
    """
    tokenizer_source = repo_id if base_tok is None else base_tok
    tok = AutoTokenizer.from_pretrained(tokenizer_source, do_lower_case=False)

    load_kwargs = {"from_tf": True}
    if subfolder:
        load_kwargs["subfolder"] = subfolder

    mdl = AutoModel.from_pretrained(repo_id, **load_kwargs)
    mdl.eval()
    return tok, mdl
def embed_seq(model_ref, seq, chunk):
    """Return the first-token (CLS) embedding of ``seq``, averaged over chunks.

    ``model_ref`` is either a ``(repo_id, subfolder)`` tuple (fine-tuned
    ProtBERT, tokenized with the ``Rostlab/prot_bert`` tokenizer) or a plain
    repository id (base ESM-2).  The sequence is cut into consecutive pieces
    of at most ``chunk`` residues; each piece is encoded separately and the
    per-piece first-token vectors are averaged.
    """
    if isinstance(model_ref, tuple):
        # fine-tuned ProtBERT: weights live in a subfolder of the repo
        repo_id, subf = model_ref
        tok, mdl = load_hf_encoder(
            repo_id, subfolder=subf, base_tok="Rostlab/prot_bert"
        )
    else:
        # base ESM-2 model
        tok, mdl = load_hf_encoder(model_ref)

    cls_vectors = []
    for start in range(0, len(seq), chunk):
        piece = seq[start:start + chunk]
        # residues are passed to the tokenizer separated by single spaces
        encoded = tok(" ".join(piece), return_tensors="pt", truncation=False)
        with torch.no_grad():
            inputs = {key: tensor.to(mdl.device) for key, tensor in encoded.items()}
            output = mdl(**inputs)
        cls_vectors.append(output.last_hidden_state[:, 0, :].cpu().numpy())

    return np.mean(cls_vectors, axis=0)
@st.cache_resource(show_spinner=False)
def load_go_info():
    """Parse ``data/go.obo`` and return ``{term_id: (name, raw_definition)}``.

    Cached with ``st.cache_resource`` because Streamlit re-executes the script
    on every interaction and the ontology would otherwise be re-parsed each
    time.  The returned dict is shared between reruns and must be treated as
    read-only.  ``show_spinner=False`` keeps the cached call from emitting UI
    elements before ``st.set_page_config`` runs.
    """
    dag = GODag(download_file("data/go.obo"), optional_attrs=["defn"])
    return {tid: (term.name, term.defn) for tid, term in dag.items()}
# GO term id -> (name, raw definition), used when rendering results
GO_INFO = load_go_info()

# MODELS #
# One MLP head per encoder plus the stacking model, loaded in this order.
mlp_pb, mlp_bfd, mlp_esm, stacking = (
    load_keras(filename)
    for filename in (
        "mlp_protbert.h5",
        "mlp_protbertbfd.h5",
        "mlp_esm2.h5",
        "ensemble_stack.h5",
    )
)

# NOTE(review): joblib.load unpickles the file; only safe while the Space
# repository content is trusted.
mlb = joblib.load(download_file("data/mlb_597.pkl"))
GO = mlb.classes_  # label order used to index prediction columns
# UI #
st.set_page_config(
    layout="centered",
    page_icon="🧬",
    page_title="Predição de Funções Moleculares de Proteínas",
)

# Page-wide CSS: white background, tighter top padding, smaller textarea
# font and a divider on the right edge of the first column.
_page_css = """
<style>
body, .stApp { background:#FFFFFF !important; }
.block-container { padding-top:1.5rem; }
textarea { font-size:0.9rem !important; }
div[data-testid="column"]:first-child {
border-right:1px solid #E0E0E0; padding-right:1rem !important;
}
</style>
"""
st.markdown(_page_css, unsafe_allow_html=True)

# The logo is optional: it is shown only when the file exists next to the app.
if os.path.exists("logo.png"):
    st.image("logo.png", width=180)

st.title("Predição de Funções Moleculares de Proteínas (GO:MF)")

# Input widgets read later by the inference section.
fasta_input = st.text_area("Insere uma ou mais sequências FASTA:", height=300)
predict_clicked = st.button("Prever GO terms")
| # UTILITÁRIOS # | |
def _append_fasta_record(records, header, residues):
    """Append ``(header, sequence)`` to ``records`` unless the sequence is empty."""
    seq = "".join(residues).upper()
    if seq:
        # Records without a usable header get a positional placeholder name.
        records.append((header or f"Seq_{len(records) + 1}", seq))


def parse_fasta_multiple(text):
    """Extract ``[(header, sequence)]`` from FASTA text.

    A line starts a new record only when ``>`` is its first non-blank
    character, so headers that contain ``>`` (for example ``A->B``) stay
    intact.  Sequence lines before the first header form a record named
    ``Seq_1``.  All whitespace inside sequence lines is removed and the
    sequence is upper-cased.  Records with no sequence are dropped.

    Args:
        text: Raw text holding zero or more FASTA records.

    Returns:
        List of ``(header, sequence)`` tuples in input order.
    """
    records = []
    header = None
    residues = []

    for raw_line in text.strip().splitlines():
        line = raw_line.strip()
        if line.startswith(">"):
            # A new header closes the record collected so far.
            _append_fasta_record(records, header, residues)
            header = line[1:].strip()
            residues = []
        elif line:
            # split()/join removes spaces, tabs and any other whitespace
            residues.append("".join(line.split()))

    _append_fasta_record(records, header, residues)
    return records
| def clean_definition(defin: str) -> str: | |
| """ | |
| Retorna apenas o texto dentro das primeiras aspas. | |
| Se não houver aspas, devolve texto antes do primeiro '['. | |
| """ | |
| if not defin: | |
| return "" | |
| m = re.search(r'"([^"]+)"', defin) | |
| if m: | |
| return m.group(1).strip() | |
| return defin.split("[", 1)[0].strip() | |
def go_link(go_id, name=""):
    """Return a Markdown link to the QuickGO page of ``go_id``.

    The link text is ``"<go_id> - <name>"`` when ``name`` is non-empty and
    just ``go_id`` otherwise.
    """
    label = f"{go_id} - {name}" if name else go_id
    return f"[{label}](https://www.ebi.ac.uk/QuickGO/term/{go_id})"
| # MOSTRAR RESULTADOS # | |
def mostrar(header, y_pred):
    """Render the predictions for one sequence inside an expander.

    The left column lists every GO term whose score reaches ``THRESH``
    (with its cleaned definition); the right column ranks the ``TOP_N``
    highest-scoring terms.

    Args:
        header: FASTA header of the sequence.  Its first whitespace-delimited
            token is used to build the UniProt link.
        y_pred: 2-D score array; the whole array is thresholded for
            ``mlb.inverse_transform`` and row 0 is used for the ranking.
    """
    from urllib.parse import quote

    pid = header.split()[0]
    # UniProt FASTA headers look like "sp|P12345|NAME_HUMAN ..."; the entry
    # page is addressed by the accession, i.e. the middle field.
    fields = pid.split("|")
    if len(fields) >= 3 and fields[0] in ("sp", "tr") and fields[1]:
        pid = fields[1]
    # The header is user input and ends up inside raw HTML, so percent-encode
    # it; afterwards no quote or angle bracket can reach the href attribute.
    uniprot = f"https://www.uniprot.org/uniprotkb/{quote(pid, safe='')}"

    with st.expander(header, expanded=True):
        st.markdown(
            '<div style="text-align:right;margin-bottom:0.5rem">'
            f'<a href="{uniprot}" target="_blank">'
            '<button style="background:#2b8cbe;border:none;border-radius:4px;'
            "padding:0.35rem 0.8rem;color:#fff;font-size:0.9rem;"
            'cursor:pointer">Visitar UniProt</button>'
            "</a>"
            "</div>",
            unsafe_allow_html=True,
        )

        col1, col2 = st.columns(2)

        # column 1: terms at or above the threshold
        with col1:
            st.markdown(f"**GO terms com prob ≥ {THRESH}**")
            hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
            if hits:
                for go_id in hits:
                    name, defin_raw = GO_INFO.get(go_id, ("- sem nome -", ""))
                    defin = clean_definition(defin_raw)
                    st.markdown(f"- {go_link(go_id, name)}")
                    if defin:
                        st.caption(defin)
            else:
                st.code("- nenhum -")

        # column 2: the TOP_N highest scores, best first
        with col2:
            st.markdown(f"**Top {TOP_N} GO terms mais prováveis**")
            scores = y_pred[0]
            for rank, idx in enumerate(np.argsort(-scores)[:TOP_N], 1):
                go_id = GO[idx]
                name, _ = GO_INFO.get(go_id, ("", ""))
                st.markdown(f"{rank}. {go_link(go_id, name)} : {scores[idx]:.4f}")
# INFERENCE #
# Runs only on the script pass triggered by the button click.
if predict_clicked:
    records = parse_fasta_multiple(fasta_input)
    if not records:
        # Without this message a click on empty or unparsable input produced
        # no visible reaction at all.
        st.warning("Nenhuma sequência FASTA válida foi encontrada.")

    for header, seq in records:
        with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
            # one embedding per encoder
            emb_pb = embed_seq(FINETUNED_PB, seq, CHUNK_PB)
            emb_bfd = embed_seq(FINETUNED_BFD, seq, CHUNK_PB)
            emb_esm = embed_seq(BASE_ESM, seq, CHUNK_ESM)

            # per-encoder MLP scores; the ESM-2 head output is cut to its
            # first 597 columns before stacking
            y_pb = mlp_pb.predict(emb_pb)
            y_bfd = mlp_bfd.predict(emb_bfd)
            y_esm = mlp_esm.predict(emb_esm)[:, :597]

            # the stacking model consumes the three score blocks side by side
            y_ens = stacking.predict(np.concatenate([y_pb, y_bfd, y_esm], axis=1))
        mostrar(header, y_ens)
# FULL LIST WITH SEARCH BAR #
with st.expander("Mostrar lista completa dos 597 GO terms possíveis", expanded=False):
    search_term = st.text_input("Filtra GO term ou nome:")
    needle = search_term.strip().lower()

    # Keep a term when the filter text occurs in its id or in its name
    # (case-insensitive); an empty filter keeps every term.
    labelled_terms = [(go_id, GO_INFO.get(go_id, ("", ""))[0]) for go_id in GO]
    filtered_go_terms = [
        (go_id, name)
        for go_id, name in labelled_terms
        if needle in go_id.lower() or needle in name.lower()
    ]

    if not filtered_go_terms:
        st.info("Nenhum GO term corresponde ao filtro inserido.")
    else:
        # spread the matches round-robin over three columns
        columns = st.columns(3)
        for position, (go_id, name) in enumerate(filtered_go_terms):
            columns[position % 3].markdown(f"- {go_link(go_id, name)}")