melvinalves committed on
Commit
bdfb703
·
verified ·
1 Parent(s): 6f54ff6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -42
app.py CHANGED
@@ -8,97 +8,100 @@ from huggingface_hub import hf_hub_download
8
  from keras.models import load_model
9
 
10
# ——————————————————— CONFIGURATION ——————————————————— #
SPACE_ID = "melvinalves/protein_function_prediction"  # id of this Space
TOP_N = 10        # how many top-scoring GO terms to display
CHUNK_PB = 512    # residues per chunk for ProtBERT / ProtBERT-BFD
CHUNK_ESM = 1024  # residues per chunk for ESM-2
16
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” HELPERS DE CACHE β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
17
@st.cache_resource
def download_file(path_in_repo: str):
    """Download (and cache) a file from this Space's repo, LFS files included."""
    return hf_hub_download(
        repo_id=SPACE_ID,
        repo_type="space",
        filename=path_in_repo,
    )
26
 
27
@st.cache_resource
def load_keras(file_name: str):
    """Load a Keras (.h5) model from the Space's models/ folder."""
    local_path = download_file(f"models/{file_name}")
    return load_model(local_path, compile=False)
32
 
33
@st.cache_resource
def load_hf_encoder(model_name: str):
    """Load a HuggingFace tokenizer + encoder (ProtBERT/BFD/ESM) in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    encoder = AutoModel.from_pretrained(model_name)
    encoder.eval()  # inference only — disable dropout etc.
    return tokenizer, encoder
40
 
41
# ——————————————————— KERAS MODELS (.h5) ——————————————————— #
mlp_pb = load_keras("mlp_protbert.h5")
mlp_bfd = load_keras("mlp_protbertbfd.h5")
mlp_esm = load_keras("mlp_esm2.h5")            # 602 outputs -> trimmed to 597 later
stacking = load_keras("ensemble_stacking.h5")  # expects 1791 concatenated inputs

# ——————————————————— LABEL BINARIZER ——————————————————— #
mlb = joblib.load(download_file("data/mlb_597.pkl"))
go_terms = mlb.classes_
50
 
51
# ——————————————————— CHUNKED EMBEDDINGS ——————————————————— #
def embed_seq(encoder_name: str, seq: str, chunk: int) -> np.ndarray:
    """Return the mean CLS embedding of `seq`, computed over `chunk`-sized slices.

    Each slice is embedded separately and the per-slice CLS vectors are
    averaged into a single (1, hidden_dim) array.
    """
    tok, mdl = load_hf_encoder(encoder_name)
    slices = [seq[start:start + chunk] for start in range(0, len(seq), chunk)]
    cls_vectors = []
    for piece in slices:
        # Residues are space-separated, as the ProtBERT-style tokenizers expect.
        inputs = tok(" ".join(list(piece)), return_tensors="pt", truncation=True)
        with torch.no_grad():
            out = mdl(**inputs)
        cls_vectors.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
    return np.mean(cls_vectors, axis=0, keepdims=True)
62
 
63
# ——————————————————— STREAMLIT UI ——————————————————— #
st.title("🔬 Predição de Funções de Proteínas")

# Slightly smaller font inside the FASTA textarea.
st.markdown("""
<style>
textarea {
    font-size: 0.9rem !important;
}
</style>
""", unsafe_allow_html=True)

fasta = st.text_area("Insere a sequência FASTA:", height=200)

if fasta and st.button("Prever GO terms"):
    # Drop FASTA header lines, then strip spaces/newlines and upper-case.
    seq = "".join(l for l in fasta.splitlines() if not l.startswith(">"))
    seq = seq.replace(" ", "").upper()

    if not seq:
        st.warning("Por favor, insere uma sequência válida.")
        st.stop()

    st.write("⏳ A gerar embeddings…")
    emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
    emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
    emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)

    st.write("🧠 A fazer predições…")
    y_pb = mlp_pb.predict(emb_pb)
    y_bfd = mlp_bfd.predict(emb_bfd)
    y_esm = mlp_esm.predict(emb_esm)[:, :597]  # trim 602 -> 597 outputs

    # Stacking ensemble over the concatenated per-model probabilities.
    X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)  # (1, 1791)
    y_pred = stacking.predict(X_stack)

    # ——— Results ———
    st.subheader("GO terms com probabilidade ≥ 0.5")
    hits = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
    st.code("\n".join(hits) if hits else "— nenhum —")

    st.subheader(f"Top {TOP_N} GO terms mais prováveis")
    for idx in np.argsort(-y_pred[0])[:TOP_N]:
        st.write(f"{go_terms[idx]} : {y_pred[0][idx]:.4f}")
 
 
 
 
 
 
 
 
 
 
 
8
  from keras.models import load_model
9
 
10
# ——————————————————— CONFIGURATION ——————————————————— #
SPACE_ID = "melvinalves/protein_function_prediction"
TOP_N = 10        # how many top-scoring GO terms to display
THRESH = 0.50     # probability threshold for listing GO terms
CHUNK_PB = 512    # residues per chunk for ProtBERT / ProtBERT-BFD
CHUNK_ESM = 1024  # residues per chunk for ESM-2
16
 
17
# ——————————————————— CACHE HELPERS ——————————————————— #
@st.cache_resource
def download_file(path_in_repo: str):
    """Fetch (and cache) a file from this Space's repo, LFS files included."""
    return hf_hub_download(
        repo_id=SPACE_ID,
        repo_type="space",
        filename=path_in_repo,
    )
 
 
 
 
 
 
