melvinalves commited on
Commit
c843aa7
·
verified ·
1 Parent(s): 062e02f

Upload 13 files

Browse files
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModel
5
+ from tensorflow.keras.models import load_model
6
+ import joblib
7
+ import streamlit as st
8
+
9
+ # ---------- Caminhos ----------
10
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11
+ MODELS_DIR = os.path.join(BASE_DIR, "models")
12
+ MLB_PATH = os.path.join(BASE_DIR, "data", "mlb_597.pkl")
13
+
14
+ # ---------- Parâmetros ----------
15
+ TOP_N = 10
16
+ CHUNK_PB = 512
17
+ CHUNK_ESM = 1024
18
+
19
+ # ---------- Cache dos modelos HuggingFace ----------
20
+ @st.cache_resource
21
+ def load_hf_model(name):
22
+ tokenizer = AutoTokenizer.from_pretrained(name, do_lower_case=False)
23
+ model = AutoModel.from_pretrained(name)
24
+ model.eval()
25
+ return tokenizer, model
26
+
27
+ # ---------- Função para gerar embedding por chunk ----------
28
+ def embed_sequence(model_name, seq, chunk_size):
29
+ tokenizer, model = load_hf_model(model_name)
30
+
31
+ def format_seq(s):
32
+ return " ".join(list(s))
33
+
34
+ chunks = [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]
35
+ embeddings = []
36
+
37
+ for chunk in chunks:
38
+ formatted = format_seq(chunk)
39
+ inputs = tokenizer(formatted, return_tensors="pt", truncation=True)
40
+ with torch.no_grad():
41
+ outputs = model(**inputs)
42
+ cls = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
43
+ embeddings.append(cls)
44
+
45
+ return np.mean(embeddings, axis=0, keepdims=True)
46
+
47
+ # ---------- Carregar modelos ----------
48
+ mlp_pb = load_model(os.path.join(MODELS_DIR, "mlp_protbert.h5"), compile=False)
49
+ mlp_bfd = load_model(os.path.join(MODELS_DIR, "mlp_protbertbfd.h5"), compile=False)
50
+ mlp_esm = load_model(os.path.join(MODELS_DIR, "mlp_esm2.h5"), compile=False)
51
+ stacking = load_model(os.path.join(MODELS_DIR, "modelo_ensemble_stack.h5"), compile=False)
52
+
53
+ # ---------- Carregar MultiLabelBinarizer ----------
54
+ mlb = joblib.load(MLB_PATH)
55
+ go_terms = mlb.classes_
56
+
57
+ # ---------- Interface Streamlit ----------
58
+ st.title("Predição de Funções de Proteínas")
59
+
60
+ seq = st.text_area("Insere a sequência FASTA:", height=200)
61
+
62
+ # Limpar sequência: remover cabeçalhos (">") e espaços/quebras
63
+ if seq:
64
+ seq = "\n".join([line for line in seq.splitlines() if not line.startswith(">")])
65
+ seq = seq.replace(" ", "").replace("\n", "").strip()
66
+
67
+ if st.button("Prever GO terms"):
68
+ if not seq:
69
+ st.warning("Por favor, insere uma sequência válida.")
70
+ else:
71
+ st.write("A gerar embeddings por chunks...")
72
+
73
+ emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
74
+ emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
75
+ emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
76
+
77
+ st.write("A fazer predições base...")
78
+
79
+ y_pb = mlp_pb.predict(emb_pb)[:, :597]
80
+ y_bfd = mlp_bfd.predict(emb_bfd)[:, :597]
81
+ y_esm = mlp_esm.predict(emb_esm)[:, :597]
82
+
83
+ X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
84
+ y_pred = stacking.predict(X_stack)
85
+
86
+ st.subheader("GO terms com probabilidade ≥ 0.5:")
87
+ predicted = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
88
+ if predicted:
89
+ st.code("\n".join(predicted))
90
+ else:
91
+ st.info("Nenhum GO term com probabilidade ≥ 0.5.")
92
+
93
+ st.subheader(f"Top {TOP_N} GO terms mais prováveis:")
94
+ top_idx = np.argsort(-y_pred[0])[:TOP_N]
95
+ for i in top_idx:
96
+ st.write(f"{go_terms[i]} : {y_pred[0][i]:.4f}")
97
+
data/mlb_597.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:652f2a315accef73abf13ff733cb24aacda95448fe7f1194a63eca9ca7e7dee5
3
+ size 19564
models/mlp_esm2.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e85d48faa31a582b632f152a48dbb25f84d79e4cb06e4e1bcc9543c92fbf4646
3
+ size 8598312
models/mlp_protbert.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:950ab09eecdb21bedb0d68aae6ed7412722cceac42ed94489887c3e43ecb0598
3
+ size 7539476
models/mlp_protbertbfd.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:562ea1f6667928274a38169d2a25eab49305de0103101599b08404c84a470845
3
+ size 7539476
models/modelo_ensemble_stack.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a6c06c7c16c30841ce54e3cf6b551e21ec5942daba1f77029f4fe752c17caed
3
+ size 4826900
notebooks/Ensemble.ipynb ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "0fbbb46c-1a00-4585-9ecd-a490a46e8b99",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
14
+ "mlb com 597 GO terms guardado em data/mlb_597.pkl\n",
15
+ " ProtBERT Fmax=0.6616 Thr=0.40 AuPRC=0.7009 Smin=13.9047\n",
16
+ " ProtBERT-BFD Fmax=0.6573 Thr=0.41 AuPRC=0.6925 Smin=13.7060\n",
17
+ " ESM-2 Fmax=0.6375 Thr=0.39 AuPRC=0.6875 Smin=14.1194\n",
18
+ " Ensemble Fmax=0.6864 Thr=0.37 AuPRC=0.7332 Smin=12.7879\n"
19
+ ]
20
+ }
21
+ ],
22
+ "source": [
23
+ "# %%\n",
24
+ "import numpy as np, joblib, math\n",
25
+ "from sklearn.metrics import precision_recall_curve, auc\n",
26
+ "from goatools.obo_parser import GODag\n",
27
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
28
+ "\n",
29
+ "GO_FILE = \"go.obo\"\n",
30
+ "dag = GODag(GO_FILE)\n",
31
+ "\n",
32
+ "# ---------- 1. y_true + GO terms (referência ProtBERT) ----------\n",
33
+ "test_pb = joblib.load(\"embeddings/test_protbert.pkl\")\n",
34
+ "y_true = test_pb[\"labels\"] # (1724, 597) ← ground-truth\n",
35
+ "go_ref = list(test_pb[\"go_terms\"]) # ordem exacta das colunas\n",
36
+ "\n",
37
+ "n_go = len(go_ref) # 597\n",
38
+ "\n",
39
+ "# --- Recriar o MultiLabelBinarizer com os 597 termos corretos ---\n",
40
+ "mlb = MultiLabelBinarizer(classes=go_ref)\n",
41
+ "mlb.fit([go_ref]) # necessário para permitir inverse_transform depois\n",
42
+ "\n",
43
+ "# ---------- 2. Carregar predições ----------\n",
44
+ "y_pb = np.load(\"predictions/mf-protbert-pam1.npy\") # 1724×597\n",
45
+ "y_bfd = np.load(\"predictions/mf-protbertbfd-pam1.npy\") # 1724×597\n",
46
+ "y_esm0 = np.load(\"predictions/mf-esm2.npy\") # 1724×602\n",
47
+ "\n",
48
+ "# ---------- 3. Remapear ESM-2 para ordem ProtBERT ----------\n",
49
+ "mlb_esm = joblib.load(\"data/mlb.pkl\") # 602 GO terms\n",
50
+ "idx_map = [list(mlb_esm.classes_).index(t) for t in go_ref]\n",
51
+ "y_esm = y_esm0[:, idx_map] # 1724×597\n",
52
+ "\n",
53
+ "# ---------- 4. Garantir shapes iguais ----------\n",
54
+ "assert (y_true.shape == y_pb.shape == y_bfd.shape\n",
55
+ " == y_esm.shape == (1724, n_go)), \"Ainda há desalinhamento!\"\n",
56
+ "\n",
57
+ "# ---------- 4. Guardar mlb (y_true) alinhado ----------\n",
58
+ "joblib.dump(mlb, \"data/mlb_597.pkl\")\n",
59
+ "print(\"mlb com 597 GO terms guardado em data/mlb_597.pkl\")\n",
60
+ "\n",
61
+ "# ---------- 5. Métricas ----------\n",
62
+ "THR = np.linspace(0,1,101)\n",
63
+ "def fmax(y_t,y_p):\n",
64
+ " best,thr = 0,0\n",
65
+ " for t in THR:\n",
66
+ " y_b = (y_p>=t).astype(int)\n",
67
+ " tp = (y_t*y_b).sum(1); fp=((1-y_t)*y_b).sum(1); fn=(y_t*(1-y_b)).sum(1)\n",
68
+ " f1 = 2*tp/(2*tp+fp+fn+1e-8); m=f1.mean()\n",
69
+ " if m>best: best,thr = m,t\n",
70
+ " return best,thr\n",
71
+ "\n",
72
+ "def auprc(y_t,y_p):\n",
73
+ " p,r,_ = precision_recall_curve(y_t.ravel(), y_p.ravel()); return auc(r,p)\n",
74
+ "\n",
75
+ "def smin(y_t,y_p,thr,alpha=0.5):\n",
76
+ " y_b=(y_p>=thr).astype(int)\n",
77
+ " ic=-(np.log((y_t+y_b).sum(0)+1e-8)-np.log((y_t+y_b).sum()+1e-8))\n",
78
+ " ru=np.logical_and(y_b, np.logical_not(y_t))*ic\n",
79
+ " mi=np.logical_and(y_t, np.logical_not(y_b))*ic\n",
80
+ " return np.sqrt((alpha*ru.sum(1))**2 + ((1-alpha)*mi.sum(1))**2).mean()\n",
81
+ "\n",
82
+ "def show(name,y_p):\n",
83
+ " f,thr=fmax(y_true,y_p)\n",
84
+ " print(f\"{name:>13s} Fmax={f:.4f} Thr={thr:.2f} \"\n",
85
+ " f\"AuPRC={auprc(y_true,y_p):.4f} Smin={smin(y_true,y_p,thr):.4f}\")\n",
86
+ "\n",
87
+ "show(\"ProtBERT\", y_pb)\n",
88
+ "show(\"ProtBERT-BFD\", y_bfd)\n",
89
+ "show(\"ESM-2\", y_esm)\n",
90
+ "show(\"Ensemble\", (y_pb + y_bfd + y_esm)/3)\n",
91
+ "\n"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 9,
97
+ "id": "f1807404-c2ce-48d0-b87c-a7e0fecc1728",
98
+ "metadata": {},
99
+ "outputs": [
100
+ {
101
+ "name": "stdout",
102
+ "output_type": "stream",
103
+ "text": [
104
+ "Epoch 1/50\n",
105
+ "19/19 [==============================] - 0s 17ms/step - loss: 0.3811 - val_loss: 0.0868\n",
106
+ "Epoch 2/50\n",
107
+ "19/19 [==============================] - 0s 8ms/step - loss: 0.0882 - val_loss: 0.0696\n",
108
+ "Epoch 3/50\n",
109
+ "19/19 [==============================] - 0s 6ms/step - loss: 0.0628 - val_loss: 0.0563\n",
110
+ "Epoch 4/50\n",
111
+ "19/19 [==============================] - 0s 6ms/step - loss: 0.0552 - val_loss: 0.0520\n",
112
+ "Epoch 5/50\n",
113
+ "19/19 [==============================] - 0s 5ms/step - loss: 0.0507 - val_loss: 0.0486\n",
114
+ "Epoch 6/50\n",
115
+ "19/19 [==============================] - 0s 5ms/step - loss: 0.0473 - val_loss: 0.0455\n",
116
+ "Epoch 7/50\n",
117
+ "19/19 [==============================] - 0s 5ms/step - loss: 0.0437 - val_loss: 0.0431\n",
118
+ "Epoch 8/50\n",
119
+ "19/19 [==============================] - 0s 8ms/step - loss: 0.0414 - val_loss: 0.0414\n",
120
+ "Epoch 9/50\n",
121
+ "19/19 [==============================] - 0s 7ms/step - loss: 0.0391 - val_loss: 0.0395\n",
122
+ "Epoch 10/50\n",
123
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0371 - val_loss: 0.0383\n",
124
+ "Epoch 11/50\n",
125
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0355 - val_loss: 0.0372\n",
126
+ "Epoch 12/50\n",
127
+ "19/19 [==============================] - 0s 4ms/step - loss: 0.0341 - val_loss: 0.0362\n",
128
+ "Epoch 13/50\n",
129
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0329 - val_loss: 0.0352\n",
130
+ "Epoch 14/50\n",
131
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0318 - val_loss: 0.0345\n",
132
+ "Epoch 15/50\n",
133
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0307 - val_loss: 0.0341\n",
134
+ "Epoch 16/50\n",
135
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0300 - val_loss: 0.0337\n",
136
+ "Epoch 17/50\n",
137
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0291 - val_loss: 0.0334\n",
138
+ "Epoch 18/50\n",
139
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0281 - val_loss: 0.0331\n",
140
+ "Epoch 19/50\n",
141
+ "19/19 [==============================] - 0s 4ms/step - loss: 0.0278 - val_loss: 0.0329\n",
142
+ "Epoch 20/50\n",
143
+ "19/19 [==============================] - 0s 5ms/step - loss: 0.0270 - val_loss: 0.0328\n",
144
+ "Epoch 21/50\n",
145
+ "19/19 [==============================] - 0s 4ms/step - loss: 0.0266 - val_loss: 0.0329\n",
146
+ "Epoch 22/50\n",
147
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0261 - val_loss: 0.0326\n",
148
+ "Epoch 23/50\n",
149
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0257 - val_loss: 0.0325\n",
150
+ "Epoch 24/50\n",
151
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0249 - val_loss: 0.0324\n",
152
+ "Epoch 25/50\n",
153
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0247 - val_loss: 0.0325\n",
154
+ "Epoch 26/50\n",
155
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0243 - val_loss: 0.0322\n",
156
+ "Epoch 27/50\n",
157
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0239 - val_loss: 0.0325\n",
158
+ "Epoch 28/50\n",
159
+ "19/19 [==============================] - 0s 3ms/step - loss: 0.0235 - val_loss: 0.0325\n",
160
+ "Epoch 29/50\n",
161
+ "19/19 [==============================] - 0s 5ms/step - loss: 0.0226 - val_loss: 0.0328\n",
162
+ "Epoch 30/50\n",
163
+ "19/19 [==============================] - 0s 5ms/step - loss: 0.0227 - val_loss: 0.0326\n",
164
+ "Epoch 31/50\n",
165
+ "19/19 [==============================] - 0s 4ms/step - loss: 0.0224 - val_loss: 0.0326\n",
166
+ "\n",
167
+ " STACKING (GPU-Keras MLP)\n",
168
+ "Fmax = 0.6937\n",
169
+ "Thr. = 0.34\n",
170
+ "AuPRC = 0.7551\n",
171
+ "Smin = 12.2407\n"
172
+ ]
173
+ }
174
+ ],
175
+ "source": [
176
+ "# %%\n",
177
+ "from tensorflow.keras.models import Sequential\n",
178
+ "from tensorflow.keras.layers import Dense, Dropout\n",
179
+ "from tensorflow.keras.optimizers import Adam\n",
180
+ "from sklearn.model_selection import train_test_split\n",
181
+ "from sklearn.metrics import precision_recall_curve, auc\n",
182
+ "import numpy as np\n",
183
+ "import math\n",
184
+ "\n",
185
+ "# --- Preparar dados para stacking ---\n",
186
+ "# (já com y_pb, y_bfd, y_esm com shape (1724, 597))\n",
187
+ "X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1) # (1724, 597*3)\n",
188
+ "y_stack = y_true.copy() # (1724, 597)\n",
189
+ "\n",
190
+ "# --- Divisão treino/validação ---\n",
191
+ "X_train, X_val, y_train, y_val = train_test_split(X_stack, y_stack, test_size=0.3, random_state=42)\n",
192
+ "\n",
193
+ "# --- Modelo MLP (usa GPU automaticamente se disponível) ---\n",
194
+ "from tensorflow.keras.callbacks import EarlyStopping\n",
195
+ "\n",
196
+ "model = Sequential([\n",
197
+ " Dense(512, activation=\"relu\", input_shape=(X_train.shape[1],)),\n",
198
+ " Dropout(0.3),\n",
199
+ " Dense(256, activation=\"relu\"),\n",
200
+ " Dropout(0.3),\n",
201
+ " Dense(y_stack.shape[1], activation=\"sigmoid\")\n",
202
+ "])\n",
203
+ "\n",
204
+ "model.compile(optimizer=Adam(1e-3), loss=\"binary_crossentropy\")\n",
205
+ "\n",
206
+ "model.fit(X_train, y_train, validation_data=(X_val, y_val),\n",
207
+ " epochs=50, batch_size=64, verbose=1,\n",
208
+ " callbacks=[EarlyStopping(patience=5, restore_best_weights=True)])\n",
209
+ "\n",
210
+ "# --- Prever com stacking ---\n",
211
+ "y_pred_stack = model.predict(X_stack, batch_size=64)\n",
212
+ "\n",
213
+ "# --- Métricas ---\n",
214
+ "THR = np.linspace(0, 1, 101)\n",
215
+ "def fmax(y_t, y_p):\n",
216
+ " best, thr = 0, 0\n",
217
+ " for t in THR:\n",
218
+ " y_b = (y_p >= t).astype(int)\n",
219
+ " tp = (y_t * y_b).sum(1); fp = ((1 - y_t) * y_b).sum(1); fn = (y_t * (1 - y_b)).sum(1)\n",
220
+ " f1 = 2 * tp / (2 * tp + fp + fn + 1e-8); m = f1.mean()\n",
221
+ " if m > best: best, thr = m, t\n",
222
+ " return best, thr\n",
223
+ "\n",
224
+ "def auprc(y_t, y_p):\n",
225
+ " p, r, _ = precision_recall_curve(y_t.ravel(), y_p.ravel())\n",
226
+ " return auc(r, p)\n",
227
+ "\n",
228
+ "def smin(y_t, y_p, thr, alpha=0.5):\n",
229
+ " y_b = (y_p >= thr).astype(int)\n",
230
+ " ic = -(np.log((y_t + y_b).sum(0) + 1e-8) - np.log((y_t + y_b).sum() + 1e-8))\n",
231
+ " ru = np.logical_and(y_b, np.logical_not(y_t)) * ic\n",
232
+ " mi = np.logical_and(y_t, np.logical_not(y_b)) * ic\n",
233
+ " return np.sqrt((alpha * ru.sum(1))**2 + ((1 - alpha) * mi.sum(1))**2).mean()\n",
234
+ "\n",
235
+ "f, thr = fmax(y_stack, y_pred_stack)\n",
236
+ "print(f\"\\n STACKING (GPU-Keras MLP)\")\n",
237
+ "print(f\"Fmax = {f:.4f}\")\n",
238
+ "print(f\"Thr. = {thr:.2f}\")\n",
239
+ "print(f\"AuPRC = {auprc(y_stack, y_pred_stack):.4f}\")\n",
240
+ "print(f\"Smin = {smin(y_stack, y_pred_stack, thr):.4f}\")\n"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 10,
246
+ "id": "00695029-3d24-4803-a6e1-8ac5fd70b710",
247
+ "metadata": {},
248
+ "outputs": [
249
+ {
250
+ "name": "stdout",
251
+ "output_type": "stream",
252
+ "text": [
253
+ "guardado em models/modelo_ensemble_stacking.keras\n"
254
+ ]
255
+ }
256
+ ],
257
+ "source": [
258
+ "model.save(\"models/modelo_ensemble_stacking.keras\")\n",
259
+ "print('guardado em models/modelo_ensemble_stacking.keras')"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": null,
265
+ "id": "37629e3a-1c24-4f0f-9d12-dddf48be8724",
266
+ "metadata": {},
267
+ "outputs": [],
268
+ "source": []
269
+ }
270
+ ],
271
+ "metadata": {
272
+ "kernelspec": {
273
+ "display_name": "Python 3 (ipykernel)",
274
+ "language": "python",
275
+ "name": "python3"
276
+ },
277
+ "language_info": {
278
+ "codemirror_mode": {
279
+ "name": "ipython",
280
+ "version": 3
281
+ },
282
+ "file_extension": ".py",
283
+ "mimetype": "text/x-python",
284
+ "name": "python",
285
+ "nbconvert_exporter": "python",
286
+ "pygments_lexer": "ipython3",
287
+ "version": "3.8.18"
288
+ }
289
+ },
290
+ "nbformat": 4,
291
+ "nbformat_minor": 5
292
+ }
notebooks/Input.ipynb ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "9eca7d69-3f17-4306-84d0-58a0363144fa",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "A gerar embeddings por chunks...\n"
14
+ ]
15
+ },
16
+ {
17
+ "name": "stderr",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\huggingface_hub\\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
21
+ " warnings.warn(\n",
22
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\huggingface_hub\\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
23
+ " warnings.warn(\n",
24
+ "Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
25
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
26
+ ]
27
+ },
28
+ {
29
+ "name": "stdout",
30
+ "output_type": "stream",
31
+ "text": [
32
+ "A fazer predições base...\n",
33
+ "\n",
34
+ " GO terms com prob ≥ 0.5:\n",
35
+ "('GO:0003674', 'GO:0003824', 'GO:0005488', 'GO:0016491', 'GO:0036094', 'GO:0043167')\n",
36
+ "\n",
37
+ " Top 10 GO terms mais prováveis:\n",
38
+ "GO:0003674 : 0.9975\n",
39
+ "GO:0003824 : 0.9156\n",
40
+ "GO:0036094 : 0.6652\n",
41
+ "GO:0043167 : 0.6336\n",
42
+ "GO:0016491 : 0.6327\n",
43
+ "GO:0005488 : 0.5595\n",
44
+ "GO:0043169 : 0.4801\n",
45
+ "GO:0140096 : 0.4790\n",
46
+ "GO:0051213 : 0.4551\n",
47
+ "GO:0046872 : 0.4098\n"
48
+ ]
49
+ }
50
+ ],
51
+ "source": [
52
+ "# %%\n",
53
+ "import numpy as np\n",
54
+ "import torch\n",
55
+ "from transformers import AutoTokenizer, AutoModel\n",
56
+ "from tensorflow.keras.models import load_model\n",
57
+ "import joblib\n",
58
+ "\n",
59
+ "# --- Parâmetros ---\n",
60
+ "SEQ_FASTA = \"MPISSSSSSSTKSMRRAASELERSDSVTSPRFIGRRQSLIEDARKEREAAAAAAEAAEATEQIVFEEEDGKALLNLFFTLRSSKTPALSRSLKVFETFEAKIHHLETRPCRKPRDSLEGLEYFVRCEVHLSDVSTLISSIKRIAEDVKTTKEVKFHWFPKKISELDRCHHLITKFDPDLDQEHPGFTDPVYRQRRKMIGDIAFRYKQGEPIPRVEYTEEEIGTWREVYSTLRDLYTTHACSEHLEAFNLLERHCGYSPENIPQLEDVSRFLRERTGFQLRPVAGLLSARDFLASLAFRVFQCTQYIRHASSPMHSPEPDCVHELLGHVPILADRVFAQFSQNIGLASLGASEEDIEKLSTLYWFTVEFGLCKQGGIVKAYGAGLLSSYGELVHALSDEPERREFDPEAAAIQPYQDQNYQSVYFVSESFTDAKEKLRSYVAGIKRPFSVRFDPYTYSIEVLDNPLKIRGGLESVKDELKMLTDALNVLA\"\n",
61
+ "TOP_N = 10\n",
62
+ "\n",
63
+ "# --- 1. Função para dividir sequência (512 para Protbert e Protbertbfd. 1024 para ESM2) ---\n",
64
+ "def slice_sequence(seq, chunk_size):\n",
65
+ " return [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]\n",
66
+ "\n",
67
+ "# --- 2. Função para gerar embeddings médios ---\n",
68
+ "def get_embedding_mean(model_name, seq, chunk_size):\n",
69
+ " tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)\n",
70
+ " model = AutoModel.from_pretrained(model_name)\n",
71
+ " model.eval()\n",
72
+ "\n",
73
+ " chunks = [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]\n",
74
+ " embeddings = []\n",
75
+ "\n",
76
+ " for chunk in chunks:\n",
77
+ " seq_chunk = \" \".join(list(chunk))\n",
78
+ " # tokenizar SEM truncar\n",
79
+ " inputs = tokenizer(seq_chunk,\n",
80
+ " return_tensors=\"pt\",\n",
81
+ " truncation=False, # ≤ 512 ou 1024 já garantido\n",
82
+ " padding=False)\n",
83
+ " with torch.no_grad():\n",
84
+ " cls = model(**inputs).last_hidden_state[:, 0, :].squeeze().numpy()\n",
85
+ " embeddings.append(cls)\n",
86
+ "\n",
87
+ " return np.mean(embeddings, axis=0, keepdims=True) # (1, dim)\n",
88
+ "\n",
89
+ "print(\"A gerar embeddings por chunks...\")\n",
90
+ "emb_pb = get_embedding_mean(\"Rostlab/prot_bert\", SEQ_FASTA, 512)\n",
91
+ "emb_bfd = get_embedding_mean(\"Rostlab/prot_bert_bfd\", SEQ_FASTA, 512)\n",
92
+ "emb_esm = get_embedding_mean(\"facebook/esm2_t33_650M_UR50D\", SEQ_FASTA, 1024)\n",
93
+ "\n",
94
+ "# --- 3. Carregar os MLPs base ---\n",
95
+ "mlp_pb = load_model(\"models/protbert_mlp.keras\")\n",
96
+ "mlp_bfd = load_model(\"models/protbertbfd_mlp.keras\")\n",
97
+ "mlp_esm = load_model(\"models/esm2_mlp.keras\")\n",
98
+ "\n",
99
+ "# --- 4. Gerar predições base (garantir 597 colunas) ---\n",
100
+ "print(\"A fazer predições base...\")\n",
101
+ "y_pb = mlp_pb.predict(emb_pb)[:, :597]\n",
102
+ "y_bfd = mlp_bfd.predict(emb_bfd)[:, :597]\n",
103
+ "y_esm = mlp_esm.predict(emb_esm)[:, :597]\n",
104
+ "\n",
105
+ "# --- 5. Concatenar para o stacking ---\n",
106
+ "X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)\n",
107
+ "\n",
108
+ "# --- 6. Carregar modelo de stacking ---\n",
109
+ "stacking = load_model(\"models/modelo_ensemble_stacking.keras\")\n",
110
+ "y_pred = stacking.predict(X_stack)\n",
111
+ "\n",
112
+ "# --- 7. Carregar binarizador (597 GO terms) ---\n",
113
+ "mlb = joblib.load(\"data/mlb_597.pkl\")\n",
114
+ "go_terms = mlb.classes_\n",
115
+ "\n",
116
+ "# --- 8. Mostrar resultados ---\n",
117
+ "print(\"\\n GO terms com prob ≥ 0.5:\")\n",
118
+ "predicted_terms = mlb.inverse_transform((y_pred >= 0.5).astype(int))\n",
119
+ "print(predicted_terms[0] if predicted_terms[0] else \"Nenhum GO term acima de 0.5\")\n",
120
+ "\n",
121
+ "print(f\"\\n Top {TOP_N} GO terms mais prováveis:\")\n",
122
+ "top_idx = np.argsort(-y_pred[0])[:TOP_N]\n",
123
+ "for i in top_idx:\n",
124
+ " print(f\"{go_terms[i]} : {y_pred[0][i]:.4f}\")\n"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": null,
130
+ "id": "e959e7d9-15ba-4533-a2bb-ddd7df2a639d",
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": []
134
+ }
135
+ ],
136
+ "metadata": {
137
+ "kernelspec": {
138
+ "display_name": "Python 3 (ipykernel)",
139
+ "language": "python",
140
+ "name": "python3"
141
+ },
142
+ "language_info": {
143
+ "codemirror_mode": {
144
+ "name": "ipython",
145
+ "version": 3
146
+ },
147
+ "file_extension": ".py",
148
+ "mimetype": "text/x-python",
149
+ "name": "python",
150
+ "nbconvert_exporter": "python",
151
+ "pygments_lexer": "ipython3",
152
+ "version": "3.8.18"
153
+ }
154
+ },
155
+ "nbformat": 4,
156
+ "nbformat_minor": 5
157
+ }
notebooks/PAM1_ESM2.ipynb ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 9,
6
+ "id": "641053e3-7fec-4f9b-a75e-ddd957af03c4",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
14
+ "✓ Dataset preparado:\n",
15
+ " - Training: (31142, 3)\n",
16
+ " - Validation: (1724, 3)\n",
17
+ " - Test: (1724, 3)\n",
18
+ " - GO terms: 602\n"
19
+ ]
20
+ }
21
+ ],
22
+ "source": [
23
+ "# %%\n",
24
+ "import pandas as pd\n",
25
+ "import numpy as np\n",
26
+ "from Bio import SeqIO\n",
27
+ "from goatools.obo_parser import GODag\n",
28
+ "from collections import Counter\n",
29
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
30
+ "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
31
+ "import os, random\n",
32
+ "\n",
33
+ "# --- 1. Carregar ficheiros principais ---\n",
34
+ "FASTA = \"uniprot_sprot_exp.fasta\"\n",
35
+ "ANNOT = \"uniprot_sprot_exp.txt\"\n",
36
+ "GO_OBO = \"go.obo\"\n",
37
+ "\n",
38
+ "# --- 2. Ler sequências ---\n",
39
+ "seqs, ids = [], []\n",
40
+ "for record in SeqIO.parse(FASTA, \"fasta\"):\n",
41
+ " ids.append(record.id)\n",
42
+ " seqs.append(str(record.seq))\n",
43
+ "\n",
44
+ "df_seq = pd.DataFrame({\"protein_id\": ids, \"sequence\": seqs})\n",
45
+ "\n",
46
+ "# --- 3. Ler anotações GO:MF ---\n",
47
+ "df_ann = pd.read_csv(ANNOT, sep=\"\\t\", names=[\"protein_id\", \"go_term\", \"category\"])\n",
48
+ "df_ann = df_ann[df_ann[\"category\"] == \"F\"]\n",
49
+ "\n",
50
+ "# --- 4. Propagação hierárquica dos GO terms ---\n",
51
+ "go_dag = GODag(GO_OBO)\n",
52
+ "mf_terms = {t for t, o in go_dag.items() if o.namespace == \"molecular_function\"}\n",
53
+ "\n",
54
+ "def propagate_terms(terms):\n",
55
+ " expanded = set()\n",
56
+ " for t in terms:\n",
57
+ " if t in go_dag:\n",
58
+ " expanded |= go_dag[t].get_all_parents()\n",
59
+ " expanded.add(t)\n",
60
+ " return list(expanded & mf_terms)\n",
61
+ "\n",
62
+ "grouped = df_ann.groupby(\"protein_id\")[\"go_term\"].apply(list).reset_index()\n",
63
+ "grouped[\"go_term\"] = grouped[\"go_term\"].apply(propagate_terms)\n",
64
+ "\n",
65
+ "# --- 5. Juntar com sequência ---\n",
66
+ "df = df_seq.merge(grouped, on=\"protein_id\")\n",
67
+ "df = df[df[\"go_term\"].str.len() > 0]\n",
68
+ "\n",
69
+ "# --- 6. Filtrar GO terms com ≥50 proteínas ---\n",
70
+ "all_terms = [term for sublist in df[\"go_term\"] for term in sublist]\n",
71
+ "term_counts = Counter(all_terms)\n",
72
+ "valid_terms = {t for t, count in term_counts.items() if count >= 50}\n",
73
+ "\n",
74
+ "df[\"go_term\"] = df[\"go_term\"].apply(lambda ts: [t for t in ts if t in valid_terms])\n",
75
+ "df = df[df[\"go_term\"].str.len() > 0]\n",
76
+ "\n",
77
+ "# --- 7. Preparar labels e dividir por proteína ---\n",
78
+ "df[\"go_terms\"] = df[\"go_term\"].apply(lambda x: ';'.join(sorted(set(x))))\n",
79
+ "df = df[[\"protein_id\", \"sequence\", \"go_terms\"]].drop_duplicates()\n",
80
+ "\n",
81
+ "mlb = MultiLabelBinarizer()\n",
82
+ "Y = mlb.fit_transform(df[\"go_terms\"].str.split(\";\"))\n",
83
+ "X = df[[\"protein_id\", \"sequence\"]].values\n",
84
+ "\n",
85
+ "mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)\n",
86
+ "train_idx, temp_idx = next(mskf.split(X, Y))\n",
87
+ "val_idx, test_idx = np.array_split(temp_idx, 2)\n",
88
+ "\n",
89
+ "df_train = df.iloc[train_idx].copy()\n",
90
+ "df_val = df.iloc[val_idx].copy()\n",
91
+ "df_test = df.iloc[test_idx].copy()\n",
92
+ "\n",
93
+ "os.makedirs(\"data\", exist_ok=True)\n",
94
+ "df_train.to_csv(\"data/mf-training.csv\", index=False)\n",
95
+ "df_val.to_csv(\"data/mf-validation.csv\", index=False)\n",
96
+ "df_test.to_csv(\"data/mf-test.csv\", index=False)\n",
97
+ "\n",
98
+ "# --- 8. Guardar o binarizador ---\n",
99
+ "import joblib\n",
100
+ "joblib.dump(mlb, \"data/mlb.pkl\")\n",
101
+ "\n",
102
+ "print(\"✓ Dataset preparado:\")\n",
103
+ "print(\" - Training:\", df_train.shape)\n",
104
+ "print(\" - Validation:\", df_val.shape)\n",
105
+ "print(\" - Test:\", df_test.shape)\n",
106
+ "print(\" - GO terms:\", len(mlb.classes_))\n"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": 10,
112
+ "id": "40ba1798-daf8-4649-ae3f-bfe81df6437f",
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": [
116
+ "# %%\n",
117
+ "import random\n",
118
+ "from collections import defaultdict\n",
119
+ "\n",
120
+ "# --- PAM1 matrix normalizada ---\n",
121
+ "pam_data = {\n",
122
+ " 'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],\n",
123
+ " 'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],\n",
124
+ " 'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],\n",
125
+ " 'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],\n",
126
+ " 'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],\n",
127
+ " 'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],\n",
128
+ " 'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],\n",
129
+ " 'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],\n",
130
+ " 'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],\n",
131
+ " 'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],\n",
132
+ " 'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],\n",
133
+ " 'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],\n",
134
+ " 'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],\n",
135
+ " 'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],\n",
136
+ " 'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],\n",
137
+ " 'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],\n",
138
+ " 'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],\n",
139
+ " 'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],\n",
140
+ " 'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],\n",
141
+ " 'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]\n",
142
+ "}\n",
143
+ "\n",
144
+ "pam_raw = pd.DataFrame(pam_data, index=pam_data.keys())\n",
145
+ "pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)\n",
146
+ "pam_dict = {aa: pam_matrix.loc[aa].to_dict() for aa in pam_matrix.index}\n",
147
+ "\n",
148
+ "def pam1_substitution(aa):\n",
149
+ " if aa not in pam_dict:\n",
150
+ " return aa\n",
151
+ " subs = list(pam_dict[aa].keys())\n",
152
+ " probs = list(pam_dict[aa].values())\n",
153
+ " return np.random.choice(subs, p=probs)\n",
154
+ "\n",
155
+ "def augment_sequence(seq, sub_prob=0.05):\n",
156
+ " return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])\n",
157
+ "\n",
158
+ "def slice_sequence(seq, win=1024):\n",
159
+ " if len(seq) <= win:\n",
160
+ " return [seq]\n",
161
+ " return [seq[i:i+win] for i in range(0, len(seq), win)]\n",
162
+ "\n",
163
+ "def format_seq(seq):\n",
164
+ " return \" \".join(seq)\n",
165
+ "\n",
166
+ "# --- Carregar labels e datasets ---\n",
167
+ "import joblib\n",
168
+ "mlb = joblib.load(\"data/mlb.pkl\")\n",
169
+ "df_train = pd.read_csv(\"data/mf-training.csv\")\n",
170
+ "df_val = pd.read_csv(\"data/mf-validation.csv\")\n",
171
+ "df_test = pd.read_csv(\"data/mf-test.csv\")\n",
172
+ "\n",
173
+ "# --- Slicing + augmentação no treino ---\n",
174
+ "X_train, y_train = [], []\n",
175
+ "\n",
176
+ "for _, row in df_train.iterrows():\n",
177
+ " seq_aug = augment_sequence(row[\"sequence\"], sub_prob=0.05)\n",
178
+ " slices = slice_sequence(seq_aug, win=1024)\n",
179
+ " label = mlb.transform([row[\"go_terms\"].split(\";\")])[0]\n",
180
+ " for sl in slices:\n",
181
+ " X_train.append(format_seq(sl))\n",
182
+ " y_train.append(label)\n",
183
+ "\n",
184
+ "# --- Sem slicing no val/test ---\n",
185
+ "X_val = [format_seq(seq) for seq in df_val[\"sequence\"]]\n",
186
+ "X_test = [format_seq(seq) for seq in df_test[\"sequence\"]]\n",
187
+ "\n",
188
+ "y_val = mlb.transform(df_val[\"go_terms\"].str.split(\";\"))\n",
189
+ "y_test = mlb.transform(df_test[\"go_terms\"].str.split(\";\"))\n",
190
+ "\n",
191
+ "np.save(\"embeddings/y_test.npy\", y_test)"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": 11,
197
+ "id": "80d5c1fb-9c84-463d-8d8c-bfcc2982afc9",
198
+ "metadata": {},
199
+ "outputs": [
200
+ {
201
+ "name": "stderr",
202
+ "output_type": "stream",
203
+ "text": [
204
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\huggingface_hub\\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
205
+ " warnings.warn(\n",
206
+ "Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
207
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
208
+ "100%|██████████| 2189/2189 [1:17:26<00:00, 2.12s/it]\n",
209
+ "100%|██████████| 108/108 [03:43<00:00, 2.07s/it]\n",
210
+ "100%|██████████| 108/108 [03:56<00:00, 2.19s/it]\n"
211
+ ]
212
+ }
213
+ ],
214
+ "source": [
215
+ "# %%\n",
216
+ "from transformers import AutoTokenizer, AutoModel\n",
217
+ "import torch\n",
218
+ "from tqdm import tqdm\n",
219
+ "import numpy as np\n",
220
+ "import os\n",
221
+ "\n",
222
+ "# --- Configurações ---\n",
223
+ "MODEL_NAME = \"facebook/esm2_t33_650M_UR50D\"\n",
224
+ "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
225
+ "CHUNK_SIZE = 16\n",
226
+ "\n",
227
+ "# --- Carregar modelo ---\n",
228
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)\n",
229
+ "model = AutoModel.from_pretrained(MODEL_NAME)\n",
230
+ "model.to(DEVICE)\n",
231
+ "model.eval()\n",
232
+ "\n",
233
+ "def extract_embeddings(texts):\n",
234
+ " embeddings = []\n",
235
+ " for i in tqdm(range(0, len(texts), CHUNK_SIZE)):\n",
236
+ " batch = texts[i:i+CHUNK_SIZE]\n",
237
+ " with torch.no_grad():\n",
238
+ " inputs = tokenizer(batch, return_tensors=\"pt\", padding=True, truncation=True, max_length=1024)\n",
239
+ " inputs = {k: v.to(DEVICE) for k, v in inputs.items()}\n",
240
+ " outputs = model(**inputs).last_hidden_state\n",
241
+ " cls_tokens = outputs[:, 0, :] # token CLS\n",
242
+ " embeddings.append(cls_tokens.cpu().numpy())\n",
243
+ " return np.vstack(embeddings)\n",
244
+ "\n",
245
+ "# --- Extrair e guardar embeddings ---\n",
246
+ "os.makedirs(\"embeddings\", exist_ok=True)\n",
247
+ "\n",
248
+ "emb_train = extract_embeddings(X_train)\n",
249
+ "emb_val = extract_embeddings(X_val)\n",
250
+ "emb_test = extract_embeddings(X_test)\n",
251
+ "\n",
252
+ "np.save(\"embeddings/esm2_train.npy\", emb_train)\n",
253
+ "np.save(\"embeddings/esm2_val.npy\", emb_val)\n",
254
+ "np.save(\"embeddings/esm2_test.npy\", emb_test)\n",
255
+ "\n",
256
+ "np.save(\"embeddings/y_train.npy\", np.array(y_train))\n",
257
+ "np.save(\"embeddings/y_val.npy\", np.array(y_val))\n"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": 1,
263
+ "id": "592e4f6c-b871-4f0b-b84c-f3918c698544",
264
+ "metadata": {},
265
+ "outputs": [
266
+ {
267
+ "name": "stdout",
268
+ "output_type": "stream",
269
+ "text": [
270
+ "Epoch 1/100\n",
271
+ "1095/1095 [==============================] - 4s 2ms/step - loss: 0.0552 - val_loss: 0.0455\n",
272
+ "Epoch 2/100\n",
273
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0445 - val_loss: 0.0424\n",
274
+ "Epoch 3/100\n",
275
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0419 - val_loss: 0.0394\n",
276
+ "Epoch 4/100\n",
277
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0403 - val_loss: 0.0381\n",
278
+ "Epoch 5/100\n",
279
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0392 - val_loss: 0.0373\n",
280
+ "Epoch 6/100\n",
281
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0383 - val_loss: 0.0362\n",
282
+ "Epoch 7/100\n",
283
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0374 - val_loss: 0.0358\n",
284
+ "Epoch 8/100\n",
285
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0368 - val_loss: 0.0351\n",
286
+ "Epoch 9/100\n",
287
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0362 - val_loss: 0.0348\n",
288
+ "Epoch 10/100\n",
289
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0357 - val_loss: 0.0344\n",
290
+ "Epoch 11/100\n",
291
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0353 - val_loss: 0.0340\n",
292
+ "Epoch 12/100\n",
293
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0349 - val_loss: 0.0335\n",
294
+ "Epoch 13/100\n",
295
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0344 - val_loss: 0.0334\n",
296
+ "Epoch 14/100\n",
297
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0342 - val_loss: 0.0330\n",
298
+ "Epoch 15/100\n",
299
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0338 - val_loss: 0.0325\n",
300
+ "Epoch 16/100\n",
301
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0337 - val_loss: 0.0327\n",
302
+ "Epoch 17/100\n",
303
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0333 - val_loss: 0.0325\n",
304
+ "Epoch 18/100\n",
305
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0330 - val_loss: 0.0322\n",
306
+ "Epoch 19/100\n",
307
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0328 - val_loss: 0.0321\n",
308
+ "Epoch 20/100\n",
309
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0326 - val_loss: 0.0322\n",
310
+ "Epoch 21/100\n",
311
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0323 - val_loss: 0.0320\n",
312
+ "Epoch 22/100\n",
313
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0322 - val_loss: 0.0317\n",
314
+ "Epoch 23/100\n",
315
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0320 - val_loss: 0.0318\n",
316
+ "Epoch 24/100\n",
317
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0317 - val_loss: 0.0315\n",
318
+ "Epoch 25/100\n",
319
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0316 - val_loss: 0.0317\n",
320
+ "Epoch 26/100\n",
321
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0314 - val_loss: 0.0313\n",
322
+ "Epoch 27/100\n",
323
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0313 - val_loss: 0.0320\n",
324
+ "Epoch 28/100\n",
325
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0311 - val_loss: 0.0315\n",
326
+ "Epoch 29/100\n",
327
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0310 - val_loss: 0.0313\n",
328
+ "Epoch 30/100\n",
329
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0309 - val_loss: 0.0313\n",
330
+ "Epoch 31/100\n",
331
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0307 - val_loss: 0.0310\n",
332
+ "Epoch 32/100\n",
333
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0306 - val_loss: 0.0310\n",
334
+ "Epoch 33/100\n",
335
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0304 - val_loss: 0.0310\n",
336
+ "Epoch 34/100\n",
337
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0303 - val_loss: 0.0312\n",
338
+ "Epoch 35/100\n",
339
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0302 - val_loss: 0.0309\n",
340
+ "Epoch 36/100\n",
341
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0300 - val_loss: 0.0310\n",
342
+ "Epoch 37/100\n",
343
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0299 - val_loss: 0.0313\n",
344
+ "Epoch 38/100\n",
345
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0298 - val_loss: 0.0312\n",
346
+ "Epoch 39/100\n",
347
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0296 - val_loss: 0.0307\n",
348
+ "Epoch 40/100\n",
349
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0296 - val_loss: 0.0306\n",
350
+ "Epoch 41/100\n",
351
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0295 - val_loss: 0.0310\n",
352
+ "Epoch 42/100\n",
353
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0294 - val_loss: 0.0304\n",
354
+ "Epoch 43/100\n",
355
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0294 - val_loss: 0.0308\n",
356
+ "Epoch 44/100\n",
357
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0293 - val_loss: 0.0306\n",
358
+ "Epoch 45/100\n",
359
+ "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0292 - val_loss: 0.0307\n",
360
+ "Epoch 46/100\n",
361
+ "1095/1095 [==============================] - 4s 4ms/step - loss: 0.0290 - val_loss: 0.0305\n",
362
+ "Epoch 47/100\n",
363
+ "1095/1095 [==============================] - 4s 4ms/step - loss: 0.0290 - val_loss: 0.0305\n",
364
+ "Modelo guardado em models/esm2_mlp.keras\n",
365
+ " Predições do ESM-2 salvas com forma: (1724, 602)\n"
366
+ ]
367
+ }
368
+ ],
369
+ "source": [
370
+ "# %%\n",
371
+ "import numpy as np\n",
372
+ "import tensorflow as tf\n",
373
+ "from tensorflow.keras.models import Sequential\n",
374
+ "from tensorflow.keras.layers import Dense, Dropout\n",
375
+ "from tensorflow.keras.callbacks import EarlyStopping\n",
376
+ "from sklearn.metrics import average_precision_score\n",
377
+ "\n",
378
+ "# --- Carregar os embeddings e labels ---\n",
379
+ "X_train = np.load(\"embeddings/esm2_train.npy\")\n",
380
+ "X_val = np.load(\"embeddings/esm2_val.npy\")\n",
381
+ "X_test = np.load(\"embeddings/esm2_test.npy\")\n",
382
+ "\n",
383
+ "y_train = np.load(\"embeddings/y_train.npy\")\n",
384
+ "y_val = np.load(\"embeddings/y_val.npy\")\n",
385
+ "y_test = np.load(\"embeddings/y_test.npy\")\n",
386
+ "\n",
387
+ "# --- Definir o modelo ---\n",
388
+ "model = Sequential([\n",
389
+ " Dense(1024, activation='relu', input_shape=(X_train.shape[1],)),\n",
390
+ " Dropout(0.3),\n",
391
+ " Dense(512, activation='relu'),\n",
392
+ " Dropout(0.3),\n",
393
+ " Dense(y_train.shape[1], activation='sigmoid')\n",
394
+ "])\n",
395
+ "\n",
396
+ "model.compile(optimizer='adam', loss='binary_crossentropy')\n",
397
+ "\n",
398
+ "# --- Treinar ---\n",
399
+ "early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
400
+ "\n",
401
+ "history = model.fit(\n",
402
+ " X_train, y_train,\n",
403
+ " validation_data=(X_val, y_val),\n",
404
+ " epochs=100,\n",
405
+ " batch_size=32,\n",
406
+ " callbacks=[early_stop],\n",
407
+ " verbose=1\n",
408
+ ")\n",
409
+ "\n",
410
+ "# --- Salvar o modelo ---\n",
411
+ "model.save(\"models/esm2_mlp.keras\")\n",
412
+ "print(\"Modelo guardado em models/esm2_mlp.keras\")\n",
413
+ "\n",
414
+ "# --- Fazer predições no conjunto de teste ---\n",
415
+ "y_prob = model.predict(X_test)\n",
416
+ "np.save(\"predictions/mf-esm2.npy\", y_prob)\n",
417
+ "\n",
418
+ "print(\" Predições do ESM-2 salvas com forma:\", y_prob.shape)\n"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": 15,
424
+ "id": "3dddb0df-3ea5-4e32-8cf0-45e90be8ba66",
425
+ "metadata": {},
426
+ "outputs": [
427
+ {
428
+ "name": "stdout",
429
+ "output_type": "stream",
430
+ "text": [
431
+ "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
432
+ "✓ Dados carregados: (1724, 602) proteínas × 602 GO terms\n",
433
+ "\n",
434
+ " Resultados finais (ESM-2 + PAM1 + propagação):\n",
435
+ "Fmax = 0.6439\n",
436
+ "Thr. = 0.34\n",
437
+ "AuPRC = 0.6948\n",
438
+ "Smin = 14.1500\n"
439
+ ]
440
+ }
441
+ ],
442
+ "source": [
443
+ "# %%\n",
444
+ "import numpy as np\n",
445
+ "import joblib\n",
446
+ "import math\n",
447
+ "from goatools.obo_parser import GODag\n",
448
+ "from sklearn.metrics import precision_recall_curve, auc\n",
449
+ "\n",
450
+ "# --- 1. Carregar dados e parâmetros ---\n",
451
+ "GO_FILE = \"go.obo\"\n",
452
+ "THRESHOLDS = np.arange(0.0, 1.01, 0.01)\n",
453
+ "ALPHA = 0.5\n",
454
+ "\n",
455
+ "mlb = joblib.load(\"data/mlb.pkl\")\n",
456
+ "y_true = np.load(\"embeddings/y_test.npy\")\n",
457
+ "y_prob = np.load(\"predictions/mf-esm2.npy\")\n",
458
+ "terms = mlb.classes_\n",
459
+ "go_dag = GODag(GO_FILE)\n",
460
+ "\n",
461
+ "print(f\"✓ Dados carregados: {y_true.shape} proteínas × {len(terms)} GO terms\")\n",
462
+ "\n",
463
+ "# --- 2. Fmax ---\n",
464
+ "def compute_fmax(y_true, y_prob, thresholds):\n",
465
+ " fmax, best_thr = 0, 0\n",
466
+ " for t in thresholds:\n",
467
+ " y_pred = (y_prob >= t).astype(int)\n",
468
+ " tp = (y_true * y_pred).sum(axis=1)\n",
469
+ " fp = ((1 - y_true) * y_pred).sum(axis=1)\n",
470
+ " fn = (y_true * (1 - y_pred)).sum(axis=1)\n",
471
+ " precision = tp / (tp + fp + 1e-8)\n",
472
+ " recall = tp / (tp + fn + 1e-8)\n",
473
+ " f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
474
+ " avg_f1 = np.mean(f1)\n",
475
+ " if avg_f1 > fmax:\n",
476
+ " fmax, best_thr = avg_f1, t\n",
477
+ " return fmax, best_thr\n",
478
+ "\n",
479
+ "# --- 3. AuPRC (micro) ---\n",
480
+ "def compute_auprc(y_true, y_prob):\n",
481
+ " precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())\n",
482
+ " return auc(recall, precision)\n",
483
+ "\n",
484
+ "# --- 4. Smin ---\n",
485
+ "def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):\n",
486
+ " y_pred = (y_prob >= threshold).astype(int)\n",
487
+ "\n",
488
+ " # Informação semântica: IC (Information Content)\n",
489
+ " ic = {}\n",
490
+ " total = (y_true + y_pred).sum(axis=0).sum()\n",
491
+ " for i, term in enumerate(terms):\n",
492
+ " freq = (y_true[:, i] + y_pred[:, i]).sum()\n",
493
+ " ic[term] = -np.log((freq + 1e-8) / total)\n",
494
+ "\n",
495
+ " # Para cada proteína, calcular RU e MI\n",
496
+ " s_values = []\n",
497
+ " for true_vec, pred_vec in zip(y_true, y_pred):\n",
498
+ " true_terms = {terms[i] for i in np.where(true_vec)[0]}\n",
499
+ " pred_terms = {terms[i] for i in np.where(pred_vec)[0]}\n",
500
+ "\n",
501
+ " anc_true = set()\n",
502
+ " for t in true_terms:\n",
503
+ " if t in go_dag:\n",
504
+ " anc_true |= go_dag[t].get_all_parents()\n",
505
+ " anc_pred = set()\n",
506
+ " for t in pred_terms:\n",
507
+ " if t in go_dag:\n",
508
+ " anc_pred |= go_dag[t].get_all_parents()\n",
509
+ "\n",
510
+ " ru = pred_terms - true_terms\n",
511
+ " mi = true_terms - pred_terms\n",
512
+ " dist_ru = sum(ic.get(t, 0) for t in ru)\n",
513
+ " dist_mi = sum(ic.get(t, 0) for t in mi)\n",
514
+ " s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)\n",
515
+ " s_values.append(s)\n",
516
+ "\n",
517
+ " return np.mean(s_values)\n",
518
+ "\n",
519
+ "# --- 5. Avaliação ---\n",
520
+ "fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)\n",
521
+ "auprc = compute_auprc(y_true, y_prob)\n",
522
+ "smin = compute_smin(y_true, y_prob, terms, thr, go_dag)\n",
523
+ "\n",
524
+ "print(f\"\\n Resultados finais (ESM-2 + PAM1 + propagação):\")\n",
525
+ "print(f\"Fmax = {fmax:.4f}\")\n",
526
+ "print(f\"Thr. = {thr:.2f}\")\n",
527
+ "print(f\"AuPRC = {auprc:.4f}\")\n",
528
+ "print(f\"Smin = {smin:.4f}\")\n"
529
+ ]
530
+ },
531
+ {
532
+ "cell_type": "code",
533
+ "execution_count": null,
534
+ "id": "1a1ea084-01de-4dc4-88da-e7ffeb8c94c9",
535
+ "metadata": {},
536
+ "outputs": [],
537
+ "source": []
538
+ }
539
+ ],
540
+ "metadata": {
541
+ "kernelspec": {
542
+ "display_name": "Python 3 (ipykernel)",
543
+ "language": "python",
544
+ "name": "python3"
545
+ },
546
+ "language_info": {
547
+ "codemirror_mode": {
548
+ "name": "ipython",
549
+ "version": 3
550
+ },
551
+ "file_extension": ".py",
552
+ "mimetype": "text/x-python",
553
+ "name": "python",
554
+ "nbconvert_exporter": "python",
555
+ "pygments_lexer": "ipython3",
556
+ "version": "3.8.18"
557
+ }
558
+ },
559
+ "nbformat": 4,
560
+ "nbformat_minor": 5
561
+ }
notebooks/PAM1_protbert.ipynb ADDED
@@ -0,0 +1,935 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "c6dbc330-062a-48f0-8242-3f21cc1c9c2b",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
14
+ "✓ Ficheiros criados:\n",
15
+ " - data/mf-training.csv : (31142, 3)\n",
16
+ " - data/mf-validation.csv: (1724, 3)\n",
17
+ " - data/mf-test.csv : (1724, 3)\n",
18
+ "GO terms únicos (após propagação e filtro): 602\n"
19
+ ]
20
+ }
21
+ ],
22
+ "source": [
23
+ "import pandas as pd\n",
24
+ "from Bio import SeqIO\n",
25
+ "from collections import Counter\n",
26
+ "from goatools.obo_parser import GODag\n",
27
+ "from sklearn.model_selection import train_test_split\n",
28
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
29
+ "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
30
+ "import numpy as np\n",
31
+ "import os\n",
32
+ "\n",
33
+ "# --- 1. Carregar GO anotações ------------------------------------------\n",
34
+ "annotations = pd.read_csv(\"uniprot_sprot_exp.txt\", sep=\"\\t\", names=[\"protein_id\", \"go_term\", \"go_category\"])\n",
35
+ "annotations_f = annotations[annotations[\"go_category\"] == \"F\"]\n",
36
+ "\n",
37
+ "# --- 2. Carregar DAG e propagar GO terms -------------------------------\n",
38
+ "# propagação hierárquica\n",
39
+ "# https://geneontology.org/docs/download-ontology/\n",
40
+ "go_dag = GODag(\"go.obo\")\n",
41
+ "mf_terms = {t for t, o in go_dag.items() if o.namespace == \"molecular_function\"}\n",
42
+ "\n",
43
+ "def propagate_terms(term_list):\n",
44
+ " full = set()\n",
45
+ " for t in term_list:\n",
46
+ " if t not in go_dag:\n",
47
+ " continue\n",
48
+ " full.add(t)\n",
49
+ " full.update(go_dag[t].get_all_parents())\n",
50
+ " return list(full & mf_terms)\n",
51
+ "\n",
52
+ "# --- 3. Carregar sequências --------------------------------------------\n",
53
+ "seqs, ids = [], []\n",
54
+ "for record in SeqIO.parse(\"uniprot_sprot_exp.fasta\", \"fasta\"):\n",
55
+ " ids.append(record.id)\n",
56
+ " seqs.append(str(record.seq))\n",
57
+ "\n",
58
+ "seq_df = pd.DataFrame({\"protein_id\": ids, \"sequence\": seqs})\n",
59
+ "\n",
60
+ "# --- 4. Juntar com GO anotado e propagar -------------------------------\n",
61
+ "grouped = annotations_f.groupby(\"protein_id\")[\"go_term\"].apply(list).reset_index()\n",
62
+ "data = seq_df.merge(grouped, on=\"protein_id\")\n",
63
+ "data = data[data[\"go_term\"].apply(len) > 0]\n",
64
+ "data[\"go_term\"] = data[\"go_term\"].apply(propagate_terms)\n",
65
+ "data = data[data[\"go_term\"].apply(len) > 0]\n",
66
+ "\n",
67
+ "# --- 5. Filtrar GO terms raros -----------------------------------------\n",
68
+ "# todos os terms com menos de 50 proteinas associadas\n",
69
+ "all_terms = [term for sublist in data[\"go_term\"] for term in sublist]\n",
70
+ "term_counts = Counter(all_terms)\n",
71
+ "valid_terms = {term for term, count in term_counts.items() if count >= 50}\n",
72
+ "data[\"go_term\"] = data[\"go_term\"].apply(lambda terms: [t for t in terms if t in valid_terms])\n",
73
+ "data = data[data[\"go_term\"].apply(len) > 0]\n",
74
+ "\n",
75
+ "# --- 6. Preparar dataset final -----------------------------------------\n",
76
+ "data[\"go_terms\"] = data[\"go_term\"].apply(lambda x: ';'.join(sorted(set(x))))\n",
77
+ "data = data[[\"protein_id\", \"sequence\", \"go_terms\"]].drop_duplicates()\n",
78
+ "\n",
79
+ "# --- 7. Binarizar labels e dividir -------------------------------------\n",
80
+ "mlb = MultiLabelBinarizer()\n",
81
+ "Y = mlb.fit_transform(data[\"go_terms\"].str.split(\";\"))\n",
82
+ "X = data[[\"protein_id\", \"sequence\"]].values\n",
83
+ "\n",
84
+ "mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)\n",
85
+ "train_idx, temp_idx = next(mskf.split(X, Y))\n",
86
+ "val_idx, test_idx = np.array_split(temp_idx, 2)\n",
87
+ "\n",
88
+ "df_train = data.iloc[train_idx].copy()\n",
89
+ "df_val = data.iloc[val_idx].copy()\n",
90
+ "df_test = data.iloc[test_idx].copy()\n",
91
+ "\n",
92
+ "# --- 8. Guardar em CSV -------------------------------------------------\n",
93
+ "os.makedirs(\"data\", exist_ok=True)\n",
94
+ "df_train.to_csv(\"data/mf-training.csv\", index=False)\n",
95
+ "df_val.to_csv(\"data/mf-validation.csv\", index=False)\n",
96
+ "df_test.to_csv(\"data/mf-test.csv\", index=False)\n",
97
+ "\n",
98
+ "# --- 9. Confirmar ------------------------------------------------------\n",
99
+ "print(\"✓ Ficheiros criados:\")\n",
100
+ "print(\" - data/mf-training.csv :\", df_train.shape)\n",
101
+ "print(\" - data/mf-validation.csv:\", df_val.shape)\n",
102
+ "print(\" - data/mf-test.csv :\", df_test.shape)\n",
103
+ "print(f\"GO terms únicos (após propagação e filtro): {len(mlb.classes_)}\")\n"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 2,
109
+ "id": "6cf7aaa6-4941-4951-8d73-1f4f1f4362f3",
110
+ "metadata": {},
111
+ "outputs": [
112
+ {
113
+ "name": "stderr",
114
+ "output_type": "stream",
115
+ "text": [
116
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
117
+ " from .autonotebook import tqdm as notebook_tqdm\n",
118
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
119
+ " _torch_pytree._register_pytree_node(\n",
120
+ "100%|██████████| 31142/31142 [00:24<00:00, 1262.18it/s]\n",
121
+ "100%|██████████| 1724/1724 [00:00<00:00, 2628.24it/s]\n",
122
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\text\\preprocessor.py:382: UserWarning: The class_names argument is replacing the classes argument. Please update your code.\n",
123
+ " warnings.warn(\n"
124
+ ]
125
+ },
126
+ {
127
+ "name": "stdout",
128
+ "output_type": "stream",
129
+ "text": [
130
+ "preprocessing train...\n",
131
+ "language: de\n",
132
+ "train sequence lengths:\n",
133
+ "\tmean : 423\n",
134
+ "\t95percentile : 604\n",
135
+ "\t99percentile : 715\n"
136
+ ]
137
+ },
138
+ {
139
+ "data": {
140
+ "text/html": [
141
+ "\n",
142
+ "<style>\n",
143
+ " /* Turns off some styling */\n",
144
+ " progress {\n",
145
+ " /* gets rid of default border in Firefox and Opera. */\n",
146
+ " border: none;\n",
147
+ " /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
148
+ " background-size: auto;\n",
149
+ " }\n",
150
+ " progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
151
+ " background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
152
+ " }\n",
153
+ " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
154
+ " background: #F44336;\n",
155
+ " }\n",
156
+ "</style>\n"
157
+ ],
158
+ "text/plain": [
159
+ "<IPython.core.display.HTML object>"
160
+ ]
161
+ },
162
+ "metadata": {},
163
+ "output_type": "display_data"
164
+ },
165
+ {
166
+ "data": {
167
+ "text/html": [],
168
+ "text/plain": [
169
+ "<IPython.core.display.HTML object>"
170
+ ]
171
+ },
172
+ "metadata": {},
173
+ "output_type": "display_data"
174
+ },
175
+ {
176
+ "name": "stdout",
177
+ "output_type": "stream",
178
+ "text": [
179
+ "Is Multi-Label? True\n",
180
+ "preprocessing test...\n",
181
+ "language: de\n",
182
+ "test sequence lengths:\n",
183
+ "\tmean : 408\n",
184
+ "\t95percentile : 603\n",
185
+ "\t99percentile : 714\n"
186
+ ]
187
+ },
188
+ {
189
+ "data": {
190
+ "text/html": [
191
+ "\n",
192
+ "<style>\n",
193
+ " /* Turns off some styling */\n",
194
+ " progress {\n",
195
+ " /* gets rid of default border in Firefox and Opera. */\n",
196
+ " border: none;\n",
197
+ " /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
198
+ " background-size: auto;\n",
199
+ " }\n",
200
+ " progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
201
+ " background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
202
+ " }\n",
203
+ " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
204
+ " background: #F44336;\n",
205
+ " }\n",
206
+ "</style>\n"
207
+ ],
208
+ "text/plain": [
209
+ "<IPython.core.display.HTML object>"
210
+ ]
211
+ },
212
+ "metadata": {},
213
+ "output_type": "display_data"
214
+ },
215
+ {
216
+ "data": {
217
+ "text/html": [],
218
+ "text/plain": [
219
+ "<IPython.core.display.HTML object>"
220
+ ]
221
+ },
222
+ "metadata": {},
223
+ "output_type": "display_data"
224
+ },
225
+ {
226
+ "name": "stderr",
227
+ "output_type": "stream",
228
+ "text": [
229
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\text\\preprocessor.py:1093: UserWarning: Could not load a Tensorflow version of model. (If this worked before, it might be an out-of-memory issue.) Attempting to download/load PyTorch version as TensorFlow model using from_pt=True. You will need PyTorch installed for this.\n",
230
+ " warnings.warn(\n"
231
+ ]
232
+ },
233
+ {
234
+ "name": "stdout",
235
+ "output_type": "stream",
236
+ "text": [
237
+ "\n",
238
+ "\n",
239
+ "begin training using triangular learning rate policy with max lr of 1e-05...\n",
240
+ "Epoch 1/10\n",
241
+ "40995/40995 [==============================] - 13053s 318ms/step - loss: 0.0745 - binary_accuracy: 0.9866 - val_loss: 0.0582 - val_binary_accuracy: 0.9859\n",
242
+ "Epoch 2/10\n",
243
+ "40995/40995 [==============================] - 14484s 353ms/step - loss: 0.0504 - binary_accuracy: 0.9873 - val_loss: 0.0499 - val_binary_accuracy: 0.9867\n",
244
+ "Epoch 3/10\n",
245
+ "40995/40995 [==============================] - 14472s 353ms/step - loss: 0.0450 - binary_accuracy: 0.9879 - val_loss: 0.0449 - val_binary_accuracy: 0.9873\n",
246
+ "Epoch 4/10\n",
247
+ "40995/40995 [==============================] - 14445s 352ms/step - loss: 0.0407 - binary_accuracy: 0.9884 - val_loss: 0.0413 - val_binary_accuracy: 0.9878\n",
248
+ "Epoch 5/10\n",
249
+ "40995/40995 [==============================] - 12524s 305ms/step - loss: 0.0378 - binary_accuracy: 0.9888 - val_loss: 0.0394 - val_binary_accuracy: 0.9881\n",
250
+ "Epoch 6/10\n",
251
+ "40995/40995 [==============================] - 14737s 359ms/step - loss: 0.0359 - binary_accuracy: 0.9891 - val_loss: 0.0383 - val_binary_accuracy: 0.9883\n",
252
+ "Epoch 7/10\n",
253
+ "40995/40995 [==============================] - 20317s 495ms/step - loss: 0.0343 - binary_accuracy: 0.9894 - val_loss: 0.0371 - val_binary_accuracy: 0.9885\n",
254
+ "Epoch 8/10\n",
255
+ "40995/40995 [==============================] - 9073s 221ms/step - loss: 0.0331 - binary_accuracy: 0.9896 - val_loss: 0.0364 - val_binary_accuracy: 0.9887\n",
256
+ "Epoch 9/10\n",
257
+ "40995/40995 [==============================] - 9001s 219ms/step - loss: 0.0320 - binary_accuracy: 0.9898 - val_loss: 0.0360 - val_binary_accuracy: 0.9888\n",
258
+ "Epoch 10/10\n",
259
+ "40995/40995 [==============================] - 8980s 219ms/step - loss: 0.0311 - binary_accuracy: 0.9900 - val_loss: 0.0356 - val_binary_accuracy: 0.9890\n"
260
+ ]
261
+ },
262
+ {
263
+ "ename": "RuntimeError",
264
+ "evalue": "Can't decrement id ref count (unable to extend file properly)",
265
+ "output_type": "error",
266
+ "traceback": [
267
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
268
+ "\u001b[1;31mOSError\u001b[0m Traceback (most recent call last)",
269
+ "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\engine\\training.py:2252\u001b[0m, in \u001b[0;36mModel.save_weights\u001b[1;34m(self, filepath, overwrite, save_format, options)\u001b[0m\n\u001b[0;32m 2251\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m h5py\u001b[38;5;241m.\u001b[39mFile(filepath, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m-> 2252\u001b[0m \u001b[43mhdf5_format\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_weights_to_hdf5_group\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayers\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 2253\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
270
+ "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\saving\\hdf5_format.py:646\u001b[0m, in \u001b[0;36msave_weights_to_hdf5_group\u001b[1;34m(f, layers)\u001b[0m\n\u001b[0;32m 645\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 646\u001b[0m param_dset[:] \u001b[38;5;241m=\u001b[39m val\n",
271
+ "File \u001b[1;32mh5py\\\\_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
272
+ "File \u001b[1;32mh5py\\\\_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
273
+ "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\h5py\\_hl\\dataset.py:999\u001b[0m, in \u001b[0;36mDataset.__setitem__\u001b[1;34m(self, args, val)\u001b[0m\n\u001b[0;32m 998\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m fspace \u001b[38;5;129;01min\u001b[39;00m selection\u001b[38;5;241m.\u001b[39mbroadcast(mshape):\n\u001b[1;32m--> 999\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mid\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmspace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfspace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdxpl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dxpl\u001b[49m\u001b[43m)\u001b[49m\n",
274
+ "File \u001b[1;32mh5py\\\\_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
275
+ "File \u001b[1;32mh5py\\\\_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
276
+ "File \u001b[1;32mh5py\\\\h5d.pyx:282\u001b[0m, in \u001b[0;36mh5py.h5d.DatasetID.write\u001b[1;34m()\u001b[0m\n",
277
+ "File \u001b[1;32mh5py\\\\_proxy.pyx:115\u001b[0m, in \u001b[0;36mh5py._proxy.dset_rw\u001b[1;34m()\u001b[0m\n",
278
+ "\u001b[1;31mOSError\u001b[0m: [Errno 28] Can't write data (file write failed: time = Wed May 7 10:48:36 2025\n, filename = 'mf-fine-tuned-protbert\\weights-10-0.04.hdf5', file descriptor = 4, errno = 28, error message = 'No space left on device', buf = 000002CC552FF040, total write size = 4194304, bytes this sub-write = 4194304, bytes actually written = 18446744073709551615, offset = 1180551864)",
279
+ "\nDuring handling of the above exception, another exception occurred:\n",
280
+ "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)",
281
+ "Cell \u001b[1;32mIn[2], line 119\u001b[0m\n\u001b[0;32m 113\u001b[0m model \u001b[38;5;241m=\u001b[39m t\u001b[38;5;241m.\u001b[39mget_classifier()\n\u001b[0;32m 114\u001b[0m learner \u001b[38;5;241m=\u001b[39m ktrain\u001b[38;5;241m.\u001b[39mget_learner(model,\n\u001b[0;32m 115\u001b[0m train_data\u001b[38;5;241m=\u001b[39mtrn,\n\u001b[0;32m 116\u001b[0m val_data\u001b[38;5;241m=\u001b[39mval,\n\u001b[0;32m 117\u001b[0m batch_size\u001b[38;5;241m=\u001b[39mBATCH_SIZE)\n\u001b[1;32m--> 119\u001b[0m \u001b[43mlearner\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautofit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-5\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 120\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 121\u001b[0m \u001b[43m \u001b[49m\u001b[43mearly_stopping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 122\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheckpoint_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmf-fine-tuned-protbert\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
282
+ "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\core.py:1239\u001b[0m, in \u001b[0;36mLearner.autofit\u001b[1;34m(self, lr, epochs, early_stopping, reduce_on_plateau, reduce_factor, cycle_momentum, max_momentum, min_momentum, monitor, checkpoint_folder, class_weight, callbacks, steps_per_epoch, verbose)\u001b[0m\n\u001b[0;32m 1234\u001b[0m policy \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtriangular learning rate\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1235\u001b[0m U\u001b[38;5;241m.\u001b[39mvprint(\n\u001b[0;32m 1236\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbegin training using \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m policy with max lr of \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m...\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m (policy, lr),\n\u001b[0;32m 1237\u001b[0m verbose\u001b[38;5;241m=\u001b[39mverbose,\n\u001b[0;32m 1238\u001b[0m )\n\u001b[1;32m-> 1239\u001b[0m hist \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1240\u001b[0m \u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1241\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1242\u001b[0m \u001b[43m \u001b[49m\u001b[43mearly_stopping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mearly_stopping\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1243\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheckpoint_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcheckpoint_folder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1244\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1245\u001b[0m \u001b[43m \u001b[49m\u001b[43mclass_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclass_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1246\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkcallbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1247\u001b[0m \u001b[43m \u001b[49m\u001b[43msteps_per_epoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msteps_per_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1248\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1249\u001b[0m hist\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m clr\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m 1250\u001b[0m hist\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124miterations\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m clr\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124miterations\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
283
+ "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\core.py:1650\u001b[0m, in \u001b[0;36mGenLearner.fit\u001b[1;34m(self, lr, n_cycles, cycle_len, cycle_mult, lr_decay, checkpoint_folder, early_stopping, class_weight, callbacks, steps_per_epoch, verbose)\u001b[0m\n\u001b[0;32m 1648\u001b[0m warnings\u001b[38;5;241m.\u001b[39mfilterwarnings(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mignore\u001b[39m\u001b[38;5;124m\"\u001b[39m, message\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.*Check your callbacks.*\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 1649\u001b[0m fit_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mfit\n\u001b[1;32m-> 1650\u001b[0m hist \u001b[38;5;241m=\u001b[39m \u001b[43mfit_fn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1651\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_data\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1652\u001b[0m \u001b[43m \u001b[49m\u001b[43msteps_per_epoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msteps_per_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1653\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidation_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidation_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1654\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1655\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mval_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1656\u001b[0m \u001b[43m \u001b[49m\u001b[43mworkers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mworkers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1657\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_multiprocessing\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_multiprocessing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1658\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1659\u001b[0m \u001b[43m \u001b[49m\u001b[43mshuffle\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 1660\u001b[0m \u001b[43m \u001b[49m\u001b[43mclass_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclass_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1661\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkcallbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1662\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1663\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m sgdr \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 1664\u001b[0m hist\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m sgdr\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
284
+ "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\engine\\training.py:1230\u001b[0m, in \u001b[0;36mModel.fit\u001b[1;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[0;32m 1227\u001b[0m val_logs \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m name: val \u001b[38;5;28;01mfor\u001b[39;00m name, val \u001b[38;5;129;01min\u001b[39;00m val_logs\u001b[38;5;241m.\u001b[39mitems()}\n\u001b[0;32m 1228\u001b[0m epoch_logs\u001b[38;5;241m.\u001b[39mupdate(val_logs)\n\u001b[1;32m-> 1230\u001b[0m \u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mon_epoch_end\u001b[49m\u001b[43m(\u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch_logs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1231\u001b[0m training_logs \u001b[38;5;241m=\u001b[39m epoch_logs\n\u001b[0;32m 1232\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstop_training:\n",
285
+ "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\callbacks.py:413\u001b[0m, in \u001b[0;36mCallbackList.on_epoch_end\u001b[1;34m(self, epoch, logs)\u001b[0m\n\u001b[0;32m 411\u001b[0m logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_logs(logs)\n\u001b[0;32m 412\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m callback \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallbacks:\n\u001b[1;32m--> 413\u001b[0m \u001b[43mcallback\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mon_epoch_end\u001b[49m\u001b[43m(\u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlogs\u001b[49m\u001b[43m)\u001b[49m\n",
286
+ "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\callbacks.py:1368\u001b[0m, in \u001b[0;36mModelCheckpoint.on_epoch_end\u001b[1;34m(self, epoch, logs)\u001b[0m\n\u001b[0;32m 1366\u001b[0m \u001b[38;5;66;03m# pylint: disable=protected-access\u001b[39;00m\n\u001b[0;32m 1367\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msave_freq \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mepoch\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m-> 1368\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mepoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlogs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlogs\u001b[49m\u001b[43m)\u001b[49m\n",
287
+ "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\callbacks.py:1431\u001b[0m, in \u001b[0;36mModelCheckpoint._save_model\u001b[1;34m(self, epoch, batch, logs)\u001b[0m\n\u001b[0;32m 1429\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m%05d\u001b[39;00m\u001b[38;5;124m: saving model to \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m%\u001b[39m (epoch \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m, filepath))\n\u001b[0;32m 1430\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msave_weights_only:\n\u001b[1;32m-> 1431\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_weights\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1432\u001b[0m \u001b[43m \u001b[49m\u001b[43mfilepath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverwrite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_options\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1433\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 1434\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39msave(filepath, overwrite\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, options\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_options)\n",
288
+ "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\engine\\training.py:2252\u001b[0m, in \u001b[0;36mModel.save_weights\u001b[1;34m(self, filepath, overwrite, save_format, options)\u001b[0m\n\u001b[0;32m 2250\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m save_format \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mh5\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m 2251\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m h5py\u001b[38;5;241m.\u001b[39mFile(filepath, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m-> 2252\u001b[0m hdf5_format\u001b[38;5;241m.\u001b[39msave_weights_to_hdf5_group(f, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers)\n\u001b[0;32m 2253\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 2254\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tf\u001b[38;5;241m.\u001b[39mexecuting_eagerly():\n",
289
+ "File \u001b[1;32mh5py\\\\_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
290
+ "File \u001b[1;32mh5py\\\\_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
291
+ "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\h5py\\_hl\\files.py:599\u001b[0m, in \u001b[0;36mFile.__exit__\u001b[1;34m(self, *args)\u001b[0m\n\u001b[0;32m 596\u001b[0m \u001b[38;5;129m@with_phil\u001b[39m\n\u001b[0;32m 597\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__exit__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs):\n\u001b[0;32m 598\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid:\n\u001b[1;32m--> 599\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclose\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
292
+ "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\h5py\\_hl\\files.py:581\u001b[0m, in \u001b[0;36mFile.close\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 575\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid\u001b[38;5;241m.\u001b[39mvalid:\n\u001b[0;32m 576\u001b[0m \u001b[38;5;66;03m# We have to explicitly murder all open objects related to the file\u001b[39;00m\n\u001b[0;32m 577\u001b[0m \n\u001b[0;32m 578\u001b[0m \u001b[38;5;66;03m# Close file-resident objects first, then the files.\u001b[39;00m\n\u001b[0;32m 579\u001b[0m \u001b[38;5;66;03m# Otherwise we get errors in MPI mode.\u001b[39;00m\n\u001b[0;32m 580\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid\u001b[38;5;241m.\u001b[39m_close_open_objects(h5f\u001b[38;5;241m.\u001b[39mOBJ_LOCAL \u001b[38;5;241m|\u001b[39m \u001b[38;5;241m~\u001b[39mh5f\u001b[38;5;241m.\u001b[39mOBJ_FILE)\n\u001b[1;32m--> 581\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mid\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_close_open_objects\u001b[49m\u001b[43m(\u001b[49m\u001b[43mh5f\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mOBJ_LOCAL\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m|\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mh5f\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mOBJ_FILE\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 583\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid\u001b[38;5;241m.\u001b[39mclose()\n\u001b[0;32m 584\u001b[0m _objects\u001b[38;5;241m.\u001b[39mnonlocal_close()\n",
293
+ "File \u001b[1;32mh5py\\\\_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
294
+ "File \u001b[1;32mh5py\\\\_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
295
+ "File \u001b[1;32mh5py\\\\h5f.pyx:355\u001b[0m, in \u001b[0;36mh5py.h5f.FileID._close_open_objects\u001b[1;34m()\u001b[0m\n",
296
+ "\u001b[1;31mRuntimeError\u001b[0m: Can't decrement id ref count (unable to extend file properly)"
297
+ ]
298
+ }
299
+ ],
300
+ "source": [
301
+ "import pandas as pd\n",
302
+ "import numpy as np\n",
303
+ "from tqdm import tqdm\n",
304
+ "import random\n",
305
+ "import os\n",
306
+ "import ktrain\n",
307
+ "from ktrain import text\n",
308
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
309
+ "\n",
310
+ "\n",
311
+ "# PAM1\n",
312
+ "# PAM matrix model of protein evolution\n",
313
+ "# DOI:10.1093/oxfordjournals.molbev.a040360\n",
314
+ "pam_data = {\n",
315
+ " 'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],\n",
316
+ " 'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],\n",
317
+ " 'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],\n",
318
+ " 'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],\n",
319
+ " 'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],\n",
320
+ " 'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],\n",
321
+ " 'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],\n",
322
+ " 'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],\n",
323
+ " 'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],\n",
324
+ " 'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],\n",
325
+ " 'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],\n",
326
+ " 'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],\n",
327
+ " 'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],\n",
328
+ " 'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],\n",
329
+ " 'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],\n",
330
+ " 'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],\n",
331
+ " 'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],\n",
332
+ " 'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],\n",
333
+ " 'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],\n",
334
+ " 'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]\n",
335
+ "}\n",
336
+ "pam_raw = pd.DataFrame(pam_data, index=list(pam_data.keys()))\n",
337
+ "pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)\n",
338
+ "list_amino = pam_raw.columns.tolist()\n",
339
+ "pam_dict = {\n",
340
+ " aa: {sub: pam_matrix.loc[aa, sub] for sub in list_amino}\n",
341
+ " for aa in list_amino\n",
342
+ "}\n",
343
+ "\n",
344
+ "def pam1_substitution(aa):\n",
345
+ " if aa not in pam_dict:\n",
346
+ " return aa\n",
347
+ " subs = list(pam_dict[aa].keys())\n",
348
+ " probs = list(pam_dict[aa].values())\n",
349
+ " return np.random.choice(subs, p=probs)\n",
350
+ "\n",
351
+ "def augment_sequence(seq, sub_prob=0.05):\n",
352
+ " return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])\n",
353
+ "\n",
354
+ "def slice_sequence(seq, win=500, min_overlap=250):\n",
355
+ " if len(seq) <= win:\n",
356
+ " return [seq]\n",
357
+ " slices, start = [], 0\n",
358
+ " while start + win <= len(seq):\n",
359
+ " slices.append(seq[start:start+win])\n",
360
+ " start += win\n",
361
+ " leftover = seq[start:]\n",
362
+ " if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
363
+ " extra = slices[-1][-min_overlap:] + leftover\n",
364
+ " slices.append(extra)\n",
365
+ " return slices\n",
366
+ "\n",
367
+ "def generate_data(df, augment=False):\n",
368
+ " X, y = [], []\n",
369
+ " label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
370
+ " for _, row in tqdm(df.iterrows(), total=len(df)):\n",
371
+ " seq = row[\"sequence\"]\n",
372
+ " if augment:\n",
373
+ " seq = augment_sequence(seq)\n",
374
+ " seq_slices = slice_sequence(seq)\n",
375
+ " X.extend(seq_slices)\n",
376
+ " lbl = row[label_cols].values.astype(int)\n",
377
+ " y.extend([lbl] * len(seq_slices))\n",
378
+ " return X, np.array(y), label_cols\n",
379
+ "\n",
380
+ "def format_sequence(seq): return \" \".join(list(seq))\n",
381
+ "\n",
382
+ "# Função para carregar e binarizar\n",
383
+ "def load_and_binarize(csv_path, mlb=None):\n",
384
+ " df = pd.read_csv(csv_path)\n",
385
+ " df[\"go_terms\"] = df[\"go_terms\"].str.split(\";\")\n",
386
+ " if mlb is None:\n",
387
+ " mlb = MultiLabelBinarizer()\n",
388
+ " labels = mlb.fit_transform(df[\"go_terms\"])\n",
389
+ " else:\n",
390
+ " labels = mlb.transform(df[\"go_terms\"])\n",
391
+ " labels_df = pd.DataFrame(labels, columns=mlb.classes_)\n",
392
+ " df = df.reset_index(drop=True).join(labels_df)\n",
393
+ " return df, mlb\n",
394
+ "\n",
395
+ "# Carregar os dados\n",
396
+ "df_train, mlb = load_and_binarize(\"data/mf-training.csv\")\n",
397
+ "df_val, _ = load_and_binarize(\"data/mf-validation.csv\", mlb=mlb)\n",
398
+ "\n",
399
+ "# Gerar com augmentation no treino\n",
400
+ "X_train, y_train, term_cols = generate_data(df_train, augment=True)\n",
401
+ "X_val, y_val, _ = generate_data(df_val, augment=False)\n",
402
+ "\n",
403
+ "# Preparar texto para tokenizer\n",
404
+ "X_train_fmt = list(map(format_sequence, X_train))\n",
405
+ "X_val_fmt = list(map(format_sequence, X_val))\n",
406
+ "\n",
407
+ "# Fine-tune ProtBERT\n",
408
+ "# https://huggingface.co/Rostlab/prot_bert\n",
409
+ "# https://doi.org/10.1093/bioinformatics/btac020\n",
410
+ "# dados de treino-> UniRef100 (216 milhões de sequências)\n",
411
+ "MODEL_NAME = \"Rostlab/prot_bert\"\n",
412
+ "MAX_LEN = 512\n",
413
+ "BATCH_SIZE = 1\n",
414
+ "\n",
415
+ "t = text.Transformer(MODEL_NAME, maxlen=MAX_LEN, classes=term_cols)\n",
416
+ "trn = t.preprocess_train(X_train_fmt, y_train)\n",
417
+ "val = t.preprocess_test(X_val_fmt, y_val)\n",
418
+ "\n",
419
+ "model = t.get_classifier()\n",
420
+ "learner = ktrain.get_learner(model,\n",
421
+ " train_data=trn,\n",
422
+ " val_data=val,\n",
423
+ " batch_size=BATCH_SIZE)\n",
424
+ "\n",
425
+ "learner.autofit(lr=1e-5,\n",
426
+ " epochs=10,\n",
427
+ " early_stopping=1,\n",
428
+ " checkpoint_folder=\"mf-fine-tuned-protbert\")\n"
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "code",
433
+ "execution_count": 7,
434
+ "id": "c66774b3-6cf0-41c5-bb01-9467a5283102",
435
+ "metadata": {},
436
+ "outputs": [
437
+ {
438
+ "name": "stdout",
439
+ "output_type": "stream",
440
+ "text": [
441
+ "✅ Existe: weights/mf-fine-tuned-protbert-epoch10\n",
442
+ "📁 Conteúdo:\n",
443
+ " - config.json\n",
444
+ " - tf_model.h5\n"
445
+ ]
446
+ }
447
+ ],
448
+ "source": [
449
+ "import os\n",
450
+ "\n",
451
+ "path = \"weights/mf-fine-tuned-protbert-epoch10\"\n",
452
+ "\n",
453
+ "if os.path.exists(path):\n",
454
+ " print(f\"✅ Existe: {path}\")\n",
455
+ " print(\"📁 Conteúdo:\")\n",
456
+ " for f in os.listdir(path):\n",
457
+ " print(\" -\", f)\n",
458
+ "else:\n",
459
+ " print(f\"❌ Não existe: {path}\")\n",
460
+ "\n"
461
+ ]
462
+ },
463
+ {
464
+ "cell_type": "code",
465
+ "execution_count": 19,
466
+ "id": "9b39c439-5708-4787-bfee-d3a4d3aa190d",
467
+ "metadata": {},
468
+ "outputs": [
469
+ {
470
+ "name": "stderr",
471
+ "output_type": "stream",
472
+ "text": [
473
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
474
+ " from .autonotebook import tqdm as notebook_tqdm\n",
475
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
476
+ " _torch_pytree._register_pytree_node(\n",
477
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\huggingface_hub\\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
478
+ " warnings.warn(\n",
479
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:309: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
480
+ " _torch_pytree._register_pytree_node(\n",
481
+ "Some layers from the model checkpoint at weights/mf-fine-tuned-protbert-epoch10 were not used when initializing TFBertModel: ['classifier', 'dropout_183']\n",
482
+ "- This IS expected if you are initializing TFBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
483
+ "- This IS NOT expected if you are initializing TFBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
484
+ "All the layers of TFBertModel were initialized from the model checkpoint at weights/mf-fine-tuned-protbert-epoch10.\n",
485
+ "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.\n"
486
+ ]
487
+ },
488
+ {
489
+ "name": "stdout",
490
+ "output_type": "stream",
491
+ "text": [
492
+ "✓ Tokenizer base e modelo fine-tuned carregados com sucesso\n"
493
+ ]
494
+ },
495
+ {
496
+ "name": "stderr",
497
+ "output_type": "stream",
498
+ "text": [
499
+ "Processando data/mf-training.csv: 0%| | 25/31142 [00:06<2:23:28, 3.61it/s]\n"
500
+ ]
501
+ },
502
+ {
503
+ "ename": "KeyboardInterrupt",
504
+ "evalue": "",
505
+ "output_type": "error",
506
+ "traceback": [
507
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
508
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
509
+ "Cell \u001b[1;32mIn[19], line 78\u001b[0m\n\u001b[0;32m 75\u001b[0m \u001b[38;5;66;03m# --- 4. Aplicar -----------------------------------------------------------\u001b[39;00m\n\u001b[0;32m 76\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(OUT_DIR, exist_ok\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m---> 78\u001b[0m \u001b[43mprocess_split\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdata/mf-training.csv\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mOUT_DIR\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtrain_protbert.pkl\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 79\u001b[0m process_split(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata/mf-validation.csv\u001b[39m\u001b[38;5;124m\"\u001b[39m, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(OUT_DIR, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mval_protbert.pkl\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[0;32m 80\u001b[0m process_split(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata/mf-test.csv\u001b[39m\u001b[38;5;124m\"\u001b[39m, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(OUT_DIR, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_protbert.pkl\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n",
510
+ "Cell \u001b[1;32mIn[19], line 61\u001b[0m, in \u001b[0;36mprocess_split\u001b[1;34m(csv_path, out_path)\u001b[0m\n\u001b[0;32m 59\u001b[0m embeds\u001b[38;5;241m.\u001b[39mappend(prot_embed\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32))\n\u001b[0;32m 60\u001b[0m labels\u001b[38;5;241m.\u001b[39mappend(row[label_cols]\u001b[38;5;241m.\u001b[39mvalues\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mint8))\n\u001b[1;32m---> 61\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[0;32m 63\u001b[0m embeds \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mvstack(embeds)\n\u001b[0;32m 64\u001b[0m labels \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mvstack(labels)\n",
511
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
512
+ ]
513
+ }
514
+ ],
515
+ "source": [
516
+ "import os\n",
517
+ "import pandas as pd\n",
518
+ "import numpy as np\n",
519
+ "from tqdm import tqdm\n",
520
+ "import joblib\n",
521
+ "import gc\n",
522
+ "from transformers import AutoTokenizer, TFAutoModel\n",
523
+ "\n",
524
+ "# --- 1. Parâmetros --------------------------------------------------------\n",
525
+ "MODEL_DIR = \"weights/mf-fine-tuned-protbert-epoch10\"\n",
526
+ "BASE_MODEL = \"Rostlab/prot_bert\"\n",
527
+ "OUT_DIR = \"embeddings\"\n",
528
+ "BATCH_TOK = 16\n",
529
+ "\n",
530
+ "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, do_lower_case=False)\n",
531
+ "model = TFAutoModel.from_pretrained(MODEL_DIR, from_pt=False)\n",
532
+ "\n",
533
+ "print(\"✓ Tokenizer base e modelo fine-tuned carregados com sucesso\")\n",
534
+ "\n",
535
+ "# --- 3. Funções auxiliares ------------------------------------------------\n",
536
+ "def format_sequence(seq):\n",
537
+ " return \" \".join(list(seq))\n",
538
+ "\n",
539
+ "def slice_sequence(seq, win=500, min_overlap=250):\n",
540
+ " if len(seq) <= win:\n",
541
+ " return [seq]\n",
542
+ " slices, start = [], 0\n",
543
+ " while start + win <= len(seq):\n",
544
+ " slices.append(seq[start:start+win])\n",
545
+ " start += win\n",
546
+ " leftover = seq[start:]\n",
547
+ " if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
548
+ " extra = slices[-1][-min_overlap:] + leftover\n",
549
+ " slices.append(extra)\n",
550
+ " return slices\n",
551
+ "\n",
552
+ "def get_embeddings(batch, tokenizer, model):\n",
553
+ " tokens = tokenizer(batch, return_tensors=\"tf\", padding=True, truncation=True, max_length=512)\n",
554
+ " output = model(**tokens)\n",
555
+ " return output.last_hidden_state[:, 0, :].numpy()\n",
556
+ "\n",
557
+ "def process_split(csv_path, out_path):\n",
558
+ " df = pd.read_csv(csv_path)\n",
559
+ " label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
560
+ " prot_ids, embeds, labels = [], [], []\n",
561
+ "\n",
562
+ " for _, row in tqdm(df.iterrows(), total=len(df), desc=f\"Processando {csv_path}\"):\n",
563
+ " slices = slice_sequence(row[\"sequence\"])\n",
564
+ " slices_fmt = list(map(format_sequence, slices))\n",
565
+ "\n",
566
+ " slice_embeds = []\n",
567
+ " for i in range(0, len(slices_fmt), BATCH_TOK):\n",
568
+ " batch = slices_fmt[i:i+BATCH_TOK]\n",
569
+ " slice_embeds.append(get_embeddings(batch, tokenizer, model))\n",
570
+ " slice_embeds = np.vstack(slice_embeds)\n",
571
+ "\n",
572
+ " prot_embed = slice_embeds.mean(axis=0)\n",
573
+ " prot_ids.append(row[\"protein_id\"])\n",
574
+ " embeds.append(prot_embed.astype(np.float32))\n",
575
+ " labels.append(row[label_cols].values.astype(np.int8))\n",
576
+ " gc.collect()\n",
577
+ "\n",
578
+ " embeds = np.vstack(embeds)\n",
579
+ " labels = np.vstack(labels)\n",
580
+ "\n",
581
+ " joblib.dump({\n",
582
+ " \"protein_ids\": prot_ids,\n",
583
+ " \"embeddings\": embeds,\n",
584
+ " \"labels\": labels,\n",
585
+ " \"go_terms\": label_cols\n",
586
+ " }, out_path, compress=3)\n",
587
+ "\n",
588
+ " print(f\"✓ Guardado {out_path} — {embeds.shape[0]} proteínas\")\n",
589
+ "\n",
590
+ "# --- 4. Aplicar -----------------------------------------------------------\n",
591
+ "os.makedirs(OUT_DIR, exist_ok=True)\n",
592
+ "\n",
593
+ "process_split(\"data/mf-training.csv\", os.path.join(OUT_DIR, \"train_protbert.pkl\"))\n",
594
+ "process_split(\"data/mf-validation.csv\", os.path.join(OUT_DIR, \"val_protbert.pkl\"))\n",
595
+ "process_split(\"data/mf-test.csv\", os.path.join(OUT_DIR, \"test_protbert.pkl\"))\n"
596
+ ]
597
+ },
598
+ {
599
+ "cell_type": "code",
600
+ "execution_count": 27,
601
+ "id": "ad0c5421-e0a1-4a6a-8ace-2c69aeab0e0d",
602
+ "metadata": {},
603
+ "outputs": [
604
+ {
605
+ "name": "stdout",
606
+ "output_type": "stream",
607
+ "text": [
608
+ "✓ Corrigido: embeddings/train_protbert.pkl — 31142 exemplos, 597 GO terms\n",
609
+ "✓ Corrigido: embeddings/val_protbert.pkl — 1724 exemplos, 597 GO terms\n",
610
+ "✓ Corrigido: embeddings/test_protbert.pkl — 1724 exemplos, 597 GO terms\n"
611
+ ]
612
+ }
613
+ ],
614
+ "source": [
615
+ "import pandas as pd\n",
616
+ "import joblib\n",
617
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
618
+ "\n",
619
+ "# --- 1. Obter GO terms do ficheiro de teste --------------------------------\n",
620
+ "df_test = pd.read_csv(\"data/mf-test.csv\")\n",
621
+ "test_terms = sorted(set(term for row in df_test[\"go_terms\"].str.split(\";\") for term in row))\n",
622
+ "\n",
623
+ "# --- 2. Função para corrigir um .pkl com base nos GO terms do teste --------\n",
624
+ "def patch_to_common_terms(csv_path, pkl_path, common_terms):\n",
625
+ " df = pd.read_csv(csv_path)\n",
626
+ " terms_split = df[\"go_terms\"].str.split(\";\")\n",
627
+ " \n",
628
+ " # Apenas termos presentes nos common_terms\n",
629
+ " terms_filtered = terms_split.apply(lambda lst: [t for t in lst if t in common_terms])\n",
630
+ " \n",
631
+ " mlb = MultiLabelBinarizer(classes=common_terms)\n",
632
+ " Y = mlb.fit_transform(terms_filtered)\n",
633
+ "\n",
634
+ " data = joblib.load(pkl_path)\n",
635
+ " data[\"labels\"] = Y\n",
636
+ " data[\"go_terms\"] = mlb.classes_.tolist()\n",
637
+ " \n",
638
+ " joblib.dump(data, pkl_path, compress=3)\n",
639
+ " print(f\"✓ Corrigido: {pkl_path} — {Y.shape[0]} exemplos, {Y.shape[1]} GO terms\")\n",
640
+ "\n",
641
+ "# --- 3. Aplicar às 3 partições --------------------------------------------\n",
642
+ "patch_to_common_terms(\"data/mf-training.csv\", \"embeddings/train_protbert.pkl\", test_terms)\n",
643
+ "patch_to_common_terms(\"data/mf-validation.csv\", \"embeddings/val_protbert.pkl\", test_terms)\n",
644
+ "patch_to_common_terms(\"data/mf-test.csv\", \"embeddings/test_protbert.pkl\", test_terms)\n"
645
+ ]
646
+ },
647
+ {
648
+ "cell_type": "code",
649
+ "execution_count": 1,
650
+ "id": "dbd5c35f-4a08-4906-9cf4-e1df501d1ecb",
651
+ "metadata": {},
652
+ "outputs": [],
653
+ "source": [
654
+ "import joblib\n",
655
+ "train = joblib.load(\"embeddings/train_protbert.pkl\")\n",
656
+ "val = joblib.load(\"embeddings/val_protbert.pkl\")\n",
657
+ "test = joblib.load(\"embeddings/test_protbert.pkl\")\n",
658
+ "\n",
659
+ "X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
660
+ "X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
661
+ "X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n"
662
+ ]
663
+ },
664
+ {
665
+ "cell_type": "code",
666
+ "execution_count": 3,
667
+ "id": "1785d8a9-23fc-4490-8d71-29cc91a4cb57",
668
+ "metadata": {},
669
+ "outputs": [
670
+ {
671
+ "name": "stdout",
672
+ "output_type": "stream",
673
+ "text": [
674
+ "✓ Embeddings carregados: (31142, 1024) → 597 GO terms\n",
675
+ "Epoch 1/100\n",
676
+ "974/974 [==============================] - 4s 3ms/step - loss: 0.0358 - binary_accuracy: 0.9893 - val_loss: 0.0336 - val_binary_accuracy: 0.9901\n",
677
+ "Epoch 2/100\n",
678
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0276 - binary_accuracy: 0.9914 - val_loss: 0.0331 - val_binary_accuracy: 0.9902\n",
679
+ "Epoch 3/100\n",
680
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0268 - binary_accuracy: 0.9916 - val_loss: 0.0330 - val_binary_accuracy: 0.9902\n",
681
+ "Epoch 4/100\n",
682
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0264 - binary_accuracy: 0.9917 - val_loss: 0.0320 - val_binary_accuracy: 0.9904\n",
683
+ "Epoch 5/100\n",
684
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0260 - binary_accuracy: 0.9917 - val_loss: 0.0319 - val_binary_accuracy: 0.9904\n",
685
+ "Epoch 6/100\n",
686
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0256 - binary_accuracy: 0.9918 - val_loss: 0.0322 - val_binary_accuracy: 0.9904\n",
687
+ "Epoch 7/100\n",
688
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0255 - binary_accuracy: 0.9918 - val_loss: 0.0317 - val_binary_accuracy: 0.9903\n",
689
+ "Epoch 8/100\n",
690
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0252 - binary_accuracy: 0.9919 - val_loss: 0.0320 - val_binary_accuracy: 0.9905\n",
691
+ "Epoch 9/100\n",
692
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0251 - binary_accuracy: 0.9919 - val_loss: 0.0316 - val_binary_accuracy: 0.9904\n",
693
+ "Epoch 10/100\n",
694
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0250 - binary_accuracy: 0.9920 - val_loss: 0.0314 - val_binary_accuracy: 0.9905\n",
695
+ "Epoch 11/100\n",
696
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0248 - binary_accuracy: 0.9920 - val_loss: 0.0317 - val_binary_accuracy: 0.9905\n",
697
+ "Epoch 12/100\n",
698
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0247 - binary_accuracy: 0.9920 - val_loss: 0.0315 - val_binary_accuracy: 0.9905\n",
699
+ "Epoch 13/100\n",
700
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0246 - binary_accuracy: 0.9920 - val_loss: 0.0322 - val_binary_accuracy: 0.9904\n",
701
+ "Epoch 14/100\n",
702
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0245 - binary_accuracy: 0.9920 - val_loss: 0.0319 - val_binary_accuracy: 0.9905\n",
703
+ "Epoch 15/100\n",
704
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0244 - binary_accuracy: 0.9920 - val_loss: 0.0319 - val_binary_accuracy: 0.9906\n",
705
+ "Previsões guardadas em mf-protbert-pam1.npy\n",
706
+ "Modelo guardado em models/protbert_mlp.keras\n"
707
+ ]
708
+ }
709
+ ],
710
+ "source": [
711
+ "import tensorflow as tf\n",
712
+ "import joblib\n",
713
+ "import numpy as np\n",
714
+ "from tensorflow.keras.models import Sequential\n",
715
+ "from tensorflow.keras.layers import Dense, Dropout\n",
716
+ "from tensorflow.keras.callbacks import EarlyStopping\n",
717
+ "\n",
718
+ "# --- 1. Carregar embeddings ----------------------------------------------\n",
719
+ "train = joblib.load(\"embeddings/train_protbert.pkl\")\n",
720
+ "val = joblib.load(\"embeddings/val_protbert.pkl\")\n",
721
+ "test = joblib.load(\"embeddings/test_protbert.pkl\")\n",
722
+ "\n",
723
+ "X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
724
+ "X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
725
+ "X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n",
726
+ "\n",
727
+ "print(f\"✓ Embeddings carregados: {X_train.shape} → {y_train.shape[1]} GO terms\")\n",
728
+ "\n",
729
+ "# --- 2. Garantir consistência de classes ---------------------------------\n",
730
+ "max_classes = y_train.shape[1] # 602 GO terms (do treino)\n",
731
+ "\n",
732
+ "def pad_labels(y, target_dim=max_classes):\n",
733
+ " if y.shape[1] < target_dim:\n",
734
+ " padding = np.zeros((y.shape[0], target_dim - y.shape[1]), dtype=np.int8)\n",
735
+ " return np.hstack([y, padding])\n",
736
+ " return y\n",
737
+ "\n",
738
+ "y_val = pad_labels(y_val)\n",
739
+ "y_test = pad_labels(y_test)\n",
740
+ "\n",
741
+ "# --- 3. Modelo MLP ------------------------------------------------------\n",
742
+ "model = Sequential([\n",
743
+ " Dense(1024, activation=\"relu\", input_shape=(X_train.shape[1],)),\n",
744
+ " Dropout(0.3),\n",
745
+ " Dense(512, activation=\"relu\"),\n",
746
+ " Dropout(0.3),\n",
747
+ " Dense(max_classes, activation=\"sigmoid\")\n",
748
+ "])\n",
749
+ "\n",
750
+ "model.compile(loss=\"binary_crossentropy\",\n",
751
+ " optimizer=\"adam\",\n",
752
+ " metrics=[\"binary_accuracy\"])\n",
753
+ "\n",
754
+ "# --- 4. Early stopping e treino -----------------------------------------\n",
755
+ "callbacks = [\n",
756
+ " EarlyStopping(monitor=\"val_loss\", patience=5, restore_best_weights=True)\n",
757
+ "]\n",
758
+ "\n",
759
+ "model.fit(X_train, y_train,\n",
760
+ " validation_data=(X_val, y_val),\n",
761
+ " epochs=100,\n",
762
+ " batch_size=32,\n",
763
+ " callbacks=callbacks,\n",
764
+ " verbose=1)\n",
765
+ "\n",
766
+ "# --- 5. Previsões --------------------------------------------------------\n",
767
+ "y_prob = model.predict(X_test)\n",
768
+ "np.save(\"predictions/mf-protbert-pam1.npy\", y_prob)\n",
769
+ "print(\"Previsões guardadas em mf-protbert-pam1.npy\")\n",
770
+ "\n",
771
+ "# --- 6. Modelo ----------------------------------------------------------\n",
772
+ "model.save(\"models/protbert_mlp.keras\")\n",
773
+ "print(\"Modelo guardado em models/protbert_mlp.keras\")"
774
+ ]
775
+ },
776
+ {
777
+ "cell_type": "code",
778
+ "execution_count": 30,
779
+ "id": "fdb66630-76dc-43a0-bd56-45052175fdba",
780
+ "metadata": {},
781
+ "outputs": [
782
+ {
783
+ "name": "stdout",
784
+ "output_type": "stream",
785
+ "text": [
786
+ "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
787
+ "✓ Embeddings: (1724, 597) labels × 597 GO terms\n",
788
+ "\n",
789
+ "📊 Resultados finais (ProtBERT + PAM1 + propagação):\n",
790
+ "Fmax = 0.6666\n",
791
+ "Thr. = 0.50\n",
792
+ "AuPRC = 0.7028\n",
793
+ "Smin = 13.1745\n"
794
+ ]
795
+ }
796
+ ],
797
+ "source": [
798
+ "import numpy as np\n",
799
+ "from sklearn.metrics import precision_recall_curve, auc\n",
800
+ "from goatools.obo_parser import GODag\n",
801
+ "import joblib\n",
802
+ "import math\n",
803
+ "\n",
804
+ "# --- 1. Parâmetros -------------------------------------------------------\n",
805
+ "GO_FILE = \"go.obo\"\n",
806
+ "THRESHOLDS = np.arange(0.0, 1.01, 0.01)\n",
807
+ "ALPHA = 0.5\n",
808
+ "\n",
809
+ "# --- 2. Carregar dados ---------------------------------------------------\n",
810
+ "test = joblib.load(\"embeddings/test_protbert.pkl\")\n",
811
+ "y_true = test[\"labels\"]\n",
812
+ "terms = test[\"go_terms\"]\n",
813
+ "y_prob = np.load(\"predictions/mf-protbert-pam1.npy\")\n",
814
+ "go_dag = GODag(GO_FILE)\n",
815
+ "\n",
816
+ "print(f\"✓ Embeddings: {y_true.shape} labels × {len(terms)} GO terms\")\n",
817
+ "\n",
818
+ "# --- 3. Fmax -------------------------------------------------------------\n",
819
+ "def compute_fmax(y_true, y_prob, thresholds):\n",
820
+ " fmax, best_thr = 0, 0\n",
821
+ " for t in thresholds:\n",
822
+ " y_pred = (y_prob >= t).astype(int)\n",
823
+ " tp = (y_true * y_pred).sum(axis=1)\n",
824
+ " fp = ((1 - y_true) * y_pred).sum(axis=1)\n",
825
+ " fn = (y_true * (1 - y_pred)).sum(axis=1)\n",
826
+ " precision = tp / (tp + fp + 1e-8)\n",
827
+ " recall = tp / (tp + fn + 1e-8)\n",
828
+ " f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
829
+ " avg_f1 = np.mean(f1)\n",
830
+ " if avg_f1 > fmax:\n",
831
+ " fmax, best_thr = avg_f1, t\n",
832
+ " return fmax, best_thr\n",
833
+ "\n",
834
+ "# --- 4. AuPRC micro ------------------------------------------------------\n",
835
+ "def compute_auprc(y_true, y_prob):\n",
836
+ " precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())\n",
837
+ " return auc(recall, precision)\n",
838
+ "\n",
839
+ "# --- 5. Smin -------------------------------------------------------------\n",
840
+ "def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):\n",
841
+ " y_pred = (y_prob >= threshold).astype(int)\n",
842
+ " ic = {}\n",
843
+ " total = (y_true + y_pred).sum(axis=0).sum()\n",
844
+ " for i, term in enumerate(terms):\n",
845
+ " freq = (y_true[:, i] + y_pred[:, i]).sum()\n",
846
+ " ic[term] = -np.log((freq + 1e-8) / total)\n",
847
+ "\n",
848
+ " s_values = []\n",
849
+ " for true_vec, pred_vec in zip(y_true, y_pred):\n",
850
+ " true_terms = {terms[i] for i in np.where(true_vec)[0]}\n",
851
+ " pred_terms = {terms[i] for i in np.where(pred_vec)[0]}\n",
852
+ "\n",
853
+ " anc_true = set()\n",
854
+ " for t in true_terms:\n",
855
+ " if t in go_dag:\n",
856
+ " anc_true |= go_dag[t].get_all_parents()\n",
857
+ " anc_pred = set()\n",
858
+ " for t in pred_terms:\n",
859
+ " if t in go_dag:\n",
860
+ " anc_pred |= go_dag[t].get_all_parents()\n",
861
+ "\n",
862
+ " ru = pred_terms - true_terms\n",
863
+ " mi = true_terms - pred_terms\n",
864
+ " dist_ru = sum(ic.get(t, 0) for t in ru)\n",
865
+ " dist_mi = sum(ic.get(t, 0) for t in mi)\n",
866
+ " s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)\n",
867
+ " s_values.append(s)\n",
868
+ "\n",
869
+ " return np.mean(s_values)\n",
870
+ "\n",
871
+ "# --- 6. Avaliar ----------------------------------------------------------\n",
872
+ "fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)\n",
873
+ "auprc = compute_auprc(y_true, y_prob)\n",
874
+ "smin = compute_smin(y_true, y_prob, terms, thr, go_dag)\n",
875
+ "\n",
876
+ "print(f\"\\n📊 Resultados finais (ProtBERT + PAM1 + propagação):\")\n",
877
+ "print(f\"Fmax = {fmax:.4f}\")\n",
878
+ "print(f\"Thr. = {thr:.2f}\")\n",
879
+ "print(f\"AuPRC = {auprc:.4f}\")\n",
880
+ "print(f\"Smin = {smin:.4f}\")\n"
881
+ ]
882
+ },
883
+ {
884
+ "cell_type": "code",
885
+ "execution_count": 3,
886
+ "id": "70d131ef-ef84-42ee-953b-0d3f1268694d",
887
+ "metadata": {},
888
+ "outputs": [
889
+ {
890
+ "data": {
891
+ "text/plain": [
892
+ "['data/mlb_protbert.pkl']"
893
+ ]
894
+ },
895
+ "execution_count": 3,
896
+ "metadata": {},
897
+ "output_type": "execute_result"
898
+ }
899
+ ],
900
+ "source": [
901
+ "import joblib, pickle\n",
902
+ "joblib.dump(mlb, \"data/mlb_protbert.pkl\")"
903
+ ]
904
+ },
905
+ {
906
+ "cell_type": "code",
907
+ "execution_count": null,
908
+ "id": "9f89c3bc-6b78-4a4c-8ddd-b69c7d3d0e65",
909
+ "metadata": {},
910
+ "outputs": [],
911
+ "source": []
912
+ }
913
+ ],
914
+ "metadata": {
915
+ "kernelspec": {
916
+ "display_name": "Python 3 (ipykernel)",
917
+ "language": "python",
918
+ "name": "python3"
919
+ },
920
+ "language_info": {
921
+ "codemirror_mode": {
922
+ "name": "ipython",
923
+ "version": 3
924
+ },
925
+ "file_extension": ".py",
926
+ "mimetype": "text/x-python",
927
+ "name": "python",
928
+ "nbconvert_exporter": "python",
929
+ "pygments_lexer": "ipython3",
930
+ "version": "3.8.18"
931
+ }
932
+ },
933
+ "nbformat": 4,
934
+ "nbformat_minor": 5
935
+ }
notebooks/PAM1_protbertBFD.ipynb ADDED
@@ -0,0 +1,871 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "c6dbc330-062a-48f0-8242-3f21cc1c9c2b",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
14
+ "✓ Ficheiros criados:\n",
15
+ " - data/mf-training.csv : (31142, 3)\n",
16
+ " - data/mf-validation.csv: (1724, 3)\n",
17
+ " - data/mf-test.csv : (1724, 3)\n",
18
+ "GO terms únicos (após propagação e filtro): 602\n"
19
+ ]
20
+ }
21
+ ],
22
+ "source": [
23
+ "import pandas as pd\n",
24
+ "from Bio import SeqIO\n",
25
+ "from collections import Counter\n",
26
+ "from goatools.obo_parser import GODag\n",
27
+ "from sklearn.model_selection import train_test_split\n",
28
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
29
+ "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
30
+ "import numpy as np\n",
31
+ "import os\n",
32
+ "\n",
33
+ "# --- 1. Carregar GO anotações ------------------------------------------\n",
34
+ "annotations = pd.read_csv(\"uniprot_sprot_exp.txt\", sep=\"\\t\", names=[\"protein_id\", \"go_term\", \"go_category\"])\n",
35
+ "annotations_f = annotations[annotations[\"go_category\"] == \"F\"]\n",
36
+ "\n",
37
+ "# --- 2. Carregar DAG e propagar GO terms -------------------------------\n",
38
+ "# propagação hierárquica\n",
39
+ "# https://geneontology.org/docs/download-ontology/\n",
40
+ "go_dag = GODag(\"go.obo\")\n",
41
+ "mf_terms = {t for t, o in go_dag.items() if o.namespace == \"molecular_function\"}\n",
42
+ "\n",
43
+ "def propagate_terms(term_list):\n",
44
+ " full = set()\n",
45
+ " for t in term_list:\n",
46
+ " if t not in go_dag:\n",
47
+ " continue\n",
48
+ " full.add(t)\n",
49
+ " full.update(go_dag[t].get_all_parents())\n",
50
+ " return list(full & mf_terms)\n",
51
+ "\n",
52
+ "# --- 3. Carregar sequências --------------------------------------------\n",
53
+ "seqs, ids = [], []\n",
54
+ "for record in SeqIO.parse(\"uniprot_sprot_exp.fasta\", \"fasta\"):\n",
55
+ " ids.append(record.id)\n",
56
+ " seqs.append(str(record.seq))\n",
57
+ "\n",
58
+ "seq_df = pd.DataFrame({\"protein_id\": ids, \"sequence\": seqs})\n",
59
+ "\n",
60
+ "# --- 4. Juntar com GO anotado e propagar -------------------------------\n",
61
+ "grouped = annotations_f.groupby(\"protein_id\")[\"go_term\"].apply(list).reset_index()\n",
62
+ "data = seq_df.merge(grouped, on=\"protein_id\")\n",
63
+ "data = data[data[\"go_term\"].apply(len) > 0]\n",
64
+ "data[\"go_term\"] = data[\"go_term\"].apply(propagate_terms)\n",
65
+ "data = data[data[\"go_term\"].apply(len) > 0]\n",
66
+ "\n",
67
+ "# --- 5. Filtrar GO terms raros -----------------------------------------\n",
68
+ "# todos os terms com menos de 50 proteinas associadas\n",
69
+ "all_terms = [term for sublist in data[\"go_term\"] for term in sublist]\n",
70
+ "term_counts = Counter(all_terms)\n",
71
+ "valid_terms = {term for term, count in term_counts.items() if count >= 50}\n",
72
+ "data[\"go_term\"] = data[\"go_term\"].apply(lambda terms: [t for t in terms if t in valid_terms])\n",
73
+ "data = data[data[\"go_term\"].apply(len) > 0]\n",
74
+ "\n",
75
+ "# --- 6. Preparar dataset final -----------------------------------------\n",
76
+ "data[\"go_terms\"] = data[\"go_term\"].apply(lambda x: ';'.join(sorted(set(x))))\n",
77
+ "data = data[[\"protein_id\", \"sequence\", \"go_terms\"]].drop_duplicates()\n",
78
+ "\n",
79
+ "# --- 7. Binarizar labels e dividir -------------------------------------\n",
80
+ "mlb = MultiLabelBinarizer()\n",
81
+ "Y = mlb.fit_transform(data[\"go_terms\"].str.split(\";\"))\n",
82
+ "X = data[[\"protein_id\", \"sequence\"]].values\n",
83
+ "\n",
84
+ "mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)\n",
85
+ "train_idx, temp_idx = next(mskf.split(X, Y))\n",
86
+ "val_idx, test_idx = np.array_split(temp_idx, 2)\n",
87
+ "\n",
88
+ "df_train = data.iloc[train_idx].copy()\n",
89
+ "df_val = data.iloc[val_idx].copy()\n",
90
+ "df_test = data.iloc[test_idx].copy()\n",
91
+ "\n",
92
+ "# --- 8. Guardar em CSV -------------------------------------------------\n",
93
+ "os.makedirs(\"data\", exist_ok=True)\n",
94
+ "df_train.to_csv(\"data/mf-training.csv\", index=False)\n",
95
+ "df_val.to_csv(\"data/mf-validation.csv\", index=False)\n",
96
+ "df_test.to_csv(\"data/mf-test.csv\", index=False)\n",
97
+ "\n",
98
+ "# --- 9. Confirmar ------------------------------------------------------\n",
99
+ "print(\"✓ Ficheiros criados:\")\n",
100
+ "print(\" - data/mf-training.csv :\", df_train.shape)\n",
101
+ "print(\" - data/mf-validation.csv:\", df_val.shape)\n",
102
+ "print(\" - data/mf-test.csv :\", df_test.shape)\n",
103
+ "print(f\"GO terms únicos (após propagação e filtro): {len(mlb.classes_)}\")\n"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 2,
109
+ "id": "6cf7aaa6-4941-4951-8d73-1f4f1f4362f3",
110
+ "metadata": {},
111
+ "outputs": [
112
+ {
113
+ "name": "stderr",
114
+ "output_type": "stream",
115
+ "text": [
116
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
117
+ " from .autonotebook import tqdm as notebook_tqdm\n",
118
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
119
+ " _torch_pytree._register_pytree_node(\n",
120
+ "100%|██████████| 31142/31142 [00:26<00:00, 1192.86it/s]\n",
121
+ "100%|██████████| 1724/1724 [00:00<00:00, 2570.68it/s]\n",
122
+ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\text\\preprocessor.py:382: UserWarning: The class_names argument is replacing the classes argument. Please update your code.\n",
123
+ " warnings.warn(\n"
124
+ ]
125
+ },
126
+ {
127
+ "name": "stdout",
128
+ "output_type": "stream",
129
+ "text": [
130
+ "preprocessing train...\n",
131
+ "language: en\n",
132
+ "train sequence lengths:\n",
133
+ "\tmean : 423\n",
134
+ "\t95percentile : 604\n",
135
+ "\t99percentile : 715\n"
136
+ ]
137
+ },
138
+ {
139
+ "data": {
140
+ "text/html": [
141
+ "\n",
142
+ "<style>\n",
143
+ " /* Turns off some styling */\n",
144
+ " progress {\n",
145
+ " /* gets rid of default border in Firefox and Opera. */\n",
146
+ " border: none;\n",
147
+ " /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
148
+ " background-size: auto;\n",
149
+ " }\n",
150
+ " progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
151
+ " background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
152
+ " }\n",
153
+ " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
154
+ " background: #F44336;\n",
155
+ " }\n",
156
+ "</style>\n"
157
+ ],
158
+ "text/plain": [
159
+ "<IPython.core.display.HTML object>"
160
+ ]
161
+ },
162
+ "metadata": {},
163
+ "output_type": "display_data"
164
+ },
165
+ {
166
+ "data": {
167
+ "text/html": [],
168
+ "text/plain": [
169
+ "<IPython.core.display.HTML object>"
170
+ ]
171
+ },
172
+ "metadata": {},
173
+ "output_type": "display_data"
174
+ },
175
+ {
176
+ "name": "stdout",
177
+ "output_type": "stream",
178
+ "text": [
179
+ "Is Multi-Label? True\n",
180
+ "preprocessing test...\n",
181
+ "language: en\n",
182
+ "test sequence lengths:\n",
183
+ "\tmean : 408\n",
184
+ "\t95percentile : 603\n",
185
+ "\t99percentile : 714\n"
186
+ ]
187
+ },
188
+ {
189
+ "data": {
190
+ "text/html": [
191
+ "\n",
192
+ "<style>\n",
193
+ " /* Turns off some styling */\n",
194
+ " progress {\n",
195
+ " /* gets rid of default border in Firefox and Opera. */\n",
196
+ " border: none;\n",
197
+ " /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
198
+ " background-size: auto;\n",
199
+ " }\n",
200
+ " progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
201
+ " background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
202
+ " }\n",
203
+ " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
204
+ " background: #F44336;\n",
205
+ " }\n",
206
+ "</style>\n"
207
+ ],
208
+ "text/plain": [
209
+ "<IPython.core.display.HTML object>"
210
+ ]
211
+ },
212
+ "metadata": {},
213
+ "output_type": "display_data"
214
+ },
215
+ {
216
+ "data": {
217
+ "text/html": [],
218
+ "text/plain": [
219
+ "<IPython.core.display.HTML object>"
220
+ ]
221
+ },
222
+ "metadata": {},
223
+ "output_type": "display_data"
224
+ },
225
+ {
226
+ "name": "stdout",
227
+ "output_type": "stream",
228
+ "text": [
229
+ "\n",
230
+ "\n",
231
+ "begin training using triangular learning rate policy with max lr of 1e-05...\n",
232
+ "Epoch 1/10\n",
233
+ "40995/40995 [==============================] - 9020s 219ms/step - loss: 0.0740 - binary_accuracy: 0.9869 - val_loss: 0.0526 - val_binary_accuracy: 0.9866\n",
234
+ "Epoch 2/10\n",
235
+ "40995/40995 [==============================] - 8939s 218ms/step - loss: 0.0464 - binary_accuracy: 0.9877 - val_loss: 0.0457 - val_binary_accuracy: 0.9871\n",
236
+ "Epoch 3/10\n",
237
+ "40995/40995 [==============================] - 8881s 217ms/step - loss: 0.0413 - binary_accuracy: 0.9883 - val_loss: 0.0418 - val_binary_accuracy: 0.9877\n",
238
+ "Epoch 4/10\n",
239
+ "40995/40995 [==============================] - 10277s 251ms/step - loss: 0.0380 - binary_accuracy: 0.9888 - val_loss: 0.0396 - val_binary_accuracy: 0.9881\n",
240
+ "Epoch 5/10\n",
241
+ "40995/40995 [==============================] - 10565s 258ms/step - loss: 0.0357 - binary_accuracy: 0.9892 - val_loss: 0.0380 - val_binary_accuracy: 0.9883\n",
242
+ "Epoch 6/10\n",
243
+ "40995/40995 [==============================] - 10693s 261ms/step - loss: 0.0338 - binary_accuracy: 0.9895 - val_loss: 0.0369 - val_binary_accuracy: 0.9885\n",
244
+ "Epoch 7/10\n",
245
+ "40995/40995 [==============================] - 12055s 294ms/step - loss: 0.0323 - binary_accuracy: 0.9898 - val_loss: 0.0360 - val_binary_accuracy: 0.9888\n",
246
+ "Epoch 8/10\n",
247
+ "40995/40995 [==============================] - 10225s 249ms/step - loss: 0.0309 - binary_accuracy: 0.9901 - val_loss: 0.0353 - val_binary_accuracy: 0.9890\n",
248
+ "Epoch 9/10\n",
249
+ "40995/40995 [==============================] - 10308s 251ms/step - loss: 0.0297 - binary_accuracy: 0.9904 - val_loss: 0.0347 - val_binary_accuracy: 0.9891\n",
250
+ "Epoch 10/10\n",
251
+ "40995/40995 [==============================] - 10275s 251ms/step - loss: 0.0286 - binary_accuracy: 0.9907 - val_loss: 0.0346 - val_binary_accuracy: 0.9893\n",
252
+ "Weights from best epoch have been loaded into model.\n"
253
+ ]
254
+ },
255
+ {
256
+ "data": {
257
+ "text/plain": [
258
+ "<keras.callbacks.History at 0x2b644b84fd0>"
259
+ ]
260
+ },
261
+ "execution_count": 2,
262
+ "metadata": {},
263
+ "output_type": "execute_result"
264
+ }
265
+ ],
266
+ "source": [
267
+ "import pandas as pd\n",
268
+ "import numpy as np\n",
269
+ "from tqdm import tqdm\n",
270
+ "import random\n",
271
+ "import os\n",
272
+ "import ktrain\n",
273
+ "from ktrain import text\n",
274
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
275
+ "\n",
276
+ "\n",
277
+ "# PAM1\n",
278
+ "# PAM matrix model of protein evolution\n",
279
+ "# DOI:10.1093/oxfordjournals.molbev.a040360\n",
280
+ "pam_data = {\n",
281
+ " 'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],\n",
282
+ " 'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],\n",
283
+ " 'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],\n",
284
+ " 'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],\n",
285
+ " 'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],\n",
286
+ " 'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],\n",
287
+ " 'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],\n",
288
+ " 'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],\n",
289
+ " 'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],\n",
290
+ " 'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],\n",
291
+ " 'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],\n",
292
+ " 'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],\n",
293
+ " 'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],\n",
294
+ " 'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],\n",
295
+ " 'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],\n",
296
+ " 'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],\n",
297
+ " 'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],\n",
298
+ " 'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],\n",
299
+ " 'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],\n",
300
+ " 'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]\n",
301
+ "}\n",
302
+ "pam_raw = pd.DataFrame(pam_data, index=list(pam_data.keys()))\n",
303
+ "pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)\n",
304
+ "list_amino = pam_raw.columns.tolist()\n",
305
+ "pam_dict = {\n",
306
+ " aa: {sub: pam_matrix.loc[aa, sub] for sub in list_amino}\n",
307
+ " for aa in list_amino\n",
308
+ "}\n",
309
+ "\n",
310
+ "def pam1_substitution(aa):\n",
311
+ " if aa not in pam_dict:\n",
312
+ " return aa\n",
313
+ " subs = list(pam_dict[aa].keys())\n",
314
+ " probs = list(pam_dict[aa].values())\n",
315
+ " return np.random.choice(subs, p=probs)\n",
316
+ "\n",
317
+ "def augment_sequence(seq, sub_prob=0.05):\n",
318
+ " return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])\n",
319
+ "\n",
320
+ "def slice_sequence(seq, win=500, min_overlap=250):\n",
321
+ " if len(seq) <= win:\n",
322
+ " return [seq]\n",
323
+ " slices, start = [], 0\n",
324
+ " while start + win <= len(seq):\n",
325
+ " slices.append(seq[start:start+win])\n",
326
+ " start += win\n",
327
+ " leftover = seq[start:]\n",
328
+ " if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
329
+ " extra = slices[-1][-min_overlap:] + leftover\n",
330
+ " slices.append(extra)\n",
331
+ " return slices\n",
332
+ "\n",
333
+ "def generate_data(df, augment=False):\n",
334
+ " X, y = [], []\n",
335
+ " label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
336
+ " for _, row in tqdm(df.iterrows(), total=len(df)):\n",
337
+ " seq = row[\"sequence\"]\n",
338
+ " if augment:\n",
339
+ " seq = augment_sequence(seq)\n",
340
+ " seq_slices = slice_sequence(seq)\n",
341
+ " X.extend(seq_slices)\n",
342
+ " lbl = row[label_cols].values.astype(int)\n",
343
+ " y.extend([lbl] * len(seq_slices))\n",
344
+ " return X, np.array(y), label_cols\n",
345
+ "\n",
346
+ "def format_sequence(seq): return \" \".join(list(seq))\n",
347
+ "\n",
348
+ "# Função para carregar e binarizar\n",
349
+ "def load_and_binarize(csv_path, mlb=None):\n",
350
+ " df = pd.read_csv(csv_path)\n",
351
+ " df[\"go_terms\"] = df[\"go_terms\"].str.split(\";\")\n",
352
+ " if mlb is None:\n",
353
+ " mlb = MultiLabelBinarizer()\n",
354
+ " labels = mlb.fit_transform(df[\"go_terms\"])\n",
355
+ " else:\n",
356
+ " labels = mlb.transform(df[\"go_terms\"])\n",
357
+ " labels_df = pd.DataFrame(labels, columns=mlb.classes_)\n",
358
+ " df = df.reset_index(drop=True).join(labels_df)\n",
359
+ " return df, mlb\n",
360
+ "\n",
361
+ "# Carregar os dados\n",
362
+ "df_train, mlb = load_and_binarize(\"data/mf-training.csv\")\n",
363
+ "df_val, _ = load_and_binarize(\"data/mf-validation.csv\", mlb=mlb)\n",
364
+ "\n",
365
+ "# Gerar com augmentation no treino\n",
366
+ "X_train, y_train, term_cols = generate_data(df_train, augment=True)\n",
367
+ "X_val, y_val, _ = generate_data(df_val, augment=False)\n",
368
+ "\n",
369
+ "# Preparar texto para tokenizer\n",
370
+ "X_train_fmt = list(map(format_sequence, X_train))\n",
371
+ "X_val_fmt = list(map(format_sequence, X_val))\n",
372
+ "\n",
373
+ "# Fine-tune ProtBERT\n",
374
+ "# https://huggingface.co/Rostlab/prot_bert\n",
375
+ "# https://doi.org/10.1093/bioinformatics/btac020\n",
376
+ "# dados de treino-> UniRef100 (216 milhões de sequências)\n",
377
+ "MODEL_NAME = \"Rostlab/prot_bert_bfd\"\n",
378
+ "MAX_LEN = 512\n",
379
+ "BATCH_SIZE = 1\n",
380
+ "\n",
381
+ "t = text.Transformer(MODEL_NAME, maxlen=MAX_LEN, classes=term_cols)\n",
382
+ "trn = t.preprocess_train(X_train_fmt, y_train)\n",
383
+ "val = t.preprocess_test(X_val_fmt, y_val)\n",
384
+ "\n",
385
+ "model = t.get_classifier()\n",
386
+ "learner = ktrain.get_learner(model,\n",
387
+ " train_data=trn,\n",
388
+ " val_data=val,\n",
389
+ " batch_size=BATCH_SIZE)\n",
390
+ "\n",
391
+ "learner.autofit(lr=1e-5,\n",
392
+ " epochs=10,\n",
393
+ " early_stopping=1,\n",
394
+ " checkpoint_folder=\"mf-fine-tuned-protbertbfd\")\n"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": 6,
400
+ "id": "c66774b3-6cf0-41c5-bb01-9467a5283102",
401
+ "metadata": {},
402
+ "outputs": [
403
+ {
404
+ "name": "stdout",
405
+ "output_type": "stream",
406
+ "text": [
407
+ "✅ Existe: weights/mf-fine-tuned-protbertbfd\n",
408
+ "📁 Conteúdo:\n",
409
+ " - config.json\n",
410
+ " - tf_model.h5\n"
411
+ ]
412
+ }
413
+ ],
414
+ "source": [
415
+ "import os\n",
416
+ "learner.save_model('weights/mf-fine-tuned-protbertbfd')\n",
417
+ "path = \"weights/mf-fine-tuned-protbertbfd\"\n",
418
+ "\n",
419
+ "if os.path.exists(path):\n",
420
+ " print(f\"✅ Existe: {path}\")\n",
421
+ " print(\"📁 Conteúdo:\")\n",
422
+ " for f in os.listdir(path):\n",
423
+ " print(\" -\", f)\n",
424
+ "else:\n",
425
+ " print(f\"❌ Não existe: {path}\")\n",
426
+ "\n"
427
+ ]
428
+ },
429
+ {
430
+ "cell_type": "code",
431
+ "execution_count": 8,
432
+ "id": "9b39c439-5708-4787-bfee-d3a4d3aa190d",
433
+ "metadata": {},
434
+ "outputs": [
435
+ {
436
+ "name": "stdout",
437
+ "output_type": "stream",
438
+ "text": [
439
+ "✓ Tokenizer base e modelo fine-tuned carregados com sucesso\n"
440
+ ]
441
+ },
442
+ {
443
+ "name": "stderr",
444
+ "output_type": "stream",
445
+ "text": [
446
+ "Processando data/mf-training.csv: 100%|██████████| 31142/31142 [5:17:56<00:00, 1.63it/s] \n"
447
+ ]
448
+ },
449
+ {
450
+ "name": "stdout",
451
+ "output_type": "stream",
452
+ "text": [
453
+ "✓ Guardado embeddings\\train_protbertbfd.pkl — 31142 proteínas\n"
454
+ ]
455
+ },
456
+ {
457
+ "name": "stderr",
458
+ "output_type": "stream",
459
+ "text": [
460
+ "Processando data/mf-validation.csv: 100%|██████████| 1724/1724 [19:15<00:00, 1.49it/s]\n"
461
+ ]
462
+ },
463
+ {
464
+ "name": "stdout",
465
+ "output_type": "stream",
466
+ "text": [
467
+ "✓ Guardado embeddings\\val_protbertbfd.pkl — 1724 proteínas\n"
468
+ ]
469
+ },
470
+ {
471
+ "name": "stderr",
472
+ "output_type": "stream",
473
+ "text": [
474
+ "Processando data/mf-test.csv: 100%|██████████| 1724/1724 [17:15<00:00, 1.66it/s]\n"
475
+ ]
476
+ },
477
+ {
478
+ "name": "stdout",
479
+ "output_type": "stream",
480
+ "text": [
481
+ "✓ Guardado embeddings\\test_protbertbfd.pkl — 1724 proteínas\n"
482
+ ]
483
+ }
484
+ ],
485
+ "source": [
486
+ "import os\n",
487
+ "import pandas as pd\n",
488
+ "import numpy as np\n",
489
+ "from tqdm import tqdm\n",
490
+ "import joblib\n",
491
+ "import gc\n",
492
+ "from transformers import AutoTokenizer, TFAutoModel\n",
493
+ "\n",
494
+ "# --- 1. Parâmetros --------------------------------------------------------\n",
495
+ "MODEL_DIR = \"weights/mf-fine-tuned-protbertbfd\"\n",
496
+ "MODEL_NAME = \"Rostlab/prot_bert_bfd\"\n",
497
+ "OUT_DIR = \"embeddings\"\n",
498
+ "BATCH_TOK = 16\n",
499
+ "\n",
500
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)\n",
501
+ "model = TFAutoModel.from_pretrained(MODEL_DIR, from_pt=False)\n",
502
+ "\n",
503
+ "print(\"✓ Tokenizer base e modelo fine-tuned carregados com sucesso\")\n",
504
+ "\n",
505
+ "# --- 3. Funções auxiliares ------------------------------------------------\n",
506
+ "def format_sequence(seq):\n",
507
+ " return \" \".join(list(seq))\n",
508
+ "\n",
509
+ "def slice_sequence(seq, win=500, min_overlap=250):\n",
510
+ " if len(seq) <= win:\n",
511
+ " return [seq]\n",
512
+ " slices, start = [], 0\n",
513
+ " while start + win <= len(seq):\n",
514
+ " slices.append(seq[start:start+win])\n",
515
+ " start += win\n",
516
+ " leftover = seq[start:]\n",
517
+ " if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
518
+ " extra = slices[-1][-min_overlap:] + leftover\n",
519
+ " slices.append(extra)\n",
520
+ " return slices\n",
521
+ "\n",
522
+ "def get_embeddings(batch, tokenizer, model):\n",
523
+ " tokens = tokenizer(batch, return_tensors=\"tf\", padding=True, truncation=True, max_length=512)\n",
524
+ " output = model(**tokens)\n",
525
+ " return output.last_hidden_state[:, 0, :].numpy()\n",
526
+ "\n",
527
+ "def process_split(csv_path, out_path):\n",
528
+ " df = pd.read_csv(csv_path)\n",
529
+ " label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
530
+ " prot_ids, embeds, labels = [], [], []\n",
531
+ "\n",
532
+ " for _, row in tqdm(df.iterrows(), total=len(df), desc=f\"Processando {csv_path}\"):\n",
533
+ " slices = slice_sequence(row[\"sequence\"])\n",
534
+ " slices_fmt = list(map(format_sequence, slices))\n",
535
+ "\n",
536
+ " slice_embeds = []\n",
537
+ " for i in range(0, len(slices_fmt), BATCH_TOK):\n",
538
+ " batch = slices_fmt[i:i+BATCH_TOK]\n",
539
+ " slice_embeds.append(get_embeddings(batch, tokenizer, model))\n",
540
+ " slice_embeds = np.vstack(slice_embeds)\n",
541
+ "\n",
542
+ " prot_embed = slice_embeds.mean(axis=0)\n",
543
+ " prot_ids.append(row[\"protein_id\"])\n",
544
+ " embeds.append(prot_embed.astype(np.float32))\n",
545
+ " labels.append(row[label_cols].values.astype(np.int8))\n",
546
+ " gc.collect()\n",
547
+ "\n",
548
+ " embeds = np.vstack(embeds)\n",
549
+ " labels = np.vstack(labels)\n",
550
+ "\n",
551
+ " joblib.dump({\n",
552
+ " \"protein_ids\": prot_ids,\n",
553
+ " \"embeddings\": embeds,\n",
554
+ " \"labels\": labels,\n",
555
+ " \"go_terms\": label_cols\n",
556
+ " }, out_path, compress=3)\n",
557
+ "\n",
558
+ " print(f\"✓ Guardado {out_path} — {embeds.shape[0]} proteínas\")\n",
559
+ "\n",
560
+ "# --- 4. Aplicar -----------------------------------------------------------\n",
561
+ "os.makedirs(OUT_DIR, exist_ok=True)\n",
562
+ "\n",
563
+ "process_split(\"data/mf-training.csv\", os.path.join(OUT_DIR, \"train_protbertbfd.pkl\"))\n",
564
+ "process_split(\"data/mf-validation.csv\", os.path.join(OUT_DIR, \"val_protbertbfd.pkl\"))\n",
565
+ "process_split(\"data/mf-test.csv\", os.path.join(OUT_DIR, \"test_protbertbfd.pkl\"))\n"
566
+ ]
567
+ },
568
+ {
569
+ "cell_type": "code",
570
+ "execution_count": 9,
571
+ "id": "ad0c5421-e0a1-4a6a-8ace-2c69aeab0e0d",
572
+ "metadata": {},
573
+ "outputs": [
574
+ {
575
+ "name": "stdout",
576
+ "output_type": "stream",
577
+ "text": [
578
+ "✓ Corrigido: embeddings/train_protbertbfd.pkl — 31142 exemplos, 597 GO terms\n",
579
+ "✓ Corrigido: embeddings/val_protbertbfd.pkl — 1724 exemplos, 597 GO terms\n",
580
+ "✓ Corrigido: embeddings/test_protbertbfd.pkl — 1724 exemplos, 597 GO terms\n"
581
+ ]
582
+ }
583
+ ],
584
+ "source": [
585
+ "import pandas as pd\n",
586
+ "import joblib\n",
587
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
588
+ "\n",
589
+ "# --- 1. Obter GO terms do ficheiro de teste --------------------------------\n",
590
+ "df_test = pd.read_csv(\"data/mf-test.csv\")\n",
591
+ "test_terms = sorted(set(term for row in df_test[\"go_terms\"].str.split(\";\") for term in row))\n",
592
+ "\n",
593
+ "# --- 2. Função para corrigir um .pkl com base nos GO terms do teste --------\n",
594
+ "def patch_to_common_terms(csv_path, pkl_path, common_terms):\n",
595
+ " df = pd.read_csv(csv_path)\n",
596
+ " terms_split = df[\"go_terms\"].str.split(\";\")\n",
597
+ " \n",
598
+ " # Apenas termos presentes nos common_terms\n",
599
+ " terms_filtered = terms_split.apply(lambda lst: [t for t in lst if t in common_terms])\n",
600
+ " \n",
601
+ " mlb = MultiLabelBinarizer(classes=common_terms)\n",
602
+ " Y = mlb.fit_transform(terms_filtered)\n",
603
+ "\n",
604
+ " data = joblib.load(pkl_path)\n",
605
+ " data[\"labels\"] = Y\n",
606
+ " data[\"go_terms\"] = mlb.classes_.tolist()\n",
607
+ " \n",
608
+ " joblib.dump(data, pkl_path, compress=3)\n",
609
+ " print(f\"✓ Corrigido: {pkl_path} — {Y.shape[0]} exemplos, {Y.shape[1]} GO terms\")\n",
610
+ "\n",
611
+ "# --- 3. Aplicar às 3 partições --------------------------------------------\n",
612
+ "patch_to_common_terms(\"data/mf-training.csv\", \"embeddings/train_protbertbfd.pkl\", test_terms)\n",
613
+ "patch_to_common_terms(\"data/mf-validation.csv\", \"embeddings/val_protbertbfd.pkl\", test_terms)\n",
614
+ "patch_to_common_terms(\"data/mf-test.csv\", \"embeddings/test_protbertbfd.pkl\", test_terms)\n"
615
+ ]
616
+ },
617
+ {
618
+ "cell_type": "code",
619
+ "execution_count": 2,
620
+ "id": "dbd5c35f-4a08-4906-9cf4-e1df501d1ecb",
621
+ "metadata": {},
622
+ "outputs": [],
623
+ "source": [
624
+ "import joblib\n",
625
+ "train = joblib.load(\"embeddings/train_protbertbfd.pkl\")\n",
626
+ "val = joblib.load(\"embeddings/val_protbertbfd.pkl\")\n",
627
+ "test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
628
+ "\n",
629
+ "X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
630
+ "X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
631
+ "X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n"
632
+ ]
633
+ },
634
+ {
635
+ "cell_type": "code",
636
+ "execution_count": 5,
637
+ "id": "1785d8a9-23fc-4490-8d71-29cc91a4cb57",
638
+ "metadata": {},
639
+ "outputs": [
640
+ {
641
+ "name": "stdout",
642
+ "output_type": "stream",
643
+ "text": [
644
+ "✓ Embeddings carregados: (31142, 1024) → 597 GO terms\n",
645
+ "Epoch 1/100\n",
646
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0337 - binary_accuracy: 0.9901 - val_loss: 0.0331 - val_binary_accuracy: 0.9905\n",
647
+ "Epoch 2/100\n",
648
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0252 - binary_accuracy: 0.9921 - val_loss: 0.0326 - val_binary_accuracy: 0.9905\n",
649
+ "Epoch 3/100\n",
650
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0244 - binary_accuracy: 0.9924 - val_loss: 0.0330 - val_binary_accuracy: 0.9905\n",
651
+ "Epoch 4/100\n",
652
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0240 - binary_accuracy: 0.9925 - val_loss: 0.0322 - val_binary_accuracy: 0.9907\n",
653
+ "Epoch 5/100\n",
654
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0236 - binary_accuracy: 0.9925 - val_loss: 0.0328 - val_binary_accuracy: 0.9907\n",
655
+ "Epoch 6/100\n",
656
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0232 - binary_accuracy: 0.9926 - val_loss: 0.0325 - val_binary_accuracy: 0.9908\n",
657
+ "Epoch 7/100\n",
658
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0231 - binary_accuracy: 0.9926 - val_loss: 0.0325 - val_binary_accuracy: 0.9907\n",
659
+ "Epoch 8/100\n",
660
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0228 - binary_accuracy: 0.9927 - val_loss: 0.0326 - val_binary_accuracy: 0.9908\n",
661
+ "Epoch 9/100\n",
662
+ "974/974 [==============================] - 3s 3ms/step - loss: 0.0226 - binary_accuracy: 0.9927 - val_loss: 0.0326 - val_binary_accuracy: 0.9908\n",
663
+ "Previsões guardadas em mf-protbertbfd-pam1.npy\n",
664
+ "Modelo guardado em models/protbertbfd_mlp.keras\n"
665
+ ]
666
+ }
667
+ ],
668
+ "source": [
669
+ "import tensorflow as tf\n",
670
+ "import joblib\n",
671
+ "import numpy as np\n",
672
+ "from tensorflow.keras.models import Sequential\n",
673
+ "from tensorflow.keras.layers import Dense, Dropout\n",
674
+ "from tensorflow.keras.callbacks import EarlyStopping\n",
675
+ "\n",
676
+ "# --- 1. Carregar embeddings ----------------------------------------------\n",
677
+ "train = joblib.load(\"embeddings/train_protbertbfd.pkl\")\n",
678
+ "val = joblib.load(\"embeddings/val_protbertbfd.pkl\")\n",
679
+ "test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
680
+ "\n",
681
+ "X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
682
+ "X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
683
+ "X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n",
684
+ "\n",
685
+ "print(f\"✓ Embeddings carregados: {X_train.shape} → {y_train.shape[1]} GO terms\")\n",
686
+ "\n",
687
+ "# --- 2. Garantir consistência de classes ---------------------------------\n",
688
+ "max_classes = y_train.shape[1] # 602 GO terms (do treino)\n",
689
+ "\n",
690
+ "def pad_labels(y, target_dim=max_classes):\n",
691
+ " if y.shape[1] < target_dim:\n",
692
+ " padding = np.zeros((y.shape[0], target_dim - y.shape[1]), dtype=np.int8)\n",
693
+ " return np.hstack([y, padding])\n",
694
+ " return y\n",
695
+ "\n",
696
+ "y_val = pad_labels(y_val)\n",
697
+ "y_test = pad_labels(y_test)\n",
698
+ "\n",
699
+ "# --- 3. Modelo MLP ------------------------------------------------------\n",
700
+ "model = Sequential([\n",
701
+ " Dense(1024, activation=\"relu\", input_shape=(X_train.shape[1],)),\n",
702
+ " Dropout(0.3),\n",
703
+ " Dense(512, activation=\"relu\"),\n",
704
+ " Dropout(0.3),\n",
705
+ " Dense(max_classes, activation=\"sigmoid\")\n",
706
+ "])\n",
707
+ "\n",
708
+ "model.compile(loss=\"binary_crossentropy\",\n",
709
+ " optimizer=\"adam\",\n",
710
+ " metrics=[\"binary_accuracy\"])\n",
711
+ "\n",
712
+ "# --- 4. Early stopping e treino -----------------------------------------\n",
713
+ "callbacks = [\n",
714
+ " EarlyStopping(monitor=\"val_loss\", patience=5, restore_best_weights=True)\n",
715
+ "]\n",
716
+ "\n",
717
+ "model.fit(X_train, y_train,\n",
718
+ " validation_data=(X_val, y_val),\n",
719
+ " epochs=100,\n",
720
+ " batch_size=32,\n",
721
+ " callbacks=callbacks,\n",
722
+ " verbose=1)\n",
723
+ "\n",
724
+ "# --- 5. Previsões --------------------------------------------------------\n",
725
+ "y_prob = model.predict(X_test)\n",
726
+ "np.save(\"predictions/mf-protbertbfd-pam1.npy\", y_prob)\n",
727
+ "print(\"Previsões guardadas em mf-protbertbfd-pam1.npy\")\n",
728
+ "\n",
729
+ "# --- 6. Modelo ----------------------------------------------------------\n",
730
+ "model.save(\"models/protbertbfd_mlp.keras\")\n",
731
+ "print(\"Modelo guardado em models/protbertbfd_mlp.keras\")"
732
+ ]
733
+ },
734
+ {
735
+ "cell_type": "code",
736
+ "execution_count": 12,
737
+ "id": "fdb66630-76dc-43a0-bd56-45052175fdba",
738
+ "metadata": {},
739
+ "outputs": [
740
+ {
741
+ "name": "stdout",
742
+ "output_type": "stream",
743
+ "text": [
744
+ "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
745
+ "✓ Embeddings: (1724, 597) labels × 597 GO terms\n",
746
+ "\n",
747
+ "📊 Resultados finais (ProtBERTBFD + PAM1 + propagação):\n",
748
+ "Fmax = 0.6570\n",
749
+ "Thr. = 0.41\n",
750
+ "AuPRC = 0.6929\n",
751
+ "Smin = 13.8114\n"
752
+ ]
753
+ }
754
+ ],
755
+ "source": [
756
+ "import numpy as np\n",
757
+ "from sklearn.metrics import precision_recall_curve, auc\n",
758
+ "from goatools.obo_parser import GODag\n",
759
+ "import joblib\n",
760
+ "import math\n",
761
+ "\n",
762
+ "# --- 1. Parâmetros -------------------------------------------------------\n",
763
+ "GO_FILE = \"go.obo\"\n",
764
+ "THRESHOLDS = np.arange(0.0, 1.01, 0.01)\n",
765
+ "ALPHA = 0.5\n",
766
+ "\n",
767
+ "# --- 2. Carregar dados ---------------------------------------------------\n",
768
+ "test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
769
+ "y_true = test[\"labels\"]\n",
770
+ "terms = test[\"go_terms\"]\n",
771
+ "y_prob = np.load(\"predictions/mf-protbertbfd-pam1.npy\")\n",
772
+ "go_dag = GODag(GO_FILE)\n",
773
+ "\n",
774
+ "print(f\"✓ Embeddings: {y_true.shape} labels × {len(terms)} GO terms\")\n",
775
+ "\n",
776
+ "# --- 3. Fmax -------------------------------------------------------------\n",
777
+ "def compute_fmax(y_true, y_prob, thresholds):\n",
778
+ " fmax, best_thr = 0, 0\n",
779
+ " for t in thresholds:\n",
780
+ " y_pred = (y_prob >= t).astype(int)\n",
781
+ " tp = (y_true * y_pred).sum(axis=1)\n",
782
+ " fp = ((1 - y_true) * y_pred).sum(axis=1)\n",
783
+ " fn = (y_true * (1 - y_pred)).sum(axis=1)\n",
784
+ " precision = tp / (tp + fp + 1e-8)\n",
785
+ " recall = tp / (tp + fn + 1e-8)\n",
786
+ " f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
787
+ " avg_f1 = np.mean(f1)\n",
788
+ " if avg_f1 > fmax:\n",
789
+ " fmax, best_thr = avg_f1, t\n",
790
+ " return fmax, best_thr\n",
791
+ "\n",
792
+ "# --- 4. AuPRC micro ------------------------------------------------------\n",
793
+ "def compute_auprc(y_true, y_prob):\n",
794
+ " precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())\n",
795
+ " return auc(recall, precision)\n",
796
+ "\n",
797
+ "# --- 5. Smin -------------------------------------------------------------\n",
798
+ "def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):\n",
799
+ " y_pred = (y_prob >= threshold).astype(int)\n",
800
+ " ic = {}\n",
801
+ " total = (y_true + y_pred).sum(axis=0).sum()\n",
802
+ " for i, term in enumerate(terms):\n",
803
+ " freq = (y_true[:, i] + y_pred[:, i]).sum()\n",
804
+ " ic[term] = -np.log((freq + 1e-8) / total)\n",
805
+ "\n",
806
+ " s_values = []\n",
807
+ " for true_vec, pred_vec in zip(y_true, y_pred):\n",
808
+ " true_terms = {terms[i] for i in np.where(true_vec)[0]}\n",
809
+ " pred_terms = {terms[i] for i in np.where(pred_vec)[0]}\n",
810
+ "\n",
811
+ " anc_true = set()\n",
812
+ " for t in true_terms:\n",
813
+ " if t in go_dag:\n",
814
+ " anc_true |= go_dag[t].get_all_parents()\n",
815
+ " anc_pred = set()\n",
816
+ " for t in pred_terms:\n",
817
+ " if t in go_dag:\n",
818
+ " anc_pred |= go_dag[t].get_all_parents()\n",
819
+ "\n",
820
+ " ru = pred_terms - true_terms\n",
821
+ " mi = true_terms - pred_terms\n",
822
+ " dist_ru = sum(ic.get(t, 0) for t in ru)\n",
823
+ " dist_mi = sum(ic.get(t, 0) for t in mi)\n",
824
+ " s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)\n",
825
+ " s_values.append(s)\n",
826
+ "\n",
827
+ " return np.mean(s_values)\n",
828
+ "\n",
829
+ "# --- 6. Avaliar ----------------------------------------------------------\n",
830
+ "fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)\n",
831
+ "auprc = compute_auprc(y_true, y_prob)\n",
832
+ "smin = compute_smin(y_true, y_prob, terms, thr, go_dag)\n",
833
+ "\n",
834
+ "print(f\"\\n📊 Resultados finais (ProtBERTBFD + PAM1 + propagação):\")\n",
835
+ "print(f\"Fmax = {fmax:.4f}\")\n",
836
+ "print(f\"Thr. = {thr:.2f}\")\n",
837
+ "print(f\"AuPRC = {auprc:.4f}\")\n",
838
+ "print(f\"Smin = {smin:.4f}\")\n"
839
+ ]
840
+ },
841
+ {
842
+ "cell_type": "code",
843
+ "execution_count": null,
844
+ "id": "70d131ef-ef84-42ee-953b-0d3f1268694d",
845
+ "metadata": {},
846
+ "outputs": [],
847
+ "source": []
848
+ }
849
+ ],
850
+ "metadata": {
851
+ "kernelspec": {
852
+ "display_name": "Python 3 (ipykernel)",
853
+ "language": "python",
854
+ "name": "python3"
855
+ },
856
+ "language_info": {
857
+ "codemirror_mode": {
858
+ "name": "ipython",
859
+ "version": 3
860
+ },
861
+ "file_extension": ".py",
862
+ "mimetype": "text/x-python",
863
+ "name": "python",
864
+ "nbconvert_exporter": "python",
865
+ "pygments_lexer": "ipython3",
866
+ "version": "3.8.18"
867
+ }
868
+ },
869
+ "nbformat": 4,
870
+ "nbformat_minor": 5
871
+ }
notebooks/keras_models_fix.ipynb ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "id": "39935741-50be-4766-873c-99f3c3f14e55",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "A converter modelos .keras para .h5...\n",
14
+ "\n",
15
+ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
16
+ "Guardado com sucesso\n",
17
+ "\n",
18
+ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
19
+ "Guardado com sucesso\n",
20
+ "\n",
21
+ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
22
+ "Guardado com sucesso\n",
23
+ "\n",
24
+ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
25
+ "Guardado com sucesso\n",
26
+ "\n",
27
+ "Conversão concluída.\n"
28
+ ]
29
+ }
30
+ ],
31
+ "source": [
32
+ "import os\n",
33
+ "from tensorflow.keras.models import load_model\n",
34
+ "\n",
35
+ "MODELS_DIR = \"models\"\n",
36
+ "\n",
37
+ "# Modelos a converter: (nome original .keras → novo nome .h5)\n",
38
+ "modelos = [\n",
39
+ " (\"protbert_mlp.keras\", \"mlp_protbert.h5\"),\n",
40
+ " (\"protbertbfd_mlp.keras\", \"mlp_protbertbfd.h5\"),\n",
41
+ " (\"esm2_mlp.keras\", \"mlp_esm2.h5\"),\n",
42
+ " (\"modelo_ensemble_stacking.keras\", \"modelo_ensemble_stack.h5\"),\n",
43
+ "]\n",
44
+ "\n",
45
+ "print(\"A converter modelos .keras para .h5...\\n\")\n",
46
+ "\n",
47
+ "for origem, destino in modelos:\n",
48
+ " origem_path = os.path.join(MODELS_DIR, origem)\n",
49
+ " destino_path = os.path.join(MODELS_DIR, destino)\n",
50
+ "\n",
51
+ " if not os.path.exists(origem_path):\n",
52
+ " print(f\"Ficheiro não encontrado: {origem_path}\")\n",
53
+ " continue\n",
54
+ "\n",
55
+ "\n",
56
+ " model = load_model(origem_path, compile=False)\n",
57
+ " model.save(destino_path)\n",
58
+ "\n",
59
+ " print(\"Guardado com sucesso\\n\")\n",
60
+ "\n",
61
+ "print(\"Conversão concluída.\")\n"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "id": "c0f2a4b7-d3eb-48a7-b97e-213b58b2b2ca",
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": []
71
+ }
72
+ ],
73
+ "metadata": {
74
+ "kernelspec": {
75
+ "display_name": "Python 3 (ipykernel)",
76
+ "language": "python",
77
+ "name": "python3"
78
+ },
79
+ "language_info": {
80
+ "codemirror_mode": {
81
+ "name": "ipython",
82
+ "version": 3
83
+ },
84
+ "file_extension": ".py",
85
+ "mimetype": "text/x-python",
86
+ "name": "python",
87
+ "nbconvert_exporter": "python",
88
+ "pygments_lexer": "ipython3",
89
+ "version": "3.8.18"
90
+ }
91
+ },
92
+ "nbformat": 4,
93
+ "nbformat_minor": 5
94
+ }
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers
4
+ tensorflow
5
+ scikit-learn==1.3.2
6
+ numpy
7
+ joblib