jpuglia committed on
Commit
923a7e5
·
1 Parent(s): 816b1ef

Refactor my_utils.py: Simplify type hints, enhance evaluate and training functions, and improve error handling in sequence fetching

Browse files
Files changed (2) hide show
  1. notebooks/hyperparamsRF.ipynb +2 -2
  2. src/my_utils.py +91 -49
notebooks/hyperparamsRF.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ed5fec8d6f5354ecaef873661bc650c07f91e4e425b3d20c5f221ab8d1d21b11
3
- size 707241
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be08020829c6e68c1b659bca93f71ede388f4c5d6fba3b7bd4aa85b363806f28
3
+ size 101568
src/my_utils.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  import re
4
  from pprint import pprint
5
  from io import StringIO
6
- from typing import Literal, Optional, Union
7
  import tkinter as tk
8
  from tkinter import filedialog, messagebox, ttk
9
 
@@ -14,17 +14,26 @@ import numpy as np
14
  from sklearn.ensemble import RandomForestClassifier
15
  from sklearn import svm
16
  from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
17
- from sklearn.metrics import classification_report, accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
 
 
 
 
 
 
 
18
  from sklearn.decomposition import PCA
19
  from sklearn.preprocessing import StandardScaler, LabelEncoder
20
  from sklearn.pipeline import Pipeline
21
  from sklearn.manifold import TSNE
22
  from sklearn.model_selection import train_test_split
23
  from sklearn.utils import resample
 
24
 
25
  import umap
26
 
27
  import requests
 
28
  from Bio import Entrez
29
  from Bio import SeqIO
30
  from tqdm import tqdm
@@ -234,7 +243,9 @@ def plot_emb(x: np.ndarray, labels : np.ndarray, model_name: str):
234
  tsne_plot(x, labels, title=f't-SNE - {model_name}')
235
  plot_umap(x, labels, title=f'UMAP - {model_name}')
236
 
237
- def evaluate(model: Union[RandomForestClassifier, svm.SVC], X_test : np.ndarray, y_test : np.ndarray) -> dict:
 
 
238
 
239
  """
240
  Evaluates a classification model on test data and computes performance metrics.
@@ -253,7 +264,7 @@ def evaluate(model: Union[RandomForestClassifier, svm.SVC], X_test : np.ndarray,
253
  """
254
 
255
  result = {}
256
- y_pred = model.predict(X_test)
257
 
258
  result['Accuracy'] = accuracy_score(y_test, y_pred)
259
  result['Recall'] = recall_score(y_test, y_pred, average = 'weighted')
