jpuglia committed on
Commit
572d8c1
·
1 Parent(s): b816ad0

Update model evaluation metrics and classification reports for improved accuracy

Browse files
Classification_Reports/ESMC-300m - Random Forest_classification_report.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ,Cellwall,Cytoplasmic,CytoplasmicMembrane,Extracellular,OuterMembrane,Periplasmic,accuracy,macro avg,weighted avg
2
+ precision,0.6875,0.9736842105263158,0.9544658493870403,0.823321554770318,0.9058823529411765,0.7301587301587301,0.9409844982322546,0.8458354496305969,0.9431482190855852
3
+ recall,0.7096774193548387,0.9907949790794979,0.8384615384615385,0.8859315589353612,0.8415300546448088,0.8625,0.9409844982322546,0.8548159250793409,0.9409844982322546
4
+ f1-score,0.6984126984126984,0.9821650767316467,0.8927108927108927,0.8534798534798534,0.8725212464589235,0.7908309455587392,0.9409844982322546,0.8483534522254591,0.9409727898172672
5
+ support,31.0,2390.0,650.0,263.0,183.0,160.0,0.9409844982322546,3677.0,3677.0
Classification_Reports/ESMC-300m - Support Vector Machine_classification_report.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ,Cellwall,Cytoplasmic,CytoplasmicMembrane,Extracellular,OuterMembrane,Periplasmic,accuracy,macro avg,weighted avg
2
+ precision,0.8260869565217391,0.9794238683127572,0.9726247987117552,0.9003831417624522,0.9408284023668639,0.815028901734104,0.9621974435681262,0.9057293449016118,0.9622014817178194
3
+ recall,0.6129032258064516,0.99581589958159,0.9292307692307692,0.8935361216730038,0.8688524590163934,0.88125,0.9621974435681262,0.8635980792180346,0.9621974435681262
4
+ f1-score,0.7037037037037037,0.9875518672199171,0.950432730133753,0.8969465648854962,0.9034090909090909,0.8468468468468469,0.9621974435681262,0.8814818006164681,0.961806189217868
5
+ support,31.0,2390.0,650.0,263.0,183.0,160.0,0.9621974435681262,3677.0,3677.0
Classification_Reports/ESMC-600m - Random Forest_classification_report.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ,Cellwall,Cytoplasmic,CytoplasmicMembrane,Extracellular,OuterMembrane,Periplasmic,accuracy,macro avg,weighted avg
2
+ precision,0.7307692307692307,0.9733388022969647,0.975736568457539,0.8404255319148937,0.92,0.7821229050279329,0.9499592058743541,0.8703988397444268,0.9512357717810928
3
+ recall,0.6129032258064516,0.992887029288703,0.8661538461538462,0.9011406844106464,0.8797814207650273,0.875,0.9499592058743541,0.8546443677374458,0.9499592058743541
4
+ f1-score,0.6666666666666666,0.9830157415078707,0.9176854115729421,0.8697247706422019,0.8994413407821229,0.8259587020648967,0.9499592058743541,0.8604154388727835,0.9497031761667939
5
+ support,31.0,2390.0,650.0,263.0,183.0,160.0,0.9499592058743541,3677.0,3677.0
Classification_Reports/ESMC-600m - Support Vector Machine_classification_report.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ,Cellwall,Cytoplasmic,CytoplasmicMembrane,Extracellular,OuterMembrane,Periplasmic,accuracy,macro avg,weighted avg
2
+ precision,0.8571428571428571,0.9782072368421053,0.9587301587301588,0.8897338403041825,0.9470588235294117,0.8385093167701864,0.9602937177046506,0.9115637055531502,0.9597863973858514
3
+ recall,0.5806451612903226,0.9953974895397489,0.9292307692307692,0.8897338403041825,0.8797814207650273,0.84375,0.9602937177046506,0.8530897801883417,0.9602937177046506
4
+ f1-score,0.6923076923076923,0.9867274989630859,0.94375,0.8897338403041825,0.9121813031161473,0.8411214953271028,0.9602937177046506,0.8776369716697019,0.9596645033195284
5
+ support,31.0,2390.0,650.0,263.0,183.0,160.0,0.9602937177046506,3677.0,3677.0
Classification_Reports/Prost T5 - Random Forest_classification_report.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ,Cellwall,Cytoplasmic,CytoplasmicMembrane,Extracellular,OuterMembrane,Periplasmic,accuracy,macro avg,weighted avg
2
+ precision,0.8181818181818182,0.9783783783783784,0.9553429027113237,0.7916666666666666,0.9523809523809523,0.8391608391608392,0.9510470492249116,0.8891852595799964,0.9522492872817795
3
+ recall,0.5806451612903226,0.9845188284518829,0.9215384615384615,0.9391634980988594,0.8743169398907104,0.75,0.9510470492249116,0.841697148211706,0.9510470492249116
4
+ f1-score,0.6792452830188679,0.9814389989572472,0.9381362568519969,0.8591304347826086,0.9116809116809117,0.7920792079207921,0.9510470492249116,0.8602851822020706,0.9507767100048853
5
+ support,31.0,2390.0,650.0,263.0,183.0,160.0,0.9510470492249116,3677.0,3677.0
Classification_Reports/Prost T5 - Support Vector Machine_classification_report.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ,Cellwall,Cytoplasmic,CytoplasmicMembrane,Extracellular,OuterMembrane,Periplasmic,accuracy,macro avg,weighted avg
2
+ precision,0.9,0.9764948453608248,0.9648562300319489,0.8868613138686131,0.9479768786127167,0.8301886792452831,0.9597497960293717,0.917729657853231,0.9595957881278095
3
+ recall,0.5806451612903226,0.9907949790794979,0.9292307692307692,0.9239543726235742,0.8961748633879781,0.825,0.9597497960293717,0.8576333576020237,0.9597497960293717
4
+ f1-score,0.7058823529411765,0.9835929387331257,0.9467084639498433,0.9050279329608939,0.9213483146067416,0.8275862068965517,0.9597497960293717,0.8816910350147221,0.959225689183014
5
+ support,31.0,2390.0,650.0,263.0,183.0,160.0,0.9597497960293717,3677.0,3677.0
Data/evaluations.csv CHANGED
@@ -1,7 +1,7 @@
1
  model,Accuracy,Recall,Precision,F1
