melvinalves's picture
Update app.py
e0864d7 verified
raw
history blame
5.36 kB
import os
import re
import numpy as np
import torch
import joblib
import streamlit as st
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
from keras.models import load_model
from goatools.obo_parser import GODag
# ——————————————————— CONFIG ——————————————————— #
# Hugging Face Space repository from which download_file fetches model and data files.
SPACE_ID = "melvinalves/protein_function_prediction"
# Number of highest-scoring GO terms listed in the "Top N" section of each result.
TOP_N = 10
# Minimum predicted score for a GO term to be reported as a hit.
THRESH = 0.37
# Maximum residues per chunk when embedding with the two Rostlab/prot_bert encoders.
CHUNK_PB = 512
# Maximum residues per chunk when embedding with the ESM-2 encoder.
CHUNK_ESM = 1024
# ——————————————————— HELPERS ——————————————————— #
@st.cache_resource
def download_file(path):
    """Download ``path`` from the Space repository and return its local location.

    Args:
        path: File name, relative to the root of the ``SPACE_ID`` repository.

    Returns:
        Whatever ``hf_hub_download`` returns for that file (the local path of
        the downloaded copy).
    """
    local_path = hf_hub_download(
        filename=path,
        repo_id=SPACE_ID,
        repo_type="space",
    )
    return local_path
@st.cache_resource
def load_keras(name):
    """Load a Keras model stored under ``models/`` in the Space repository.

    Args:
        name: File name of the model inside the ``models/`` folder.

    Returns:
        The model returned by ``load_model``; it is loaded with
        ``compile=False``, i.e. for inference only.
    """
    model_file = download_file(f"models/{name}")
    return load_model(model_file, compile=False)
@st.cache_resource
def load_hf_encoder(model):
    """Load the tokenizer and encoder for a Hugging Face model identifier.

    Args:
        model: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, encoder)`` tuple; the encoder has been switched to
        evaluation mode.
    """
    # Protein sequences are upper-case; the tokenizer must not lower-case them.
    tokenizer = AutoTokenizer.from_pretrained(model, do_lower_case=False)
    encoder = AutoModel.from_pretrained(model)
    encoder.eval()
    return tokenizer, encoder
def embed_seq(model, seq, chunk):
    """Embed a protein sequence as the mean of per-chunk first-token vectors.

    The sequence is cut into consecutive windows of at most ``chunk``
    residues. Each window is written as space-separated residues, tokenized
    and passed through the encoder; the hidden state at token position 0 is
    kept as the window vector. The window vectors are then averaged.

    Args:
        model: Hugging Face model identifier, forwarded to ``load_hf_encoder``.
        seq: Amino-acid sequence as a string.
        chunk: Maximum number of residues per window; must be positive.

    Returns:
        A NumPy array with a leading axis of length 1 (the mean is taken with
        ``keepdims=True``), i.e. one row holding the averaged vector.

    Raises:
        ValueError: If ``seq`` is empty or ``chunk`` is not positive. Without
            this check an empty window list would make ``np.mean`` return NaN,
            which would then be fed silently into the downstream classifiers.
    """
    if not seq:
        raise ValueError("embed_seq requires a non-empty sequence")
    if chunk <= 0:
        raise ValueError(f"chunk must be a positive integer, got {chunk!r}")

    tok, mdl = load_hf_encoder(model)
    vecs = []
    # Inference only: no autograd bookkeeping is needed for any window.
    with torch.no_grad():
        for start in range(0, len(seq), chunk):
            window = seq[start:start + chunk]
            inputs = tok(" ".join(window), return_tensors="pt", truncation=False)
            out = mdl(**inputs)
            # Single-item batch: take token position 0 of the only row.
            vecs.append(out.last_hidden_state[0, 0, :].numpy())
    return np.mean(vecs, axis=0, keepdims=True)
@st.cache_resource
def load_go_info():
    """Build a lookup of GO term id to ``(name, definition)``.

    The ontology file ``data/go.obo`` is fetched with ``download_file`` and
    parsed with ``GODag``, requesting the optional ``defn`` attribute so that
    each term carries its definition text.

    Returns:
        A dict mapping every key of the parsed DAG to a
        ``(term.name, term.defn)`` tuple.
    """
    go_dag = GODag(download_file("data/go.obo"), optional_attrs=['defn'])
    info = {}
    for term_id, term in go_dag.items():
        info[term_id] = (term.name, term.defn)
    return info
