jpuglia committed on
Commit
96b4d50
·
1 Parent(s): 4df44a7

Update model loading and prediction methods to use SVM; enhance label encoding and improve variable naming for clarity

Browse files
Data/idmapping_2025_06_24_predictions.txt CHANGED
@@ -1,6 +1,6 @@
1
  Sequence_ID,Prediction 1,Prediction 2,Prediction 3,Prediction 4,Prediction 5,Prediction 6
2
- sp|P0A7V8|RS4_ECOLI,Cytoplasmic (0.9020),CytoplasmicMembrane (0.0480),Periplasmic (0.0247),Extracellular (0.0184),OuterMembrane (0.0061),Cellwall (0.0006)
3
- sp|P0A910|OMPA_ECOLI,OuterMembrane (0.9400),Extracellular (0.0251),Periplasmic (0.0177),CytoplasmicMembrane (0.0124),Cytoplasmic (0.0028),Cellwall (0.0019)
4
- sp|P0A6F5|CH60_ECOLI,Cytoplasmic (0.9935),CytoplasmicMembrane (0.0059),OuterMembrane (0.0004),Periplasmic (0.0003),Extracellular (0.0000),Cellwall (0.0000)
5
- sp|P02930|TOLC_ECOLI,OuterMembrane (0.9483),CytoplasmicMembrane (0.0166),Extracellular (0.0154),Periplasmic (0.0150),Cytoplasmic (0.0031),Cellwall (0.0016)
6
- tr|Q9L1T3|Q9L1T3_STRCO,CytoplasmicMembrane (0.5872),Cytoplasmic (0.1784),Extracellular (0.1599),Periplasmic (0.0445),OuterMembrane (0.0246),Cellwall (0.0053)
 
1
  Sequence_ID,Prediction 1,Prediction 2,Prediction 3,Prediction 4,Prediction 5,Prediction 6
