melvinalves committed on
Commit
1dadffc
·
verified ·
1 Parent(s): c6dfc57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -46
app.py CHANGED
@@ -1,6 +1,7 @@
1
  # -------------------------------------------------------------------------------------------------
2
  # app.py – Streamlit app para predição de GO:MF
3
- # Versão: usa ProtBERT & ProtBERT-BFD fine-tuned (melvinalves/FineTune) + ESM-2 base
 
4
  # -------------------------------------------------------------------------------------------------
5
  import os, re, numpy as np, torch, joblib, streamlit as st
6
  from huggingface_hub import login
@@ -8,17 +9,17 @@ from transformers import AutoTokenizer, AutoModel
8
  from keras.models import load_model
9
  from goatools.obo_parser import GODag
10
 
11
- # ——————————————————— AUTHENTICAÇÃO ——————————————————— #
12
  login(os.environ["HF_TOKEN"])
13
 
14
  # ——————————————————— CONFIG ——————————————————— #
15
  SPACE_ID = "melvinalves/protein_function_prediction"
16
  TOP_N = 10
17
  THRESH = 0.37
18
- CHUNK_PB = 512
19
- CHUNK_ESM = 1024
20
 
21
- # Repositórios dos modelos
22
  FINETUNED_PB = ("melvinalves/FineTune", "fineTunedProtbert")
23
  FINETUNED_BFD = ("melvinalves/FineTune", "fineTunedProtbertbfd")
24
  BASE_ESM = "facebook/esm2_t33_650M_UR50D"
@@ -26,53 +27,60 @@ BASE_ESM = "facebook/esm2_t33_650M_UR50D"
26
  # ——————————————————— HELPERS ——————————————————— #
27
  @st.cache_resource
28
  def download_file(path):
29
- """Ficheiros pequenos guardados no repositório do Space (≤1 GB total)."""
30
  from huggingface_hub import hf_hub_download
31
  return hf_hub_download(repo_id=SPACE_ID, repo_type="space", filename=path)
32
 
33
  @st.cache_resource
34
  def load_keras(name):
35
- """Carrega modelos Keras (MLPs + stacking)."""
36
  return load_model(download_file(f"models/{name}"), compile=False)
37
 
 
38
  @st.cache_resource
39
- def load_hf_encoder(repo_id, subfolder=None, base_tok="Rostlab/prot_bert"):
40
  """
41
- Carrega um encoder HF (PyTorch) – se existir apenas tf_model.h5 no repo,
42
- usa from_tf=True para converter on-the-fly.
 
 
43
  """
 
 
44
  tok = AutoTokenizer.from_pretrained(base_tok, do_lower_case=False)
45
- mdl = AutoModel.from_pretrained(
46
- repo_id,
47
- subfolder=subfolder,
48
- from_tf=True, # converte pesos TF se necessário
49
- )
50
  mdl.eval()
51
  return tok, mdl
52
 
 
53
  def embed_seq(model_ref, seq, chunk):
54
  """
55
- Extrai embedding médio (CLS) para sequências grandes usando chunks.
56
- - model_ref pode ser string (modelo base) ou tuple (repo_id, subfolder) p/ fine-tuned.
57
  """
58
- if isinstance(model_ref, tuple):
59
- tok, mdl = load_hf_encoder(*model_ref)
60
- else:
61
- # mantém o tokenizer apropriado
62
- base_tok = "Rostlab/prot_bert" if "prot_bert" in model_ref else model_ref
63
- tok, mdl = load_hf_encoder(model_ref, base_tok=base_tok)
64
 
65
  parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)]
66
- vecs = []
67
  for p in parts:
68
- tokens = tok(" ".join(p), return_tensors="pt", truncation=False)
69
  with torch.no_grad():
70
- out = mdl(**{k: v.to(mdl.device) for k, v in tokens.items()})
71
  vecs.append(out.last_hidden_state[:, 0, :].cpu().numpy())
72
  return np.mean(vecs, axis=0, keepdims=True)
73
 
74
  @st.cache_resource
