GilbertKrantz committed on
Commit
975672a
·
1 Parent(s): 48f264e

FEAT: Add Type Handler on Function and Class

Browse files
Files changed (5) hide show
  1. main.py +19 -7
  2. utils/Comparator.py +20 -2
  3. utils/Evaluator.py +12 -8
  4. utils/ModelCreator.py +20 -15
  5. utils/Trainer.py +18 -11
main.py CHANGED
@@ -11,7 +11,7 @@ import torch
11
  import torch.nn as nn
12
  import matplotlib.pyplot as plt
13
  from torchvision import transforms, datasets
14
- from torch.utils.data import DataLoader, random_split
15
 
16
  # Import custom modules
17
  sys.path.append("./utils")
@@ -23,7 +23,7 @@ from Trainer import model_train
23
 
24
 
25
  # Set random seeds for reproducibility
26
- def set_seed(seed=42):
27
  """Set seeds for reproducibility."""
28
  random.seed(seed)
29
  np.random.seed(seed)
@@ -35,7 +35,7 @@ def set_seed(seed=42):
35
  torch.backends.cudnn.benchmark = False
36
 
37
 
38
- def get_transform():
39
  """
40
  Get standard data transform for both training and validation/testing.
41
 
@@ -55,7 +55,7 @@ def get_transform():
55
  return transform
56
 
57
 
58
- def load_data(args):
59
  """
60
  Load and prepare datasets from separate directories for training and evaluation.
61
 
@@ -135,7 +135,13 @@ def load_data(args):
135
  return train_loader, val_loader, test_loader, eval_dataset
136
 
137
 
138
- def train_single_model(args, train_loader, val_loader, test_loader, dataset):
 
 
 
 
 
 
139
  """Train a single model specified by the arguments."""
140
 
141
  print(f"Creating {args.model} model...")
@@ -177,7 +183,13 @@ def train_single_model(args, train_loader, val_loader, test_loader, dataset):
177
  print("Training failed. Cannot evaluate on test set.")
178
 
179
 
180
- def compare_multiple_models(args, train_loader, val_loader, test_loader, dataset):
 
 
 
 
 
 
181
  """Compare multiple models."""
182
 
183
  print("Preparing to compare multiple models...")
@@ -217,7 +229,7 @@ def compare_multiple_models(args, train_loader, val_loader, test_loader, dataset
217
  )
218
 
219
 
220
- def main():
221
  """Main function to run the eye disease detection application."""
222
 
223
  # Set up argument parser with example usage
 
11
  import torch.nn as nn
12
  import matplotlib.pyplot as plt
13
  from torchvision import transforms, datasets
14
+ from torch.utils.data import DataLoader, random_split, Dataset
15
 
16
  # Import custom modules
17
  sys.path.append("./utils")
 
23
 
24
 
25
  # Set random seeds for reproducibility
26
+ def set_seed(seed=42) -> None:
27
  """Set seeds for reproducibility."""
28
  random.seed(seed)
29
  np.random.seed(seed)
 
35
  torch.backends.cudnn.benchmark = False
36
 
37
 
38
+ def get_transform() -> transforms.Compose:
39
  """
40
  Get standard data transform for both training and validation/testing.
41
 
 
55
  return transform
56
 
57
 
58
+ def load_data(args) -> tuple:
59
  """
60
  Load and prepare datasets from separate directories for training and evaluation.
61
 
 
135
  return train_loader, val_loader, test_loader, eval_dataset
136
 
137
 
138
+ def train_single_model(
139
+ args,
140
+ train_loader: DataLoader,
141
+ val_loader: DataLoader,
142
+ test_loader: DataLoader,
143
+ dataset: Dataset,
144
+ ) -> None:
145
  """Train a single model specified by the arguments."""
146
 
147
  print(f"Creating {args.model} model...")
 
183
  print("Training failed. Cannot evaluate on test set.")
184
 
185
 
186
+ def compare_multiple_models(
187
+ args,
188
+ train_loader: DataLoader,
189
+ val_loader: DataLoader,
190
+ test_loader: DataLoader,
191
+ dataset: Dataset,
192
+ ) -> None:
193
  """Compare multiple models."""
194
 
195
  print("Preparing to compare multiple models...")
 
229
  )
230
 
231
 
232
+ def main() -> None:
233
  """Main function to run the eye disease detection application."""
234
 
235
  # Set up argument parser with example usage
utils/Comparator.py CHANGED
@@ -1,10 +1,28 @@
 
1
  from Trainer import model_train
2
  from Evaluator import ClassificationEvaluator
3
 
4
 
