melvinalves committed on
Commit
c1b30d0
·
verified ·
1 Parent(s): a65bfa7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -35
app.py CHANGED
@@ -7,87 +7,101 @@ from transformers import AutoTokenizer, AutoModel
7
  from huggingface_hub import hf_hub_download
8
  from keras.models import load_model
9
 
10
-
11
- SPACE_ID = "melvinalves/protein_function_prediction" # id deste Space
12
  TOP_N = 10
 
13
  CHUNK_PB = 512
14
  CHUNK_ESM = 1024
15
 
16
-
17
  @st.cache_resource
18
  def download_file(path_in_repo: str):
19
- """Descarrega (e faz cache) um ficheiro do próprio Space, mesmo que esteja em LFS."""
20
- local = hf_hub_download(
21
- repo_id=SPACE_ID,
22
- repo_type="space",
23
- filename=path_in_repo,
24
- )
25
- return local
26
 
27
  @st.cache_resource
28
  def load_keras(file_name: str):
29
- """Carrega um modelo Keras (.h5) via hf_hub_download + load_model()."""
30
- full_path = download_file(f"models/{file_name}")
31
- return load_model(full_path, compile=False)
32
 
33
  @st.cache_resource
34
  def load_hf_encoder(model_name: str):
35
- """Carrega tokenizer + encoder HuggingFace (ProtBERT/BFD/ESM)."""
36
  tok = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
37
  mdl = AutoModel.from_pretrained(model_name)
38
  mdl.eval()
39
  return tok, mdl
40
 
41
-
42
  mlp_pb = load_keras("mlp_protbert.h5")
43
  mlp_bfd = load_keras("mlp_protbertbfd.h5")
44
- mlp_esm = load_keras("mlp_esm2.h5") # 602 saídas → corta-se p/ 597
45
- stacking = load_keras("ensemble_stack.h5") # espera 1791 entradas
46
 
47
- mlb = joblib.load(download_file("data/mlb_597.pkl"))
48
- goterms = mlb.classes
 
49
 
50
- def embed_seq(encoder_name: str, seq: str, chunk: int) -> np.ndarray:
51
- tok, mdl = load_hf_encoder(encoder_name)
52
- fmt = lambda s: " ".join(list(s))
53
  parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)]
54
  vecs = []
55
  for p in parts:
56
  with torch.no_grad():
57
- out = mdl(**tok(fmt(p), return_tensors="pt", truncation=True))
58
  vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
59
  return np.mean(vecs, axis=0, keepdims=True)
60
 
 
61
  st.title("🔬 Predição de Funções de Proteínas")
62
 
 
 
 
 
 
 
 
63
  fasta = st.text_area("Insere a sequência FASTA:", height=200)
64
 
 
65
  if fasta and st.button("Prever GO terms"):
66
- # limpar FASTA
67
  seq = "\n".join(l for l in fasta.splitlines() if not l.startswith(">"))
68
  seq = seq.replace(" ", "").replace("\n", "").upper()
 
69
  if not seq:
70
  st.warning("Por favor, insere uma sequência válida.")
71
  st.stop()
72
 
 
73
  st.write("⏳ A gerar embeddings…")
74
  emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
75
  emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
76
  emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
77
 
 
78
  st.write("🧠 A fazer predições…")
79
  y_pb = mlp_pb.predict(emb_pb)
80
  y_bfd = mlp_bfd.predict(emb_bfd)
81
  y_esm = mlp_esm.predict(emb_esm)[:, :597] # corta 602 → 597
82
 
83
- X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1) # (1, 1791)
84
- y_pred = stacking.predict(X_stack)
85
-
86
- # ——— Resultados ———
87
- st.subheader("GO terms com probabilidade 0.5")
88
- hits = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
89
- st.code("\n".join(hits) if hits else " nenhum —")
90
-
91
- st.subheader(f"Top {TOP_N} GO terms mais prováveis")
92
- for idx in np.argsort(-y_pred[0])[:TOP_N]:
93
- st.write(f"{go_terms[idx]} : {y_pred[0][idx]:.4f}")
 
 
 
 
 
 
 
 
 
 
 
7
  from huggingface_hub import hf_hub_download
8
  from keras.models import load_model
9
 
10
# --------------------------- Configuration --------------------------- #
SPACE_ID = "melvinalves/protein_function_prediction"  # repo id of this Space
TOP_N = 10        # how many top-scoring GO terms to display per model
THRESH = 0.50     # probability threshold for listing GO terms
CHUNK_PB = 512    # residues per chunk for ProtBERT / ProtBERT-BFD
CHUNK_ESM = 1024  # residues per chunk for ESM-2
16
 
17
# --------------------------- Cache helpers --------------------------- #
@st.cache_resource
def download_file(path_in_repo: str):
    """Download (and cache) a file from this Space's own repo, LFS included."""
    local_path = hf_hub_download(
        repo_id=SPACE_ID,
        repo_type="space",
        filename=path_in_repo,
    )
    return local_path
 
 
 
 
 
 
21
 
