Update model evaluation metrics and classification reports for improved accuracy
Browse files- Classification_Reports/ESMC-300m - Random Forest_classification_report.csv +5 -0
- Classification_Reports/ESMC-300m - Support Vector Machine_classification_report.csv +5 -0
- Classification_Reports/ESMC-600m - Random Forest_classification_report.csv +5 -0
- Classification_Reports/ESMC-600m - Support Vector Machine_classification_report.csv +5 -0
- Classification_Reports/Prost T5 - Random Forest_classification_report.csv +5 -0
- Classification_Reports/Prost T5 - Support Vector Machine_classification_report.csv +5 -0
- Data/evaluations.csv +3 -3
- Models/ESMC-300m_rf.joblib +2 -2
- Models/ESMC-300m_svm.joblib +1 -1
- Models/ESMC-600m_rf.joblib +2 -2
- Models/ESMC-600m_svm.joblib +1 -1
- Models/Prost T5_rf.joblib +2 -2
- Models/Prost T5_svm.joblib +1 -1
- notebooks/04_Training.ipynb +2 -2
- src/my_utils.py +24 -3
Classification_Reports/ESMC-300m - Random Forest_classification_report.csv
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
,Cellwall,Cytoplasmic,CytoplasmicMembrane,Extracellular,OuterMembrane,Periplasmic,accuracy,macro avg,weighted avg
|
| 2 |
+
precision,0.6875,0.9736842105263158,0.9544658493870403,0.823321554770318,0.9058823529411765,0.7301587301587301,0.9409844982322546,0.8458354496305969,0.9431482190855852
|
| 3 |
+
recall,0.7096774193548387,0.9907949790794979,0.8384615384615385,0.8859315589353612,0.8415300546448088,0.8625,0.9409844982322546,0.8548159250793409,0.9409844982322546
|
| 4 |
+
f1-score,0.6984126984126984,0.9821650767316467,0.8927108927108927,0.8534798534798534,0.8725212464589235,0.7908309455587392,0.9409844982322546,0.8483534522254591,0.9409727898172672
|
| 5 |
+
support,31.0,2390.0,650.0,263.0,183.0,160.0,0.9409844982322546,3677.0,3677.0
|
Classification_Reports/ESMC-300m - Support Vector Machine_classification_report.csv
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
,Cellwall,Cytoplasmic,CytoplasmicMembrane,Extracellular,OuterMembrane,Periplasmic,accuracy,macro avg,weighted avg
|
| 2 |
+
precision,0.8260869565217391,0.9794238683127572,0.9726247987117552,0.9003831417624522,0.9408284023668639,0.815028901734104,0.9621974435681262,0.9057293449016118,0.9622014817178194
|
| 3 |
+
recall,0.6129032258064516,0.99581589958159,0.9292307692307692,0.8935361216730038,0.8688524590163934,0.88125,0.9621974435681262,0.8635980792180346,0.9621974435681262
|
| 4 |
+
f1-score,0.7037037037037037,0.9875518672199171,0.950432730133753,0.8969465648854962,0.9034090909090909,0.8468468468468469,0.9621974435681262,0.8814818006164681,0.961806189217868
|
| 5 |
+
support,31.0,2390.0,650.0,263.0,183.0,160.0,0.9621974435681262,3677.0,3677.0
|
Classification_Reports/ESMC-600m - Random Forest_classification_report.csv
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
,Cellwall,Cytoplasmic,CytoplasmicMembrane,Extracellular,OuterMembrane,Periplasmic,accuracy,macro avg,weighted avg
|
| 2 |
+
precision,0.7307692307692307,0.9733388022969647,0.975736568457539,0.8404255319148937,0.92,0.7821229050279329,0.9499592058743541,0.8703988397444268,0.9512357717810928
|
| 3 |
+
recall,0.6129032258064516,0.992887029288703,0.8661538461538462,0.9011406844106464,0.8797814207650273,0.875,0.9499592058743541,0.8546443677374458,0.9499592058743541
|
| 4 |
+
f1-score,0.6666666666666666,0.9830157415078707,0.9176854115729421,0.8697247706422019,0.8994413407821229,0.8259587020648967,0.9499592058743541,0.8604154388727835,0.9497031761667939
|
| 5 |
+
support,31.0,2390.0,650.0,263.0,183.0,160.0,0.9499592058743541,3677.0,3677.0
|
Classification_Reports/ESMC-600m - Support Vector Machine_classification_report.csv
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
,Cellwall,Cytoplasmic,CytoplasmicMembrane,Extracellular,OuterMembrane,Periplasmic,accuracy,macro avg,weighted avg
|
| 2 |
+
precision,0.8571428571428571,0.9782072368421053,0.9587301587301588,0.8897338403041825,0.9470588235294117,0.8385093167701864,0.9602937177046506,0.9115637055531502,0.9597863973858514
|
| 3 |
+
recall,0.5806451612903226,0.9953974895397489,0.9292307692307692,0.8897338403041825,0.8797814207650273,0.84375,0.9602937177046506,0.8530897801883417,0.9602937177046506
|
| 4 |
+
f1-score,0.6923076923076923,0.9867274989630859,0.94375,0.8897338403041825,0.9121813031161473,0.8411214953271028,0.9602937177046506,0.8776369716697019,0.9596645033195284
|
| 5 |
+
support,31.0,2390.0,650.0,263.0,183.0,160.0,0.9602937177046506,3677.0,3677.0
|
Classification_Reports/Prost T5 - Random Forest_classification_report.csv
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
,Cellwall,Cytoplasmic,CytoplasmicMembrane,Extracellular,OuterMembrane,Periplasmic,accuracy,macro avg,weighted avg
|
| 2 |
+
precision,0.8181818181818182,0.9783783783783784,0.9553429027113237,0.7916666666666666,0.9523809523809523,0.8391608391608392,0.9510470492249116,0.8891852595799964,0.9522492872817795
|
| 3 |
+
recall,0.5806451612903226,0.9845188284518829,0.9215384615384615,0.9391634980988594,0.8743169398907104,0.75,0.9510470492249116,0.841697148211706,0.9510470492249116
|
| 4 |
+
f1-score,0.6792452830188679,0.9814389989572472,0.9381362568519969,0.8591304347826086,0.9116809116809117,0.7920792079207921,0.9510470492249116,0.8602851822020706,0.9507767100048853
|
| 5 |
+
support,31.0,2390.0,650.0,263.0,183.0,160.0,0.9510470492249116,3677.0,3677.0
|
Classification_Reports/Prost T5 - Support Vector Machine_classification_report.csv
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
,Cellwall,Cytoplasmic,CytoplasmicMembrane,Extracellular,OuterMembrane,Periplasmic,accuracy,macro avg,weighted avg
|
| 2 |
+
precision,0.9,0.9764948453608248,0.9648562300319489,0.8868613138686131,0.9479768786127167,0.8301886792452831,0.9597497960293717,0.917729657853231,0.9595957881278095
|
| 3 |
+
recall,0.5806451612903226,0.9907949790794979,0.9292307692307692,0.9239543726235742,0.8961748633879781,0.825,0.9597497960293717,0.8576333576020237,0.9597497960293717
|
| 4 |
+
f1-score,0.7058823529411765,0.9835929387331257,0.9467084639498433,0.9050279329608939,0.9213483146067416,0.8275862068965517,0.9597497960293717,0.8816910350147221,0.959225689183014
|
| 5 |
+
support,31.0,2390.0,650.0,263.0,183.0,160.0,0.9597497960293717,3677.0,3677.0
|
Data/evaluations.csv
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
model,Accuracy,Recall,Precision,F1
|
| 2 |
-
Prost T5_rf,0.
|
| 3 |
Prost T5_svm,0.9597497960293717,0.9597497960293717,0.9595957881278095,0.959225689183014
|
| 4 |
-
ESMC-300m_rf,0.
|
| 5 |
ESMC-300m_svm,0.9621974435681262,0.9621974435681262,0.9622014817178194,0.961806189217868
|
| 6 |
-
ESMC-600m_rf,0.
|
| 7 |
ESMC-600m_svm,0.9602937177046506,0.9602937177046506,0.9597863973858514,0.9596645033195284
|
|
|
|
| 1 |
model,Accuracy,Recall,Precision,F1
|
| 2 |
+
Prost T5_rf,0.9510470492249116,0.9510470492249116,0.9522492872817795,0.9507767100048853
|
| 3 |
Prost T5_svm,0.9597497960293717,0.9597497960293717,0.9595957881278095,0.959225689183014
|
| 4 |
+
ESMC-300m_rf,0.9409844982322546,0.9409844982322546,0.9431482190855852,0.9409727898172672
|
| 5 |
ESMC-300m_svm,0.9621974435681262,0.9621974435681262,0.9622014817178194,0.961806189217868
|
| 6 |
+
ESMC-600m_rf,0.9499592058743541,0.9499592058743541,0.9512357717810928,0.9497031761667939
|
| 7 |
ESMC-600m_svm,0.9602937177046506,0.9602937177046506,0.9597863973858514,0.9596645033195284
|
Models/ESMC-300m_rf.joblib
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0395f0e953b36fdb2ab140c5dd1b957759a9932a00037a76bfd26150ffdedac1
|
| 3 |
+
size 10191625
|
Models/ESMC-300m_svm.joblib
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 18294469
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2002c8a06e4fcf8ead2098ac214eb78327730bf9095e32e324245a9b9a3fdbc6
|
| 3 |
size 18294469
|
Models/ESMC-600m_rf.joblib
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6c6f14ba2ad661fa7878a69b107280df0b82e762bc47215496d22f578974f606
|
| 3 |
+
size 18049497
|
Models/ESMC-600m_svm.joblib
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 22787493
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:20b804e6ef0c58a22af1f5c68e10221f8da4d1745e88404df9fe8e3048674ae6
|
| 3 |
size 22787493
|
Models/Prost T5_rf.joblib
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e4ed73692393a2a87059379e3f4d07aa54d3aed5659394a79419f9b6d9837179
|
| 3 |
+
size 10359641
|
Models/Prost T5_svm.joblib
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 18267605
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4a3000d817cbf6466de239488ea378b6438d15fba03e7e102e0de5508351a54d
|
| 3 |
size 18267605
|
notebooks/04_Training.ipynb
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b903593cd4834f92401fa1423f8d5fd103e9c8cc961554e8f9cd883c4405b0aa
|
| 3 |
+
size 8532
|
src/my_utils.py
CHANGED
|
@@ -40,6 +40,7 @@ from tqdm import tqdm
|
|
| 40 |
# Visualization libraries
|
| 41 |
import seaborn as sns
|
| 42 |
import matplotlib.pyplot as plt
|
|
|
|
| 43 |
|
| 44 |
from esm.models.esmc import ESMC
|
| 45 |
from esm.sdk.api import ESMProtein, LogitsConfig, ESMProteinError, LogitsOutput
|
|
@@ -49,6 +50,11 @@ from joblib import load
|
|
| 49 |
|
| 50 |
import torch
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def load_emb(path: str, acc: list[str]) -> np.ndarray:
|
| 53 |
"""
|
| 54 |
Loads and processes embedding files from a specified directory for a list of accession identifiers.
|
|
@@ -90,7 +96,7 @@ def load_emb(path: str, acc: list[str]) -> np.ndarray:
|
|
| 90 |
|
| 91 |
return np.vstack(x)
|
| 92 |
|
| 93 |
-
def confusion(title : str, y_true: np.ndarray, y_pred: np.ndarray) ->
|
| 94 |
|
| 95 |
"""
|
| 96 |
Plot a confusion matrix for the given true and predicted labels.
|
|
@@ -105,7 +111,7 @@ def confusion(title : str, y_true: np.ndarray, y_pred: np.ndarray) -> None:
|
|
| 105 |
normalize = 'pred')
|
| 106 |
|
| 107 |
class_names = list(np.unique(y_true))
|
| 108 |
-
plt.figure(figsize=(10, 10))
|
| 109 |
sns.heatmap(cm, annot=True, fmt='.2f', cmap='Greys',
|
| 110 |
xticklabels=class_names, yticklabels=class_names)
|
| 111 |
|
|
@@ -115,6 +121,8 @@ def confusion(title : str, y_true: np.ndarray, y_pred: np.ndarray) -> None:
|
|
| 115 |
plt.tight_layout()
|
| 116 |
plt.show()
|
| 117 |
|
|
|
|
|
|
|
| 118 |
def plot_umap(x: np.ndarray, y: np.ndarray, title: str) -> None:
|
| 119 |
"""
|
| 120 |
Plots a 2D UMAP projection of high-dimensional data with class labels.
|
|
@@ -300,6 +308,13 @@ def train_rf(title: str,
|
|
| 300 |
y_pred = classifier.predict(x_test)
|
| 301 |
|
| 302 |
evaluation = evaluate(classifier, x_test, y_test)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
print(classification_report(y_test,
|
| 305 |
y_pred,
|
|
@@ -359,7 +374,12 @@ def train_svm(title: str, x: np.ndarray, y: np.ndarray, params: dict) -> tuple[P
|
|
| 359 |
confusion(title=title, y_true=y_test_str, y_pred=y_pred_str)
|
| 360 |
|
| 361 |
|
| 362 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
return pipeline, evaluation, le
|
| 365 |
|
|
@@ -466,6 +486,7 @@ def randomSearch(x: np.ndarray, y: np.ndarray) -> dict: #type: ignore
|
|
| 466 |
|
| 467 |
print('Best Params')
|
| 468 |
pprint(rf_random.best_params_)
|
|
|
|
| 469 |
|
| 470 |
return rf_random.best_params_
|
| 471 |
|
|
|
|
| 40 |
# Visualization libraries
|
| 41 |
import seaborn as sns
|
| 42 |
import matplotlib.pyplot as plt
|
| 43 |
+
from matplotlib.figure import Figure
|
| 44 |
|
| 45 |
from esm.models.esmc import ESMC
|
| 46 |
from esm.sdk.api import ESMProtein, LogitsConfig, ESMProteinError, LogitsOutput
|
|
|
|
| 50 |
|
| 51 |
import torch
|
| 52 |
|
| 53 |
+
import sys
|
| 54 |
+
import os
|
| 55 |
+
|
| 56 |
+
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
|
| 57 |
+
|
| 58 |
def load_emb(path: str, acc: list[str]) -> np.ndarray:
|
| 59 |
"""
|
| 60 |
Loads and processes embedding files from a specified directory for a list of accession identifiers.
|
|
|
|
| 96 |
|
| 97 |
return np.vstack(x)
|
| 98 |
|
| 99 |
+
def confusion(title : str, y_true: np.ndarray, y_pred: np.ndarray) -> Figure:
|
| 100 |
|
| 101 |
"""
|
| 102 |
Plot a confusion matrix for the given true and predicted labels.
|
|
|
|
| 111 |
normalize = 'pred')
|
| 112 |
|
| 113 |
class_names = list(np.unique(y_true))
|
| 114 |
+
fig = plt.figure(figsize=(10, 10))
|
| 115 |
sns.heatmap(cm, annot=True, fmt='.2f', cmap='Greys',
|
| 116 |
xticklabels=class_names, yticklabels=class_names)
|
| 117 |
|
|
|
|
| 121 |
plt.tight_layout()
|
| 122 |
plt.show()
|
| 123 |
|
| 124 |
+
return fig
|
| 125 |
+
|
| 126 |
def plot_umap(x: np.ndarray, y: np.ndarray, title: str) -> None:
|
| 127 |
"""
|
| 128 |
Plots a 2D UMAP projection of high-dimensional data with class labels.
|
|
|
|
| 308 |
y_pred = classifier.predict(x_test)
|
| 309 |
|
| 310 |
evaluation = evaluate(classifier, x_test, y_test)
|
| 311 |
+
|
| 312 |
+
classification = classification_report(y_test,
|
| 313 |
+
y_pred,
|
| 314 |
+
zero_division=0,
|
| 315 |
+
target_names = le.classes_,
|
| 316 |
+
output_dict=True)
|
| 317 |
+
pd.DataFrame(classification).to_csv(os.path.join(project_root, 'Classification_Reports', f'{title}_classification_report.csv'), index=True)
|
| 318 |
|
| 319 |
print(classification_report(y_test,
|
| 320 |
y_pred,
|
|
|
|
| 374 |
confusion(title=title, y_true=y_test_str, y_pred=y_pred_str)
|
| 375 |
|
| 376 |
|
| 377 |
+
classification = classification_report(y_test,
|
| 378 |
+
y_pred,
|
| 379 |
+
zero_division=0,
|
| 380 |
+
target_names = le.classes_,
|
| 381 |
+
output_dict=True)
|
| 382 |
+
pd.DataFrame(classification).to_csv(os.path.join(project_root, 'Classification_Reports', f'{title}_classification_report.csv'), index=True)
|
| 383 |
|
| 384 |
return pipeline, evaluation, le
|
| 385 |
|
|
|
|
| 486 |
|
| 487 |
print('Best Params')
|
| 488 |
pprint(rf_random.best_params_)
|
| 489 |
+
|
| 490 |
|
| 491 |
return rf_random.best_params_
|
| 492 |
|