75
  def load_go_info():
 
76
  obo_path = download_file("data/go.obo")
77
  dag = GODag(obo_path, optional_attrs=["defn"])
78
  return {tid: (term.name, term.defn) for tid, term in dag.items()}
@@ -91,25 +99,28 @@ GO = mlb.classes_
91
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” UI β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
92
  st.title("PrediΓ§Γ£o de FunΓ§Γ΅es Moleculares de ProteΓ­nas")
93
 
94
- st.markdown(
95
- "<style> textarea { font-size: 0.9rem !important; } </style>",
96
- unsafe_allow_html=True,
97
- )
98
 
99
  fasta_input = st.text_area("Insere uma ou mais sequΓͺncias FASTA:", height=300)
100
  predict_clicked = st.button("Prever GO terms")
101
 
102
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” PARSE DE MÚLTIPLAS SEQUÊNCIAS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
103
  def parse_fasta_multiple(fasta_str):
 
 
 
 
104
  entries, parsed = fasta_str.strip().split(">"), []
105
  for i, entry in enumerate(entries):
106
  if not entry.strip():
107
  continue
108
  lines = entry.strip().splitlines()
109
- if i > 0:
110
  header = lines[0].strip()
111
  seq = "".join(lines[1:]).replace(" ", "").upper()
112
- else:
113
  header = f"Seq_{i+1}"
114
  seq = "".join(lines).replace(" ", "").upper()
115
  if seq:
@@ -125,42 +136,46 @@ if predict_clicked:
125
 
126
  for header, seq in parsed_seqs:
127
  with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
128
- # ——— Embeddings ——— #
129
  emb_pb = embed_seq(FINETUNED_PB, seq, CHUNK_PB)
130
  emb_bfd = embed_seq(FINETUNED_BFD, seq, CHUNK_PB)
131
  emb_esm = embed_seq(BASE_ESM, seq, CHUNK_ESM)
132
 
133
- # ——— Predições dos MLPs ——— #
134
  y_pb = mlp_pb.predict(emb_pb)
135
  y_bfd = mlp_bfd.predict(emb_bfd)
136
- y_esm = mlp_esm.predict(emb_esm)[:, :597]
137
 
138
- # ——— Stacking ——— #
139
  X = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
140
  y_ens = stacking.predict(X)
141
 
142
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” RESULTADOS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
143
- def mostrar_resultados(tag, y_pred):
144
  with st.expander(tag, expanded=True):
145
- hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
146
  st.markdown(f"**GO terms com prob β‰₯ {THRESH}**")
 
147
  if hits:
148
  for go_id in hits:
149
  name, defin = GO_INFO.get(go_id, ("β€” sem nome β€”", ""))
150
- limp = re.sub(r'^\s*"?(.+?)"?\s*(\[[^\]]*\])?\s*$', r'\1', defin or "")
 
151
  st.write(f"**{go_id} β€” {name}**")
152
- st.caption(limp)
153
  else:
154
  st.code("β€” nenhum β€”")
155
 
 
156
  st.markdown(f"**Top {TOP_N} GO terms mais provΓ‘veis**")
157
  for idx in np.argsort(-y_pred[0])[:TOP_N]:
158
  go_id = GO[idx]
159
  name, _ = GO_INFO.get(go_id, ("", ""))
160
  st.write(f"{go_id} β€” {name} : {y_pred[0][idx]:.4f}")
161
 
162
- # Mostrar apenas ensemble (descomenta se quiseres os individuais)
163
- # mostrar_resultados(f"{header} β€” ProtBERT", y_pb)
164
- # mostrar_resultados(f"{header} β€” ProtBERT-BFD", y_bfd)
165
- # mostrar_resultados(f"{header} β€” ESM-2", y_esm)
166
- mostrar_resultados(header, y_ens)
 
 
1
  # -------------------------------------------------------------------------------------------------
2
  # app.py – Streamlit app para predição de GO:MF
