melvinalves commited on
Commit
201f653
·
verified ·
1 Parent(s): 0ffe325

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -64
app.py CHANGED
@@ -1,93 +1,102 @@
1
- import os, pathlib
2
  import numpy as np
3
- import torch, joblib, streamlit as st
 
 
4
  from transformers import AutoTokenizer, AutoModel
5
  from huggingface_hub import hf_hub_download
6
- from tensorflow.keras.models import load_model
7
 
8
  # ---------- Configuração ----------
9
- SPACE_ID = "melvinalves/protein_function_prediction" # id deste Space
10
- TOP_N = 10
11
- CHUNK_PB = 512
12
- CHUNK_ESM = 1024
13
 
14
- # ---------- Helpers ----------
15
  @st.cache_resource
16
- def hf_cached(fname: str):
17
- """Faz download de um ficheiro (LFS ou não) e devolve o caminho local."""
18
- return hf_hub_download(
19
- repo_id = SPACE_ID,
20
- repo_type = "space",
21
- filename = fname, # ex: "models/mlp_protbert.keras"
22
  )
 
 
23
 
24
  @st.cache_resource
25
- def load_hf_model(model_name: str):
26
- tok = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
27
- mdl = AutoModel.from_pretrained(model_name)
28
- mdl.eval()
29
- return tok, mdl
30
 
31
  @st.cache_resource
32
- def load_local_model(fname: str):
33
- full_path = hf_cached(f"models/{fname}")
34
- print(">>", pathlib.Path(full_path).name,
35
- "tamanho:", os.path.getsize(full_path), "bytes")
36
- return load_model(full_path, compile=False)
37
-
38
- # ---------- Carregamento dos modelos locais (.keras) ----------
39
- mlp_pb = load_local_model("mlp_protbert.keras")
40
- mlp_bfd = load_local_model("mlp_protbertbfd.keras")
41
- mlp_esm = load_local_model("mlp_esm2.keras") # 602 saídas – cortamos depois
42
- stacking = load_local_model("ensemble_stacking.keras") # espera 1791 entradas
43
-
44
- # ---------- Label binarizer ----------
45
- mlb = joblib.load(hf_cached("data/mlb_597.pkl"))
 
 
46
  go_terms = mlb.classes_
47
 
48
- # ---------- Função de embedding ----------
49
- def embed_sequence(model_name: str, seq: str, chunk: int) -> np.ndarray:
50
- tok, mdl = load_hf_model(model_name)
51
- fmt = lambda s: " ".join(list(s))
52
- parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)]
53
- vecs = []
54
- for p in parts:
 
 
55
  with torch.no_grad():
56
- out = mdl(**tok(fmt(p), return_tensors="pt", truncation=True))
57
- vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
58
- return np.mean(vecs, axis=0, keepdims=True)
 
59
 
60
- # ---------- Interface ----------
61
  st.title("🔬 Predição de Funções de Proteínas")
62
 
63
- src = st.text_area("Insere a sequência FASTA:", height=200)
64
 
65
- if src and st.button("Prever GO terms"):
66
- # limpar FASTA
67
- seq = "\n".join(l for l in src.splitlines() if not l.startswith(">"))
68
- seq = seq.replace(" ", "").replace("\n", "").upper()
69
- if not seq:
70
- st.warning("Sequência vazia.")
71
  st.stop()
72
 
73
  st.write("⏳ A gerar embeddings…")
74
- emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
75
- emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
76
- emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
77
 
78
- st.write("🧠 A prever com cada modelo…")
79
  y_pb = mlp_pb.predict(emb_pb)
80
  y_bfd = mlp_bfd.predict(emb_bfd)
81
- y_esm = mlp_esm.predict(emb_esm)[:, :597] # corta 602→597 para alinhar
82
 
83
- X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1) # (1, 1791)
84
- y_pred = stacking.predict(X_stack)
85
 
86
- # ---------- Resultados ----------
87
- st.subheader("GO terms com probabilidade ≥ 0.5")
88
  hits = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
89
- st.code("\n".join(hits) or "— nenhum —")
 
 
 
 
 
90
 
91
- st.subheader(f"Top {TOP_N} GO terms mais prováveis")
92
- for idx in np.argsort(-y_pred[0])[:TOP_N]:
93
- st.write(f"{go_terms[idx]} : {y_pred[0][idx]:.4f}")
 
1
+ import os
2
  import numpy as np