2
+ sp|P0A7V8|RS4_ECOLI,Cytoplasmic (0.9860),CytoplasmicMembrane (0.0081),Periplasmic (0.0029),Extracellular (0.0019),OuterMembrane (0.0007),Cellwall (0.0003)
3
+ sp|P0A910|OMPA_ECOLI,OuterMembrane (0.9870),CytoplasmicMembrane (0.0054),Periplasmic (0.0041),Cellwall (0.0021),Extracellular (0.0012),Cytoplasmic (0.0001)
4
+ sp|P0A6F5|CH60_ECOLI,Cytoplasmic (0.9771),CytoplasmicMembrane (0.0103),Periplasmic (0.0063),Extracellular (0.0037),OuterMembrane (0.0016),Cellwall (0.0010)
5
+ sp|P02930|TOLC_ECOLI,OuterMembrane (0.9771),Periplasmic (0.0076),CytoplasmicMembrane (0.0074),Cellwall (0.0034),Extracellular (0.0030),Cytoplasmic (0.0015)
6
+ tr|Q9L1T3|Q9L1T3_STRCO,CytoplasmicMembrane (0.9944),Extracellular (0.0028),Periplasmic (0.0012),Cytoplasmic (0.0007),Cellwall (0.0005),OuterMembrane (0.0004)
src/my_utils.py CHANGED
@@ -725,25 +725,33 @@ def predict_with_esm(fasta_path : str,
725
  tk.Button(progress_win, text="Close", command=progress_win.destroy).pack(pady=5)
726
 
727
  # Load model
728
- messagebox.showinfo("Info", "Loading random forest model for predictions...")
729
  project_root: str = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
730
- model_path = os.path.join(project_root, 'Models', 'rfESM300.joblib' if model == 'esmc_300m' else 'rfESM600.joblib')
 
 
 
 
 
 
 
731
 
732
  try:
733
  predictor = load(model_path)
 
734
 
735
  except FileNotFoundError:
736
  print(f"Error: Could not find the model file '{model_path}'")
737
  return
738
 
739
  sequence_ids = list(embeddings.keys())
740
- X = np.array(list(embeddings.values()))# type: ignore
741
  messagebox.showinfo("Info", "Making predictions...")
742
- y_pred_proba = predictor.predict_proba(X)
743
 
744
  # Get class names
745
  if hasattr(predictor, 'classes_'):
746
- class_names = [str(cls) for cls in predictor.classes_]
747
  else:
748
  class_names = [f"Class_{i}" for i in range(y_pred_proba.shape[1])]
749
 
@@ -938,25 +946,28 @@ def predict_with_prost(fasta_path: str):
938
  tk.Button(progress_win, text="Close", command=progress_win.destroy).pack(pady=5)
939
 
940
  # Load model
941
- messagebox.showinfo("Info", "Loading random forest model for predictions...")
942
  project_root: str = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
943
- model_path = os.path.join(project_root, 'Models', 'rfProst.joblib')
 
944
 
945
  try:
946
  predictor = load(model_path)
 
 
947
  except FileNotFoundError:
948
  print(f"Error: Could not find the model file '{model_path}'")
949
  return
950
 
951
  sequence_ids = list(embeddings.keys())
952
- X = np.array(list(embeddings.values())) # type: ignore
953
 
954
  print("Making predictions...")
955
- y_pred_proba = predictor.predict_proba(X)
956
 
957
  # Get class names
958
  if hasattr(predictor, 'classes_'):
959
- class_names = [str(cls) for cls in predictor.classes_]
960
  else:
961
  class_names = [f"Class_{i}" for i in range(y_pred_proba.shape[1])]
962
 
 
725
  tk.Button(progress_win, text="Close", command=progress_win.destroy).pack(pady=5)
726
 
727
  # Load model
728
+ messagebox.showinfo("Info", "Loading SVM for predictions...")
729
  project_root: str = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
730
+ model_path = os.path.join(project_root, 'Models/ESMC-300m_svm.joblib'
731
+ if
732
+ model == 'esmc_300m'
733
+ else 'Models/ESMC-600m_svm.joblib')
734
+ le_path = os.path.join(project_root, 'Models/esm_300m_le_svm.joblib'
735
+ if
736
+ model == 'esmc_300m'
737
+ else 'Models/ESMC-600m_le_svm.joblib')
738
 
739
  try:
740
  predictor = load(model_path)
741
+ le: LabelEncoder = load(le_path)
742
 
743
  except FileNotFoundError:
744
  print(f"Error: Could not find the model file '{model_path}'")
745
  return
746
 
747
  sequence_ids = list(embeddings.keys())
748
+ x = np.array(list(embeddings.values()))# type: ignore
749
  messagebox.showinfo("Info", "Making predictions...")
750
+ y_pred_proba = predictor.predict_proba(x)
751
 
752
  # Get class names
753
  if hasattr(predictor, 'classes_'):
754
+ class_names = le.inverse_transform(predictor.classes_)
755
  else:
756
  class_names = [f"Class_{i}" for i in range(y_pred_proba.shape[1])]
757
 
 
946
  tk.Button(progress_win, text="Close", command=progress_win.destroy).pack(pady=5)
947
 
948
  # Load model
949
+ messagebox.showinfo("Info", "Loading SVM model for predictions...")
950
  project_root: str = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
951
+ model_path = os.path.join(project_root, 'Models/Prost T5_svm.joblib')
952
+ le_path = os.path.join(project_root, 'Models/Prost T5_le_svm.joblib')
953
 
954
  try:
955
  predictor = load(model_path)
956
+ le : LabelEncoder = load(le_path)
957
+
958
  except FileNotFoundError:
959
  print(f"Error: Could not find the model file '{model_path}'")
960
  return
961
 
962
  sequence_ids = list(embeddings.keys())
963
+ x = np.array(list(embeddings.values())) #type: ignore
964
 
965
  print("Making predictions...")
966
+ y_pred_proba = predictor.predict_proba(x)
967
 
968
  # Get class names
969
  if hasattr(predictor, 'classes_'):
970
+ class_names = le.inverse_transform(predictor.classes_)
971
  else:
972
  class_names = [f"Class_{i}" for i in range(y_pred_proba.shape[1])]
973