2
- Prost T5_rf,0.9494152841990753,0.9494152841990753,0.9500906030394936,0.9487261816656973
3
  Prost T5_svm,0.9597497960293717,0.9597497960293717,0.9595957881278095,0.959225689183014
4
- ESMC-300m_rf,0.939896654881697,0.939896654881697,0.9410635663803479,0.9399078225424956
5
  ESMC-300m_svm,0.9621974435681262,0.9621974435681262,0.9622014817178194,0.961806189217868
6
- ESMC-600m_rf,0.9472395974979603,0.9472395974979603,0.9471989241244075,0.9464063102910955
7
  ESMC-600m_svm,0.9602937177046506,0.9602937177046506,0.9597863973858514,0.9596645033195284
 
1
  model,Accuracy,Recall,Precision,F1
2
+ Prost T5_rf,0.9510470492249116,0.9510470492249116,0.9522492872817795,0.9507767100048853
3
  Prost T5_svm,0.9597497960293717,0.9597497960293717,0.9595957881278095,0.959225689183014
4
+ ESMC-300m_rf,0.9409844982322546,0.9409844982322546,0.9431482190855852,0.9409727898172672
5
  ESMC-300m_svm,0.9621974435681262,0.9621974435681262,0.9622014817178194,0.961806189217868
6
+ ESMC-600m_rf,0.9499592058743541,0.9499592058743541,0.9512357717810928,0.9497031761667939
7
  ESMC-600m_svm,0.9602937177046506,0.9602937177046506,0.9597863973858514,0.9596645033195284
