melvinalves committed on
Commit
d31f1ca
·
verified ·
1 Parent(s): 0cbd3d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -65
app.py CHANGED
@@ -7,96 +7,90 @@ from transformers import AutoTokenizer, AutoModel
7
  from huggingface_hub import hf_hub_download
8
  from keras.models import load_model
9
 
10
- # ---------- Configuração ----------
11
- SPACE_ID = "melvinalves/protein_function_prediction"
12
- TOP_N = 10
13
- CHUNK_PB = 512
14
- CHUNK_ESM = 1024
15
 
16
- # ---------- Cache de downloads ----------
17
  @st.cache_resource
18
- def download_model_file(filename):
19
- local_path = hf_hub_download(
 
20
  repo_id=SPACE_ID,
21
  repo_type="space",
22
- filename=f"models/{filename}",
23
  )
24
- print(f"📦 {filename} → {os.path.getsize(local_path)} bytes")
25
- return local_path
26
 
27
  @st.cache_resource
28
- def load_hf_model(model_name):
29
- tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
30
- model = AutoModel.from_pretrained(model_name)
31
- model.eval()
32
- return tokenizer, model
33
 
34
  @st.cache_resource
35
- def load_keras_model(filename):
36
- path = download_model_file(filename)
37
- return load_model(path, compile=False)
38
-
39
- # ---------- Carregar modelos ----------
40
- mlp_pb = load_keras_model("mlp_protbert.keras")
41
- mlp_bfd = load_keras_model("mlp_protbertbfd.keras")
42
- mlp_esm = load_keras_model("mlp_esm2.keras")
43
- stacking = load_keras_model("ensemble_stacking.keras")
44
-
45
- # ---------- Carregar MultiLabelBinarizer ----------
46
- mlb = joblib.load(hf_hub_download(
47
- repo_id=SPACE_ID,
48
- repo_type="space",
49
- filename="data/mlb_597.pkl"
50
- ))
51
  go_terms = mlb.classes_
52
 
53
- # ---------- Função para gerar embeddings ----------
54
- def embed_sequence(model_name, seq, chunk_size):
55
- tokenizer, model = load_hf_model(model_name)
56
- def format_seq(s): return " ".join(list(s))
57
- chunks = [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]
58
- embeddings = []
59
- for chunk in chunks:
60
- formatted = format_seq(chunk)
61
- inputs = tokenizer(formatted, return_tensors="pt", truncation=True)
62
  with torch.no_grad():
63
- outputs = model(**inputs)
64
- cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
65
- embeddings.append(cls_embedding)
66
- return np.mean(embeddings, axis=0, keepdims=True)
67
 
68
- # ---------- Interface Streamlit ----------
69
  st.title("🔬 Predição de Funções de Proteínas")
70
 
71
- user_input = st.text_area("Insere a sequência FASTA:", height=200)
72
 
73
- if user_input and st.button("Prever GO terms"):
74
- # Limpar sequência FASTA
75
- sequence = "\n".join([line for line in user_input.splitlines() if not line.startswith(">")])
76
- sequence = sequence.replace(" ", "").replace("\n", "").strip().upper()
77
- if not sequence:
78
  st.warning("Por favor, insere uma sequência válida.")
79
  st.stop()
80
 
81
  st.write("⏳ A gerar embeddings…")
82
- emb_pb = embed_sequence("Rostlab/prot_bert", sequence, CHUNK_PB)
83
- emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", sequence, CHUNK_PB)
84
- emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", sequence, CHUNK_ESM)
85
 
86
- st.write("🧠 A fazer predições...")
87
  y_pb = mlp_pb.predict(emb_pb)
88
  y_bfd = mlp_bfd.predict(emb_bfd)
89
- y_esm = mlp_esm.predict(emb_esm)[:, :597] # garantir alinhamento
90
 
91
- X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
92
- y_pred = stacking.predict(X_stack)
93
 
94
- st.subheader("🎯 GO terms com probabilidade ≥ 0.5")
 
95
  hits = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
96
  st.code("\n".join(hits) if hits else "— nenhum —")
97
 
98
- st.subheader(f"⭐ Top {TOP_N} GO terms mais prováveis")
99
- top_idx = np.argsort(-y_pred[0])[:TOP_N]
100
- for i in top_idx:
101
- st.write(f"{go_terms[i]} : {y_pred[0][i]:.4f}")
102
-
 