5
  def compare_models(
6
- models, train_loader, val_loader, test_loader, dataset, epochs=20, names=None
7
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  if names is None:
9
  names = [f"Model {i+1}" for i in range(len(models))]
10
 
 
1
+ from torch.utils.data import DataLoader, Dataset
2
  from Trainer import model_train
3
  from Evaluator import ClassificationEvaluator
4
 
5
 
6
  def compare_models(
7
+ models: list,
8
+ train_loader: DataLoader,
9
+ val_loader: DataLoader,
10
+ test_loader: DataLoader,
11
+ dataset: Dataset,
12
+ epochs: int = 20,
13
+ names: list = None,
14
+ ) -> None:
15
+ """
16
+ Compare multiple models on validation and test datasets.
17
+ Args:
18
+ models (list): List of models to compare.
19
+ train_loader (DataLoader): DataLoader for training data.
20
+ val_loader (DataLoader): DataLoader for validation data.
21
+ test_loader (DataLoader): DataLoader for test data.
22
+ dataset (Dataset): Dataset object containing class names.
23
+ epochs (int): Number of epochs for training.
24
+ names (list): List of model names. If None, default names will be used.
25
+ """
26
  if names is None:
27
  names = [f"Model {i+1}" for i in range(len(models))]
28
 
utils/Evaluator.py CHANGED
@@ -1,4 +1,6 @@
1
  import numpy as np
 
 
2
  import torch
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
@@ -23,7 +25,7 @@ class ClassificationEvaluator:
23
  and generate visualizations for model evaluation.
24
  """
25
 
26
- def __init__(self, class_names):
27
  """
28
  Initialize the evaluator with class names.
29
 
@@ -39,7 +41,7 @@ class ClassificationEvaluator:
39
  return data.cpu().numpy()
40
  return np.array(data)
41
 
42
- def evaluate_model(self, model, test_loader):
43
  """
44
  Evaluate a trained model on test dataset.
45
 
@@ -75,7 +77,7 @@ class ClassificationEvaluator:
75
  results = self.compute_metrics(all_labels, all_preds, all_scores)
76
  return results
77
 
78
- def compute_metrics(self, y_true, y_pred, y_scores, model_name=""):
79
  """
80
  Compute comprehensive classification metrics.
81
 
@@ -135,7 +137,7 @@ class ClassificationEvaluator:
135
  "kappa": kappa,
136
  }
137
 
138
- def plot_roc_curves(self, y_true, y_scores):
139
  """
140
  Plot ROC curves for multi-class classification.
141
 
@@ -197,7 +199,7 @@ class ClassificationEvaluator:
197
 
198
  return roc_auc
199
 
200
- def plot_pr_curves(self, y_true, y_scores):
201
  """
202
  Plot Precision-Recall curves for multi-class classification.
203
 
@@ -263,7 +265,7 @@ class ClassificationEvaluator:
263
 
264
  return avg_precision
265
 
266
- def plot_confusion_matrix(self, y_true, y_pred):
267
  """
268
  Plot confusion matrix.
269
 
@@ -296,7 +298,7 @@ class ClassificationEvaluator:
296
  plt.tight_layout()
297
  plt.show()
298
 
299
- def plot_per_class_accuracy(self, y_true, y_pred):
300
  """
301
  Plot per-class accuracy.
302
 
@@ -328,7 +330,9 @@ class ClassificationEvaluator:
328
 
329
  return per_class_accuracy
330
 
331
- def plot_training_history(self, train_losses, val_losses, train_accs, val_accs):
 
 
332
  """
333
  Plot accuracy and loss curves from training history.
334
 
 
1
  import numpy as np
2
+ import torch.nn as nn
3
+ from torch.utils.data import DataLoader
4
  import torch
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
 
25
  and generate visualizations for model evaluation.
26
  """
27
 
28
+ def __init__(self, class_names: list):
29
  """
30
  Initialize the evaluator with class names.
31
 
 
41
  return data.cpu().numpy()
42
  return np.array(data)
43
 
44
+ def evaluate_model(self, model: nn.Module, test_loader: DataLoader) -> dict:
45
  """
46
  Evaluate a trained model on test dataset.
47
 
 
77
  results = self.compute_metrics(all_labels, all_preds, all_scores)
78
  return results
79
 
80
+ def compute_metrics(self, y_true, y_pred, y_scores, model_name: str = "") -> dict:
81
  """
82
  Compute comprehensive classification metrics.
83
 
 
137
  "kappa": kappa,
138
  }
139
 
140
+ def plot_roc_curves(self, y_true, y_scores) -> dict:
141
  """
142
  Plot ROC curves for multi-class classification.
143
 
 
199
 
200
  return roc_auc
201
 