3
+ # • ProtBERT / ProtBERT-BFD fine-tuned (melvinalves/FineTune)
4
+ # • ESM-2 base (facebook/esm2_t33_650M_UR50D)
5
  # -------------------------------------------------------------------------------------------------
6
  import os, re, numpy as np, torch, joblib, streamlit as st
7
  from huggingface_hub import login
 
9
  from keras.models import load_model
10
  from goatools.obo_parser import GODag
11
 
12
+ # ——————————————————— AUTENTICAÇÃO ——————————————————— #
13
  login(os.environ["HF_TOKEN"])
14
 
15
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” CONFIG β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
16
  SPACE_ID = "melvinalves/protein_function_prediction"
17
  TOP_N = 10
18
  THRESH = 0.37
19
+ CHUNK_PB = 512 # janela ProtBERT / ProtBERT-BFD
20
+ CHUNK_ESM = 1024 # janela ESM-2
21
 
22
+ # repositórios HF
23
  FINETUNED_PB = ("melvinalves/FineTune", "fineTunedProtbert")
24
  FINETUNED_BFD = ("melvinalves/FineTune", "fineTunedProtbertbfd")
25
  BASE_ESM = "facebook/esm2_t33_650M_UR50D"
 
27
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” HELPERS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
28
  @st.cache_resource
29
  def download_file(path):
30
+ """Ficheiros pequenos (≤1 GB) guardados no Space."""
31
  from huggingface_hub import hf_hub_download
32
  return hf_hub_download(repo_id=SPACE_ID, repo_type="space", filename=path)
33
 
34
  @st.cache_resource
35
  def load_keras(name):
36
+ """Carrega modelos Keras (MLPs e stacking)."""
37
  return load_model(download_file(f"models/{name}"), compile=False)
38
 
39
+ # ---------- carregar tokenizer + encoder ----------
40
  @st.cache_resource
41
+ def load_hf_encoder(repo_id, subfolder=None, base_tok=None):
42
  """
43
+ • repo_id : repositório HF ou caminho local
44
+ • subfolder : subpasta onde vivem pesos/config (None se não houver)
45
+ • base_tok : repo para o tokenizer (None => usa repo_id)
46
+ Converte tf_model.h5 → PyTorch on-the-fly (from_tf=True).
47
  """
48
+ if base_tok is None:
49
+ base_tok = repo_id
50
  tok = AutoTokenizer.from_pretrained(base_tok, do_lower_case=False)
51
+
52
+ kwargs = dict(from_tf=True)
53
+ if subfolder:
54
+ kwargs["subfolder"] = subfolder
55
+ mdl = AutoModel.from_pretrained(repo_id, **kwargs)
56
  mdl.eval()
57
  return tok, mdl
58
 
59
+ # ---------- extrair embedding ----------
60
  def embed_seq(model_ref, seq, chunk):
61
  """
62
+ • model_ref = string (modelo base) OU tuple(repo_id, subfolder) (modelo fine-tuned)
63
+ Retorna embedding CLS médio (caso a sequência seja dividida em chunks).
64
  """
65
+ if isinstance(model_ref, tuple): # ProtBERT / ProtBERT-BFD fine-tuned
66
+ repo_id, subf = model_ref
67
+ tok, mdl = load_hf_encoder(repo_id, subfolder=subf,
68
+ base_tok="Rostlab/prot_bert")
69
+ else: # modelo base (ESM-2)
70
+ tok, mdl = load_hf_encoder(model_ref)
71
 
72
  parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)]
73
+ vecs = []
74
  for p in parts:
75
+ toks = tok(" ".join(p), return_tensors="pt", truncation=False)
76
  with torch.no_grad():
77
+ out = mdl(**{k: v.to(mdl.device) for k, v in toks.items()})
78
  vecs.append(out.last_hidden_state[:, 0, :].cpu().numpy())
79
  return np.mean(vecs, axis=0, keepdims=True)
80
 
81
  @st.cache_resource
82
  def load_go_info():
83
+ """Lê GO.obo e devolve dicionário id → (name, definition)."""
84
  obo_path = download_file("data/go.obo")