Models/ESMC-300m_rf.joblib CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1314aee5d738f8ad952773301e6bcecab06f36a3dc84f3e3bbd62779ebef2b64
3
- size 18049497
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0395f0e953b36fdb2ab140c5dd1b957759a9932a00037a76bfd26150ffdedac1
3
+ size 10191625
Models/ESMC-300m_svm.joblib CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bd3e89ce2d691e60852c7660e8a52209d93dbe5d4d2fd9723faff902aad34c7d
3
  size 18294469
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2002c8a06e4fcf8ead2098ac214eb78327730bf9095e32e324245a9b9a3fdbc6
3
  size 18294469
Models/ESMC-600m_rf.joblib CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:89a1bbdfe47decc8acdb934f7c79297b9cd842e63d5567c39d058ce8ef4ebfb0
3
- size 9024153
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c6f14ba2ad661fa7878a69b107280df0b82e762bc47215496d22f578974f606
3
+ size 18049497
Models/ESMC-600m_svm.joblib CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a4e4d8f22fac3eecca31048d054bbd097219d2e345b05b5dbc5c98ad9d259e40
3
  size 22787493
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20b804e6ef0c58a22af1f5c68e10221f8da4d1745e88404df9fe8e3048674ae6
3
  size 22787493
Models/Prost T5_rf.joblib CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4fa2d7e7bb0d6000314f955b77a27ffa5fdf9fb8492afc4714e952dab1d722cf
3
- size 4842553
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4ed73692393a2a87059379e3f4d07aa54d3aed5659394a79419f9b6d9837179
3
+ size 10359641
Models/Prost T5_svm.joblib CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:184f44563c41fc75ab38769f13e1d95488b1bee2767bb9b8cd80a62cd8e826a6
3
  size 18267605
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a3000d817cbf6466de239488ea378b6438d15fba03e7e102e0de5508351a54d
3
  size 18267605
notebooks/04_Training.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a53df923a16509ac44348067c80cfcd6360095c9fa72d278c905cd2b52f3bf0b
3
- size 706158
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b903593cd4834f92401fa1423f8d5fd103e9c8cc961554e8f9cd883c4405b0aa
3
+ size 8532
src/my_utils.py CHANGED
@@ -40,6 +40,7 @@ from tqdm import tqdm
40
  # Visualization libraries
41
  import seaborn as sns
42
  import matplotlib.pyplot as plt
 
43
 
44
  from esm.models.esmc import ESMC
45
  from esm.sdk.api import ESMProtein, LogitsConfig, ESMProteinError, LogitsOutput
@@ -49,6 +50,11 @@ from joblib import load
49
 
50
  import torch
51
 
 
 
 
 
 
52
  def load_emb(path: str, acc: list[str]) -> np.ndarray:
53
  """
54
  Loads and processes embedding files from a specified directory for a list of accession identifiers.
@@ -90,7 +96,7 @@ def load_emb(path: str, acc: list[str]) -> np.ndarray:
90
 
91
  return np.vstack(x)
92
 
93
- def confusion(title : str, y_true: np.ndarray, y_pred: np.ndarray) -> None:
94
 
95
  """
96
  Plot a confusion matrix for the given true and predicted labels.
@@ -105,7 +111,7 @@ def confusion(title : str, y_true: np.ndarray, y_pred: np.ndarray) -> None:
105
  normalize = 'pred')
106
 
107
  class_names = list(np.unique(y_true))
108
- plt.figure(figsize=(10, 10))
109
  sns.heatmap(cm, annot=True, fmt='.2f', cmap='Greys',
110
  xticklabels=class_names, yticklabels=class_names)
111
 
@@ -115,6 +121,8 @@ def confusion(title : str, y_true: np.ndarray, y_pred: np.ndarray) -> None:
115
  plt.tight_layout()