202
+ def plot_pr_curves(self, y_true, y_scores) -> dict:
203
  """
204
  Plot Precision-Recall curves for multi-class classification.
205
 
 
265
 
266
  return avg_precision
267
 
268
+ def plot_confusion_matrix(self, y_true, y_pred) -> None:
269
  """
270
  Plot confusion matrix.
271
 
 
298
  plt.tight_layout()
299
  plt.show()
300
 
301
+ def plot_per_class_accuracy(self, y_true, y_pred) -> np.ndarray:
302
  """
303
  Plot per-class accuracy.
304
 
 
330
 
331
  return per_class_accuracy
332
 
333
+ def plot_training_history(
334
+ self, train_losses, val_losses, train_accs, val_accs
335
+ ) -> None:
336
  """
337
  Plot accuracy and loss curves from training history.
338
 
utils/ModelCreator.py CHANGED
@@ -1,13 +1,23 @@
1
  import torch
 
2
  import torch.nn as nn
3
  import timm
4
 
5
  # Set device
6
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
 
8
 
9
- class EyeDetectionModels:
10
- def __init__(self, num_classes, freeze_layers=True, device=DEVICE):
 
 
 
 
 
 
 
 
 
11
  """
12
  Initialize the EyeDetectionModels class.
13
  This class provides methods to create and configure various deep learning models for eye detection.
@@ -26,7 +36,7 @@ class EyeDetectionModels:
26
 
27
  # Model architecture functions
28
  @staticmethod
29
- def _get_feature_blocks(model):
30
  """
31
  Utility: locate the main feature blocks container in a timm model.
32
  Returns a list-like module of blocks.
@@ -38,14 +48,14 @@ class EyeDetectionModels:
38
  return list(model.children())[:-1]
39
 
40
  @staticmethod
41
- def _freeze_except_last_n(blocks, n):
42
  total = len(blocks)
43
  for idx, block in enumerate(blocks):
44
  requires = idx >= total - n
45
  for p in block.parameters():
46
  p.requires_grad = requires
47
 
48
- def get_model_mobilenetv4(self):
49
  model = timm.create_model(
50
  "mobilenetv4_conv_medium.e500_r256_in1k", pretrained=True
51
  )
@@ -62,14 +72,12 @@ class EyeDetectionModels:
62
  )
63
  return model.to(self.device)
64
 
65
- def get_model_levit(self):
66
  model = timm.create_model("levit_128s.fb_dist_in1k", pretrained=True)
67
  if self.freeze_layers:
68
  blocks = self._get_feature_blocks(model)
69
  self._freeze_except_last_n(blocks, 2)
70
  # Attempt to extract in_features from model.head or classifier
71
- head = getattr(model, "head_dist", None) or getattr(model, "classifier", None)
72
- linear = getattr(head, "linear")
73
  in_features = 384
74
  model.head = nn.Sequential(
75
  nn.Linear(in_features, 512),
@@ -85,15 +93,12 @@ class EyeDetectionModels:
85
  )
86
  return model.to(self.device)
87
 
88
- def get_model_efficientvit(self):
89
  model = timm.create_model("efficientvit_m1.r224_in1k", pretrained=True)
90
  if self.freeze_layers:
91
  blocks = self._get_feature_blocks(model)
92
  self._freeze_except_last_n(blocks, 2)
93
  # handle different head naming
94
- head = getattr(model, "head", None)
95
- print(head)
96
- linear = getattr(head, "linear")
97
  in_features = 192
98
  model.head.linear = nn.Sequential(
99
  nn.Linear(in_features, 512),
@@ -103,7 +108,7 @@ class EyeDetectionModels:
103
  )
104
  return model.to(self.device)
105
 
106
- def get_model_gernet(self):
107
  """
108
  Load and configure a GENet (General and Efficient Network) model with customizable classifier.
109
 
@@ -142,7 +147,7 @@ class EyeDetectionModels:
142
  )
143
  return model.to(self.device)
144
 
145
- def get_model_regnetx(self):
146
  """
147
  Load and configure a RegNetX model with customizable classifier.
148
 
 
1
  import torch
2
+ from torch import device
3
  import torch.nn as nn
4
  import timm
5
 
6
  # Set device
7
+ DEVICE = device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
 
10
+ class EyeDetectionModels(object):
11
+ """
12
+ A class to create and configure various deep learning models for eye detection tasks.
13
+ """
14
+
15
+ def __init__(
16
+ self,
17
+ num_classes: int,
18
+ freeze_layers: bool = True,
19
+ device: device = DEVICE,
20
+ ):
21
  """
22
  Initialize the EyeDetectionModels class.
23
  This class provides methods to create and configure various deep learning models for eye detection.
 
36
 
37
  # Model architecture functions
38
  @staticmethod
39
+ def _get_feature_blocks(model: nn.Module) -> nn.ModuleList:
40
  """