22
@st.cache_resource
def load_keras(file_name: str):
    """Load a Keras model (.h5) from the Space's models/ folder, without compiling."""
    model_path = download_file(f"models/{file_name}")
    return load_model(model_path, compile=False)
 
 
25
 
26
@st.cache_resource
def load_hf_encoder(model_name: str):
    """Load a HuggingFace tokenizer + encoder (ProtBERT / BFD / ESM-2) in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    encoder = AutoModel.from_pretrained(model_name)
    # Inference only: disable dropout etc.
    encoder.eval()
    return tokenizer, encoder
32
 
33
# --------------------------- Models --------------------------- #
mlp_pb = load_keras("mlp_protbert.h5")
mlp_bfd = load_keras("mlp_protbertbfd.h5")
mlp_esm = load_keras("mlp_esm2.h5")         # output is trimmed to 597 labels at predict time
stacking = load_keras("ensemble_stack.h5")  # stacking head over the concatenated MLP outputs

# --------------------------- Label binarizer --------------------------- #
mlb = joblib.load(download_file("data/mlb_597.pkl"))
GO_TERMS = mlb.classes_
42
 
43
# --------------------------- Embeddings --------------------------- #
def embed_seq(model_name: str, seq: str, chunk: int) -> np.ndarray:
    """Embed a protein sequence with the given HuggingFace encoder.

    The sequence is split into `chunk`-sized pieces, each piece is encoded
    separately, and the per-chunk first-token ([CLS]) vectors are averaged
    into a single (1, hidden_size) embedding.

    Parameters
    ----------
    model_name : hub id of the encoder (ProtBERT / ProtBERT-BFD / ESM-2).
    seq        : cleaned amino-acid sequence (upper-case, no whitespace).
    chunk      : maximum number of residues fed to the encoder per pass.
    """
    tok, mdl = load_hf_encoder(model_name)
    # ProtBERT-family tokenizers expect residues separated by spaces
    # ("M K T ..."), whereas ESM tokenizers expect the raw sequence —
    # feeding a space-separated string to ESM-2 injects unknown tokens.
    spaced = "prot_bert" in model_name.lower()
    parts = [seq[i:i + chunk] for i in range(0, len(seq), chunk)]
    vecs = []
    for part in parts:
        text = " ".join(part) if spaced else part
        with torch.no_grad():
            # truncation=True is a safety net against exceeding the
            # encoder's positional limit (chunking should normally
            # already keep each piece under it).
            out = mdl(**tok(text, return_tensors="pt", truncation=True))
        # First-token embedding of this chunk.
        vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
    return np.mean(vecs, axis=0, keepdims=True)
53
 
54
# --------------------------- UI --------------------------- #
st.title("🔬 Predição de Funções de Proteínas")

# Slightly smaller font inside the FASTA text area.
st.markdown(
    """
    <style> textarea { font-size: 0.9rem !important; } </style>
    """,
    unsafe_allow_html=True,
)

fasta = st.text_area("Insere a sequência FASTA:", height=200)

# ---------- Predict button ----------
if fasta and st.button("Prever GO terms"):
    # Drop FASTA header lines, then strip ALL whitespace (spaces, tabs,
    # any stray separators) before upper-casing. The previous
    # replace(" ")/replace("\n") pair left tab characters and other
    # non-newline whitespace in the sequence fed to the tokenizers.
    body_lines = (l for l in fasta.splitlines() if not l.startswith(">"))
    seq = "".join("".join(l.split()) for l in body_lines).upper()

    if not seq:
        st.warning("Por favor, insere uma sequência válida.")
        st.stop()

    # 1) Embeddings from the three encoders.
    st.write("⏳ A gerar embeddings…")
    emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
    emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
    emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)

    # 2) Per-encoder MLP predictions.
    st.write("🧠 A fazer predições…")
    y_pb = mlp_pb.predict(emb_pb)
    y_bfd = mlp_bfd.predict(emb_bfd)
    y_esm = mlp_esm.predict(emb_esm)[:, :597]  # trim 602 -> 597 labels

    # 3) Stacking ensemble over the concatenated predictions.
    X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)  # (1, 1791)
    y_ens = stacking.predict(X_stack)

    # Helper: render one model's predictions inside an expander
    # (the ensemble expander opens expanded by default).
    def show_results(label: str, y_pred):
        with st.expander(label, expanded=(label == "Ensemble (Stacking)")):
            hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
            st.markdown(f"**GO terms com prob ≥ {THRESH}**")
            st.code("\n".join(hits) if hits else "— nenhum —")

            st.markdown(f"**Top {TOP_N} GO terms mais prováveis**")
            top_idx = np.argsort(-y_pred[0])[:TOP_N]
            for i in top_idx:
                st.write(f"{GO_TERMS[i]} : {y_pred[0][i]:.4f}")

    # 4) Output, one expander per model plus the ensemble.
    show_results("ProtBERT (MLP)", y_pb)
    show_results("ProtBERT-BFD (MLP)", y_bfd)
    show_results("ESM-2 (MLP)", y_esm)
    show_results("Ensemble (Stacking)", y_ens)