3
+ import torch
4
+ import joblib
5
+ import streamlit as st
6
  from transformers import AutoTokenizer, AutoModel
7
  from huggingface_hub import hf_hub_download
8
+ from keras.models import load_model
9
 
10
  # ---------- Configuração ----------
11
+ SPACE_ID = "melvinalves/protein_function_prediction"
12
+ TOP_N = 10
13
+ CHUNK_PB = 512
14
+ CHUNK_ESM = 1024
15
 
16
+ # ---------- Cache de downloads ----------
17
  @st.cache_resource
18
+ def download_model_file(filename):
19
+ local_path = hf_hub_download(
20
+ repo_id=SPACE_ID,
21
+ repo_type="space",
22
+ filename=f"models/{filename}",
 
23
  )
24
+ print(f"📦 {filename} → {os.path.getsize(local_path)} bytes")
25
+ return local_path
26
 
27
  @st.cache_resource
28
+ def load_hf_model(model_name):
29
+ tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
30
+ model = AutoModel.from_pretrained(model_name)
31
+ model.eval()
32
+ return tokenizer, model
33
 
34
  @st.cache_resource
35
+ def load_keras_model(filename):
36
+ path = download_model_file(filename)
37
+ return load_model(path, compile=False)
38
+
39
+ # ---------- Carregar modelos ----------
40
+ mlp_pb = load_keras_model("mlp_protbert.keras")
41
+ mlp_bfd = load_keras_model("mlp_protbertbfd.keras")
42
+ mlp_esm = load_keras_model("mlp_esm2.keras")
43
+ stacking = load_keras_model("ensemble_stacking.keras")
44
+
45
+ # ---------- Carregar MultiLabelBinarizer ----------
46
+ mlb = joblib.load(hf_hub_download(
47
+ repo_id=SPACE_ID,
48
+ repo_type="space",
49
+ filename="data/mlb_597.pkl"
50
+ ))
51
  go_terms = mlb.classes_
52
 
53
+ # ---------- Função para gerar embeddings ----------
54
+ def embed_sequence(model_name, seq, chunk_size):
55
+ tokenizer, model = load_hf_model(model_name)
56
+ def format_seq(s): return " ".join(list(s))
57
+ chunks = [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]
58
+ embeddings = []
59
+ for chunk in chunks:
60
+ formatted = format_seq(chunk)
61
+ inputs = tokenizer(formatted, return_tensors="pt", truncation=True)
62
  with torch.no_grad():
63
+ outputs = model(**inputs)
64
+ cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
65
+ embeddings.append(cls_embedding)
66
+ return np.mean(embeddings, axis=0, keepdims=True)
67
 
68
+ # ---------- Interface Streamlit ----------
69
  st.title("🔬 Predição de Funções de Proteínas")
70
 
71
+ user_input = st.text_area("Insere a sequência FASTA:", height=200)
72
 
73
+ if user_input and st.button("Prever GO terms"):
74
+ # Limpar sequência FASTA
75
+ sequence = "\n".join([line for line in user_input.splitlines() if not line.startswith(">")])
76
+ sequence = sequence.replace(" ", "").replace("\n", "").strip().upper()
77
+ if not sequence:
78
+ st.warning("Por favor, insere uma sequência válida.")
79
  st.stop()
80
 
81
  st.write("⏳ A gerar embeddings…")
82
+ emb_pb = embed_sequence("Rostlab/prot_bert", sequence, CHUNK_PB)
83
+ emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", sequence, CHUNK_PB)
84
+ emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", sequence, CHUNK_ESM)
85
 
86
+ st.write("🧠 A fazer predições...")
87
  y_pb = mlp_pb.predict(emb_pb)
88
  y_bfd = mlp_bfd.predict(emb_bfd)
89
+ y_esm = mlp_esm.predict(emb_esm)[:, :597] # garantir alinhamento
90
 
91
+ X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
92
+ y_pred = stacking.predict(X_stack)
93
 
94
+ st.subheader("🎯 GO terms com probabilidade ≥ 0.5")
 
95
  hits = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
96
+ st.code("\n".join(hits) if hits else "— nenhum —")
97
+
98
+ st.subheader(f"⭐ Top {TOP_N} GO terms mais prováveis")
99
+ top_idx = np.argsort(-y_pred[0])[:TOP_N]
100
+ for i in top_idx:
101
+ st.write(f"{go_terms[i]} : {y_pred[0][i]:.4f}")
102