melvinalves commited on
Commit
a547fbe
·
verified ·
1 Parent(s): ed229ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -40
app.py CHANGED
@@ -1,92 +1,87 @@
1
  import os
2
  import numpy as np
3
- import torch
4
- import streamlit as st
5
- import joblib
6
  from transformers import AutoTokenizer, AutoModel
7
  from huggingface_hub import hf_hub_download
8
  from tensorflow.keras.models import load_model
9
 
10
- # ----------- Config Space -----------
11
- SPACE_REPO = "melvinalves/protein_function_prediction" # <- o teu Space
12
- MODELS_DIR = "models"
13
- DATA_DIR = "data"
14
-
15
  TOP_N = 10
16
  CHUNK_PB = 512
17
  CHUNK_ESM = 1024
18
 
19
- # ----------- Helpers -----------
20
  @st.cache_resource
21
- def hf_cached(path_inside_repo: str):
22
- """Faz download (uma vez) e devolve caminho local."""
23
  return hf_hub_download(
24
- repo_id=SPACE_REPO,
25
- repo_type="space",
26
- filename=path_inside_repo,
27
  )
28
 
29
  @st.cache_resource
30
- def load_hf_model(model_name):
31
- tok = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
32
- mdl = AutoModel.from_pretrained(model_name); mdl.eval()
 
33
  return tok, mdl
34
 
35
  @st.cache_resource
36
- def load_local_model(file_name):
37
- local_path = hf_cached(f"{MODELS_DIR}/{file_name}")
38
- return load_model(local_path, compile=False)
39
 
40
- # ----------- Carregar modelos (.keras) -----------
41
  mlp_pb = load_local_model("mlp_protbert.keras")
42
  mlp_bfd = load_local_model("mlp_protbertbfd.keras")
43
- mlp_esm = load_local_model("mlp_esm2.keras")
44
- stacking = load_local_model("ensemble_stacking.keras")
45
 
46
- # ----------- MultiLabelBinarizer -----------
47
- mlb_path = hf_cached(f"{DATA_DIR}/mlb_597.pkl")
48
- mlb = joblib.load(mlb_path)
49
  go_terms = mlb.classes_
50
 
51
- # ----------- Embedding por chunks -----------
52
  def embed_sequence(model_name: str, seq: str, chunk: int) -> np.ndarray:
53
  tok, mdl = load_hf_model(model_name)
54
  fmt = lambda s: " ".join(list(s))
55
  parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)]
56
- vecs = []
57
  for p in parts:
58
  with torch.no_grad():
59
  out = mdl(**tok(fmt(p), return_tensors="pt", truncation=True))
60
  vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
61
  return np.mean(vecs, axis=0, keepdims=True)
62
 
63
- # ----------- UI -----------
64
- st.title("Predição de Funções de Proteínas 🔬")
65
 
66
- fa_input = st.text_area("Insere a sequência FASTA:", height=200)
67
 
68
- if fa_input and st.button("Prever GO terms"):
69
- # Limpa FASTA
70
- seq = "\n".join(l for l in fa_input.splitlines() if not l.startswith(">"))
71
  seq = seq.replace(" ", "").replace("\n", "").upper()
72
  if not seq:
73
  st.warning("Sequência vazia.")
74
  st.stop()
75
 
76
- st.write("🔄 A gerar embeddings…")
77
  emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
78
  emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
79
  emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
80
 
81
- st.write("🧠 A fazer predições…")
82
  y_pb = mlp_pb.predict(emb_pb)
83
  y_bfd = mlp_bfd.predict(emb_bfd)
84
- y_esm = mlp_esm.predict(emb_esm)[:, :597] # garante 597 colunas
85
 
86
- X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
87
  y_pred = stacking.predict(X_stack)
88
 
89
- # ----------- Output -----------
90
  st.subheader("GO terms com probabilidade ≥ 0.5")
91
  hits = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
92
  st.code("\n".join(hits) or "— nenhum —")
@@ -94,4 +89,3 @@ if fa_input and st.button("Prever GO terms"):
94
  st.subheader(f"Top {TOP_N} GO terms mais prováveis")
95
  for idx in np.argsort(-y_pred[0])[:TOP_N]:
96
  st.write(f"{go_terms[idx]} : {y_pred[0][idx]:.4f}")
97
-
 
1
  import os
2
  import numpy as np
3
+ import torch, joblib, streamlit as st
 
 
4
  from transformers import AutoTokenizer, AutoModel
5
  from huggingface_hub import hf_hub_download
6
  from tensorflow.keras.models import load_model
7
 
8
# ---------- Configuration ----------
SPACE_ID = "melvinalves/protein_function_prediction"  # id of this Space (files are fetched from its repo)

TOP_N = 10        # number of top-ranked GO terms displayed in the UI
CHUNK_PB = 512    # max residues per chunk for ProtBert / ProtBert-BFD embedding
CHUNK_ESM = 1024  # max residues per chunk for ESM-2 embedding
13
 