21
 
22
@st.cache_resource
def load_keras(file_name: str):
    """Load a cached Keras .h5 model from the Space's models/ folder."""
    local_path = download_file(f"models/{file_name}")
    return load_model(local_path, compile=False)
 
 
25
 
26
@st.cache_resource
def load_hf_encoder(model_name: str):
    """Load a HuggingFace tokenizer + encoder (ProtBERT/BFD/ESM) in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    encoder = AutoModel.from_pretrained(model_name)
    encoder.eval()  # inference only — disable dropout etc.
    return tokenizer, encoder
32
 
33
# ——————————————————— MODELS ——————————————————— #
mlp_pb = load_keras("mlp_protbert.h5")
mlp_bfd = load_keras("mlp_protbertbfd.h5")
mlp_esm = load_keras("mlp_esm2.h5")
# NOTE(review): the previous revision loaded "ensemble_stacking.h5" —
# confirm this matches the file actually stored under models/.
stacking = load_keras("ensemble_stack.h5")

# ——————————————————— LABEL BINARIZER ——————————————————— #
mlb = joblib.load(download_file("data/mlb_597.pkl"))
GO_TERMS = mlb.classes_
42
 
43
# ——————————————————— EMBEDDINGS ——————————————————— #
def embed_seq(model_name: str, seq: str, chunk: int) -> np.ndarray:
    """Return the mean CLS embedding of `seq`, computed over `chunk`-sized slices.

    Parameters
    ----------
    model_name : HuggingFace model id (ProtBERT / ProtBERT-BFD / ESM-2).
    seq        : protein sequence (plain residues, no FASTA header).
    chunk      : number of residues per slice.

    Returns a (1, hidden_dim) numpy array.

    Raises ValueError for an empty sequence (np.mean over zero vectors
    would otherwise fail).
    """
    if not seq:
        raise ValueError("embed_seq: empty sequence")
    tok, mdl = load_hf_encoder(model_name)
    parts = [seq[i:i + chunk] for i in range(0, len(seq), chunk)]
    vecs = []
    for p in parts:
        # Residues space-separated, as the ProtBERT-style tokenizers expect.
        # truncation=True restores the previous revision's behaviour: with
        # special tokens added, a CHUNK_PB-sized slice can exceed the
        # encoder's maximum positions (e.g. 512 for ProtBERT), and
        # truncation=False would then fail at inference — TODO confirm
        # against the encoders' max_position_embeddings.
        inputs = tok(" ".join(p), return_tensors="pt", truncation=True)
        with torch.no_grad():
            out = mdl(**inputs)
        vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
    return np.mean(vecs, axis=0, keepdims=True)
53
 
54
# ——————————————————— UI ——————————————————— #
st.title("🔬 Predição de Funções de Proteínas")

# Slightly smaller font inside the FASTA textarea.
st.markdown(
    """
<style> textarea { font-size: 0.9rem !important; } </style>
""",
    unsafe_allow_html=True,
)
 
63
 
64
fasta = st.text_area("Insere a sequência FASTA:", height=200)

# ---------- PREDICT BUTTON ----------
if fasta and st.button("Prever GO terms"):

    # Drop FASTA header lines, then strip spaces/newlines and upper-case.
    seq = "".join(l for l in fasta.splitlines() if not l.startswith(">"))
    seq = seq.replace(" ", "").upper()

    if not seq:
        st.warning("Por favor, insere uma sequência válida.")
        st.stop()

    # 1) embeddings
    st.write("⏳ A gerar embeddings…")
    emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
    emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
    emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)

    # 2) per-model predictions
    st.write("🧠 A fazer predições…")
    y_pb = mlp_pb.predict(emb_pb)
    y_bfd = mlp_bfd.predict(emb_bfd)
    y_esm = mlp_esm.predict(emb_esm)[:, :597]  # trim 602 -> 597 outputs

    # 3) stacking ensemble over the concatenated probabilities
    X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)  # (1, 1791)
    y_ens = stacking.predict(X_stack)

    # ——— Helper to render one model's predictions ———
    def show_results(label: str, y_pred):
        """Show thresholded hits and the top-N GO terms inside an expander."""
        with st.expander(label, expanded=(label == "Ensemble (Stacking)")):
            hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
            st.markdown(f"**GO terms com prob ≥ {THRESH}**")
            st.code("\n".join(hits) if hits else "— nenhum —")

            st.markdown(f"**Top {TOP_N} GO terms mais prováveis**")
            for i in np.argsort(-y_pred[0])[:TOP_N]:
                st.write(f"{GO_TERMS[i]} : {y_pred[0][i]:.4f}")

    # 4) output
    show_results("ProtBERT (MLP)", y_pb)
    show_results("ProtBERT-BFD (MLP)", y_bfd)
    show_results("ESM-2 (MLP)", y_esm)
    show_results("Ensemble (Stacking)", y_ens)