85
  dag = GODag(obo_path, optional_attrs=["defn"])
86
  return {tid: (term.name, term.defn) for tid, term in dag.items()}
 
99
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” UI β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
100
  st.title("PrediΓ§Γ£o de FunΓ§Γ΅es Moleculares de ProteΓ­nas")
101
 
102
+ # Pequeno ajuste de fonte no textarea
103
+ st.markdown("<style> textarea { font-size: 0.9rem !important; } </style>",
104
+ unsafe_allow_html=True)
 
105
 
106
  fasta_input = st.text_area("Insere uma ou mais sequΓͺncias FASTA:", height=300)
107
  predict_clicked = st.button("Prever GO terms")
108
 
109
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” PARSE DE MÚLTIPLAS SEQUÊNCIAS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
110
  def parse_fasta_multiple(fasta_str):
111
+ """
112
+ Devolve lista de (header, seq) a partir de texto FASTA possivelmente múltiplo.
113
+ Suporta bloco inicial sem '>'.
114
+ """
115
  entries, parsed = fasta_str.strip().split(">"), []
116
  for i, entry in enumerate(entries):
117
  if not entry.strip():
118
  continue
119
  lines = entry.strip().splitlines()
120
+ if i > 0: # bloco típico FASTA
121
  header = lines[0].strip()
122
  seq = "".join(lines[1:]).replace(" ", "").upper()
123
+ else: # sequência sem '>'
124
  header = f"Seq_{i+1}"
125
  seq = "".join(lines).replace(" ", "").upper()
126
  if seq:
 
136
 
137
  for header, seq in parsed_seqs:
138
  with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
139
+ # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” EMBEDDINGS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
140
  emb_pb = embed_seq(FINETUNED_PB, seq, CHUNK_PB)
141
  emb_bfd = embed_seq(FINETUNED_BFD, seq, CHUNK_PB)
142
  emb_esm = embed_seq(BASE_ESM, seq, CHUNK_ESM)
143
 
144
+ # ———————————— PREDIÇÕES MLPs ———————————— #
145
  y_pb = mlp_pb.predict(emb_pb)
146
  y_bfd = mlp_bfd.predict(emb_bfd)
147
+ y_esm = mlp_esm.predict(emb_esm)[:, :597] # alinhar nº de termos
148
 
149
+ # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” STACKING β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
150
  X = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
151
  y_ens = stacking.predict(X)
152
 
153
+ # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” RESULTADOS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
154
+ def mostrar(tag, y_pred):
155
  with st.expander(tag, expanded=True):
156
+ # GO terms acima do threshold
157
  st.markdown(f"**GO terms com prob β‰₯ {THRESH}**")
158
+ hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
159
  if hits:
160
  for go_id in hits:
161
  name, defin = GO_INFO.get(go_id, ("β€” sem nome β€”", ""))
162
+ defin = re.sub(r'^\s*"?(.+?)"?\s*(\[[^\]]*\])?\s*$', r'\1',
163
+ defin or "")
164
  st.write(f"**{go_id} β€” {name}**")
165
+ st.caption(defin)
166
  else:
167
  st.code("β€” nenhum β€”")
168
 
169
+ # Top-N mais prováveis
170
  st.markdown(f"**Top {TOP_N} GO terms mais provΓ‘veis**")
171
  for idx in np.argsort(-y_pred[0])[:TOP_N]:
172
  go_id = GO[idx]
173
  name, _ = GO_INFO.get(go_id, ("", ""))
174
  st.write(f"{go_id} β€” {name} : {y_pred[0][idx]:.4f}")
175
 
176
+ # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” ESCOLHE QUAIS MOSTRAR β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
177
+ # Descomenta se quiseres ver as saídas individuais
178
+ # mostrar(f"{header} β€” ProtBERT (MLP)", y_pb)
179
+ # mostrar(f"{header} β€” ProtBERT-BFD (MLP)", y_bfd)
180
+ # mostrar(f"{header} β€” ESM-2 (MLP)", y_esm)
181
+ mostrar(header, y_ens) # ensemble