# Source: Hugging Face Space "melvinalves/protein_function_prediction"
# (Hub file-view page chrome removed; file revision a5e2965)
# -------------------------------------------------------------------------------------------------
# app.py – Streamlit app for GO:MF (molecular function) prediction
# • ProtBERT / ProtBERT-BFD fine-tuned (melvinalves/FineTune)
# • ESM-2 base (facebook/esm2_t33_650M_UR50D)
# -------------------------------------------------------------------------------------------------
import os, re, numpy as np, torch, joblib, streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModel
from keras.models import load_model
from goatools.obo_parser import GODag
# ——————————————————— AUTHENTICATION ——————————————————— #
# Log in to the Hugging Face Hub only when a token is present.  The original
# os.environ["HF_TOKEN"] raised KeyError and crashed the app at import time
# whenever the secret was not configured.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(_hf_token)
# ——————————————————— CONFIG ——————————————————— #
SPACE_ID = "melvinalves/protein_function_prediction"
TOP_N = 10        # number of top-ranked GO terms shown per model
THRESH = 0.37     # probability threshold for a positive GO-term prediction
CHUNK_PB = 512    # window (residues) for ProtBERT / ProtBERT-BFD
CHUNK_ESM = 1024  # window (residues) for ESM-2
# HF repositories: (repo_id, subfolder) for fine-tuned models, plain id for base
FINETUNED_PB = ("melvinalves/FineTune", "fineTunedProtbert")
FINETUNED_BFD = ("melvinalves/FineTune", "fineTunedProtbertbfd")
BASE_ESM = "facebook/esm2_t33_650M_UR50D"
# ——————————————————— HELPERS ——————————————————— #
@st.cache_resource
def download_file(path):
    """Fetch a small (≤1 GB) file stored in the Space and return its local path."""
    from huggingface_hub import hf_hub_download
    local_path = hf_hub_download(
        filename=path,
        repo_id=SPACE_ID,
        repo_type="space",
    )
    return local_path
@st.cache_resource
def load_keras(name):
    """Load a Keras model (one of the MLP heads or the stacking ensemble)."""
    weights_path = download_file(f"models/{name}")
    return load_model(weights_path, compile=False)
# ---------- load tokenizer + encoder ----------
@st.cache_resource
def load_hf_encoder(repo_id, subfolder=None, base_tok=None):
    """
    Load a (tokenizer, encoder) pair from the Hugging Face Hub.

    • repo_id   : HF repository (or local path) holding the model weights
    • subfolder : sub-directory with weights/config (None when flat)
    • base_tok  : repository for the tokenizer (defaults to repo_id)

    TF checkpoints (tf_model.h5) are converted to PyTorch on the fly
    via from_tf=True.
    """
    tokenizer_repo = repo_id if base_tok is None else base_tok
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo, do_lower_case=False)

    load_kwargs = {"from_tf": True}
    if subfolder:
        load_kwargs["subfolder"] = subfolder
    encoder = AutoModel.from_pretrained(repo_id, **load_kwargs)
    encoder.eval()  # inference only — disable dropout etc.
    return tokenizer, encoder
