Enhance CLI functionality: Implement file loading and prediction features with user feedback in Tkinter GUI
Browse files- cli.py +64 -25
- src/my_utils.py +117 -119
cli.py
CHANGED
|
@@ -1,19 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import tkinter as tk
|
| 2 |
-
from tkinter import Menu
|
| 3 |
from src.my_utils import predict_with_prost
|
| 4 |
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
def menu():
|
| 7 |
"""
|
| 8 |
-
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
- A "Help" menu with options for welcome and about dialogs.
|
| 13 |
-
- Two buttons below the menu: one for loading a FASTA file (triggers `predict_with_prost`), and one for exiting the application.
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
"""
|
| 18 |
# root window
|
| 19 |
root = tk.Tk()
|
|
@@ -26,19 +76,10 @@ def menu():
|
|
| 26 |
|
| 27 |
# create the file_menu
|
| 28 |
file_menu = Menu(menubar, tearoff=0)
|
| 29 |
-
file_menu.add_command(label='
|
| 30 |
-
file_menu.add_command(label='
|
| 31 |
-
file_menu.add_command(label='Close')
|
| 32 |
-
file_menu.add_separator()
|
| 33 |
-
|
| 34 |
-
sub_menu = Menu(file_menu, tearoff=0)
|
| 35 |
-
sub_menu.add_command(label='Keyboard Shortcuts')
|
| 36 |
-
sub_menu.add_command(label='Color Themes')
|
| 37 |
-
|
| 38 |
-
file_menu.add_cascade(label="Preferences", menu=sub_menu)
|
| 39 |
file_menu.add_separator()
|
| 40 |
-
|
| 41 |
-
menubar.add_cascade(label="File", menu=file_menu, underline=0)
|
| 42 |
|
| 43 |
# help menu
|
| 44 |
help_menu = Menu(menubar, tearoff=0)
|
|
@@ -50,17 +91,15 @@ def menu():
|
|
| 50 |
# Add Buttons Below Menu
|
| 51 |
# =========================
|
| 52 |
|
| 53 |
-
btn_prost = tk.Button(root, text="Predict with Prost", command=
|
| 54 |
btn_prost.pack(pady=5)
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
|
| 59 |
btn_exit = tk.Button(root, text="Exit", command=root.quit)
|
| 60 |
btn_exit.pack(pady=5)
|
| 61 |
|
| 62 |
root.mainloop()
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
menu()
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Protein Location Predictor CLI
|
| 3 |
+
This module provides a Tkinter-based GUI for loading FASTA files
|
| 4 |
+
and running protein location prediction tools.
|
| 5 |
+
"""
|
| 6 |
import tkinter as tk
|
| 7 |
+
from tkinter import Menu, filedialog, messagebox
|
| 8 |
from src.my_utils import predict_with_prost
|
| 9 |
|
| 10 |
|
| 11 |
+
FASTA_FILE_PATH = None # Global or instance variable
|
| 12 |
+
|
| 13 |
+
def load_fasta_file():
    """
    Prompt the user to pick a FASTA file and remember its path.

    Opens a file-selection dialog restricted to .fasta/.fa files and stores
    the chosen path in the module-level FASTA_FILE_PATH variable. An info
    dialog confirms the selection; a warning dialog is shown when the user
    cancels the dialog.

    Global Variables:
        FASTA_FILE_PATH (str): Path to the selected FASTA file.
    """
    global FASTA_FILE_PATH  # pylint: disable=global-statement
    selected = filedialog.askopenfilename(
        filetypes=[("FASTA files", "*.fasta *.fa")],
        title="Select a FASTA file"
    )
    FASTA_FILE_PATH = selected
    if not FASTA_FILE_PATH:
        messagebox.showwarning("No file", "No file was selected.")
    else:
        messagebox.showinfo("File Loaded", f"Loaded file:\n{FASTA_FILE_PATH}")
|
| 36 |
+
|
| 37 |
+
def run_prediction():
    """
    Launch the PROST-based location prediction for the loaded FASTA file.

    Requires that a FASTA file was previously selected via load_fasta_file();
    if none was, an error dialog is shown and nothing else happens. Otherwise
    the work is delegated to predict_with_prost().
    """
    if FASTA_FILE_PATH:
        predict_with_prost(FASTA_FILE_PATH)
    else:
        messagebox.showerror("Error", "Please load a FASTA file first.")
|
| 51 |
+
|
| 52 |
def menu():
|
| 53 |
"""
|
| 54 |
+
Displays the main GUI menu for the Protein Tools application.
|
| 55 |
|
| 56 |
+
This function creates a Tkinter window with a menu bar containing 'File' and 'Help' menus,
|
| 57 |
+
and buttons for running protein prediction tools and exiting the application.
|
|
|
|
|
|
|
| 58 |
|
| 59 |
+
Menus:
|
| 60 |
+
- File: Options to load a FASTA file or close the application.
|
| 61 |
+
- Help: Options for welcome information and about dialog.
|
| 62 |
+
|
| 63 |
+
Buttons:
|
| 64 |
+
- Predict with Prost: Runs the Prost prediction tool.
|
| 65 |
+
- Predict with ESM C: Placeholder for ESM prediction functionality (not yet implemented).
|
| 66 |
+
- Exit: Closes the application.
|
| 67 |
"""
|
| 68 |
# root window
|
| 69 |
root = tk.Tk()
|
|
|
|
| 76 |
|
| 77 |
# create the file_menu
|
| 78 |
file_menu = Menu(menubar, tearoff=0)
|
| 79 |
+
file_menu.add_command(label='Load FASTA', command=load_fasta_file)
|
| 80 |
+
file_menu.add_command(label='Close', command=root.quit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
file_menu.add_separator()
|
| 82 |
+
menubar.add_cascade(label="File", menu=file_menu, underline=0)
|
|
|
|
| 83 |
|
| 84 |
# help menu
|
| 85 |
help_menu = Menu(menubar, tearoff=0)
|
|
|
|
| 91 |
# Add Buttons Below Menu
|
| 92 |
# =========================
|
| 93 |
|
| 94 |
+
btn_prost = tk.Button(root, text="Predict with Prost", command=run_prediction)
|
| 95 |
btn_prost.pack(pady=5)
|
| 96 |
|
| 97 |
+
btn_esm = tk.Button(root, text="Predict with ESM C") #type: ignore
|
| 98 |
+
btn_esm.pack(pady=5)
|
| 99 |
|
| 100 |
btn_exit = tk.Button(root, text="Exit", command=root.quit)
|
| 101 |
btn_exit.pack(pady=5)
|
| 102 |
|
| 103 |
root.mainloop()
|
| 104 |
|
|
|
|
|
|
|
| 105 |
menu()
|
src/my_utils.py
CHANGED
|
@@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
| 7 |
from urllib.error import HTTPError
|
| 8 |
from typing import Literal, Optional
|
| 9 |
import tkinter as tk
|
| 10 |
-
from tkinter import filedialog
|
| 11 |
|
| 12 |
|
| 13 |
import pandas as pd
|
|
@@ -39,6 +39,7 @@ import plotly.express as px
|
|
| 39 |
from esm.models.esmc import ESMC
|
| 40 |
from esm.sdk.api import ESMProtein, LogitsConfig, ESMProteinError, LogitsOutput
|
| 41 |
from transformers import T5Tokenizer, T5EncoderModel, PreTrainedModel
|
|
|
|
| 42 |
|
| 43 |
from joblib import load
|
| 44 |
|
|
@@ -503,10 +504,9 @@ def _fetch_sequence_for_row(idx, row):
|
|
| 503 |
|
| 504 |
return idx, sequence
|
| 505 |
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
def fetch_sequences_for_dataframe(df: pd.DataFrame, batch_size: Optional[int] = None, max_workers: int = 5) -> pd.DataFrame:
|
| 510 |
"""
|
| 511 |
Add a 'sequence' column to the dataframe by fetching sequences from
|
| 512 |
SwissProt or RefSeq based on available IDs, with parallel execution and a progress bar.
|
|
@@ -553,69 +553,74 @@ def fetch_sequences_for_dataframe(df: pd.DataFrame, batch_size: Optional[int] =
|
|
| 553 |
f"({round(success_count/total_rows*100, 2)}%)")
|
| 554 |
return result_df
|
| 555 |
|
| 556 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 557 |
|
| 558 |
"""
|
| 559 |
-
|
| 560 |
Args:
|
| 561 |
-
model:
|
| 562 |
-
|
|
|
|
|
|
|
| 563 |
Returns:
|
| 564 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
"""
|
| 566 |
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
protein = ESMProtein(sequence=sequence)
|
| 570 |
-
protein_tensor = client.encode(protein)
|
| 571 |
|
| 572 |
if isinstance(protein_tensor, ESMProteinError):
|
| 573 |
-
|
| 574 |
raise protein_tensor
|
|
|
|
|
|
|
| 575 |
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
path: Directory to save the embeddings.
|
| 594 |
-
"""
|
| 595 |
-
|
| 596 |
-
assert len(seq_list) == len(id_list), "Sequence and ID lists must be the same length."
|
| 597 |
-
os.makedirs(path, exist_ok=True)
|
| 598 |
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
|
| 606 |
-
if len(emb_array.shape) == 3:
|
| 607 |
-
emb_array = emb_array.squeeze(axis=0).mean(axis=0)
|
| 608 |
-
elif len(emb_array.shape) == 2:
|
| 609 |
-
emb_array = emb_array.mean(axis=0)
|
| 610 |
|
| 611 |
-
np.save(os.path.join(path, f"{acc}.npy"), emb_array)
|
| 612 |
|
| 613 |
-
except ESMProteinError as e:
|
| 614 |
-
print(f"Error processing {acc}: {e}")
|
| 615 |
|
| 616 |
-
if i % 100 == 0:
|
| 617 |
-
gc.collect()
|
| 618 |
-
torch.cuda.empty_cache()
|
| 619 |
|
| 620 |
def prost_embed_sequence(seq : str,
|
| 621 |
acc : str,
|
|
@@ -732,103 +737,96 @@ def save_predictions_to_txt(predictions_dict: dict[str, tuple[list[str], list[fl
|
|
| 732 |
|
| 733 |
f.write(f"{seq_id},{pred_line}\n")
|
| 734 |
|
| 735 |
-
def predict_with_prost():
|
| 736 |
"""
|
| 737 |
-
Function to
|
| 738 |
"""
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
fasta_path : str = filedialog.askopenfilename(
|
| 743 |
-
title="Select a FASTA file",
|
| 744 |
-
filetypes=[("FASTA files", "*.fasta *.fa")],
|
| 745 |
-
initialdir="."
|
| 746 |
-
)
|
| 747 |
-
|
| 748 |
-
if not fasta_path:
|
| 749 |
-
print("No file selected.")
|
| 750 |
return
|
| 751 |
|
| 752 |
-
#
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
initialdir="."
|
| 756 |
-
)
|
| 757 |
|
|
|
|
| 758 |
if not output_dir:
|
| 759 |
-
print("No output directory selected.")
|
| 760 |
return
|
| 761 |
|
| 762 |
result = fasta_to_seq(fasta_path)
|
| 763 |
-
|
| 764 |
if result is None:
|
| 765 |
-
|
| 766 |
-
return
|
| 767 |
-
else:
|
| 768 |
-
sequences, ids = result
|
| 769 |
-
print(f"Sequences loaded from {fasta_path}: {len(sequences)} sequences found.")
|
| 770 |
-
print("Embedding sequences using ProstT5...")
|
| 771 |
|
| 772 |
-
|
| 773 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 774 |
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 785 |
|
| 786 |
-
print("Loading pre-trained SVM model for prediction...")
|
| 787 |
try:
|
| 788 |
-
predictor = load(
|
| 789 |
except FileNotFoundError:
|
| 790 |
-
print("Error: Could not find the model file '
|
| 791 |
-
print("Please check the path to your trained model.")
|
| 792 |
return
|
| 793 |
|
| 794 |
sequence_ids = list(embeddings.keys())
|
| 795 |
-
X = np.array(list(embeddings.values()))
|
|
|
|
| 796 |
print("Making predictions...")
|
| 797 |
y_pred_proba = predictor.predict_proba(X)
|
| 798 |
-
|
| 799 |
-
# Get class names
|
| 800 |
if hasattr(predictor, 'classes_'):
|
| 801 |
-
class_names = predictor.classes_
|
| 802 |
else:
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
class_names = [f"Class_{i}" for i in range(n_classes)]
|
| 806 |
-
|
| 807 |
-
# Convert class names to strings if they aren't already
|
| 808 |
-
class_names = [str(cls) for cls in class_names]
|
| 809 |
-
|
| 810 |
-
# Create predictions dictionary
|
| 811 |
predictions_dict = {}
|
| 812 |
for i, seq_id in enumerate(sequence_ids):
|
| 813 |
-
|
| 814 |
-
class_prob_pairs = sorted(zip(class_names, probabilities), key=lambda x: x[1], reverse=True)
|
| 815 |
sorted_classes, sorted_probs = zip(*class_prob_pairs)
|
| 816 |
predictions_dict[seq_id] = (list(sorted_classes), list(sorted_probs))
|
| 817 |
-
|
| 818 |
-
#
|
| 819 |
input_filename = os.path.splitext(os.path.basename(fasta_path))[0]
|
| 820 |
output_file = os.path.join(output_dir, f"{input_filename}_predictions.txt")
|
| 821 |
-
|
| 822 |
-
# Save predictions to file
|
| 823 |
print(f"Saving predictions to {output_file}...")
|
| 824 |
save_predictions_to_txt(predictions_dict, output_file)
|
| 825 |
-
|
| 826 |
-
print(f"Predictions saved successfully!")
|
| 827 |
print(f"Total sequences processed: {len(embeddings)}")
|
| 828 |
-
|
| 829 |
-
|
| 830 |
-
# Print a few sample predictions
|
| 831 |
print("\nSample predictions:")
|
| 832 |
for i, (seq_id, (classes, probs)) in enumerate(list(predictions_dict.items())[:3]):
|
| 833 |
pred_str = ", ".join([f"{cls} ({prob:.4f})" for cls, prob in zip(classes, probs)])
|
| 834 |
-
print(f"{seq_id}: {pred_str}")
|
|
|
|
|
|
| 7 |
from urllib.error import HTTPError
|
| 8 |
from typing import Literal, Optional
|
| 9 |
import tkinter as tk
|
| 10 |
+
from tkinter import filedialog, messagebox, ttk
|
| 11 |
|
| 12 |
|
| 13 |
import pandas as pd
|
|
|
|
| 39 |
from esm.models.esmc import ESMC
|
| 40 |
from esm.sdk.api import ESMProtein, LogitsConfig, ESMProteinError, LogitsOutput
|
| 41 |
from transformers import T5Tokenizer, T5EncoderModel, PreTrainedModel
|
| 42 |
+
from esm.sdk.forge import ESM3ForgeInferenceClient
|
| 43 |
|
| 44 |
from joblib import load
|
| 45 |
|
|
|
|
| 504 |
|
| 505 |
return idx, sequence
|
| 506 |
|
| 507 |
+
def fetch_sequences_for_dataframe(df: pd.DataFrame,
|
| 508 |
+
batch_size: Optional[int] = None,
|
| 509 |
+
max_workers: int = 5) -> pd.DataFrame:
|
|
|
|
| 510 |
"""
|
| 511 |
Add a 'sequence' column to the dataframe by fetching sequences from
|
| 512 |
SwissProt or RefSeq based on available IDs, with parallel execution and a progress bar.
|
|
|
|
| 553 |
f"({round(success_count/total_rows*100, 2)}%)")
|
| 554 |
return result_df
|
| 555 |
|
| 556 |
+
def esm_embed(model: ESMC,
              seq: str,
              acc: str,
              device: torch.device = torch.device(
                  'cuda' if torch.cuda.is_available()
                  else 'cpu'
              )) -> Optional[np.ndarray]:
    """
    Generates an embedding for a given protein sequence using an ESM model.

    Args:
        model (ESMC): The ESM model used for encoding and generating embeddings.
        seq (str): The amino acid sequence of the protein.
        acc (str): The accession identifier for the protein (used for error reporting).
        device (torch.device, optional): The device to run the computation on.
            Defaults to CUDA if available, otherwise CPU.

    Returns:
        Optional[np.ndarray]: The mean-pooled embedding vector for the protein
        sequence, or None if no embedding could be generated.

    Raises:
        ESMProteinError: If the protein could not be encoded.

    Side Effects:
        Displays an error message via `messagebox.showerror` when an error occurs.
    """
    protein: ESMProtein = ESMProtein(sequence=seq)

    # Check for an encoding failure BEFORE moving the result to the target
    # device: encode() may return an ESMProteinError, and calling .to(device)
    # on that error object would raise an unrelated AttributeError and mask
    # the real failure (the original guard was unreachable).
    protein_tensor = model.encode(protein)
    if isinstance(protein_tensor, ESMProteinError):
        messagebox.showerror("Error", f"Error processing {acc}: {protein_tensor}")
        raise protein_tensor
    protein_tensor = protein_tensor.to(device)

    try:
        output: LogitsOutput = model.logits(protein_tensor,
                                            LogitsConfig(sequence=True,
                                                         return_embeddings=True))

        if output is not None and output.embeddings is not None:
            arr_output: np.ndarray = output.embeddings.cpu().numpy()

            # Mean-pool over the sequence dimension so every protein maps to
            # a fixed-size vector regardless of sequence length.
            if len(arr_output.shape) == 3:
                arr_output = arr_output.squeeze(axis=0).mean(axis=0)
            elif len(arr_output.shape) == 2:
                arr_output = arr_output.mean(axis=0)

            return arr_output
    except ESMProteinError as e:
        messagebox.showerror("Error", f"Error processing {acc}: {e}")
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
|
| 606 |
+
def predict_with_esm(fasta_path : str,
                     model : Literal['esmc_600m', 'esmc_300m'],
                     device : torch.device = torch.device('cuda' if torch.cuda.is_available()
                                                          else 'cpu'),
                     ) -> None:
    """
    Validate inputs and load sequences for ESM-based location prediction.

    NOTE(review): this function currently stops after parsing the FASTA
    file — the embedding/prediction pipeline for the ESM-C models appears
    to be unfinished, and the `model` and `device` arguments are not yet
    used. Confirm intended scope before relying on it.

    Args:
        fasta_path (str): Path to the FASTA file to process.
        model (Literal['esmc_600m', 'esmc_300m']): Which ESM-C checkpoint to
            use (currently unused).
        device (torch.device, optional): Computation device; defaults to CUDA
            when available (currently unused).

    Returns:
        None. Shows an error dialog and returns early on invalid input or
        when no sequences are found.
    """
    # Guard against a missing/nonexistent path before any file work.
    if fasta_path is None or not os.path.exists(fasta_path):
        messagebox.showerror("Error", "Invalid FASTA file path.")
        return

    result = fasta_to_seq(fasta_path)
    if result is None:
        messagebox.showerror("Error", "No sequences found in FASTA file.")
        return
    # Sequences and their accession IDs, parallel lists.
    seq, ids = result
|
| 620 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
|
|
|
|
| 622 |
|
|
|
|
|
|
|
| 623 |
|
|
|
|
|
|
|
|
|
|
| 624 |
|
| 625 |
def prost_embed_sequence(seq : str,
|
| 626 |
acc : str,
|
|
|
|
| 737 |
|
| 738 |
f.write(f"{seq_id},{pred_line}\n")
|
| 739 |
|
| 740 |
+
def predict_with_prost(fasta_path: str):
    """
    Embed sequences from a FASTA file using ProstT5 and predict locations.

    Asks the user for an output directory, embeds every sequence with the
    ProstT5 encoder (showing a Tk progress bar), runs a pre-trained
    random-forest classifier over the embeddings, and writes per-sequence
    class probabilities to '<input>_predictions.txt' in the chosen directory.

    Args:
        fasta_path (str): Path to the FASTA file to process.

    Returns:
        None. Prints progress/diagnostics and returns early on any failure
        (bad path, cancelled dialog, empty FASTA, missing model file, or no
        embeddable sequences).
    """
    if not fasta_path or not os.path.exists(fasta_path):
        print("Invalid FASTA file path.")
        return

    # Ask user for output directory
    # NOTE(review): this creates a second Tk root if a GUI is already
    # running (e.g. when called from cli.py) — confirm this is intended.
    root = tk.Tk()
    root.withdraw()  # Hide root window

    output_dir = filedialog.askdirectory(title="Select output directory")
    if not output_dir:
        return

    result = fasta_to_seq(fasta_path)
    if result is None:
        messagebox.showerror("Error", "No sequences found in FASTA file.")
        return

    sequences, ids = result
    total = len(sequences)

    # Create progress bar window
    progress_win = tk.Toplevel(root)
    progress_win.title("Embedding Progress")
    progress_label = tk.Label(progress_win, text="Embedding sequences...")
    progress_label.pack(padx=10, pady=5)
    progress = ttk.Progressbar(progress_win, length=300, mode='determinate', maximum=total)
    progress.pack(padx=10, pady=10)

    # Load model/tokenizer once, outside the embedding loop
    tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5", do_lower_case=False, legacy=True)
    model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")

    embeddings = {}

    for i, (seq, acc) in enumerate(zip(sequences, ids)):
        emb = prost_embed_sequence(seq, acc, tokenizer, model)
        if emb is not None:
            embeddings[acc] = emb

        # Update progress
        progress['value'] = i + 1
        progress_win.update_idletasks()  # Keeps the window responsive

    progress_label.config(text="Embedding complete!")
    tk.Button(progress_win, text="Close", command=progress_win.destroy).pack(pady=5)

    # If every embedding failed, bail out with a clear message instead of
    # feeding an empty array to predict_proba (which would crash with an
    # opaque sklearn error).
    if not embeddings:
        messagebox.showerror("Error", "No sequences could be embedded.")
        return

    # Load model
    messagebox.showinfo("Info", "Loading random forest model for predictions...")
    project_root: str = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    model_path = os.path.join(project_root, 'Models', 'rfProst.joblib')

    try:
        predictor = load(model_path)
    except FileNotFoundError:
        print(f"Error: Could not find the model file '{model_path}'")
        return

    sequence_ids = list(embeddings.keys())
    X = np.array(list(embeddings.values()))  # type: ignore

    print("Making predictions...")
    y_pred_proba = predictor.predict_proba(X)

    # Get class names
    if hasattr(predictor, 'classes_'):
        class_names = [str(cls) for cls in predictor.classes_]
    else:
        class_names = [f"Class_{i}" for i in range(y_pred_proba.shape[1])]

    # Per sequence: classes sorted by descending probability.
    predictions_dict = {}
    for i, seq_id in enumerate(sequence_ids):
        class_prob_pairs = sorted(zip(class_names, y_pred_proba[i]), key=lambda x: x[1], reverse=True)
        sorted_classes, sorted_probs = zip(*class_prob_pairs)
        predictions_dict[seq_id] = (list(sorted_classes), list(sorted_probs))

    # Save results
    input_filename = os.path.splitext(os.path.basename(fasta_path))[0]
    output_file = os.path.join(output_dir, f"{input_filename}_predictions.txt")

    print(f"Saving predictions to {output_file}...")
    save_predictions_to_txt(predictions_dict, output_file)
    print("Predictions saved successfully!")
    print(f"Total sequences processed: {len(embeddings)}")

    print("\nSample predictions:")
    for i, (seq_id, (classes, probs)) in enumerate(list(predictions_dict.items())[:3]):
        pred_str = ", ".join([f"{cls} ({prob:.4f})" for cls, prob in zip(classes, probs)])
        print(f"{seq_id}: {pred_str}")
|
| 832 |
+
|