jpuglia commited on
Commit
4b71c5b
·
1 Parent(s): 6587358

Enhance CLI functionality: Implement file loading and prediction features with user feedback in Tkinter GUI

Browse files
Files changed (2) hide show
  1. cli.py +64 -25
  2. src/my_utils.py +117 -119
cli.py CHANGED
@@ -1,19 +1,69 @@
 
 
 
 
 
1
  import tkinter as tk
2
- from tkinter import Menu
3
  from src.my_utils import predict_with_prost
4
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def menu():
7
  """
8
- Creates and displays the main GUI menu for the Protein Tools application using Tkinter.
9
 
10
- The menu includes:
11
- - A "File" menu with options for creating a new file, opening, closing, preferences (with sub-menu for keyboard shortcuts and color themes), and exiting the application.
12
- - A "Help" menu with options for welcome and about dialogs.
13
- - Two buttons below the menu: one for loading a FASTA file (triggers `predict_with_prost`), and one for exiting the application.
14
 
15
- Returns:
16
- None
 
 
 
 
 
 
17
  """
18
  # root window
19
  root = tk.Tk()
@@ -26,19 +76,10 @@ def menu():
26
 
27
  # create the file_menu
28
  file_menu = Menu(menubar, tearoff=0)
29
- file_menu.add_command(label='New')
30
- file_menu.add_command(label='Open...')
31
- file_menu.add_command(label='Close')
32
- file_menu.add_separator()
33
-
34
- sub_menu = Menu(file_menu, tearoff=0)
35
- sub_menu.add_command(label='Keyboard Shortcuts')
36
- sub_menu.add_command(label='Color Themes')
37
-
38
- file_menu.add_cascade(label="Preferences", menu=sub_menu)
39
  file_menu.add_separator()
40
- file_menu.add_command(label='Exit', command=root.destroy)
41
- menubar.add_cascade(label="File", menu=file_menu, underline=0)
42
 
43
  # help menu
44
  help_menu = Menu(menubar, tearoff=0)
@@ -50,17 +91,15 @@ def menu():
50
  # Add Buttons Below Menu
51
  # =========================
52
 
53
- btn_prost = tk.Button(root, text="Predict with Prost", command=predict_with_prost)
54
  btn_prost.pack(pady=5)
55
 
56
- btn_ESM = tk.Button(root, text="Predict with ESMC", command=print(NotImplementedError("ESM functionality not implemented yet."))) #type: ignore
57
- btn_ESM.pack(pady=5)
58
 
59
  btn_exit = tk.Button(root, text="Exit", command=root.quit)
60
  btn_exit.pack(pady=5)
61
 
62
  root.mainloop()
63
 
64
-
65
-
66
  menu()
 
1
+ """
2
+ Protein Location Predictor CLI
3
+ This module provides a Tkinter-based GUI for loading FASTA files
4
+ and running protein location prediction tools.
5
+ """
6
  import tkinter as tk
7
+ from tkinter import Menu, filedialog, messagebox
8
  from src.my_utils import predict_with_prost
9
 
10
 
11
+ FASTA_FILE_PATH = None # Global or instance variable
12
+
13
def load_fasta_file():
    """
    Open a file dialog for the user to select a FASTA file and store the
    selected path in the module-level ``FASTA_FILE_PATH`` variable.

    If a file is selected, an information message with the file path is shown
    and the global path is updated. If the dialog is cancelled, a warning is
    shown and any previously loaded path is KEPT — the original implementation
    assigned the dialog's empty-string result directly to the global, so
    cancelling silently "unloaded" a file that was already selected.

    Uses:
        - filedialog.askopenfilename for file selection.
        - messagebox.showinfo and messagebox.showwarning for user feedback.

    Global Variables:
        FASTA_FILE_PATH (str | None): Path to the selected FASTA file.
    """
    global FASTA_FILE_PATH  # pylint: disable=global-statement
    selected = filedialog.askopenfilename(
        filetypes=[("FASTA files", "*.fasta *.fa")],
        title="Select a FASTA file"
    )
    if selected:
        # Only overwrite the stored path on a successful selection so that
        # cancelling the dialog does not discard a previously loaded file.
        FASTA_FILE_PATH = selected
        messagebox.showinfo("File Loaded", f"Loaded file:\n{FASTA_FILE_PATH}")
    else:
        messagebox.showwarning("No file", "No file was selected.")
