melvinalves committed on
Commit
706ee91
Β·
verified Β·
1 Parent(s): d7b43cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -49
app.py CHANGED
@@ -7,101 +7,87 @@ from transformers import AutoTokenizer, AutoModel
7
  from huggingface_hub import hf_hub_download
8
  from keras.models import load_model
9
 
10
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” CONFIGURAÇÃO β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
11
- SPACE_ID = "melvinalves/protein_function_prediction"
12
  TOP_N = 10
13
- THRESH = 0.50 # limiar para listar GO terms
14
  CHUNK_PB = 512
15
  CHUNK_ESM = 1024
16
 
17
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” HELPERS DE CACHE β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
18
  @st.cache_resource
19
  def download_file(path_in_repo: str):
20
- return hf_hub_download(repo_id=SPACE_ID, repo_type="space", filename=path_in_repo)
 
 
 
 
 
 
21
 
22
  @st.cache_resource
23
  def load_keras(file_name: str):
24
- return load_model(download_file(f"models/{file_name}"), compile=False)
 
 
25
 
26
  @st.cache_resource
27
  def load_hf_encoder(model_name: str):
 
28
  tok = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
29
  mdl = AutoModel.from_pretrained(model_name)
30
  mdl.eval()
31
  return tok, mdl
32
 
33
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” MODELOS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
34
  mlp_pb = load_keras("mlp_protbert.h5")
35
  mlp_bfd = load_keras("mlp_protbertbfd.h5")
36
- mlp_esm = load_keras("mlp_esm2.h5")
37
- stacking = load_keras("ensemble_stack.h5") # usa o nome que tiveres guardado
38
 
39
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” LABEL BINARIZER β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
40
- mlb = joblib.load(download_file("data/mlb_597.pkl"))
41
- GO_TERMS = mlb.classes_
42
 
43
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” EMBEDDINGS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
44
- def embed_seq(model_name: str, seq: str, chunk: int) -> np.ndarray:
45
- tok, mdl = load_hf_encoder(model_name)
46
  parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)]
47
  vecs = []
48
  for p in parts:
49
  with torch.no_grad():
50
- out = mdl(**tok(" ".join(p), return_tensors="pt", truncation=False))
51
  vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
52
  return np.mean(vecs, axis=0, keepdims=True)
53
 
54
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” UI β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
55
  st.title("πŸ”¬ PrediΓ§Γ£o de FunΓ§Γ΅es de ProteΓ­nas")
56
 
57
- st.markdown(
58
- """
59
- <style> textarea { font-size: 0.9rem !important; } </style>
60
- """,
61
- unsafe_allow_html=True,
62
- )
63
-
64
  fasta = st.text_area("Insere a sequΓͺncia FASTA:", height=200)
65
 
66
- # ---------- BOTÃO ----------
67
  if fasta and st.button("Prever GO terms"):
 
68
  seq = "\n".join(l for l in fasta.splitlines() if not l.startswith(">"))
69
  seq = seq.replace(" ", "").replace("\n", "").upper()
70
-
71
  if not seq:
72
  st.warning("Por favor, insere uma sequΓͺncia vΓ‘lida.")
73
  st.stop()
74
 
75
- # 1) EMBEDDINGS
76
  st.write("⏳ A gerar embeddings…")
77
  emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
78
  emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
79
  emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
80
 
81
- # 2) PREDIÇÕES INDIVIDUAIS
82
  st.write("🧠 A fazer prediΓ§Γ΅es…")
83
  y_pb = mlp_pb.predict(emb_pb)
84
  y_bfd = mlp_bfd.predict(emb_bfd)
85
  y_esm = mlp_esm.predict(emb_esm)[:, :597] # corta 602 β†’ 597
86
 
87
- # 3) ENSEMBLE
88
- X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1) # (1, 1791)
89
- y_ens = stacking.predict(X_stack)
90
-
91
- # β€”β€”β€” FunΓ§Γ£o auxiliar para mostrar resultados β€”β€”β€”
92
- def show_results(label: str, y_pred):
93
- with st.expander(label, expanded=(label == "Ensemble (Stacking)")):
94
- hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
95
- st.markdown(f"**GO terms com prob β‰₯ {THRESH}**")
96
- st.code("\n".join(hits) if hits else "β€” nenhum β€”")
97
-
98
- st.markdown(f"**Top {TOP_N} GO terms mais provΓ‘veis**")
99
- top_idx = np.argsort(-y_pred[0])[:TOP_N]
100
- for i in top_idx:
101
- st.write(f"{GO_TERMS[i]} : {y_pred[0][i]:.4f}")
102
-
103
- # 4) OUTPUT
104
- show_results("ProtBERT (MLP)", y_pb)
105
- show_results("ProtBERT-BFD (MLP)", y_bfd)
106
- show_results("ESM-2 (MLP)", y_esm)
107
- show_results("Ensemble (Stacking)", y_ens)
 