@@ -261,8 +272,6 @@ def evaluate(model: Union[RandomForestClassifier, svm.SVC], X_test : np.ndarray,
261
  result['F1'] = f1_score(y_test, y_pred, average='weighted')
262
 
263
  pprint(result)
264
-
265
-
266
  return result
267
 
268
 
@@ -270,9 +279,27 @@ def evaluate(model: Union[RandomForestClassifier, svm.SVC], X_test : np.ndarray,
270
  def train_rf(title: str,
271
  x: np.ndarray,
272
  y : np.ndarray,
273
- params: dict) -> tuple[RandomForestClassifier, dict]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
- y_encoded = LabelEncoder().fit_transform(y)
 
276
 
277
  x_train, x_test, y_train, y_test = train_test_split(x, y_encoded, test_size=0.33, stratify=y_encoded, random_state=42)
278
 
@@ -287,61 +314,78 @@ def train_rf(title: str,
287
 
288
  evaluation = evaluate(classifier, x_test, y_test)
289
 
290
- print(classification_report(y_test, y_pred, zero_division=0))
291
-
292
- confusion(title = title,
293
- y_true = y_test,
294
- y_pred = y_pred)
 
 
295
 
296
- del x_train, x_test, y_train, y_test
297
 
298
- return classifier, evaluation
299
 
300
- def train_svm(title: str, x: np.ndarray, y: list[str], params: dict) -> tuple[Pipeline, dict]:
301
  """
302
- Train a Support Vector Machine (SVM) classifier with the provided data and parameters, evaluate its performance, and return the trained pipeline and evaluation metrics.
303
 
304
  Args:
305
  title (str): Title for the confusion matrix plot.
306
- x (np.ndarray): Feature matrix.
307
- y (list[str]): List of labels.
308
- params (dict): Dictionary of parameters for the SVM.
309
 
310
  Returns:
311
- tuple[Pipeline, dict]: The trained pipeline and a dictionary of evaluation metrics.
 
 
 
 
 
 
 
312
  """
 
 
 
 
313
  x_train, x_test, y_train, y_test = train_test_split(
314
- x, y, test_size=0.33, stratify=y, random_state=42
315
  )
316
 
317
  svc_params = {k.replace('svm__', ''): v for k, v in params.items() if k.startswith('svm__')}
318
  pipeline = Pipeline([
319
  ('scaler', StandardScaler()),
320
- ('svm', svm.SVC(**svc_params))
321
  ])
322
 
323
  pipeline.fit(x_train, y_train)
324
 
325
  y_pred = pipeline.predict(x_test)
326
 
327
- evaluation = evaluate(model=pipeline, X_test=x_test, y_test=y_test)
328
 
329
- confusion(title=title,
330
- y_true=y_test,
331
- y_pred=y_pred)
332
 
333
- print(classification_report(y_test, y_pred, zero_division=0))
334
 
335
- return pipeline, evaluation
336
 
 
337
 
338
- def randomSVM(X: list[np.ndarray], y = list[str]) -> dict:
339
 
340
- X_train, _, y_train, _ = train_test_split(X,
341
- y,
 
 
 
 
 
342
  test_size=0.33,
343
- stratify=y,
344
- random_state=42)
345
 
346
  pipeline = Pipeline([('scaler', StandardScaler()),
347
  ('svm', svm.SVC())])
@@ -365,28 +409,26 @@ def randomSVM(X: list[np.ndarray], y = list[str]) -> dict:
365
  n_iter=50,
366
  scoring='f1_weighted',
367
  cv=3,
368
- verbose=2,
369
  random_state=42,
370
  n_jobs=-1
371
  )
372
 
373
- random_search.fit(X_sample, y_sample)
 
374
 
375
  pprint(random_search.best_params_)
376
 
377
  return random_search.best_params_
378
 
379
- def randomSearch(X: np.ndarray, y: np.ndarray) -> dict:
380
 
381
- X_train, _, y_train, _ = train_test_split(X, y, test_size=0.33, stratify=y, random_state=42)
 
 
 
382
  classifier : RandomForestClassifier = RandomForestClassifier(random_state=42)
383
 
384
- X_sample, y_sample = resample(X_train,
385
- y_train,
386
- n_samples = 3500,
387
- stratify = y_train,
388
- random_state = 42) #type: ignore
389
-
390
  param_grid = {
391
  'n_estimators': list(np.arange(500,4000, 400)),
392
  'max_depth': [None, 10, 20, 30, 40, 50],
@@ -404,10 +446,10 @@ def randomSearch(X: np.ndarray, y: np.ndarray) -> dict:
404
  n_iter= 50,
405
  scoring = 'f1_weighted',
406
  cv = 3,
407
- verbose = 2,
408
  n_jobs = -1)
409
 
410
- rf_random.fit(X = X_sample, y = y_sample)
411
 
412
  print('Best Params')
413
  pprint(rf_random.best_params_)
@@ -534,7 +576,7 @@ def _fetch_sequence_for_row(idx, row):
534
  try:
535
  sequence = fetch_uniprot_sequence(swiss_id)
536
  except HTTPError as e:
537
- print(f"Warning: SwissProt fetch failed for {swiss_id} with HTTP {e.code}")
538
  sequence = None
539
 
540
  # Try RefSeq if no SwissProt
@@ -542,7 +584,7 @@ def _fetch_sequence_for_row(idx, row):
542
  try:
543
  sequence = fetch_refseq_sequence(row['Refseq_Accession'])
544
  except HTTPError as e:
545
- print(f"Warning: RefSeq fetch failed for {row['Refseq_Accession']} with HTTP {e.code}")
546
  sequence = None
547
 
548
  # Try Other_Accession if still no sequence
@@ -550,7 +592,7 @@ def _fetch_sequence_for_row(idx, row):
550
  try:
551
  sequence = fetch_refseq_sequence(row['Other_Accession'])
552
  except HTTPError as e:
553
- print(f"Warning: RefSeq fetch failed for {row['Other_Accession']} with HTTP {e.code}")
554
  sequence = None
555
 
556
  return idx, sequence
 
3
  import re
4
  from pprint import pprint
5
  from io import StringIO
6
+ from typing import Literal, Optional
7
  import tkinter as tk
8
  from tkinter import filedialog, messagebox, ttk
9
 
 
14
  from sklearn.ensemble import RandomForestClassifier
15
  from sklearn import svm
16
  from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
17
+ from sklearn.metrics import (
18
+ classification_report,
19
+ accuracy_score,
20
+ f1_score,
21
+ recall_score,
22
+ precision_score,
23
+ confusion_matrix,
24
+ )
25
  from sklearn.decomposition import PCA
26
  from sklearn.preprocessing import StandardScaler, LabelEncoder
27
  from sklearn.pipeline import Pipeline
28
  from sklearn.manifold import TSNE
29
  from sklearn.model_selection import train_test_split
30
  from sklearn.utils import resample
31
+ from sklearn.base import BaseEstimator
32
 
33
  import umap
34
 
35
  import requests
36
+ from requests.exceptions import HTTPError
37
  from Bio import Entrez
38
  from Bio import SeqIO
39
  from tqdm import tqdm
 
243
  tsne_plot(x, labels, title=f't-SNE - {model_name}')
244
  plot_umap(x, labels, title=f'UMAP - {model_name}')
245
 
246
+ def evaluate(model: BaseEstimator,
247
+ x_test: np.ndarray,
248
+ y_test: np.ndarray) -> dict:
249
 
250
  """
251
  Evaluates a classification model on test data and computes performance metrics.
 
264
  """
265
 
266
  result = {}
267
+ y_pred = model.predict(x_test) # type: ignore
268
 
269
  result['Accuracy'] = accuracy_score(y_test, y_pred)
270
  result['Recall'] = recall_score(y_test, y_pred, average = 'weighted')
 
272
  result['F1'] = f1_score(y_test, y_pred, average='weighted')
273
 
274
  pprint(result)
 
 
275
  return result
276
 
277
 
 
279
  def train_rf(title: str,
280
  x: np.ndarray,
281
  y : np.ndarray,
282
+ params: dict) -> tuple[RandomForestClassifier, dict, LabelEncoder]:
283
+
284
+ """
285
+ Trains a RandomForestClassifier on the provided data, evaluates its performance, and displays results.
286
+ Args:
287
+ title (str): Title for the confusion matrix plot.
288
+ x (np.ndarray): Feature matrix for training and testing.
289
+ y (np.ndarray): Target labels corresponding to the feature matrix.
290
+ params (dict): Parameters to initialize the RandomForestClassifier.
291
+ Returns:
292
+ tuple[RandomForestClassifier, dict, LabelEncoder]:
293
+ - Trained RandomForestClassifier instance,
294
+ - Evaluation metrics as a dictionary,
295
+ - Fitted LabelEncoder for label transformations.
296
+ Side Effects:
297
+ - Prints a classification report to stdout.
298
+ - Displays a confusion matrix plot.
299
+ """
300
 
301
+ le = LabelEncoder()
302
+ y_encoded = le.fit_transform(y)
303
 
304
  x_train, x_test, y_train, y_test = train_test_split(x, y_encoded, test_size=0.33, stratify=y_encoded, random_state=42)
305
 
 
314
 
315
  evaluation = evaluate(classifier, x_test, y_test)
316
 
317
+ print(classification_report(y_test,
318
+ y_pred,
319
+ zero_division=0,
320
+ target_names = le.classes_))
321
+
322
+ y_pred_str = le.inverse_transform(y_pred)
323
+ y_test_str = le.inverse_transform(y_test)
324
 
325
+ confusion(title=title, y_true=y_test_str, y_pred=y_pred_str)
326
 
327
+ return classifier, evaluation, le
328
 
329
+ def train_svm(title: str, x: np.ndarray, y: np.ndarray, params: dict) -> tuple[Pipeline, dict, LabelEncoder]:
330
  """
331
+ Trains an SVM classifier using the provided data and parameters, evaluates its performance, and returns the trained pipeline, evaluation metrics, and label encoder.
332
 
333
  Args:
334
  title (str): Title for the confusion matrix plot.
335
+ x (np.ndarray): Feature matrix for training and testing.
336
+ y (np.ndarray): Target labels corresponding to the feature matrix.
337
+ params (dict): Dictionary of parameters for the SVM classifier. SVM-specific parameters should be prefixed with 'svm__'.
338
 
339
  Returns:
340
+ tuple[Pipeline, dict, LabelEncoder]:
341
+ - Trained scikit-learn Pipeline object containing the scaler and SVM.
342
+ - Dictionary with evaluation metrics from the `evaluate` function.
343
+ - Fitted LabelEncoder instance for encoding and decoding labels.
344
+
345
+ Side Effects:
346
+ - Displays a confusion matrix plot using the provided title.
347
+ - Prints a classification report to the standard output.
348
  """
349
+
350
+ le = LabelEncoder()
351
+ y_encoded = le.fit_transform(y)
352
+
353
  x_train, x_test, y_train, y_test = train_test_split(
354
+ x, y_encoded, test_size=0.33, stratify=y_encoded, random_state=42
355
  )
356
 
357
  svc_params = {k.replace('svm__', ''): v for k, v in params.items() if k.startswith('svm__')}
358
  pipeline = Pipeline([
359
  ('scaler', StandardScaler()),
360
+ ('svm', svm.SVC(**svc_params, probability = True))
361
  ])
362
 
363
  pipeline.fit(x_train, y_train)
364
 
365
  y_pred = pipeline.predict(x_test)
366
 
367
+ evaluation = evaluate(model=pipeline, x_test=x_test, y_test=y_test)
368
 
369
+ y_pred_str = le.inverse_transform(y_pred)
370
+ y_test_str = le.inverse_transform(y_test)
 
371
 
372
+ confusion(title=title, y_true=y_test_str, y_pred=y_pred_str)
373
 
 
374
 
375
+ print(classification_report(y_test, y_pred, zero_division=0, target_names = le.classes_))
376
 
377
+ return pipeline, evaluation, le
378
 
379
+
380
+ def randomSVM(x: np.ndarray, y: np.ndarray) -> dict:
381
+
382
+ le = LabelEncoder()
383
+ y_encoded = le.fit_transform(y)
384
+ x_train, _, y_train, _ = train_test_split(x,
385
+ y_encoded,
386
  test_size=0.33,
387
+ stratify=y_encoded,
388
+ random_state=42)
389
 
390
  pipeline = Pipeline([('scaler', StandardScaler()),
391
  ('svm', svm.SVC())])
 
409
  n_iter=50,
410
  scoring='f1_weighted',
411
  cv=3,
412
+ verbose=1,
413
  random_state=42,
414
  n_jobs=-1
415
  )
416
 
417
+ random_search.fit(x_train, y_train)
418
+ random_search.best_params_['svm__probability'] = True
419
 
420
  pprint(random_search.best_params_)
421
 
422
  return random_search.best_params_
423
 
424
+ def randomSearch(x: np.ndarray, y: np.ndarray) -> dict:
425
 
426
+ le = LabelEncoder()
427
+ y_encoded = le.fit_transform(y)
428
+
429
+ x_train, _, y_train, _ = train_test_split(x, y_encoded, test_size=0.33, stratify=y_encoded, random_state=42)
430
  classifier : RandomForestClassifier = RandomForestClassifier(random_state=42)
431
 
 
 
 
 
 
 
432
  param_grid = {
433
  'n_estimators': list(np.arange(500,4000, 400)),
434
  'max_depth': [None, 10, 20, 30, 40, 50],
 
446
  n_iter= 50,
447
  scoring = 'f1_weighted',
448
  cv = 3,
449
+ verbose = 1,
450
  n_jobs = -1)
451
 
452
+ rf_random.fit(X = x_train, y = y_train)
453
 
454
  print('Best Params')
455
  pprint(rf_random.best_params_)
 
576
  try:
577
  sequence = fetch_uniprot_sequence(swiss_id)
578
  except HTTPError as e:
579
+ print(f"Warning: SwissProt fetch failed for {swiss_id} with HTTP {e}")
580
  sequence = None
581
 
582
  # Try RefSeq if no SwissProt
 
584
  try:
585
  sequence = fetch_refseq_sequence(row['Refseq_Accession'])
586
  except HTTPError as e:
587
+ print(f"Warning: RefSeq fetch failed for {row['Refseq_Accession']} with HTTP {e}")
588
  sequence = None
589
 
590
  # Try Other_Accession if still no sequence
 
592
  try:
593
  sequence = fetch_refseq_sequence(row['Other_Accession'])
594
  except HTTPError as e:
595
+ print(f"Warning: RefSeq fetch failed for {row['Other_Accession']} with HTTP {e}")
596
  sequence = None
597
 
598
  return idx, sequence