melvinalves committed on
Commit
be01d59
·
verified ·
1 Parent(s): 4ee1dd6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -14
app.py CHANGED
@@ -6,11 +6,12 @@ import streamlit as st
6
  from transformers import AutoTokenizer, AutoModel
7
  from huggingface_hub import hf_hub_download
8
  from keras.models import load_model
 
9
 
10
  # ——————————————————— CONFIG ——————————————————— #
11
  SPACE_ID = "melvinalves/protein_function_prediction"
12
  TOP_N = 10
13
- THRESH = 0.35
14
  CHUNK_PB = 512
15
  CHUNK_ESM = 1024
16
 
@@ -40,6 +41,14 @@ def embed_seq(model, seq, chunk):
40
  vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
41
  return np.mean(vecs, axis=0, keepdims=True)
42
 
 
 
 
 
 
 
 
 
43
  # ——————————————————— CARGA MODELOS ——————————————————— #
44
  mlp_pb = load_keras("mlp_protbert.h5")
45
  mlp_bfd = load_keras("mlp_protbertbfd.h5")
@@ -60,24 +69,35 @@ st.markdown(
60
  )
61
 
62
  fasta_input = st.text_area("Insere a sequência FASTA:", height=200)
 
 
 
 
 
 
63
  predict_clicked = st.button("Prever GO terms")
64
 
65
  if predict_clicked:
66
 
67
- # ——— Validação mínima ———
68
- seq = "".join(l.strip() for l in fasta_input.splitlines() if not l.startswith(">")).replace(" ", "").upper()
 
 
 
69
  if not seq:
70
  st.warning("Por favor, insere primeiro uma sequência FASTA válida.")
71
  st.stop()
72
 
 
 
73
 
74
- # ——— 1) EMBEDDINGS ———
75
  with st.spinner("⏳ A gerar embeddings…"):
76
  emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
77
  emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
78
  emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
79
 
80
- # ——— 2) PREDIÇÕES ———
81
  with st.spinner("🧠 A fazer predições…"):
82
  y_pb = mlp_pb.predict(emb_pb)
83
  y_bfd = mlp_bfd.predict(emb_bfd)
@@ -85,17 +105,32 @@ if predict_clicked:
85
  X = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
86
  y_ens = stacking.predict(X)
87
 
88
- # ——— 3) MOSTRAR RESULTADOS ———
89
  def mostrar(tag, y_pred):
90
- with st.expander(tag, expanded=(tag == "Ensemble (Stacking)")):
91
  hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
92
  st.markdown(f"**GO terms com prob ≥ {THRESH}**")
93
- st.code("\n".join(hits) if hits else "— nenhum —")
 
 
 
 
 
 
 
94
  st.markdown(f"**Top {TOP_N} GO terms mais prováveis**")
95
- for i in np.argsort(-y_pred[0])[:TOP_N]:
96
- st.write(f"{GO[i]} : {y_pred[0][i]:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- #mostrar("ProtBERT (MLP)", y_pb)
99
- #mostrar("ProtBERT-BFD (MLP)", y_bfd)
100
- #mostrar("ESM-2 (MLP)", y_esm)
101
- mostrar("Ensemble (Stacking)", y_ens)
 
6
  from transformers import AutoTokenizer, AutoModel
7
  from huggingface_hub import hf_hub_download
8
  from keras.models import load_model
9
+ from goatools.obo_parser import GODag
10
 
11
  # ——————————————————— CONFIG ——————————————————— #
12
  SPACE_ID = "melvinalves/protein_function_prediction"
13
  TOP_N = 10
14
+ THRESH = 0.37
15
  CHUNK_PB = 512
16
  CHUNK_ESM = 1024
17
 
 
41
  vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
42
  return np.mean(vecs, axis=0, keepdims=True)
43
 
44
+ @st.cache_resource
45
+ def load_go_info():
46
+ obo_path = download_file("data/go.obo")
47
+ dag = GODag(obo_path, optional_attrs=['defn'])
48
+ return {tid: (term.name, term.defn) for tid, term in dag.items()}
49
+
50
+ GO_INFO = load_go_info()
51
+
52
  # ——————————————————— CARGA MODELOS ——————————————————— #
53
  mlp_pb = load_keras("mlp_protbert.h5")
54
  mlp_bfd = load_keras("mlp_protbertbfd.h5")
 
69
  )
70
 
71
  fasta_input = st.text_area("Insere a sequência FASTA:", height=200)
72
+ selected_model = st.selectbox("Modelo a utilizar:", [
73
+ "ProtBERT (MLP)",
74
+ "ProtBERT-BFD (MLP)",
75
+ "ESM-2 (MLP)",
76
+ "Ensemble (Stacking)"
77
+ ])
78
  predict_clicked = st.button("Prever GO terms")
79
 
80
  if predict_clicked:
81
 
82
+ # ——— 1) PRÉ-PROCESSAMENTO FASTA ———
83
+ lines = fasta_input.splitlines()
84
+ header = next((l for l in lines if l.startswith(">")), None)
85
+ seq = "".join(l.strip() for l in lines if not l.startswith(">")).replace(" ", "").upper()
86
+
87
  if not seq:
88
  st.warning("Por favor, insere primeiro uma sequência FASTA válida.")
89
  st.stop()
90
 
91
+ if header:
92
+ st.markdown(f"**🧬 ID da proteína:** `{header[1:].strip()}`")
93
 
94
+ # ——— 2) EMBEDDINGS ———
95
  with st.spinner("⏳ A gerar embeddings…"):
96
  emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
97
  emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
98
  emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
99
 
100
+ # ——— 3) PREDIÇÕES ———
101
  with st.spinner("🧠 A fazer predições…"):
102
  y_pb = mlp_pb.predict(emb_pb)
103
  y_bfd = mlp_bfd.predict(emb_bfd)
 
105
  X = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
106
  y_ens = stacking.predict(X)
107
 
108
+ # ——— 4) RESULTADOS ———
109
  def mostrar(tag, y_pred):
110
+ with st.expander(tag, expanded=True):
111
  hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
112
  st.markdown(f"**GO terms com prob ≥ {THRESH}**")
113
+ if hits:
114
+ for go_id in hits:
115
+ name, defin = GO_INFO.get(go_id, ("— sem nome —", "— sem definição —"))
116
+ st.write(f"**{go_id} — {name}**")
117
+ st.caption(defin)
118
+ else:
119
+ st.code("— nenhum —")
120
+
121
  st.markdown(f"**Top {TOP_N} GO terms mais prováveis**")
122
+ for idx in np.argsort(-y_pred[0])[:TOP_N]:
123
+ go_id = GO[idx]
124
+ name, _ = GO_INFO.get(go_id, ("", ""))
125
+ st.write(f"{go_id} — {name} : {y_pred[0][idx]:.4f}")
126
+
127
+ # ——— 5) MOSTRAR RESULTADO DO MODELO ESCOLHIDO ———
128
+ if selected_model == "ProtBERT (MLP)":
129
+ mostrar("ProtBERT (MLP)", y_pb)
130
+ elif selected_model == "ProtBERT-BFD (MLP)":
131
+ mostrar("ProtBERT-BFD (MLP)", y_bfd)
132
+ elif selected_model == "ESM-2 (MLP)":
133
+ mostrar("ESM-2 (MLP)", y_esm)
134
+ else:
135
+ mostrar("Ensemble (Stacking)", y_ens)
136