# GO term id -> (name, definition) lookup built from the ontology file.
GO_INFO = load_go_info()
# ——————————————————— MODEL LOADING ——————————————————— #
# One MLP classifier per encoder, plus the stacking model fed with their
# concatenated predictions.
mlp_pb = load_keras("mlp_protbert.h5")
mlp_bfd = load_keras("mlp_protbertbfd.h5")
mlp_esm = load_keras("mlp_esm2.h5")
stacking = load_keras("ensemble_stack.h5")
# NOTE(security): joblib.load unpickles the file and can execute arbitrary code;
# acceptable only because the file is fetched from this app's own SPACE_ID repo.
# Presumably a fitted scikit-learn MultiLabelBinarizer — confirm.
mlb = joblib.load(download_file("data/mlb_597.pkl"))
# Label array of the binarizer; indexed by prediction column to name GO terms.
GO = mlb.classes_
# ——————————————————— UI ——————————————————— #
# Page title. The user-facing Portuguese text is restored from its mis-decoded
# form (UTF-8 bytes that had been read with a Greek code page).
st.title("Predição de Funções Moleculares de Proteínas")
# Inline CSS that slightly shrinks the font of the text area.
st.markdown(
    """
<style> textarea { font-size: 0.9rem !important; } </style>
""",
    unsafe_allow_html=True,
)
# Raw FASTA text typed or pasted by the user (one or more records).
fasta_input = st.text_area("Insere uma ou mais sequências FASTA:", height=300)
# True only on the script run triggered by pressing the button.
predict_clicked = st.button("Prever GO terms")
# ——————————————————— PARSING OF MULTIPLE SEQUENCES ——————————————————— #
def parse_fasta_multiple(fasta_str):
    """Parse FASTA text into a list of ``(header, sequence)`` tuples.

    Parsing is line based: a record starts at every line whose first
    non-blank character is ``>``, and all following non-header lines belong
    to that record. A ``>`` that appears inside a header (for example
    ``A->B``) therefore does not split the record.

    Sequence lines are joined with all whitespace removed and upper-cased.
    Records whose sequence is empty are dropped. A record with no header text
    (an empty ``>`` line, or sequence lines given before any header) is kept
    under a generated name ``unnamed_<k>``, where ``k`` is its 1-based
    position in the returned list.

    Args:
        fasta_str: Text holding zero or more FASTA records.

    Returns:
        A list of ``(header, sequence)`` tuples in input order; empty when no
        sequence is found.
    """
    if not fasta_str:
        return []

    # Each record is (header, [sequence fragments]); fragments are joined later.
    records = []
    for raw_line in fasta_str.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(">"):
            records.append((line[1:].strip(), []))
            continue
        if not records:
            # Sequence text given before any header line: open a nameless record.
            records.append(("", []))
        records[-1][1].append("".join(line.split()))

    parsed = []
    for header, fragments in records:
        seq = "".join(fragments).upper()
        if seq:
            parsed.append((header or f"unnamed_{len(parsed) + 1}", seq))
    return parsed
# Strips surrounding double quotes and a trailing "[...]" reference block from
# a GO definition string. Compiled once instead of on every displayed term.
_DEFN_CLEANUP_RE = re.compile(r'^\s*"?(.+?)"?\s*(\[[^\]]*\])?\s*$')


# ——————————————————— RESULTS ——————————————————— #
def mostrar(tag, y_pred):
    """Render the predictions for one sequence inside an expander.

    Two sections are written: the GO terms whose score is at least ``THRESH``
    (with name and cleaned definition), and the ``TOP_N`` highest-scoring GO
    terms with their scores.

    Args:
        tag: Label of the expander.
        y_pred: Score array; only its first row (``y_pred[0]``) is displayed,
            with columns in the order of ``GO``.
    """
    with st.expander(tag, expanded=True):
        # Binarize at the threshold and map the first row back to GO ids.
        hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
        st.markdown(f"**GO terms com prob ≥ {THRESH}**")
        if hits:
            for go_id in hits:
                name, defin = GO_INFO.get(go_id, ("— sem nome —", ""))
                limpa_def = _DEFN_CLEANUP_RE.sub(r'\1', defin or "")
                st.write(f"**{go_id} — {name}**")
                st.caption(limpa_def)
        else:
            st.code("— nenhum —")

        st.markdown(f"**Top {TOP_N} GO terms mais prováveis**")
        # Negating the scores makes argsort return indices in descending order.
        for idx in np.argsort(-y_pred[0])[:TOP_N]:
            go_id = GO[idx]
            name, _ = GO_INFO.get(go_id, ("", ""))
            st.write(f"{go_id} — {name} : {y_pred[0][idx]:.4f}")


if predict_clicked:
    parsed_seqs = parse_fasta_multiple(fasta_input)
    if not parsed_seqs:
        st.warning("Não foi possível encontrar nenhuma sequência válida.")
        st.stop()

    for header, seq in parsed_seqs:
        with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
            # One embedding per encoder.
            emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
            emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
            emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)

            # Per-encoder MLP predictions.
            y_pb = mlp_pb.predict(emb_pb)
            y_bfd = mlp_bfd.predict(emb_bfd)
            # NOTE(review): only the first 597 columns of the ESM-2 output are
            # kept; presumably this matches the number of classes in mlb — confirm.
            y_esm = mlp_esm.predict(emb_esm)[:, :597]

            # The stacking model takes the three predictions side by side.
            X = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
            y_ens = stacking.predict(X)

        # ——————————————————— CHOOSE WHICH TO SHOW ——————————————————— #
        # Per-model views, kept as switchable alternatives:
        # mostrar(f"{header} — ProtBERT (MLP)", y_pb)
        # mostrar(f"{header} — ProtBERT-BFD (MLP)", y_bfd)
        # mostrar(f"{header} — ESM-2 (MLP)", y_esm)
        mostrar(header, y_ens)