41
  Utility: locate the main feature blocks container in a timm model.
42
  Returns a list-like module of blocks.
 
48
  return list(model.children())[:-1]
49
 
50
  @staticmethod
51
+ def _freeze_except_last_n(blocks: nn.ModuleList, n: int) -> None:
52
  total = len(blocks)
53
  for idx, block in enumerate(blocks):
54
  requires = idx >= total - n
55
  for p in block.parameters():
56
  p.requires_grad = requires
57
 
58
+ def get_model_mobilenetv4(self) -> nn.Module:
59
  model = timm.create_model(
60
  "mobilenetv4_conv_medium.e500_r256_in1k", pretrained=True
61
  )
 
72
  )
73
  return model.to(self.device)
74
 
75
+ def get_model_levit(self) -> nn.Module:
76
  model = timm.create_model("levit_128s.fb_dist_in1k", pretrained=True)
77
  if self.freeze_layers:
78
  blocks = self._get_feature_blocks(model)
79
  self._freeze_except_last_n(blocks, 2)
80
  # Attempt to extract in_features from model.head or classifier
 
 
81
  in_features = 384
82
  model.head = nn.Sequential(
83
  nn.Linear(in_features, 512),
 
93
  )
94
  return model.to(self.device)
95
 
96
+ def get_model_efficientvit(self) -> nn.Module:
97
  model = timm.create_model("efficientvit_m1.r224_in1k", pretrained=True)
98
  if self.freeze_layers:
99
  blocks = self._get_feature_blocks(model)
100
  self._freeze_except_last_n(blocks, 2)
101
  # handle different head naming
 
 
 
102
  in_features = 192
103
  model.head.linear = nn.Sequential(
104
  nn.Linear(in_features, 512),
 
108
  )
109
  return model.to(self.device)
110
 
111
+ def get_model_gernet(self) -> nn.Module:
112
  """
113
  Load and configure a GENet (General and Efficient Network) model with customizable classifier.
114
 
 
147
  )
148
  return model.to(self.device)
149
 
150
+ def get_model_regnetx(self) -> nn.Module:
151
  """
152
  Load and configure a RegNetX model with customizable classifier.
153
 
utils/Trainer.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  import torch.nn.functional as F
4
  import torch.optim as optim
5
  import torch.nn as nn
 
6
  from tqdm import tqdm
7
  import gc
8
 
@@ -11,16 +12,16 @@ from Callback import EarlyStopping
11
 
12
 
13
  def train_model(
14
- model,
15
- criterion,
16
- optimizer,
17
- scheduler,
18
- train_loader,
19
- val_loader,
20
- early_stopping,
21
- epochs=15,
22
- use_ddp=False,
23
- ):
24
  """
25
  Train the model and perform validation using multiple GPUs.
26
  Supports both DataParallel (DP) and DistributedDataParallel (DDP) modes.
@@ -180,7 +181,13 @@ def train_model(
180
  )
181
 
182
 
183
- def model_train(model, train_loader, val_loader, dataset, epochs=20):
 
 
 
 
 
 
184
  model_name = type(model).__name__
185
  if hasattr(model, "pretrained_cfg") and "name" in model.pretrained_cfg:
186
  model_name = model.pretrained_cfg["name"]
 
3
  import torch.nn.functional as F
4
  import torch.optim as optim
5
  import torch.nn as nn
6
+ from torch.utils.data import DataLoader, Dataset
7
  from tqdm import tqdm
8
  import gc
9
 
 
12
 
13
 
14
  def train_model(
15
+ model: nn.Module,
16
+ criterion: nn.Module,
17
+ optimizer: optim.Optimizer,
18
+ scheduler: optim.lr_scheduler._LRScheduler,
19
+ train_loader: DataLoader,
20
+ val_loader: DataLoader,
21
+ early_stopping: EarlyStopping,
22
+ epochs: int = 15,
23
+ use_ddp: bool = False,
24
+ ) -> tuple:
25
  """
26
  Train the model and perform validation using multiple GPUs.
27
  Supports both DataParallel (DP) and DistributedDataParallel (DDP) modes.
 
181
  )
182
 
183
 
184
+ def model_train(
185
+ model: nn.Module,
186
+ train_loader: DataLoader,
187
+ val_loader: DataLoader,
188
+ dataset: Dataset,
189
+ epochs: int = 20,
190
+ ) -> tuple:
191
  model_name = type(model).__name__
192
  if hasattr(model, "pretrained_cfg") and "name" in model.pretrained_cfg:
193
  model_name = model.pretrained_cfg["name"]