7
  from huggingface_hub import hf_hub_download
8
  from keras.models import load_model
9
 
10
+
11
+ SPACE_ID = "melvinalves/protein_function_prediction" # id deste Space
12
  TOP_N = 10
 
13
  CHUNK_PB = 512
14
  CHUNK_ESM = 1024
15
 
16
+
17
  @st.cache_resource
18
  def download_file(path_in_repo: str):
19
+ """Descarrega (e faz cache) um ficheiro do prΓ³prio Space, mesmo que esteja em LFS."""
20
+ local = hf_hub_download(
21
+ repo_id=SPACE_ID,
22
+ repo_type="space",
23
+ filename=path_in_repo,
24
+ )
25
+ return local
26
 
27
  @st.cache_resource
28
  def load_keras(file_name: str):
29
+ """Carrega um modelo Keras (.h5) via hf_hub_download + load_model()."""
30
+ full_path = download_file(f"models/{file_name}")
31
+ return load_model(full_path, compile=False)
32
 
33
  @st.cache_resource
34
  def load_hf_encoder(model_name: str):
35
+ """Carrega tokenizer + encoder HuggingFace (ProtBERT/BFD/ESM)."""
36
  tok = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
37
  mdl = AutoModel.from_pretrained(model_name)
38
  mdl.eval()
39
  return tok, mdl
40
 
41
+
42
  mlp_pb = load_keras("mlp_protbert.h5")
43
  mlp_bfd = load_keras("mlp_protbertbfd.h5")
44
+ mlp_esm = load_keras("mlp_esm2.h5") # 602 saΓ­das β†’ corta-se p/ 597
45
+ stacking = load_keras("ensemble_stacking.h5") # espera 1791 entradas
46
 
47
+ mlb = joblib.load(download_file("data/mlb_597.pkl"))
48
+ go_terms = mlb.classes_
 
49
 
50
+ def embed_seq(encoder_name: str, seq: str, chunk: int) -> np.ndarray:
51
+ tok, mdl = load_hf_encoder(encoder_name)
52
+ fmt = lambda s: " ".join(list(s))
53
  parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)]
54
  vecs = []
55
  for p in parts:
56
  with torch.no_grad():
57
+ out = mdl(**tok(fmt(p), return_tensors="pt", truncation=True))
58
  vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
59
  return np.mean(vecs, axis=0, keepdims=True)
60
 
 
61
  st.title("πŸ”¬ PrediΓ§Γ£o de FunΓ§Γ΅es de ProteΓ­nas")
62
 
 
 
 
 
 
 
 
63
  fasta = st.text_area("Insere a sequΓͺncia FASTA:", height=200)
64
 
 
65
  if fasta and st.button("Prever GO terms"):
66
+ # limpar FASTA
67
  seq = "\n".join(l for l in fasta.splitlines() if not l.startswith(">"))
68
  seq = seq.replace(" ", "").replace("\n", "").upper()
 
69
  if not seq:
70
  st.warning("Por favor, insere uma sequΓͺncia vΓ‘lida.")
71
  st.stop()
72
 
 
73
  st.write("⏳ A gerar embeddings…")
74
  emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
75
  emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
76
  emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
77
 
 
78
  st.write("🧠 A fazer prediΓ§Γ΅es…")
79
  y_pb = mlp_pb.predict(emb_pb)
80
  y_bfd = mlp_bfd.predict(emb_bfd)
81
  y_esm = mlp_esm.predict(emb_esm)[:, :597] # corta 602 β†’ 597
82
 
83
+ X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1) # (1, 1791)
84
+ y_pred = stacking.predict(X_stack)
85
+
86
+ # β€”β€”β€” Resultados β€”β€”β€”
87
+ st.subheader("GO terms com probabilidade β‰₯ 0.5")
88
+ hits = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
89
+ st.code("\n".join(hits) if hits else "β€” nenhum β€”")
90
+
91
+ st.subheader(f"Top {TOP_N} GO terms mais provΓ‘veis")
92
+ for idx in np.argsort(-y_pred[0])[:TOP_N]:
93
+ st.write(f"{go_terms[idx]} : {y_pred[0][idx]:.4f}")