36
+
37
def run_prediction():
    """
    Kick off protein location prediction for the currently loaded FASTA file.

    Pops an error dialog and does nothing when no FASTA file has been loaded
    yet (``FASTA_FILE_PATH`` is falsy); otherwise delegates to the
    PROST-based predictor with the stored path.
    """
    if FASTA_FILE_PATH:
        predict_with_prost(FASTA_FILE_PATH)
    else:
        messagebox.showerror("Error", "Please load a FASTA file first.")
51
+
52
  def menu():
53
  """
54
+ Displays the main GUI menu for the Protein Tools application.
55
 
56
+ This function creates a Tkinter window with a menu bar containing 'File' and 'Help' menus,
57
+ and buttons for running protein prediction tools and exiting the application.
 
 
58
 
59
+ Menus:
60
+ - File: Options to load a FASTA file or close the application.
61
+ - Help: Options for welcome information and about dialog.
62
+
63
+ Buttons:
64
+ - Predict with Prost: Runs the Prost prediction tool.
65
+ - Predict with ESM C: Placeholder for ESM prediction functionality (not yet implemented).
66
+ - Exit: Closes the application.
67
  """
68
  # root window
69
  root = tk.Tk()
 
76
 
77
  # create the file_menu
78
  file_menu = Menu(menubar, tearoff=0)
79
+ file_menu.add_command(label='Load FASTA', command=load_fasta_file)
80
+ file_menu.add_command(label='Close', command=root.quit)
 
 
 
 
 
 
 
 
81
  file_menu.add_separator()
82
+ menubar.add_cascade(label="File", menu=file_menu, underline=0)
 
83
 
84
  # help menu
85
  help_menu = Menu(menubar, tearoff=0)
 
91
  # Add Buttons Below Menu
92
  # =========================
93
 
94
+ btn_prost = tk.Button(root, text="Predict with Prost", command=run_prediction)
95
  btn_prost.pack(pady=5)
96
 
97
+ btn_esm = tk.Button(root, text="Predict with ESM C") #type: ignore
98
+ btn_esm.pack(pady=5)
99
 
100
  btn_exit = tk.Button(root, text="Exit", command=root.quit)
101
  btn_exit.pack(pady=5)
102
 
103
  root.mainloop()
104
 
 
 
105
  menu()
src/my_utils.py CHANGED
@@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
7
  from urllib.error import HTTPError
8
  from typing import Literal, Optional
9
  import tkinter as tk
10
- from tkinter import filedialog
11
 
12
 
13
  import pandas as pd
@@ -39,6 +39,7 @@ import plotly.express as px
39
  from esm.models.esmc import ESMC
40
  from esm.sdk.api import ESMProtein, LogitsConfig, ESMProteinError, LogitsOutput
41
  from transformers import T5Tokenizer, T5EncoderModel, PreTrainedModel
 
42
 
43
  from joblib import load
44
 
@@ -503,10 +504,9 @@ def _fetch_sequence_for_row(idx, row):
503
 
504
  return idx, sequence
505
 
506
-
507
-
508
-
509
- def fetch_sequences_for_dataframe(df: pd.DataFrame, batch_size: Optional[int] = None, max_workers: int = 5) -> pd.DataFrame:
510
  """
511
  Add a 'sequence' column to the dataframe by fetching sequences from
512
  SwissProt or RefSeq based on available IDs, with parallel execution and a progress bar.
@@ -553,69 +553,74 @@ def fetch_sequences_for_dataframe(df: pd.DataFrame, batch_size: Optional[int] =
553
  f"({round(success_count/total_rows*100, 2)}%)")
554
  return result_df
555
 
556
- def esm_embed_sequence(model : Literal["esmc_300m", "esmc_600m"], sequence : str, device : str) -> LogitsOutput:
 
 
 
 
 
 
557
 
558
  """
559
- Embed a protein sequence using the specified ESM model.
560
  Args:
561
- model: Name of the ESM model to use.
562
- sequence: Protein sequence to embed.
 
 
563
  Returns:
564
- LogitsOutput: Contains the embeddings and logits for the sequence.
 
 
 
 
565
  """
566
 
567
- client = ESMC.from_pretrained(model).to(device)
568
-
569
- protein = ESMProtein(sequence=sequence)
570
- protein_tensor = client.encode(protein)
571
 
572
  if isinstance(protein_tensor, ESMProteinError):
573
-
574
  raise protein_tensor
 
 
575
 
576
- output = client.logits(protein_tensor, LogitsConfig(sequence=True, return_embeddings=True))
577
-
578
- return output
579
-
580
- def esm_save_emb(model: Literal["esmc_300m", "esmc_600m"],
581
- seq_list: list[str],
582
- id_list: list[str],
583
- path: str,
584
- device : Literal['cuda', 'cpu'] = 'cuda') -> None:
585
-
586
- """
587
- Save embeddings to disk.
588
-
589
- Args:
590
- model: ESM model name. Options are "esmc_300m" or "esmc_600m".
591
- seq_list: List of protein sequences.
592
- id_list: List of identifiers corresponding to the sequences.
593
- path: Directory to save the embeddings.
594
- """
595
-
596
- assert len(seq_list) == len(id_list), "Sequence and ID lists must be the same length."
597
- os.makedirs(path, exist_ok=True)
598
 
599
- for i, (seq, acc) in enumerate(
600
- tqdm(zip(seq_list, id_list),
601
- total=len(seq_list), desc="Saving embeddings")):
602
- try:
603
- output: LogitsOutput = esm_embed_sequence(model=model, sequence=seq, device = device)
604
- emb_array = output.embeddings.cpu().numpy()
 
 
 
 
 
 
 
 
605
 
606
- if len(emb_array.shape) == 3:
607
- emb_array = emb_array.squeeze(axis=0).mean(axis=0)
608
- elif len(emb_array.shape) == 2:
609
- emb_array = emb_array.mean(axis=0)
610
 
611
- np.save(os.path.join(path, f"{acc}.npy"), emb_array)
612
 
613
- except ESMProteinError as e:
614
- print(f"Error processing {acc}: {e}")
615
 
616
- if i % 100 == 0:
617
- gc.collect()
618
- torch.cuda.empty_cache()
619
 
620
  def prost_embed_sequence(seq : str,
621
  acc : str,
@@ -732,103 +737,96 @@ def save_predictions_to_txt(predictions_dict: dict[str, tuple[list[str], list[fl
732
 
733
  f.write(f"{seq_id},{pred_line}\n")
734
 
735
- def predict_with_prost():
736
  """