7
  from huggingface_hub import hf_hub_download
8
  from keras.models import load_model
9
 
10
+ # ——————————————————— CONFIGURAÇÃO ——————————————————— #
11
+ SPACE_ID = "melvinalves/protein_function_prediction" # id deste Space
12
+ TOP_N = 10
13
+ CHUNK_PB = 512
14
+ CHUNK_ESM = 1024
15
 
16
+ # ——————————————————— HELPERS DE CACHE ——————————————————— #
17
  @st.cache_resource
18
+ def download_file(path_in_repo: str):
19
+ """Descarrega (e faz cache) um ficheiro do próprio Space, mesmo que esteja em LFS."""
20
+ local = hf_hub_download(
21
  repo_id=SPACE_ID,
22
  repo_type="space",
23
+ filename=path_in_repo,
24
  )
25
+ return local
 
26
 
27
  @st.cache_resource
28
+ def load_keras(file_name: str):
29
+ """Carrega um modelo Keras (.h5) via hf_hub_download + load_model()."""
30
+ full_path = download_file(f"models/{file_name}")
31
+ return load_model(full_path, compile=False)
 
32
 
33
  @st.cache_resource
34
+ def load_hf_encoder(model_name: str):
35
+ """Carrega tokenizer + encoder HuggingFace (ProtBERT/BFD/ESM)."""
36
+ tok = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
37
+ mdl = AutoModel.from_pretrained(model_name)
38
+ mdl.eval()
39
+ return tok, mdl
40
+
41
+ # ——————————————————— MODELOS KERAS (.h5) ——————————————————— #
42
+ mlp_pb = load_keras("mlp_protbert.h5")
43
+ mlp_bfd = load_keras("mlp_protbertbfd.h5")
44
+ mlp_esm = load_keras("mlp_esm2.h5") # 602 saídas → corta-se p/ 597
45
+ stacking = load_keras("ensemble_stacking.h5") # espera 1791 entradas
46
+
47
+ # ——————————————————— LABEL BINARIZER ——————————————————— #
48
+ mlb = joblib.load(download_file("data/mlb_597.pkl"))
 
49
  go_terms = mlb.classes_
50
 
51
+ # ——————————————————— EMBEDDING POR CHUNKS ——————————————————— #
52
+ def embed_seq(encoder_name: str, seq: str, chunk: int) -> np.ndarray:
53
+ tok, mdl = load_hf_encoder(encoder_name)
54
+ fmt = lambda s: " ".join(list(s))
55
+ parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)]
56
+ vecs = []
57
+ for p in parts:
 
 
58
  with torch.no_grad():
59
+ out = mdl(**tok(fmt(p), return_tensors="pt", truncation=True))
60
+ vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
61
+ return np.mean(vecs, axis=0, keepdims=True)
 
62
 
63
+ # ——————————————————— INTERFACE STREAMLIT ——————————————————— #
64
  st.title("🔬 Predição de Funções de Proteínas")
65
 
66
+ fasta = st.text_area("Insere a sequência FASTA:", height=200)
67
 
68
+ if fasta and st.button("Prever GO terms"):
69
+ # limpar FASTA
70
+ seq = "\n".join(l for l in fasta.splitlines() if not l.startswith(">"))
71
+ seq = seq.replace(" ", "").replace("\n", "").upper()
72
+ if not seq:
73
  st.warning("Por favor, insere uma sequência válida.")
74
  st.stop()
75
 
76
  st.write("⏳ A gerar embeddings…")
77
+ emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
78
+ emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
79
+ emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
80
 
81
+ st.write("🧠 A fazer predições…")
82
  y_pb = mlp_pb.predict(emb_pb)
83
  y_bfd = mlp_bfd.predict(emb_bfd)
84
+ y_esm = mlp_esm.predict(emb_esm)[:, :597] # corta 602 → 597
85
 
86
+ X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1) # (1, 1791)
87
+ y_pred = stacking.predict(X_stack)
88
 
89
+ # ——— Resultados ———
90
+ st.subheader("GO terms com probabilidade ≥ 0.5")
91
  hits = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
92
  st.code("\n".join(hits) if hits else "— nenhum —")
93
 
94
+ st.subheader(f"Top {TOP_N} GO terms mais prováveis")
95
+ for idx in np.argsort(-y_pred[0])[:TOP_N]:
96
+ st.write(f"{go_terms[idx]} : {y_pred[0][idx]:.4f}")