jpuglia committed on
Commit
2491d7e
·
1 Parent(s): 5b2ffea

Update GUI title and button text for clarity in Protein Location Predictor

Browse files
Files changed (2) hide show
  1. gui.py +2 -2
  2. src/my_utils.py +18 -28
gui.py CHANGED
@@ -104,7 +104,7 @@ def menu():
104
  # root window
105
  root = tk.Tk()
106
  root.geometry('320x200')
107
- root.title('Protein Tools Menu')
108
 
109
  # create a menubar
110
  menubar = Menu(root)
@@ -123,7 +123,7 @@ def menu():
123
  help_menu.add_command(label='About...')
124
  menubar.add_cascade(label="Help", menu=help_menu, underline=0)
125
 
126
- btn_prost = tk.Button(root, text="Predict with Prost",
127
  command=run_prost) #Predict with Prost
128
  btn_prost.pack(pady=5)
129
 
 
104
  # root window
105
  root = tk.Tk()
106
  root.geometry('320x200')
107
+ root.title('Protein Location Predictor')
108
 
109
  # create a menubar
110
  menubar = Menu(root)
 
123
  help_menu.add_command(label='About...')
124
  menubar.add_cascade(label="Help", menu=help_menu, underline=0)
125
 
126
+ btn_prost = tk.Button(root, text="Predict with Prost T5",
127
  command=run_prost) #Predict with Prost
128
  btn_prost.pack(pady=5)
129
 
src/my_utils.py CHANGED
@@ -49,27 +49,20 @@ from joblib import load
49
 
50
  import torch
51
 
52
- # Load one chunk of embeddings
53
  def load_emb(path: str, acc: list[str]) -> np.ndarray:
54
-
55
  """
56
- Loads and processes embedding files from a specified directory.
57
- For each accession in the provided list, this function loads the corresponding
58
- NumPy `.npy` file from the given path, processes the embedding by averaging
59
- over axes if necessary, and collects the results.
 
60
  Args:
61
- path (str): Directory path containing the embedding `.npy` files.
62
  acc (list[str]): List of accession identifiers corresponding to the embedding files.
63
  Returns:
64
- tuple[np.ndarray, np.ndarray]:
65
- - A 2D NumPy array where each row is a processed embedding.
66
- - A 1D NumPy array of accession identifiers corresponding to the embeddings.
67
  Raises:
68
  FileNotFoundError: If the specified path does not exist.
69
- Notes:
70
- - If an embedding has 3 dimensions, it is squeezed along axis 0 and then averaged over axis 0.
71
- - If an embedding has 2 dimensions, it is averaged over axis 0.
72
- - Otherwise, the embedding is used as is.
73
  """
74
 
75
  if not os.path.exists(path):
@@ -78,7 +71,6 @@ def load_emb(path: str, acc: list[str]) -> np.ndarray:
78
  total_files = len([f for f in os.listdir(path) if f.endswith('.npy')])
79
 
80
  x = []
81
- y = []
82
 
83
  for a in tqdm(acc, desc = 'Cargando embeddings', total=total_files):
84
 
@@ -88,20 +80,20 @@ def load_emb(path: str, acc: list[str]) -> np.ndarray:
88
  emb = emb.squeeze(axis = 0)
89
  emb = emb.mean(axis = 0)
90
  x.append(emb)
91
- y.append(a)
92
  elif len(emb.shape) == 2:
93
  emb = emb.mean(axis = 0)
94
  x.append(emb)
95
- y.append(a)
96
  else:
97
  x.append(emb)
98
- y.append(a)
99
 
100
  return np.vstack(x)
101
 
102
  def confusion(title : str, y_true: np.ndarray, y_pred: np.ndarray) -> None:
103
 
104
- """ Plot a confusion matrix for the given true and predicted labels.
 
105
  Args:
106
  title (str): Title for the confusion matrix plot.
107
  y_true (np.ndarray): True labels.
@@ -125,17 +117,15 @@ def confusion(title : str, y_true: np.ndarray, y_pred: np.ndarray) -> None:
125
 
126
  def plot_umap(x: np.ndarray, y: np.ndarray, title: str) -> None:
127
  """
128
- Plot a 2D UMAP projection of high-dimensional data with color-coded labels and hover information.
129
-
130
- Args:
131
- x (list[np.ndarray]): List of feature arrays to be concatenated and visualized.
132
- y (list[str]): List of labels corresponding to each sample in x, used for coloring the scatter plot.
133
- title (str): Title of the plot.
134
- org (list[str]): List of organism or group identifiers for each sample, shown in hover data.
135
-
136
  Returns:
137
- None: Displays an interactive UMAP scatter plot using Plotly.
138
  """
 
139
  reducer = umap.UMAP(n_neighbors=30, random_state=42)
140
 
141
  scaled_x = StandardScaler().fit_transform(x)
 
49
 
50
  import torch
51
 
 
52
  def load_emb(path: str, acc: list[str]) -> np.ndarray:
 
53
  """
54
+ Loads and processes embedding files from a specified directory for a list of accession identifiers.
55
+ Each embedding is expected to be stored as a .npy file named after its accession in the given path.
56
+ - If the embedding has 3 dimensions, it is squeezed along the first axis and then averaged along the next axis.
57
+ - If the embedding has 2 dimensions, it is averaged along the first axis.
58
+ - Otherwise, the embedding is used as is.
59
  Args:
60
+ path (str): Directory path where the embedding .npy files are stored.
61
  acc (list[str]): List of accession identifiers corresponding to the embedding files.
62
  Returns:
63
+ np.ndarray: A 2D array where each row corresponds to the processed embedding of an accession.
 
 
64
  Raises:
65
  FileNotFoundError: If the specified path does not exist.
 
 
 
 
66
  """
67
 
68
  if not os.path.exists(path):
 
71
  total_files = len([f for f in os.listdir(path) if f.endswith('.npy')])
72
 
73
  x = []
 
74
 
75
  for a in tqdm(acc, desc = 'Cargando embeddings', total=total_files):
76
 
 
80
  emb = emb.squeeze(axis = 0)
81
  emb = emb.mean(axis = 0)
82
  x.append(emb)
83
+
84
  elif len(emb.shape) == 2:
85
  emb = emb.mean(axis = 0)
86
  x.append(emb)
87
+
88
  else:
89
  x.append(emb)
 
90
 
91
  return np.vstack(x)
92
 
93
  def confusion(title : str, y_true: np.ndarray, y_pred: np.ndarray) -> None:
94
 
95
+ """
96
+ Plot a confusion matrix for the given true and predicted labels.
97
  Args:
98
  title (str): Title for the confusion matrix plot.
99
  y_true (np.ndarray): True labels.
 
117
 
118
  def plot_umap(x: np.ndarray, y: np.ndarray, title: str) -> None:
119
  """
120
+ Plots a 2D UMAP projection of high-dimensional data with class labels.
121
+ Parameters:
122
+ x (np.ndarray): The input feature matrix of shape (n_samples, n_features).
123
+ y (np.ndarray): The array of labels corresponding to each sample.
124
+ title (str): The title for the plot.
 
 
 
125
  Returns:
126
+ None: Displays a scatter plot of the UMAP embedding colored by label.
127
  """
128
+
129
  reducer = umap.UMAP(n_neighbors=30, random_state=42)
130
 
131
  scaled_x = StandardScaler().fit_transform(x)