737
- Function to select a directory containing FASTA files and embed sequences using ProstT5.
738
  """
739
- root = tk.Tk()
740
- root.withdraw()
741
-
742
- fasta_path : str = filedialog.askopenfilename(
743
- title="Select a FASTA file",
744
- filetypes=[("FASTA files", "*.fasta *.fa")],
745
- initialdir="."
746
- )
747
-
748
- if not fasta_path:
749
- print("No file selected.")
750
  return
751
 
752
- # Select output directory for results
753
- output_dir: str = filedialog.askdirectory(
754
- title="Select output directory for results",
755
- initialdir="."
756
- )
757
 
 
758
  if not output_dir:
759
- print("No output directory selected.")
760
  return
761
 
762
  result = fasta_to_seq(fasta_path)
763
-
764
  if result is None:
765
- print("No sequences found in the FASTA file.")
766
- return {}
767
- else:
768
- sequences, ids = result
769
- print(f"Sequences loaded from {fasta_path}: {len(sequences)} sequences found.")
770
- print("Embedding sequences using ProstT5...")
771
 
772
- tokenizer : T5Tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5", do_lower_case=False)
773
- model : PreTrainedModel = T5EncoderModel.from_pretrained("Rostlab/ProstT5")
 
 
 
 
 
 
 
 
774
 
775
- embeddings : dict[str, np.ndarray] = {}
776
-
777
- for seq, acc in tqdm(zip(sequences, ids), total=len(sequences), desc="Embedding sequences"):
778
- emb = prost_embed_sequence(seq, acc, tokenizer, model)
779
- if emb is not None:
780
- embeddings[acc] = emb
781
- else:
782
- print(f"Failed to embed sequence {acc}. Skipping.")
783
-
784
- print(f"Embedded {len(embeddings)} sequences successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
785
 
786
- print("Loading pre-trained SVM model for prediction...")
787
  try:
788
- predictor = load('/home/juan/ProteinLocationPredictor/Models/rfProst.joblib')
789
  except FileNotFoundError:
790
- print("Error: Could not find the model file '../ProteinLocationPredictor/Models/svmProst.joblib'")
791
- print("Please check the path to your trained model.")
792
  return
793
 
794
  sequence_ids = list(embeddings.keys())
795
- X = np.array(list(embeddings.values())) #type: ignore
 
796
  print("Making predictions...")
797
  y_pred_proba = predictor.predict_proba(X)
798
-
799
- # Get class names (you may need to adjust this based on your model)
800
  if hasattr(predictor, 'classes_'):
801
- class_names = predictor.classes_.tolist()
802
  else:
803
- # If class names are not available, use generic names
804
- n_classes = y_pred_proba.shape[1]
805
- class_names = [f"Class_{i}" for i in range(n_classes)]
806
-
807
- # Convert class names to strings if they aren't already
808
- class_names = [str(cls) for cls in class_names]
809
-
810
- # Create predictions dictionary
811
  predictions_dict = {}
812
  for i, seq_id in enumerate(sequence_ids):
813
- probabilities = y_pred_proba[i].tolist()
814
- class_prob_pairs = sorted(zip(class_names, probabilities), key=lambda x: x[1], reverse=True)
815
  sorted_classes, sorted_probs = zip(*class_prob_pairs)
816
  predictions_dict[seq_id] = (list(sorted_classes), list(sorted_probs))
817
-
818
- # Generate output filename
819
  input_filename = os.path.splitext(os.path.basename(fasta_path))[0]
820
  output_file = os.path.join(output_dir, f"{input_filename}_predictions.txt")
821
-
822
- # Save predictions to file
823
  print(f"Saving predictions to {output_file}...")
824
  save_predictions_to_txt(predictions_dict, output_file)
825
-
826
- print(f"Predictions saved successfully!")
827
  print(f"Total sequences processed: {len(embeddings)}")
828
- print(f"Output file: {output_file}")
829
-
830
- # Print a few sample predictions
831
  print("\nSample predictions:")
832
  for i, (seq_id, (classes, probs)) in enumerate(list(predictions_dict.items())[:3]):
833
  pred_str = ", ".join([f"{cls} ({prob:.4f})" for cls, prob in zip(classes, probs)])
834
- print(f"{seq_id}: {pred_str}")
 
 
7
  from urllib.error import HTTPError
8
  from typing import Literal, Optional
9
  import tkinter as tk
10
+ from tkinter import filedialog, messagebox, ttk
11
 
12
 
13
  import pandas as pd
 
39
  from esm.models.esmc import ESMC
40
  from esm.sdk.api import ESMProtein, LogitsConfig, ESMProteinError, LogitsOutput
41
  from transformers import T5Tokenizer, T5EncoderModel, PreTrainedModel
42
+ from esm.sdk.forge import ESM3ForgeInferenceClient
43
 
44
  from joblib import load
45
 
 
504
 
505
  return idx, sequence
506
 
507
+ def fetch_sequences_for_dataframe(df: pd.DataFrame,
508
+ batch_size: Optional[int] = None,
509
+ max_workers: int = 5) -> pd.DataFrame:
 
510
  """
511
  Add a 'sequence' column to the dataframe by fetching sequences from
512
  SwissProt or RefSeq based on available IDs, with parallel execution and a progress bar.
 
553
  f"({round(success_count/total_rows*100, 2)}%)")
554
  return result_df
555
 