14
# ---------- Helpers ----------
@st.cache_resource
def hf_cached(fname: str):
    """Download a file (LFS or not) from this Space's repo and return its local path.

    Cached by Streamlit, so each file is fetched at most once per session.
    """
    local_path = hf_hub_download(
        repo_id=SPACE_ID,
        repo_type="space",
        filename=fname,  # e.g. "models/mlp_protbert.keras"
    )
    return local_path
23
 
24
@st.cache_resource
def load_hf_model(model_name: str):
    """Load a Hugging Face tokenizer/encoder pair in eval mode (cached across reruns).

    Returns:
        (tokenizer, model) tuple for the given hub id.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    encoder = AutoModel.from_pretrained(model_name).eval()
    return tokenizer, encoder
30
 
31
@st.cache_resource
def load_local_model(fname: str):
    """Fetch a .keras model file from the Space's models/ folder and deserialize it.

    Loaded with compile=False since the models are used for inference only.
    """
    return load_model(hf_cached(f"models/{fname}"), compile=False)
35
 
36
# ---------- Loading the local (.keras) models ----------
# One MLP head per embedding backbone, plus the stacking ensemble.
mlp_pb = load_local_model("mlp_protbert.keras")
mlp_bfd = load_local_model("mlp_protbertbfd.keras")
mlp_esm = load_local_model("mlp_esm2.keras")  # 602 outputs - trimmed to 597 later
stacking = load_local_model("ensemble_stacking.keras")  # expects 1791 inputs (3 x 597)
41
 
42
# ---------- Label binarizer ----------
# NOTE(review): loaded via pickle (joblib) from this Space's own repo — trusted source.
# Presumably fitted on the 597 GO terms (see mlb_597.pkl and the [:, :597] trim below).
mlb = joblib.load(hf_cached("data/mlb_597.pkl"))
# GO term labels in the same column order as the model outputs.
go_terms = mlb.classes_
45
 
46
# ---------- Embedding function ----------
def embed_sequence(model_name: str, seq: str, chunk: int) -> np.ndarray:
    """Embed a protein sequence with a HF encoder, mean-pooling [CLS] over chunks.

    The sequence is split into windows of at most `chunk` residues; each window
    is embedded independently and the [CLS] hidden states are averaged into a
    single vector.

    Args:
        model_name: HF hub id of the encoder (e.g. "Rostlab/prot_bert").
        seq: Cleaned amino-acid sequence (no whitespace, upper-case).
        chunk: Maximum window length in residues.

    Returns:
        np.ndarray of shape (1, hidden_size) — the mean [CLS] embedding.

    Raises:
        ValueError: if `seq` is empty (np.mean over zero chunks would
            otherwise produce a NaN with only an obscure runtime warning).
    """
    if not seq:
        raise ValueError("embed_sequence: received an empty sequence")
    tok, mdl = load_hf_model(model_name)
    parts = [seq[i:i + chunk] for i in range(0, len(seq), chunk)]
    vecs = []
    for p in parts:
        # ProtBert-style tokenizers expect residues separated by spaces ("M K T ...").
        inputs = tok(" ".join(p), return_tensors="pt", truncation=True)
        with torch.no_grad():
            out = mdl(**inputs)
        # [CLS] token embedding for this chunk.
        vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
    return np.mean(vecs, axis=0, keepdims=True)
57
 
58
# ---------- User interface ----------
st.title("🔬 Predição de Funções de Proteínas")

src = st.text_area("Insere a sequência FASTA:", height=200)

if src and st.button("Prever GO terms"):
    # Clean the FASTA input: drop header lines (">"), strip whitespace, upper-case.
    seq = "\n".join(l for l in src.splitlines() if not l.startswith(">"))
    seq = seq.replace(" ", "").replace("\n", "").upper()
    if not seq:
        st.warning("Sequência vazia.")
        st.stop()

    # One embedding per backbone; ESM-2 allows longer chunks.
    st.write(" A gerar embeddings…")
    emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
    emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
    emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)

    # Per-backbone MLP predictions, then the stacking ensemble on top.
    st.write("🧠 A prever com cada modelo…")
    y_pb = mlp_pb.predict(emb_pb)
    y_bfd = mlp_bfd.predict(emb_bfd)
    y_esm = mlp_esm.predict(emb_esm)[:, :597]  # trim 602→597 so columns align

    X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)  # (1, 1791)
    y_pred = stacking.predict(X_stack)

    # ---------- Results ----------
    st.subheader("GO terms com probabilidade ≥ 0.5")
    hits = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
    st.code("\n".join(hits) or "— nenhum —")

    st.subheader(f"Top {TOP_N} GO terms mais prováveis")
    for idx in np.argsort(-y_pred[0])[:TOP_N]:
        st.write(f"{go_terms[idx]} : {y_pred[0][idx]:.4f}")