# ---------- extract embeddings ----------
def embed_seq(model_ref, seq, chunk):
    """
    Return the mean CLS embedding of `seq`, shape (1, hidden).

    model_ref is either a plain repo-id string (base model, e.g. ESM-2) or a
    (repo_id, subfolder) tuple (fine-tuned ProtBERT / ProtBERT-BFD).  The
    sequence is cut into windows of `chunk` residues and the CLS vectors of
    all windows are averaged.
    """
    if isinstance(model_ref, tuple):
        # Fine-tuned ProtBERT / ProtBERT-BFD: tokenizer comes from the base repo.
        repo_id, subfolder = model_ref
        tokenizer, encoder = load_hf_encoder(
            repo_id, subfolder=subfolder, base_tok="Rostlab/prot_bert"
        )
    else:
        # Base model (ESM-2).
        tokenizer, encoder = load_hf_encoder(model_ref)

    cls_vectors = []
    for start in range(0, len(seq), chunk):
        window = seq[start:start + chunk]
        # ProtBERT-style tokenizers expect space-separated residues.
        inputs = tokenizer(" ".join(window), return_tensors="pt", truncation=False)
        inputs = {name: tensor.to(encoder.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            output = encoder(**inputs)
        cls_vectors.append(output.last_hidden_state[:, 0, :].cpu().numpy())
    return np.mean(cls_vectors, axis=0)
@st.cache_resource
def load_go_info():
    """Parse go.obo and return a dict mapping GO id -> (name, definition)."""
    dag = GODag(download_file("data/go.obo"), optional_attrs=["defn"])
    info = {}
    for go_id, term in dag.items():
        info[go_id] = (term.name, term.defn)
    return info
# GO id -> (name, definition) lookup, built once and cached by Streamlit.
GO_INFO = load_go_info()
# ——————————————————— MODEL LOADING ——————————————————— #
# One MLP head per embedding type, plus the stacking ensemble that combines them.
mlp_pb = load_keras("mlp_protbert.h5")
mlp_bfd = load_keras("mlp_protbertbfd.h5")
mlp_esm = load_keras("mlp_esm2.h5")
stacking = load_keras("ensemble_stack.h5")
# MultiLabelBinarizer fitted on 597 GO:MF terms; classes_ fixes the label order
# used by all predictions below.
mlb = joblib.load(download_file("data/mlb_597.pkl"))
GO = mlb.classes_
# β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” UI β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
st.title("PrediΓ§Γ£o de FunΓ§Γ΅es Moleculares de ProteΓ­nas")
# Pequeno ajuste de fonte no textarea
st.markdown("<style> textarea { font-size: 0.9rem !important; } </style>",
unsafe_allow_html=True)
fasta_input = st.text_area("Insere uma ou mais sequΓͺncias FASTA:", height=300)
predict_clicked = st.button("Prever GO terms")
# ——————————————————— MULTI-SEQUENCE FASTA PARSING ——————————————————— #
def parse_fasta_multiple(fasta_str):
    """
    Parse (possibly multi-record) FASTA text into a list of (header, seq) pairs.

    Supports a leading block without '>' (it gets the synthetic header "Seq_1").
    Sequences are upper-cased and ALL whitespace is removed — the original only
    stripped spaces, so tabs or non-breaking gaps leaked into the sequence.
    Records with an empty sequence are dropped.
    """
    parsed = []
    for i, entry in enumerate(fasta_str.strip().split(">")):
        if not entry.strip():
            continue  # empty fragment (e.g. text starting with '>')
        lines = entry.strip().splitlines()
        if i > 0:  # regular FASTA block: first line is the header
            header = lines[0].strip()
            seq_lines = lines[1:]
        else:      # leading text with no '>' header
            header = f"Seq_{i+1}"
            seq_lines = lines
        # str.split() with no argument removes every kind of whitespace.
        seq = "".join("".join(line.split()) for line in seq_lines).upper()
        if seq:
            parsed.append((header, seq))
    return parsed
# ——————————————————— INFERENCE ——————————————————— #
if predict_clicked:
    parsed_seqs = parse_fasta_multiple(fasta_input)
    if not parsed_seqs:
        st.warning("NΓ£o foi possΓ­vel encontrar nenhuma sequΓͺncia vΓ‘lida.")
        st.stop()  # nothing valid to predict: abort this Streamlit run
    for header, seq in parsed_seqs:
        with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
            # ———————— EMBEDDINGS ————————
            # One averaged CLS embedding per encoder (see embed_seq).
            emb_pb = embed_seq(FINETUNED_PB, seq, CHUNK_PB)
            emb_bfd = embed_seq(FINETUNED_BFD, seq, CHUNK_PB)
            emb_esm = embed_seq(BASE_ESM, seq, CHUNK_ESM)
            # ———————— MLP PREDICTIONS ————————
            y_pb = mlp_pb.predict(emb_pb)
            y_bfd = mlp_bfd.predict(emb_bfd)
            # NOTE(review): the slice assumes the first 597 outputs of the ESM
            # head line up with mlb.classes_ order — confirm against training.
            y_esm = mlp_esm.predict(emb_esm)[:, :597]  # align number of terms
            # ———————— STACKING ————————
            # The ensemble input is the three probability vectors side by side.
            X = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
            y_ens = stacking.predict(X)
        # ——————————————————— RESULTS ——————————————————— #
        def mostrar(tag, y_pred):
            """Render one expander: terms above THRESH, then the TOP_N terms."""
            with st.expander(tag, expanded=True):
                # GO terms above the probability threshold
                st.markdown(f"**GO terms com prob β‰₯ {THRESH}**")
                hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
                if hits:
                    for go_id in hits:
                        name, defin = GO_INFO.get(go_id, ("β€” sem nome β€”", ""))
                        # Strip surrounding quotes and the trailing "[refs]"
                        # bracket from the raw OBO definition string.
                        defin = re.sub(r'^\s*"?(.+?)"?\s*(\[[^\]]*\])?\s*$', r'\1',
                                       defin or "")
                        st.write(f"**{go_id} β€” {name}**")
                        st.caption(defin)
                else:
                    st.code("β€” nenhum β€”")
                # Top-N most probable terms, regardless of the threshold
                st.markdown(f"**Top {TOP_N} GO terms mais provΓ‘veis**")
                for idx in np.argsort(-y_pred[0])[:TOP_N]:
                    go_id = GO[idx]
                    name, _ = GO_INFO.get(go_id, ("", ""))
                    st.write(f"{go_id} β€” {name} : {y_pred[0][idx]:.4f}")
        # ——————————————————— CHOOSE WHAT TO SHOW ——————————————————— #
        # Uncomment to also show each individual model's output:
        # mostrar(f"{header} β€” ProtBERT (MLP)", y_pb)
        # mostrar(f"{header} β€” ProtBERT-BFD (MLP)", y_bfd)
        # mostrar(f"{header} β€” ESM-2 (MLP)", y_esm)
        mostrar(header, y_ens)  # ensemble output only