556
def esm_embed(model: ESMC,
              seq: str,
              acc: str,
              device: torch.device = torch.device(
                  'cuda' if torch.cuda.is_available()
                  else 'cpu'
              )) -> Optional[np.ndarray]:
    """
    Generate a fixed-size embedding for a protein sequence using an ESM model.

    Args:
        model (ESMC): The ESM model used for encoding and generating embeddings.
        seq (str): The amino acid sequence of the protein.
        acc (str): The accession identifier for the protein (used for error reporting).
        device (torch.device, optional): Device to run the computation on.
            NOTE: the default is evaluated once at import time (CUDA if
            available at import, else CPU).

    Returns:
        Optional[np.ndarray]: Mean-pooled embedding vector for the sequence,
        or None if embedding could not be generated.

    Raises:
        ESMProteinError: If encoding the protein fails.

    Side Effects:
        Displays an error dialog via `messagebox.showerror` on failure.
    """
    protein: ESMProtein = ESMProtein(sequence=seq)

    # Encode first and check for an error BEFORE moving to the device: the
    # original code did `model.encode(protein).to(device)`, which would raise
    # an AttributeError on an ESMProteinError result and never reach the
    # intended error handling below.
    protein_tensor = model.encode(protein)
    if isinstance(protein_tensor, ESMProteinError):
        messagebox.showerror("Error", f"Error processing {acc}: {protein_tensor}")
        raise protein_tensor
    protein_tensor = protein_tensor.to(device)

    try:
        output: LogitsOutput = model.logits(protein_tensor,
                                            LogitsConfig(sequence=True,
                                                         return_embeddings=True))

        if output is not None and output.embeddings is not None:
            arr_output: np.ndarray = output.embeddings.cpu().numpy()

            # Collapse to one vector per protein: squeeze a leading batch
            # axis of size 1 if present, then mean-pool over residues.
            if len(arr_output.shape) == 3:
                arr_output = arr_output.squeeze(axis=0).mean(axis=0)
            elif len(arr_output.shape) == 2:
                arr_output = arr_output.mean(axis=0)

            return arr_output
        # No embeddings returned — treat as a soft failure.
        return None
    except ESMProteinError as e:
        messagebox.showerror("Error", f"Error processing {acc}: {e}")
        return None
 
 
 
 
 
605
 
606
def predict_with_esm(fasta_path : str,
                     model : Literal['esmc_600m', 'esmc_300m'],
                     device : torch.device = torch.device('cuda' if torch.cuda.is_available()
                                                          else 'cpu'),
                     ) -> None:
    """
    Validate a FASTA file and load its sequences for ESM-based prediction.

    NOTE(review): this is currently a stub — the sequences are unpacked into
    ``seq``/``ids`` but nothing further happens; presumably `esm_embed` and a
    classifier are meant to be wired in here. Confirm before relying on it.

    Args:
        fasta_path: Path to the input FASTA file; an error dialog is shown
            when it is None or does not exist.
        model: ESM C checkpoint name ('esmc_600m' or 'esmc_300m').
            Currently unused by the visible body.
        device: Device for computation. The default is evaluated once at
            import time (CUDA if available at import, else CPU).

    Returns:
        None. Errors are reported via `messagebox.showerror`.
    """
    if fasta_path is None or not os.path.exists(fasta_path):
        messagebox.showerror("Error", "Invalid FASTA file path.")
        return

    result = fasta_to_seq(fasta_path)
    if result is None:
        messagebox.showerror("Error", "No sequences found in FASTA file.")
        return
    # Parsed sequences and their identifiers; not yet consumed (see NOTE above).
    seq, ids = result
 
 
 
 
 
621
 
 
622
 
 
 
623
 
 
 
 
624
 