116
  plt.show()
117
 
 
 
118
  def plot_umap(x: np.ndarray, y: np.ndarray, title: str) -> None:
119
  """
120
  Plots a 2D UMAP projection of high-dimensional data with class labels.
@@ -300,6 +308,13 @@ def train_rf(title: str,
300
  y_pred = classifier.predict(x_test)
301
 
302
  evaluation = evaluate(classifier, x_test, y_test)
 
 
 
 
 
 
 
303
 
304
  print(classification_report(y_test,
305
  y_pred,
@@ -359,7 +374,12 @@ def train_svm(title: str, x: np.ndarray, y: np.ndarray, params: dict) -> tuple[P
359
  confusion(title=title, y_true=y_test_str, y_pred=y_pred_str)
360
 
361
 
362
- print(classification_report(y_test, y_pred, zero_division=0, target_names = le.classes_))
 
 
 
 
 
363
 
364
  return pipeline, evaluation, le
365
 
@@ -466,6 +486,7 @@ def randomSearch(x: np.ndarray, y: np.ndarray) -> dict: #type: ignore
466
 
467
  print('Best Params')
468
  pprint(rf_random.best_params_)
 
469
 
470
  return rf_random.best_params_
471
 
 
40
  # Visualization libraries
41
  import seaborn as sns
42
  import matplotlib.pyplot as plt
43
+ from matplotlib.figure import Figure
44
 
45
  from esm.models.esmc import ESMC
46
  from esm.sdk.api import ESMProtein, LogitsConfig, ESMProteinError, LogitsOutput
 
50
 
51
  import torch
52
 
53
+ import sys
54
+ import os
55
+
56
+ project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
57
+
58
  def load_emb(path: str, acc: list[str]) -> np.ndarray:
59
  """
60
  Loads and processes embedding files from a specified directory for a list of accession identifiers.
 
96
 
97
  return np.vstack(x)
98
 
99
+ def confusion(title : str, y_true: np.ndarray, y_pred: np.ndarray) -> Figure:
100
 
101
  """
102
  Plot a confusion matrix for the given true and predicted labels.
 
111
  normalize = 'pred')
112
 
113
  class_names = list(np.unique(y_true))
114
+ fig = plt.figure(figsize=(10, 10))
115
  sns.heatmap(cm, annot=True, fmt='.2f', cmap='Greys',
116
  xticklabels=class_names, yticklabels=class_names)
117
 
 
121
  plt.tight_layout()
122
  plt.show()
123
 
124
+ return fig
125
+
126
  def plot_umap(x: np.ndarray, y: np.ndarray, title: str) -> None:
127
  """
128
  Plots a 2D UMAP projection of high-dimensional data with class labels.
 
308
  y_pred = classifier.predict(x_test)
309
 
310
  evaluation = evaluate(classifier, x_test, y_test)
311
+
312
+ classification = classification_report(y_test,
313
+ y_pred,
314
+ zero_division=0,
315
+ target_names = le.classes_,
316
+ output_dict=True)
317
+ pd.DataFrame(classification).to_csv(os.path.join(project_root, 'Classification_Reports', f'{title}_classification_report.csv'), index=True)
318
 
319
  print(classification_report(y_test,
320
  y_pred,
 
374
  confusion(title=title, y_true=y_test_str, y_pred=y_pred_str)
375
 
376
 
377
+ classification = classification_report(y_test,
378
+ y_pred,
379
+ zero_division=0,
380
+ target_names = le.classes_,
381
+ output_dict=True)
382
+ pd.DataFrame(classification).to_csv(os.path.join(project_root, 'Classification_Reports', f'{title}_classification_report.csv'), index=True)
383
 
384
  return pipeline, evaluation, le
385
 
 
486
 
487
  print('Best Params')
488
  pprint(rf_random.best_params_)
489
+
490
 
491
  return rf_random.best_params_
492