Spaces:
Sleeping
Sleeping
GilbertKrantz commited on
Commit ·
975672a
1
Parent(s): 48f264e
FEAT: Add Type Handler on Function and Clas
Browse files- main.py +19 -7
- utils/Comparator.py +20 -2
- utils/Evaluator.py +12 -8
- utils/ModelCreator.py +20 -15
- utils/Trainer.py +18 -11
main.py
CHANGED
|
@@ -11,7 +11,7 @@ import torch
|
|
| 11 |
import torch.nn as nn
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
from torchvision import transforms, datasets
|
| 14 |
-
from torch.utils.data import DataLoader, random_split
|
| 15 |
|
| 16 |
# Import custom modules
|
| 17 |
sys.path.append("./utils")
|
|
@@ -23,7 +23,7 @@ from Trainer import model_train
|
|
| 23 |
|
| 24 |
|
| 25 |
# Set random seeds for reproducibility
|
| 26 |
-
def set_seed(seed=42):
|
| 27 |
"""Set seeds for reproducibility."""
|
| 28 |
random.seed(seed)
|
| 29 |
np.random.seed(seed)
|
|
@@ -35,7 +35,7 @@ def set_seed(seed=42):
|
|
| 35 |
torch.backends.cudnn.benchmark = False
|
| 36 |
|
| 37 |
|
| 38 |
-
def get_transform():
|
| 39 |
"""
|
| 40 |
Get standard data transform for both training and validation/testing.
|
| 41 |
|
|
@@ -55,7 +55,7 @@ def get_transform():
|
|
| 55 |
return transform
|
| 56 |
|
| 57 |
|
| 58 |
-
def load_data(args):
|
| 59 |
"""
|
| 60 |
Load and prepare datasets from separate directories for training and evaluation.
|
| 61 |
|
|
@@ -135,7 +135,13 @@ def load_data(args):
|
|
| 135 |
return train_loader, val_loader, test_loader, eval_dataset
|
| 136 |
|
| 137 |
|
| 138 |
-
def train_single_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
"""Train a single model specified by the arguments."""
|
| 140 |
|
| 141 |
print(f"Creating {args.model} model...")
|
|
@@ -177,7 +183,13 @@ def train_single_model(args, train_loader, val_loader, test_loader, dataset):
|
|
| 177 |
print("Training failed. Cannot evaluate on test set.")
|
| 178 |
|
| 179 |
|
| 180 |
-
def compare_multiple_models(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
"""Compare multiple models."""
|
| 182 |
|
| 183 |
print("Preparing to compare multiple models...")
|
|
@@ -217,7 +229,7 @@ def compare_multiple_models(args, train_loader, val_loader, test_loader, dataset
|
|
| 217 |
)
|
| 218 |
|
| 219 |
|
| 220 |
-
def main():
|
| 221 |
"""Main function to run the eye disease detection application."""
|
| 222 |
|
| 223 |
# Set up argument parser with example usage
|
|
|
|
| 11 |
import torch.nn as nn
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
from torchvision import transforms, datasets
|
| 14 |
+
from torch.utils.data import DataLoader, random_split, Dataset
|
| 15 |
|
| 16 |
# Import custom modules
|
| 17 |
sys.path.append("./utils")
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
# Set random seeds for reproducibility
|
| 26 |
+
def set_seed(seed=42) -> None:
|
| 27 |
"""Set seeds for reproducibility."""
|
| 28 |
random.seed(seed)
|
| 29 |
np.random.seed(seed)
|
|
|
|
| 35 |
torch.backends.cudnn.benchmark = False
|
| 36 |
|
| 37 |
|
| 38 |
+
def get_transform() -> transforms.Compose:
|
| 39 |
"""
|
| 40 |
Get standard data transform for both training and validation/testing.
|
| 41 |
|
|
|
|
| 55 |
return transform
|
| 56 |
|
| 57 |
|
| 58 |
+
def load_data(args) -> tuple:
|
| 59 |
"""
|
| 60 |
Load and prepare datasets from separate directories for training and evaluation.
|
| 61 |
|
|
|
|
| 135 |
return train_loader, val_loader, test_loader, eval_dataset
|
| 136 |
|
| 137 |
|
| 138 |
+
def train_single_model(
|
| 139 |
+
args,
|
| 140 |
+
train_loader: DataLoader,
|
| 141 |
+
val_loader: DataLoader,
|
| 142 |
+
test_loader: DataLoader,
|
| 143 |
+
dataset: Dataset,
|
| 144 |
+
) -> None:
|
| 145 |
"""Train a single model specified by the arguments."""
|
| 146 |
|
| 147 |
print(f"Creating {args.model} model...")
|
|
|
|
| 183 |
print("Training failed. Cannot evaluate on test set.")
|
| 184 |
|
| 185 |
|
| 186 |
+
def compare_multiple_models(
|
| 187 |
+
args,
|
| 188 |
+
train_loader: DataLoader,
|
| 189 |
+
val_loader: DataLoader,
|
| 190 |
+
test_loader: DataLoader,
|
| 191 |
+
dataset: Dataset,
|
| 192 |
+
) -> None:
|
| 193 |
"""Compare multiple models."""
|
| 194 |
|
| 195 |
print("Preparing to compare multiple models...")
|
|
|
|
| 229 |
)
|
| 230 |
|
| 231 |
|
| 232 |
+
def main() -> None:
|
| 233 |
"""Main function to run the eye disease detection application."""
|
| 234 |
|
| 235 |
# Set up argument parser with example usage
|
utils/Comparator.py
CHANGED
|
@@ -1,10 +1,28 @@
|
|
|
|
|
| 1 |
from Trainer import model_train
|
| 2 |
from Evaluator import ClassificationEvaluator
|
| 3 |
|
| 4 |
|
| 5 |
def compare_models(
|
| 6 |
-
models
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
if names is None:
|
| 9 |
names = [f"Model {i+1}" for i in range(len(models))]
|
| 10 |
|
|
|
|
| 1 |
+
from torch.utils.data import DataLoader, Dataset
|
| 2 |
from Trainer import model_train
|
| 3 |
from Evaluator import ClassificationEvaluator
|
| 4 |
|
| 5 |
|
| 6 |
def compare_models(
|
| 7 |
+
models: list,
|
| 8 |
+
train_loader: DataLoader,
|
| 9 |
+
val_loader: DataLoader,
|
| 10 |
+
test_loader: DataLoader,
|
| 11 |
+
dataset: Dataset,
|
| 12 |
+
epochs: int = 20,
|
| 13 |
+
names: list = None,
|
| 14 |
+
) -> None:
|
| 15 |
+
"""
|
| 16 |
+
Compare multiple models on validation and test datasets.
|
| 17 |
+
Args:
|
| 18 |
+
models (list): List of models to compare.
|
| 19 |
+
train_loader (DataLoader): DataLoader for training data.
|
| 20 |
+
val_loader (DataLoader): DataLoader for validation data.
|
| 21 |
+
test_loader (DataLoader): DataLoader for test data.
|
| 22 |
+
dataset (Dataset): Dataset object containing class names.
|
| 23 |
+
epochs (int): Number of epochs for training.
|
| 24 |
+
names (list): List of model names. If None, default names will be used.
|
| 25 |
+
"""
|
| 26 |
if names is None:
|
| 27 |
names = [f"Model {i+1}" for i in range(len(models))]
|
| 28 |
|
utils/Evaluator.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
import numpy as np
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
import seaborn as sns
|
|
@@ -23,7 +25,7 @@ class ClassificationEvaluator:
|
|
| 23 |
and generate visualizations for model evaluation.
|
| 24 |
"""
|
| 25 |
|
| 26 |
-
def __init__(self, class_names):
|
| 27 |
"""
|
| 28 |
Initialize the evaluator with class names.
|
| 29 |
|
|
@@ -39,7 +41,7 @@ class ClassificationEvaluator:
|
|
| 39 |
return data.cpu().numpy()
|
| 40 |
return np.array(data)
|
| 41 |
|
| 42 |
-
def evaluate_model(self, model, test_loader):
|
| 43 |
"""
|
| 44 |
Evaluate a trained model on test dataset.
|
| 45 |
|
|
@@ -75,7 +77,7 @@ class ClassificationEvaluator:
|
|
| 75 |
results = self.compute_metrics(all_labels, all_preds, all_scores)
|
| 76 |
return results
|
| 77 |
|
| 78 |
-
def compute_metrics(self, y_true, y_pred, y_scores, model_name=""):
|
| 79 |
"""
|
| 80 |
Compute comprehensive classification metrics.
|
| 81 |
|
|
@@ -135,7 +137,7 @@ class ClassificationEvaluator:
|
|
| 135 |
"kappa": kappa,
|
| 136 |
}
|
| 137 |
|
| 138 |
-
def plot_roc_curves(self, y_true, y_scores):
|
| 139 |
"""
|
| 140 |
Plot ROC curves for multi-class classification.
|
| 141 |
|
|
@@ -197,7 +199,7 @@ class ClassificationEvaluator:
|
|
| 197 |
|
| 198 |
return roc_auc
|
| 199 |
|
| 200 |
-
def plot_pr_curves(self, y_true, y_scores):
|
| 201 |
"""
|
| 202 |
Plot Precision-Recall curves for multi-class classification.
|
| 203 |
|
|
@@ -263,7 +265,7 @@ class ClassificationEvaluator:
|
|
| 263 |
|
| 264 |
return avg_precision
|
| 265 |
|
| 266 |
-
def plot_confusion_matrix(self, y_true, y_pred):
|
| 267 |
"""
|
| 268 |
Plot confusion matrix.
|
| 269 |
|
|
@@ -296,7 +298,7 @@ class ClassificationEvaluator:
|
|
| 296 |
plt.tight_layout()
|
| 297 |
plt.show()
|
| 298 |
|
| 299 |
-
def plot_per_class_accuracy(self, y_true, y_pred):
|
| 300 |
"""
|
| 301 |
Plot per-class accuracy.
|
| 302 |
|
|
@@ -328,7 +330,9 @@ class ClassificationEvaluator:
|
|
| 328 |
|
| 329 |
return per_class_accuracy
|
| 330 |
|
| 331 |
-
def plot_training_history(
|
|
|
|
|
|
|
| 332 |
"""
|
| 333 |
Plot accuracy and loss curves from training history.
|
| 334 |
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
import torch
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import seaborn as sns
|
|
|
|
| 25 |
and generate visualizations for model evaluation.
|
| 26 |
"""
|
| 27 |
|
| 28 |
+
def __init__(self, class_names: list):
|
| 29 |
"""
|
| 30 |
Initialize the evaluator with class names.
|
| 31 |
|
|
|
|
| 41 |
return data.cpu().numpy()
|
| 42 |
return np.array(data)
|
| 43 |
|
| 44 |
+
def evaluate_model(self, model: nn.Module, test_loader: DataLoader) -> dict:
|
| 45 |
"""
|
| 46 |
Evaluate a trained model on test dataset.
|
| 47 |
|
|
|
|
| 77 |
results = self.compute_metrics(all_labels, all_preds, all_scores)
|
| 78 |
return results
|
| 79 |
|
| 80 |
+
def compute_metrics(self, y_true, y_pred, y_scores, model_name: str = "") -> dict:
|
| 81 |
"""
|
| 82 |
Compute comprehensive classification metrics.
|
| 83 |
|
|
|
|
| 137 |
"kappa": kappa,
|
| 138 |
}
|
| 139 |
|
| 140 |
+
def plot_roc_curves(self, y_true, y_scores) -> dict:
|
| 141 |
"""
|
| 142 |
Plot ROC curves for multi-class classification.
|
| 143 |
|
|
|
|
| 199 |
|
| 200 |
return roc_auc
|
| 201 |
|
| 202 |
+
def plot_pr_curves(self, y_true, y_scores) -> dict:
|
| 203 |
"""
|
| 204 |
Plot Precision-Recall curves for multi-class classification.
|
| 205 |
|
|
|
|
| 265 |
|
| 266 |
return avg_precision
|
| 267 |
|
| 268 |
+
def plot_confusion_matrix(self, y_true, y_pred) -> None:
|
| 269 |
"""
|
| 270 |
Plot confusion matrix.
|
| 271 |
|
|
|
|
| 298 |
plt.tight_layout()
|
| 299 |
plt.show()
|
| 300 |
|
| 301 |
+
def plot_per_class_accuracy(self, y_true, y_pred) -> np.ndarray:
|
| 302 |
"""
|
| 303 |
Plot per-class accuracy.
|
| 304 |
|
|
|
|
| 330 |
|
| 331 |
return per_class_accuracy
|
| 332 |
|
| 333 |
+
def plot_training_history(
|
| 334 |
+
self, train_losses, val_losses, train_accs, val_accs
|
| 335 |
+
) -> None:
|
| 336 |
"""
|
| 337 |
Plot accuracy and loss curves from training history.
|
| 338 |
|
utils/ModelCreator.py
CHANGED
|
@@ -1,13 +1,23 @@
|
|
| 1 |
import torch
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import timm
|
| 4 |
|
| 5 |
# Set device
|
| 6 |
-
DEVICE =
|
| 7 |
|
| 8 |
|
| 9 |
-
class EyeDetectionModels:
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
Initialize the EyeDetectionModels class.
|
| 13 |
This class provides methods to create and configure various deep learning models for eye detection.
|
|
@@ -26,7 +36,7 @@ class EyeDetectionModels:
|
|
| 26 |
|
| 27 |
# Model architecture functions
|
| 28 |
@staticmethod
|
| 29 |
-
def _get_feature_blocks(model):
|
| 30 |
"""
|
| 31 |
Utility: locate the main feature blocks container in a timm model.
|
| 32 |
Returns a list-like module of blocks.
|
|
@@ -38,14 +48,14 @@ class EyeDetectionModels:
|
|
| 38 |
return list(model.children())[:-1]
|
| 39 |
|
| 40 |
@staticmethod
|
| 41 |
-
def _freeze_except_last_n(blocks, n):
|
| 42 |
total = len(blocks)
|
| 43 |
for idx, block in enumerate(blocks):
|
| 44 |
requires = idx >= total - n
|
| 45 |
for p in block.parameters():
|
| 46 |
p.requires_grad = requires
|
| 47 |
|
| 48 |
-
def get_model_mobilenetv4(self):
|
| 49 |
model = timm.create_model(
|
| 50 |
"mobilenetv4_conv_medium.e500_r256_in1k", pretrained=True
|
| 51 |
)
|
|
@@ -62,14 +72,12 @@ class EyeDetectionModels:
|
|
| 62 |
)
|
| 63 |
return model.to(self.device)
|
| 64 |
|
| 65 |
-
def get_model_levit(self):
|
| 66 |
model = timm.create_model("levit_128s.fb_dist_in1k", pretrained=True)
|
| 67 |
if self.freeze_layers:
|
| 68 |
blocks = self._get_feature_blocks(model)
|
| 69 |
self._freeze_except_last_n(blocks, 2)
|
| 70 |
# Attempt to extract in_features from model.head or classifier
|
| 71 |
-
head = getattr(model, "head_dist", None) or getattr(model, "classifier", None)
|
| 72 |
-
linear = getattr(head, "linear")
|
| 73 |
in_features = 384
|
| 74 |
model.head = nn.Sequential(
|
| 75 |
nn.Linear(in_features, 512),
|
|
@@ -85,15 +93,12 @@ class EyeDetectionModels:
|
|
| 85 |
)
|
| 86 |
return model.to(self.device)
|
| 87 |
|
| 88 |
-
def get_model_efficientvit(self):
|
| 89 |
model = timm.create_model("efficientvit_m1.r224_in1k", pretrained=True)
|
| 90 |
if self.freeze_layers:
|
| 91 |
blocks = self._get_feature_blocks(model)
|
| 92 |
self._freeze_except_last_n(blocks, 2)
|
| 93 |
# handle different head naming
|
| 94 |
-
head = getattr(model, "head", None)
|
| 95 |
-
print(head)
|
| 96 |
-
linear = getattr(head, "linear")
|
| 97 |
in_features = 192
|
| 98 |
model.head.linear = nn.Sequential(
|
| 99 |
nn.Linear(in_features, 512),
|
|
@@ -103,7 +108,7 @@ class EyeDetectionModels:
|
|
| 103 |
)
|
| 104 |
return model.to(self.device)
|
| 105 |
|
| 106 |
-
def get_model_gernet(self):
|
| 107 |
"""
|
| 108 |
Load and configure a GENet (General and Efficient Network) model with customizable classifier.
|
| 109 |
|
|
@@ -142,7 +147,7 @@ class EyeDetectionModels:
|
|
| 142 |
)
|
| 143 |
return model.to(self.device)
|
| 144 |
|
| 145 |
-
def get_model_regnetx(self):
|
| 146 |
"""
|
| 147 |
Load and configure a RegNetX model with customizable classifier.
|
| 148 |
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from torch import device
|
| 3 |
import torch.nn as nn
|
| 4 |
import timm
|
| 5 |
|
| 6 |
# Set device
|
| 7 |
+
DEVICE = device("cuda" if torch.cuda.is_available() else "cpu")
|
| 8 |
|
| 9 |
|
| 10 |
+
class EyeDetectionModels(object):
|
| 11 |
+
"""
|
| 12 |
+
A class to create and configure various deep learning models for eye detection tasks.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
num_classes: int,
|
| 18 |
+
freeze_layers: bool = True,
|
| 19 |
+
device: device = DEVICE,
|
| 20 |
+
):
|
| 21 |
"""
|
| 22 |
Initialize the EyeDetectionModels class.
|
| 23 |
This class provides methods to create and configure various deep learning models for eye detection.
|
|
|
|
| 36 |
|
| 37 |
# Model architecture functions
|
| 38 |
@staticmethod
|
| 39 |
+
def _get_feature_blocks(model: nn.Module) -> nn.ModuleList:
|
| 40 |
"""
|
| 41 |
Utility: locate the main feature blocks container in a timm model.
|
| 42 |
Returns a list-like module of blocks.
|
|
|
|
| 48 |
return list(model.children())[:-1]
|
| 49 |
|
| 50 |
@staticmethod
|
| 51 |
+
def _freeze_except_last_n(blocks: nn.ModuleList, n: int) -> None:
|
| 52 |
total = len(blocks)
|
| 53 |
for idx, block in enumerate(blocks):
|
| 54 |
requires = idx >= total - n
|
| 55 |
for p in block.parameters():
|
| 56 |
p.requires_grad = requires
|
| 57 |
|
| 58 |
+
def get_model_mobilenetv4(self) -> nn.Module:
|
| 59 |
model = timm.create_model(
|
| 60 |
"mobilenetv4_conv_medium.e500_r256_in1k", pretrained=True
|
| 61 |
)
|
|
|
|
| 72 |
)
|
| 73 |
return model.to(self.device)
|
| 74 |
|
| 75 |
+
def get_model_levit(self) -> nn.Module:
|
| 76 |
model = timm.create_model("levit_128s.fb_dist_in1k", pretrained=True)
|
| 77 |
if self.freeze_layers:
|
| 78 |
blocks = self._get_feature_blocks(model)
|
| 79 |
self._freeze_except_last_n(blocks, 2)
|
| 80 |
# Attempt to extract in_features from model.head or classifier
|
|
|
|
|
|
|
| 81 |
in_features = 384
|
| 82 |
model.head = nn.Sequential(
|
| 83 |
nn.Linear(in_features, 512),
|
|
|
|
| 93 |
)
|
| 94 |
return model.to(self.device)
|
| 95 |
|
| 96 |
+
def get_model_efficientvit(self) -> nn.Module:
|
| 97 |
model = timm.create_model("efficientvit_m1.r224_in1k", pretrained=True)
|
| 98 |
if self.freeze_layers:
|
| 99 |
blocks = self._get_feature_blocks(model)
|
| 100 |
self._freeze_except_last_n(blocks, 2)
|
| 101 |
# handle different head naming
|
|
|
|
|
|
|
|
|
|
| 102 |
in_features = 192
|
| 103 |
model.head.linear = nn.Sequential(
|
| 104 |
nn.Linear(in_features, 512),
|
|
|
|
| 108 |
)
|
| 109 |
return model.to(self.device)
|
| 110 |
|
| 111 |
+
def get_model_gernet(self) -> nn.Module:
|
| 112 |
"""
|
| 113 |
Load and configure a GENet (General and Efficient Network) model with customizable classifier.
|
| 114 |
|
|
|
|
| 147 |
)
|
| 148 |
return model.to(self.device)
|
| 149 |
|
| 150 |
+
def get_model_regnetx(self) -> nn.Module:
|
| 151 |
"""
|
| 152 |
Load and configure a RegNetX model with customizable classifier.
|
| 153 |
|
utils/Trainer.py
CHANGED
|
@@ -3,6 +3,7 @@ import torch
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
import torch.optim as optim
|
| 5 |
import torch.nn as nn
|
|
|
|
| 6 |
from tqdm import tqdm
|
| 7 |
import gc
|
| 8 |
|
|
@@ -11,16 +12,16 @@ from Callback import EarlyStopping
|
|
| 11 |
|
| 12 |
|
| 13 |
def train_model(
|
| 14 |
-
model,
|
| 15 |
-
criterion,
|
| 16 |
-
optimizer,
|
| 17 |
-
scheduler,
|
| 18 |
-
train_loader,
|
| 19 |
-
val_loader,
|
| 20 |
-
early_stopping,
|
| 21 |
-
epochs=15,
|
| 22 |
-
use_ddp=False,
|
| 23 |
-
):
|
| 24 |
"""
|
| 25 |
Train the model and perform validation using multiple GPUs.
|
| 26 |
Supports both DataParallel (DP) and DistributedDataParallel (DDP) modes.
|
|
@@ -180,7 +181,13 @@ def train_model(
|
|
| 180 |
)
|
| 181 |
|
| 182 |
|
| 183 |
-
def model_train(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
model_name = type(model).__name__
|
| 185 |
if hasattr(model, "pretrained_cfg") and "name" in model.pretrained_cfg:
|
| 186 |
model_name = model.pretrained_cfg["name"]
|
|
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
import torch.optim as optim
|
| 5 |
import torch.nn as nn
|
| 6 |
+
from torch.utils.data import DataLoader, Dataset
|
| 7 |
from tqdm import tqdm
|
| 8 |
import gc
|
| 9 |
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
def train_model(
|
| 15 |
+
model: nn.Module,
|
| 16 |
+
criterion: nn.Module,
|
| 17 |
+
optimizer: optim.Optimizer,
|
| 18 |
+
scheduler: optim.lr_scheduler._LRScheduler,
|
| 19 |
+
train_loader: DataLoader,
|
| 20 |
+
val_loader: DataLoader,
|
| 21 |
+
early_stopping: EarlyStopping,
|
| 22 |
+
epochs: int = 15,
|
| 23 |
+
use_ddp: bool = False,
|
| 24 |
+
) -> tuple:
|
| 25 |
"""
|
| 26 |
Train the model and perform validation using multiple GPUs.
|
| 27 |
Supports both DataParallel (DP) and DistributedDataParallel (DDP) modes.
|
|
|
|
| 181 |
)
|
| 182 |
|
| 183 |
|
| 184 |
+
def model_train(
|
| 185 |
+
model: nn.Module,
|
| 186 |
+
train_loader: DataLoader,
|
| 187 |
+
val_loader: DataLoader,
|
| 188 |
+
dataset: Dataset,
|
| 189 |
+
epochs: int = 20,
|
| 190 |
+
) -> tuple:
|
| 191 |
model_name = type(model).__name__
|
| 192 |
if hasattr(model, "pretrained_cfg") and "name" in model.pretrained_cfg:
|
| 193 |
model_name = model.pretrained_cfg["name"]
|