625
  def prost_embed_sequence(seq : str,
626
  acc : str,
 
737
 
738
  f.write(f"{seq_id},{pred_line}\n")
739
 
740
def predict_with_prost(fasta_path: str):
    """
    Embed sequences from a provided FASTA file using ProstT5 and predict
    subcellular locations with a pre-trained random-forest model.

    Workflow:
        1. Validate the FASTA path.
        2. Ask the user for an output directory via a Tk file dialog.
        3. Embed each sequence with ProstT5, driving a Tk progress bar.
        4. Load ``Models/rfProst.joblib`` (resolved relative to the project
           root, i.e. the parent of this module's directory) and run
           ``predict_proba`` on the stacked embeddings.
        5. Write per-sequence class probabilities, sorted most-likely-first,
           to ``<input>_predictions.txt`` in the chosen directory.

    Args:
        fasta_path: Path to the FASTA file to process.

    Returns:
        None. Results are written to disk; progress and errors are reported
        via the GUI and stdout.
    """
    if not fasta_path or not os.path.exists(fasta_path):
        print("Invalid FASTA file path.")
        return

    # Ask user for output directory
    root = tk.Tk()
    root.withdraw() # Hide root window
    # NOTE(review): `root` is never destroyed — repeated calls leak hidden Tk
    # interpreters; consider root.destroy() once the dialogs are done.

    output_dir = filedialog.askdirectory(title="Select output directory")
    if not output_dir:
        return

    result = fasta_to_seq(fasta_path)
    if result is None:
        messagebox.showerror("Error", "No sequences found in FASTA file.")
        return

    sequences, ids = result
    total = len(sequences)

    # Create progress bar window
    progress_win = tk.Toplevel(root)
    progress_win.title("Embedding Progress")
    progress_label = tk.Label(progress_win, text="Embedding sequences...")
    progress_label.pack(padx=10, pady=5)
    progress = ttk.Progressbar(progress_win, length=300, mode='determinate', maximum=total)
    progress.pack(padx=10, pady=10)

    # Load model/tokenizer once
    tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5", do_lower_case=False, legacy=True)
    model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")

    embeddings = {}

    # Sequences that fail to embed (prost_embed_sequence returns None) are
    # silently skipped; only successful embeddings are kept.
    for i, (seq, acc) in enumerate(zip(sequences, ids)):
        emb = prost_embed_sequence(seq, acc, tokenizer, model)
        if emb is not None:
            embeddings[acc] = emb

        # Update progress
        progress['value'] = i + 1
        progress_win.update_idletasks() # Keeps the window responsive

    progress_label.config(text="Embedding complete!")
    tk.Button(progress_win, text="Close", command=progress_win.destroy).pack(pady=5)

    # Load model
    messagebox.showinfo("Info", "Loading random forest model for predictions...")
    project_root: str = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    model_path = os.path.join(project_root, 'Models', 'rfProst.joblib')

    try:
        predictor = load(model_path)
    except FileNotFoundError:
        print(f"Error: Could not find the model file '{model_path}'")
        return

    sequence_ids = list(embeddings.keys())
    X = np.array(list(embeddings.values())) # type: ignore
    # NOTE(review): if every embedding failed, X is empty and predict_proba
    # below will raise — an explicit guard may be warranted.

    print("Making predictions...")
    y_pred_proba = predictor.predict_proba(X)

    # Get class names
    if hasattr(predictor, 'classes_'):
        class_names = [str(cls) for cls in predictor.classes_]
    else:
        class_names = [f"Class_{i}" for i in range(y_pred_proba.shape[1])]

    # Pair each class with its probability and sort descending so the most
    # likely location comes first in the output.
    predictions_dict = {}
    for i, seq_id in enumerate(sequence_ids):
        class_prob_pairs = sorted(zip(class_names, y_pred_proba[i]), key=lambda x: x[1], reverse=True)
        sorted_classes, sorted_probs = zip(*class_prob_pairs)
        predictions_dict[seq_id] = (list(sorted_classes), list(sorted_probs))

    # Save results
    input_filename = os.path.splitext(os.path.basename(fasta_path))[0]
    output_file = os.path.join(output_dir, f"{input_filename}_predictions.txt")

    print(f"Saving predictions to {output_file}...")
    save_predictions_to_txt(predictions_dict, output_file)
    print("Predictions saved successfully!")
    print(f"Total sequences processed: {len(embeddings)}")

    # Print up to three example predictions as a quick sanity check.
    print("\nSample predictions:")
    for i, (seq_id, (classes, probs)) in enumerate(list(predictions_dict.items())[:3]):
        pred_str = ", ".join([f"{cls} ({prob:.4f})" for cls, prob in zip(classes, probs)])
        print(f"{seq_id}: {pred_str}")