Update model loading and prediction methods to use SVM; enhance label encoding and improve variable naming for clarity
Browse files- Data/idmapping_2025_06_24_predictions.txt +5 -5
- src/my_utils.py +21 -10
Data/idmapping_2025_06_24_predictions.txt
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
Sequence_ID,Prediction 1,Prediction 2,Prediction 3,Prediction 4,Prediction 5,Prediction 6
|
| 2 |
-
sp|P0A7V8|RS4_ECOLI,Cytoplasmic (0.
|
| 3 |
-
sp|P0A910|OMPA_ECOLI,OuterMembrane (0.
|
| 4 |
-
sp|P0A6F5|CH60_ECOLI,Cytoplasmic (0.
|
| 5 |
-
sp|P02930|TOLC_ECOLI,OuterMembrane (0.
|
| 6 |
-
tr|Q9L1T3|Q9L1T3_STRCO,CytoplasmicMembrane (0.
|
|
|
|
| 1 |
Sequence_ID,Prediction 1,Prediction 2,Prediction 3,Prediction 4,Prediction 5,Prediction 6
|
| 2 |
+
sp|P0A7V8|RS4_ECOLI,Cytoplasmic (0.9860),CytoplasmicMembrane (0.0081),Periplasmic (0.0029),Extracellular (0.0019),OuterMembrane (0.0007),Cellwall (0.0003)
|
| 3 |
+
sp|P0A910|OMPA_ECOLI,OuterMembrane (0.9870),CytoplasmicMembrane (0.0054),Periplasmic (0.0041),Cellwall (0.0021),Extracellular (0.0012),Cytoplasmic (0.0001)
|
| 4 |
+
sp|P0A6F5|CH60_ECOLI,Cytoplasmic (0.9771),CytoplasmicMembrane (0.0103),Periplasmic (0.0063),Extracellular (0.0037),OuterMembrane (0.0016),Cellwall (0.0010)
|
| 5 |
+
sp|P02930|TOLC_ECOLI,OuterMembrane (0.9771),Periplasmic (0.0076),CytoplasmicMembrane (0.0074),Cellwall (0.0034),Extracellular (0.0030),Cytoplasmic (0.0015)
|
| 6 |
+
tr|Q9L1T3|Q9L1T3_STRCO,CytoplasmicMembrane (0.9944),Extracellular (0.0028),Periplasmic (0.0012),Cytoplasmic (0.0007),Cellwall (0.0005),OuterMembrane (0.0004)
|
src/my_utils.py
CHANGED
|
@@ -725,25 +725,33 @@ def predict_with_esm(fasta_path : str,
|
|
| 725 |
tk.Button(progress_win, text="Close", command=progress_win.destroy).pack(pady=5)
|
| 726 |
|
| 727 |
# Load model
|
| 728 |
-
messagebox.showinfo("Info", "Loading
|
| 729 |
project_root: str = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 730 |
-
model_path = os.path.join(project_root, 'Models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 731 |
|
| 732 |
try:
|
| 733 |
predictor = load(model_path)
|
|
|
|
| 734 |
|
| 735 |
except FileNotFoundError:
|
| 736 |
print(f"Error: Could not find the model file '{model_path}'")
|
| 737 |
return
|
| 738 |
|
| 739 |
sequence_ids = list(embeddings.keys())
|
| 740 |
-
|
| 741 |
messagebox.showinfo("Info", "Making predictions...")
|
| 742 |
-
y_pred_proba = predictor.predict_proba(
|
| 743 |
|
| 744 |
# Get class names
|
| 745 |
if hasattr(predictor, 'classes_'):
|
| 746 |
-
class_names =
|
| 747 |
else:
|
| 748 |
class_names = [f"Class_{i}" for i in range(y_pred_proba.shape[1])]
|
| 749 |
|
|
@@ -938,25 +946,28 @@ def predict_with_prost(fasta_path: str):
|
|
| 938 |
tk.Button(progress_win, text="Close", command=progress_win.destroy).pack(pady=5)
|
| 939 |
|
| 940 |
# Load model
|
| 941 |
-
messagebox.showinfo("Info", "Loading
|
| 942 |
project_root: str = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 943 |
-
model_path = os.path.join(project_root, 'Models
|
|
|
|
| 944 |
|
| 945 |
try:
|
| 946 |
predictor = load(model_path)
|
|
|
|
|
|
|
| 947 |
except FileNotFoundError:
|
| 948 |
print(f"Error: Could not find the model file '{model_path}'")
|
| 949 |
return
|
| 950 |
|
| 951 |
sequence_ids = list(embeddings.keys())
|
| 952 |
-
|
| 953 |
|
| 954 |
print("Making predictions...")
|
| 955 |
-
y_pred_proba = predictor.predict_proba(
|
| 956 |
|
| 957 |
# Get class names
|
| 958 |
if hasattr(predictor, 'classes_'):
|
| 959 |
-
class_names =
|
| 960 |
else:
|
| 961 |
class_names = [f"Class_{i}" for i in range(y_pred_proba.shape[1])]
|
| 962 |
|
|
|
|
| 725 |
tk.Button(progress_win, text="Close", command=progress_win.destroy).pack(pady=5)
|
| 726 |
|
| 727 |
# Load model
|
| 728 |
+
messagebox.showinfo("Info", "Loading SVM for predictions...")
|
| 729 |
project_root: str = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 730 |
+
model_path = os.path.join(project_root, 'Models/ESMC-300m_svm.joblib'
|
| 731 |
+
if
|
| 732 |
+
model == 'esmc_300m'
|
| 733 |
+
else 'Models/ESMC-600m_svm.joblib')
|
| 734 |
+
le_path = os.path.join(project_root, 'Models/esm_300m_le_svm.joblib'
|
| 735 |
+
if
|
| 736 |
+
model == 'esmc_300m'
|
| 737 |
+
else 'Models/ESMC-600m_le_svm.joblib')
|
| 738 |
|
| 739 |
try:
|
| 740 |
predictor = load(model_path)
|
| 741 |
+
le: LabelEncoder = load(le_path)
|
| 742 |
|
| 743 |
except FileNotFoundError:
|
| 744 |
print(f"Error: Could not find the model file '{model_path}'")
|
| 745 |
return
|
| 746 |
|
| 747 |
sequence_ids = list(embeddings.keys())
|
| 748 |
+
x = np.array(list(embeddings.values()))# type: ignore
|
| 749 |
messagebox.showinfo("Info", "Making predictions...")
|
| 750 |
+
y_pred_proba = predictor.predict_proba(x)
|
| 751 |
|
| 752 |
# Get class names
|
| 753 |
if hasattr(predictor, 'classes_'):
|
| 754 |
+
class_names = le.inverse_transform(predictor.classes_)
|
| 755 |
else:
|
| 756 |
class_names = [f"Class_{i}" for i in range(y_pred_proba.shape[1])]
|
| 757 |
|
|
|
|
| 946 |
tk.Button(progress_win, text="Close", command=progress_win.destroy).pack(pady=5)
|
| 947 |
|
| 948 |
# Load model
|
| 949 |
+
messagebox.showinfo("Info", "Loading SVM model for predictions...")
|
| 950 |
project_root: str = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 951 |
+
model_path = os.path.join(project_root, 'Models/Prost T5_svm.joblib')
|
| 952 |
+
le_path = os.path.join(project_root, 'Models/Prost T5_le_svm.joblib')
|
| 953 |
|
| 954 |
try:
|
| 955 |
predictor = load(model_path)
|
| 956 |
+
le : LabelEncoder = load(le_path)
|
| 957 |
+
|
| 958 |
except FileNotFoundError:
|
| 959 |
print(f"Error: Could not find the model file '{model_path}'")
|
| 960 |
return
|
| 961 |
|
| 962 |
sequence_ids = list(embeddings.keys())
|
| 963 |
+
x = np.array(list(embeddings.values())) #type: ignore
|
| 964 |
|
| 965 |
print("Making predictions...")
|
| 966 |
+
y_pred_proba = predictor.predict_proba(x)
|
| 967 |
|
| 968 |
# Get class names
|
| 969 |
if hasattr(predictor, 'classes_'):
|
| 970 |
+
class_names = le.inverse_transform(predictor.classes_)
|
| 971 |
else:
|
| 972 |
class_names = [f"Class_{i}" for i in range(y_pred_proba.shape[1])]
|
| 973 |
|