neecat committed on
Commit
57d41d5
·
1 Parent(s): 41e328e

add modified files

Browse files
src/app.py CHANGED
@@ -24,7 +24,7 @@ def compute_saliency_map(model, input_tensor, method):
24
  input_tensor = input_tensor.to(config.DEVICE)
25
  input_tensor.requires_grad_()
26
 
27
- # Get prediction
28
  output = model(input_tensor)
29
  pred_class = output.argmax(dim=1).item()
30
  confidence = torch.softmax(output, dim=1)[0][pred_class].item()
@@ -35,7 +35,12 @@ def compute_saliency_map(model, input_tensor, method):
35
  elif method == "smoothgrad":
36
  attr = NoiseTunnel(Saliency(model))
37
  attributions = attr.attribute(
38
- input_tensor, nt_type="smoothgrad", target=pred_class, nt_samples=20, stdevs=0.2)
 
 
 
 
 
39
  elif method == "guided":
40
  attr = GuidedBackprop(model)
41
  attributions = attr.attribute(input_tensor, target=pred_class)
@@ -43,7 +48,7 @@ def compute_saliency_map(model, input_tensor, method):
43
  raise ValueError("Unsupported method")
44
 
45
  saliency = attributions.squeeze().abs().cpu().detach().numpy()
46
- saliency = np.max(saliency, axis=0) # to grayscale
47
 
48
  return pred_class, confidence, saliency
49
 
@@ -68,7 +73,7 @@ def run_saliency(model, input_tensor):
68
 
69
  output[0, pred_class].backward()
70
  saliency = input_tensor.grad.abs().squeeze().cpu().numpy()
71
- saliency = np.max(saliency, axis=0) # convert to grayscale
72
  return pred_class, confidence, saliency
73
 
74
 
@@ -77,7 +82,7 @@ def get_saliency_figure(input_tensor, saliency_map):
77
  saliency_map /= saliency_map.max() + 1e-10
78
 
79
  img_np = input_tensor.squeeze().detach().cpu().numpy()
80
- img_np = np.transpose(img_np, (1, 2, 0)) # C,H,W → H,W,C
81
  img_np = (img_np * 0.5 + 0.5).clip(0, 1)
82
 
83
  saliency_rgb = np.stack([saliency_map]*3, axis=-1)
 
24
  input_tensor = input_tensor.to(config.DEVICE)
25
  input_tensor.requires_grad_()
26
 
27
+
28
  output = model(input_tensor)
29
  pred_class = output.argmax(dim=1).item()
30
  confidence = torch.softmax(output, dim=1)[0][pred_class].item()
 
35
  elif method == "smoothgrad":
36
  attr = NoiseTunnel(Saliency(model))
37
  attributions = attr.attribute(
38
+ input_tensor,
39
+ nt_type="smoothgrad",
40
+ target=pred_class,
41
+ nt_samples=config.SMOOTHGRAD_SAMPLES,
42
+ stdevs=config.SMOOTHGRAD_STDEV
43
+ )
44
  elif method == "guided":
45
  attr = GuidedBackprop(model)
46
  attributions = attr.attribute(input_tensor, target=pred_class)
 
48
  raise ValueError("Unsupported method")
49
 
50
  saliency = attributions.squeeze().abs().cpu().detach().numpy()
51
+ saliency = np.max(saliency, axis=0)
52
 
53
  return pred_class, confidence, saliency
54
 
 
73
 
74
  output[0, pred_class].backward()
75
  saliency = input_tensor.grad.abs().squeeze().cpu().numpy()
76
+ saliency = np.max(saliency, axis=0)
77
  return pred_class, confidence, saliency
78
 
79
 
 
82
  saliency_map /= saliency_map.max() + 1e-10
83
 
84
  img_np = input_tensor.squeeze().detach().cpu().numpy()
85
+ img_np = np.transpose(img_np, (1, 2, 0))
86
  img_np = (img_np * 0.5 + 0.5).clip(0, 1)
87
 
88
  saliency_rgb = np.stack([saliency_map]*3, axis=-1)
src/config.py CHANGED
@@ -19,3 +19,38 @@ DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
19
 
20
 
21
  NUM_WORKERS = 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  NUM_WORKERS = 2
22
+
23
+
24
+ TUNING_EPOCHS = 5
25
+ TUNING_TRIALS = 10
26
+ TUNING_BATCH_SIZE = 32
27
+
28
+
29
+ LR_SCHEDULER_PATIENCE = 2
30
+ LR_SCHEDULER_FACTOR = 0.5
31
+
32
+
33
+ WEIGHT_DECAY = 1e-4
34
+ DROPOUT_RATE = 0.3
35
+
36
+
37
+ DATA_AUG_ROTATION = 15
38
+ DATA_AUG_COLOR_JITTER = 0.1
39
+ DATA_AUG_TRANSLATE = 0.1
40
+ DATA_AUG_SCALE = (0.8, 1.0)
41
+
42
+
43
+ GRAD_CLIP_VALUE = 1.0
44
+
45
+
46
+ SALIENCY_METHODS = ["saliency", "smoothgrad", "guided"]
47
+ SMOOTHGRAD_SAMPLES = 20
48
+ SMOOTHGRAD_STDEV = 0.2
49
+
50
+
51
+ INFERENCE_DIR = os.path.join(DATA_DIR, "inference_test")
52
+
53
+
54
+ os.makedirs(LOG_DIR, exist_ok=True)
55
+ os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
56
+ os.makedirs(INFERENCE_DIR, exist_ok=True)
src/data_loader.py CHANGED
@@ -1,13 +1,22 @@
1
  import os
 
2
  from torchvision import datasets, transforms
3
  from torch.utils.data import DataLoader
 
4
 
5
- def get_transforms(image_size=224):
6
  train_transforms = transforms.Compose([
7
  transforms.Resize((image_size, image_size)),
8
  transforms.RandomHorizontalFlip(),
9
- transforms.RandomRotation(10),
10
- transforms.ColorJitter(brightness=0.2, contrast=0.2),
 
 
 
 
 
 
 
11
  transforms.ToTensor(),
12
  transforms.Normalize([0.5]*3, [0.5]*3)
13
  ])
@@ -20,21 +29,31 @@ def get_transforms(image_size=224):
20
 
21
  return train_transforms, val_test_transforms
22
 
23
- def get_dataloaders(data_dir, batch_size=32, image_size=224, num_workers=2):
24
  train_transforms, val_test_transforms = get_transforms(image_size)
25
 
26
  train_dir = os.path.join(data_dir, 'train')
27
  val_dir = os.path.join(data_dir, 'val')
28
  test_dir = os.path.join(data_dir, 'test')
29
 
 
 
 
 
 
30
  train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
31
  val_dataset = datasets.ImageFolder(val_dir, transform=val_test_transforms)
32
  test_dataset = datasets.ImageFolder(test_dir, transform=val_test_transforms)
33
 
 
 
 
 
34
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
35
  val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
36
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
37
 
38
  class_names = train_dataset.classes
 
39
 
40
  return train_loader, val_loader, test_loader, class_names
 
1
  import os
2
+ import logging
3
  from torchvision import datasets, transforms
4
  from torch.utils.data import DataLoader
5
+ from src import config
6
 
7
+ def get_transforms(image_size=config.IMAGE_SIZE):
8
  train_transforms = transforms.Compose([
9
  transforms.Resize((image_size, image_size)),
10
  transforms.RandomHorizontalFlip(),
11
+ transforms.RandomRotation(config.DATA_AUG_ROTATION),
12
+ transforms.ColorJitter(
13
+ brightness=config.DATA_AUG_COLOR_JITTER,
14
+ contrast=config.DATA_AUG_COLOR_JITTER,
15
+ saturation=config.DATA_AUG_COLOR_JITTER,
16
+ hue=config.DATA_AUG_COLOR_JITTER
17
+ ),
18
+ transforms.RandomAffine(degrees=0, translate=(config.DATA_AUG_TRANSLATE, config.DATA_AUG_TRANSLATE)),
19
+ transforms.RandomResizedCrop(image_size, scale=config.DATA_AUG_SCALE),
20
  transforms.ToTensor(),
21
  transforms.Normalize([0.5]*3, [0.5]*3)
22
  ])
 
29
 
30
  return train_transforms, val_test_transforms
31
 
32
def get_dataloaders(data_dir, batch_size=config.BATCH_SIZE, image_size=config.IMAGE_SIZE, num_workers=config.NUM_WORKERS):
    """Create the train, validation and test data loaders.

    The three splits are read with ``datasets.ImageFolder`` from the
    ``train``, ``val`` and ``test`` sub-directories of ``data_dir``. The
    training split uses the augmenting transform returned by
    ``get_transforms``; the other two use the evaluation transform.

    Args:
        data_dir: Root directory that contains the three split folders.
        batch_size: Batch size shared by all three loaders.
        image_size: Forwarded to ``get_transforms``.
        num_workers: Worker count shared by all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)`` where
        ``class_names`` is the ``classes`` attribute of the training dataset.
    """
    augmenting_tf, evaluation_tf = get_transforms(image_size)

    split_dirs = {
        split: os.path.join(data_dir, split)
        for split in ('train', 'val', 'test')
    }

    logging.info(f"Loading datasets from: {data_dir}")
    for label, split in (("Train", 'train'), ("Validation", 'val'), ("Test", 'test')):
        logging.info(f"{label} directory: {split_dirs[split]}")

    train_set = datasets.ImageFolder(split_dirs['train'], transform=augmenting_tf)
    val_set = datasets.ImageFolder(split_dirs['val'], transform=evaluation_tf)
    test_set = datasets.ImageFolder(split_dirs['test'], transform=evaluation_tf)

    for label, dataset in (("Train", train_set), ("Validation", val_set), ("Test", test_set)):
        logging.info(f"{label} dataset size: {len(dataset)}")

    def _make_loader(dataset, shuffle):
        # Only the shuffle flag differs between the three loaders.
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    train_loader = _make_loader(train_set, True)
    val_loader = _make_loader(val_set, False)
    test_loader = _make_loader(test_set, False)

    class_names = train_set.classes
    logging.info(f"Classes: {class_names}")

    return train_loader, val_loader, test_loader, class_names
src/ensemble.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from src.model import TrashNetClassifier
4
+ from torchvision.models import resnet18, efficientnet_b0
5
+
6
class EnsembleModel(nn.Module):
    """Three-member ensemble whose output is the mean of the members' outputs.

    Members are a ``TrashNetClassifier``, a pretrained ResNet-18 and a
    pretrained EfficientNet-B0. The final linear layer of the two torchvision
    models is replaced so that each emits ``num_classes`` values.
    """

    def __init__(self, num_classes=6):
        super().__init__()

        # Member 1: the project's own classifier.
        self.model1 = TrashNetClassifier(num_classes=num_classes)

        # Member 2: ResNet-18 with a new fully connected output layer.
        resnet = resnet18(pretrained=True)
        resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
        self.model2 = resnet

        # Member 3: EfficientNet-B0 with a new final classifier layer.
        effnet = efficientnet_b0(pretrained=True)
        effnet_head_inputs = effnet.classifier[1].in_features
        effnet.classifier[1] = nn.Linear(effnet_head_inputs, num_classes)
        self.model3 = effnet

    def forward(self, x):
        """Run every member on ``x`` and return the unweighted mean output."""
        first, second, third = (
            member(x) for member in (self.model1, self.model2, self.model3)
        )
        return (first + second + third) / 3
src/hyperparameter_tuning.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import numpy as np
5
+ import random
6
+ from src.model import TrashNetClassifier
7
+ from src.data_loader import get_dataloaders
8
+ from src import config
9
+
10
+
11
+ import logging
12
+ import time
13
+ from datetime import datetime
14
+ import os
15
+
16
+
17
def setup_tuning_logging(log_dir):
    """Configure root logging for a hyperparameter-tuning run.

    Creates ``log_dir`` if necessary and sends INFO-and-above records both to
    a timestamped ``hyperparameter_tuning_<YYYYmmdd_HHMMSS>.log`` file inside
    it and to the console stream.

    Args:
        log_dir: Directory in which the log file is created.

    Returns:
        Path of the log file that receives the records.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"hyperparameter_tuning_{timestamp}.log")

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(),
        ],
        # basicConfig() is a silent no-op when the root logger already has
        # handlers (e.g. logging was configured or used earlier in the
        # process); the file above would then be opened but never written.
        # force=True (Python 3.8+) replaces the existing handlers so the
        # returned path is the one actually in use.
        force=True,
    )
    return log_file
31
+
32
def train_model_for_validation(model, train_loader, val_loader, lr, weight_decay, device, epochs=config.TUNING_EPOCHS):
    """Briefly train ``model`` and report its best validation accuracy.

    The model is optimised with Adam and cross-entropy loss for ``epochs``
    epochs and evaluated on ``val_loader`` after each one. Loss and accuracy
    are weighted by the number of samples in each batch, so a shorter final
    batch does not skew the epoch metrics.

    Args:
        model: Network to train; it is moved to ``device``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        lr: Learning rate passed to Adam.
        weight_decay: Weight decay passed to Adam.
        device: Device that the model and every batch are moved to.
        epochs: Number of training epochs.

    Returns:
        The highest validation accuracy observed over all epochs (0.0 if no
        epoch exceeded it).

    Raises:
        ZeroDivisionError: If a loader yields no samples.
    """
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    best_val_acc = 0.0
    num_train_batches = len(train_loader)

    logging.info(f"Starting validation training with lr={lr}, weight_decay={weight_decay}")

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        loss_sum, correct, seen = 0.0, 0, 0
        for batch_idx, (images, labels) in enumerate(train_loader):
            if batch_idx % 20 == 0:
                logging.info(f" Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{num_train_batches}")

            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # criterion returns the batch mean, so scale it back to a sum.
            loss_sum += loss.item() * batch_size
            correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
            seen += batch_size

        train_loss = loss_sum / seen
        train_acc = correct / seen

        # ---- validation pass ----
        model.eval()
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                batch_size = labels.size(0)
                loss_sum += loss.item() * batch_size
                correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
                seen += batch_size

        val_loss = loss_sum / seen
        val_acc = correct / seen

        logging.info(f" Epoch {epoch+1}/{epochs}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            logging.info(f" New best validation accuracy: {best_val_acc:.4f}")

    return best_val_acc
93
+
94
def run_hyperparameter_search():
    """Random-search learning rate and weight decay for ``TrashNetClassifier``.

    Up to ``config.TUNING_TRIALS`` distinct ``(lr, weight_decay)`` pairs are
    drawn from a fixed grid. For each pair a fresh ``TrashNetClassifier`` is
    trained with ``train_model_for_validation`` and the pair with the highest
    validation accuracy is kept.

    Pairs are sampled without replacement, so no configuration is trained
    twice; the number of trials is therefore capped at the grid size.

    Returns:
        Dict with keys ``"lr"`` and ``"weight_decay"`` holding the best pair
        found. If no trial runs, both values are the placeholder ``0``.
    """
    # Local import: this function is the only user of itertools.
    import itertools

    log_file = setup_tuning_logging(config.LOG_DIR)
    logging.info(f"Hyperparameter tuning logs will be saved to: {log_file}")

    device = torch.device(config.DEVICE)
    logging.info(f"Using device: {device}")

    logging.info("Loading datasets...")
    train_loader, val_loader, _, class_names = get_dataloaders(
        data_dir=config.DATA_DIR,
        batch_size=config.TUNING_BATCH_SIZE,
        image_size=config.IMAGE_SIZE,
        num_workers=config.NUM_WORKERS
    )

    learning_rates = [1e-5, 1e-4, 5e-4, 1e-3]
    weight_decays = [1e-5, 1e-4, 1e-3]

    # Draw distinct combinations so that no trial repeats another one.
    search_space = list(itertools.product(learning_rates, weight_decays))
    num_trials = max(0, min(config.TUNING_TRIALS, len(search_space)))
    trial_configs = random.sample(search_space, num_trials)

    # -inf guarantees the first completed trial replaces the placeholder
    # config, even when its validation accuracy is exactly 0.0.
    best_acc = float("-inf")
    best_config = {"lr": 0, "weight_decay": 0}

    logging.info("Starting hyperparameter search...")
    logging.info(f"Number of trials: {num_trials}")
    logging.info(f"Learning rates to try: {learning_rates}")
    logging.info(f"Weight decays to try: {weight_decays}")

    start_time = time.time()

    for trial, (lr, weight_decay) in enumerate(trial_configs):
        trial_start = time.time()

        logging.info(f"\nTrial {trial+1}/{num_trials}")
        logging.info(f"Testing lr={lr}, weight_decay={weight_decay}")

        # Fresh model per trial so that trials do not share trained weights.
        model = TrashNetClassifier(num_classes=len(class_names))

        val_acc = train_model_for_validation(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            lr=lr,
            weight_decay=weight_decay,
            device=device
        )

        trial_time = time.time() - trial_start
        logging.info(f"Trial {trial+1} completed in {trial_time:.2f}s")
        logging.info(f"Validation accuracy: {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            best_config = {"lr": lr, "weight_decay": weight_decay}
            logging.info("New best config found!")

    total_time = time.time() - start_time
    logging.info(f"\nHyperparameter search completed in {total_time:.2f}s")
    logging.info(f"Best config: lr={best_config['lr']}, weight_decay={best_config['weight_decay']}")
    logging.info(f"Best validation accuracy: {best_acc:.4f}")

    return best_config
src/model.py CHANGED
@@ -2,25 +2,29 @@ import torch
2
  import torch.nn as nn
3
  import torchvision.models as models
4
  from src import config
 
 
5
 
6
  class TrashNetClassifier(nn.Module):
7
- def __init__(self):
8
  super(TrashNetClassifier, self).__init__()
 
9
 
10
- self.backbone = models.mobilenet_v2(pretrained=True)
11
-
12
  if config.FREEZE_BACKBONE:
13
- for param in self.backbone.features.parameters():
14
  param.requires_grad = False
15
-
16
- self.backbone.classifier = nn.Sequential(
17
- nn.Dropout(p=0.3),
18
- nn.Linear(self.backbone.last_channel, 256),
19
- nn.ReLU(),
20
- nn.BatchNorm1d(256),
21
- nn.Dropout(p=0.3),
22
- nn.Linear(256, config.NUM_CLASSES)
23
  )
24
 
25
  def forward(self, x):
26
- return self.backbone(x)
 
 
 
2
  import torch.nn as nn
3
  import torchvision.models as models
4
  from src import config
5
+ from torchvision.models import mobilenet_v2
6
+
7
 
8
  class TrashNetClassifier(nn.Module):
9
+ def __init__(self, num_classes=config.NUM_CLASSES):
10
  super(TrashNetClassifier, self).__init__()
11
+ self.backbone = mobilenet_v2(pretrained=True)
12
 
13
+
 
14
  if config.FREEZE_BACKBONE:
15
+ for param in list(self.backbone.parameters())[:-8]:
16
  param.requires_grad = False
17
+
18
+
19
+ in_features = self.backbone.classifier[1].in_features
20
+ self.backbone.classifier = nn.Identity()
21
+
22
+ self.classifier = nn.Sequential(
23
+ nn.Dropout(config.DROPOUT_RATE),
24
+ nn.Linear(in_features, num_classes)
25
  )
26
 
27
  def forward(self, x):
28
+ x = self.backbone(x)
29
+ x = self.classifier(x)
30
+ return x
src/predict.py CHANGED
@@ -19,7 +19,7 @@ def preprocess_image(image_path, image_size):
19
  transforms.Normalize([0.5]*3, [0.5]*3)
20
  ])
21
  image = Image.open(image_path).convert("RGB")
22
- return transform(image).unsqueeze(0) # [1, C, H, W]
23
 
24
  def predict_image(model, image_tensor, class_names, device):
25
  image_tensor = image_tensor.to(device)
 
19
  transforms.Normalize([0.5]*3, [0.5]*3)
20
  ])
21
  image = Image.open(image_path).convert("RGB")
22
+ return transform(image).unsqueeze(0)
23
 
24
  def predict_image(model, image_tensor, class_names, device):
25
  image_tensor = image_tensor.to(device)
src/train.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.optim as optim
@@ -5,16 +8,39 @@ from src import config
5
  import time
6
  from torch.utils.tensorboard import SummaryWriter
7
 
 
8
  def calculate_accuracy(y_pred, y_true):
9
  preds = torch.argmax(y_pred, dim=1)
10
  correct = (preds == y_true).sum().item()
11
  return correct / len(y_true)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def train_one_epoch(model, dataloader, criterion, optimizer, device):
14
  model.train()
15
  running_loss, running_acc = 0.0, 0.0
 
 
 
 
 
 
16
 
17
- for images, labels in dataloader:
18
  images, labels = images.to(device), labels.to(device)
19
 
20
  optimizer.zero_grad()
@@ -23,6 +49,10 @@ def train_one_epoch(model, dataloader, criterion, optimizer, device):
23
  acc = calculate_accuracy(outputs, labels)
24
 
25
  loss.backward()
 
 
 
 
26
  optimizer.step()
27
 
28
  running_loss += loss.item()
@@ -30,54 +60,91 @@ def train_one_epoch(model, dataloader, criterion, optimizer, device):
30
 
31
  return running_loss / len(dataloader), running_acc / len(dataloader)
32
 
33
- def validate(model, dataloader, criterion, device):
34
- model.eval()
35
- val_loss, val_acc = 0.0, 0.0
36
 
37
- with torch.no_grad():
38
- for images, labels in dataloader:
39
- images, labels = images.to(device), labels.to(device)
40
- outputs = model(images)
41
- loss = criterion(outputs, labels)
42
- acc = calculate_accuracy(outputs, labels)
43
 
44
- val_loss += loss.item()
45
- val_acc += acc
46
 
47
- return val_loss / len(dataloader), val_acc / len(dataloader)
 
 
 
 
 
 
48
 
49
- def train_model(model, train_loader, val_loader, epochs=config.EPOCHS, lr=config.LEARNING_RATE, device=config.DEVICE):
50
  model = model.to(device)
51
- optimizer = optim.Adam(model.parameters(), lr=lr)
 
 
 
 
 
 
 
 
 
 
52
  criterion = nn.CrossEntropyLoss()
53
  best_val_acc = 0.0
54
-
55
-
56
 
57
  run_name = time.strftime("run_%Y%m%d-%H%M")
58
  log_dir = f"{config.LOG_DIR}/{run_name}"
59
  writer = SummaryWriter(log_dir=log_dir)
60
 
61
- print(f"Training on: {device.upper()}\n")
62
 
63
  for epoch in range(epochs):
64
- train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
 
 
 
 
 
 
65
  val_loss, val_acc = validate(model, val_loader, criterion, device)
66
 
67
- print(f"Epoch {epoch+1}/{epochs}")
68
- print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
69
- print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")
70
-
 
 
 
 
 
 
 
71
  writer.add_scalar("Loss/train", train_loss, epoch)
72
  writer.add_scalar("Loss/val", val_loss, epoch)
73
  writer.add_scalar("Accuracy/train", train_acc, epoch)
74
  writer.add_scalar("Accuracy/val", val_acc, epoch)
75
 
76
-
77
  if val_acc > best_val_acc:
78
  best_val_acc = val_acc
79
  torch.save(model.state_dict(), config.MODEL_SAVE_PATH)
80
- print("Model saved!\n")
81
 
82
  writer.close()
83
- print("Training complete. Best Val Acc: {:.2f}%".format(best_val_acc * 100))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import os
3
+ import logging
4
  import torch
5
  import torch.nn as nn
6
  import torch.optim as optim
 
8
  import time
9
  from torch.utils.tensorboard import SummaryWriter
10
 
11
+
12
  def calculate_accuracy(y_pred, y_true):
13
  preds = torch.argmax(y_pred, dim=1)
14
  correct = (preds == y_true).sum().item()
15
  return correct / len(y_true)
16
 
17
+
18
def setup_logging(log_dir):
    """Configure root logging for a training run.

    Creates ``log_dir`` if necessary and sends INFO-and-above records both to
    a timestamped ``training_<YYYYmmdd_HHMMSS>.log`` file inside it and to the
    console stream.

    Args:
        log_dir: Directory in which the log file is created.

    Returns:
        Path of the log file that receives the records.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{timestamp}.log")

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(),
        ],
        # basicConfig() is a silent no-op when the root logger already has
        # handlers, e.g. when another logging setup ran earlier in the same
        # process. force=True (Python 3.8+) replaces those handlers so the
        # returned file really receives the training log.
        force=True,
    )
    return log_file
32
+
33
+
34
  def train_one_epoch(model, dataloader, criterion, optimizer, device):
35
  model.train()
36
  running_loss, running_acc = 0.0, 0.0
37
+ batch_count = len(dataloader)
38
+
39
+ logging.info(f"Training on {batch_count} batches")
40
+ for batch_idx, (images, labels) in enumerate(dataloader):
41
+ if batch_idx % 10 == 0:
42
+ logging.info(f" Batch {batch_idx}/{batch_count}")
43
 
 
44
  images, labels = images.to(device), labels.to(device)
45
 
46
  optimizer.zero_grad()
 
49
  acc = calculate_accuracy(outputs, labels)
50
 
51
  loss.backward()
52
+
53
+ torch.nn.utils.clip_grad_norm_(
54
+ model.parameters(), max_norm=config.GRAD_CLIP_VALUE)
55
+
56
  optimizer.step()
57
 
58
  running_loss += loss.item()
 
60
 
61
  return running_loss / len(dataloader), running_acc / len(dataloader)
62
 
 
 
 
63
 
64
def train_model(model, train_loader, val_loader, epochs=config.EPOCHS, lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY, device=config.DEVICE):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Uses Adam with cross-entropy loss and a ``ReduceLROnPlateau`` scheduler
    driven by validation accuracy. Per-epoch loss and accuracy are written to
    a TensorBoard run directory under ``config.LOG_DIR``, and the model's
    ``state_dict`` is saved to ``config.MODEL_SAVE_PATH`` whenever validation
    accuracy strictly improves.

    Args:
        model: Network to train; it is moved to ``device``.
        train_loader: Training batches, passed to ``train_one_epoch``.
        val_loader: Validation batches, passed to ``validate``.
        epochs: Number of epochs to run.
        lr: Learning rate passed to Adam.
        weight_decay: Weight decay passed to Adam.
        device: Target device, given either as a string or a ``torch.device``.

    Returns:
        The best validation accuracy reached (0.0 if it never exceeded 0).
    """
    log_file = setup_logging(config.LOG_DIR)
    logging.info(f"Training logs will be saved to: {log_file}")

    logging.info("Training configuration:")
    logging.info(f" Epochs: {epochs}")
    logging.info(f" Learning rate: {lr}")
    logging.info(f" Weight decay: {weight_decay}")
    logging.info(f" Device: {device}")
    logging.info(f" Batch size: {config.BATCH_SIZE}")
    logging.info(f" Image size: {config.IMAGE_SIZE}")

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)

    # mode='max' because the scheduler is stepped with validation accuracy.
    # The deprecated ``verbose`` flag is not passed; learning-rate changes are
    # reported through logging inside the epoch loop instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=config.LR_SCHEDULER_FACTOR,
        patience=config.LR_SCHEDULER_PATIENCE,
    )

    criterion = nn.CrossEntropyLoss()
    best_val_acc = 0.0

    run_name = time.strftime("run_%Y%m%d-%H%M")
    log_dir = f"{config.LOG_DIR}/{run_name}"
    writer = SummaryWriter(log_dir=log_dir)

    # str() so that a torch.device argument works as well as a plain string.
    logging.info(f"Training on: {str(device).upper()}\n")

    try:
        for epoch in range(epochs):
            epoch_start_time = time.time()
            logging.info(f"Epoch {epoch+1}/{epochs} started")

            train_loss, train_acc = train_one_epoch(
                model, train_loader, criterion, optimizer, device)

            logging.info("Validating...")
            val_loss, val_acc = validate(model, val_loader, criterion, device)

            epoch_time = time.time() - epoch_start_time

            lr_before = optimizer.param_groups[0]["lr"]
            scheduler.step(val_acc)
            lr_after = optimizer.param_groups[0]["lr"]
            if lr_after != lr_before:
                logging.info(
                    f"Learning rate changed: {lr_before:.2e} -> {lr_after:.2e}")

            logging.info(
                f"Epoch {epoch+1}/{epochs} completed in {epoch_time:.2f}s")
            logging.info(
                f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
            logging.info(
                f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")

            writer.add_scalar("Loss/train", train_loss, epoch)
            writer.add_scalar("Loss/val", val_loss, epoch)
            writer.add_scalar("Accuracy/train", train_acc, epoch)
            writer.add_scalar("Accuracy/val", val_acc, epoch)

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), config.MODEL_SAVE_PATH)
                logging.info("Model saved!")
    finally:
        # Close the TensorBoard writer even if an epoch raises.
        writer.close()

    logging.info("Training complete. Best Val Acc: {:.2f}%".format(
        best_val_acc * 100))

    return best_val_acc
134
+
135
+
136
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Loss and accuracy are weighted by the number of samples in each batch, so
    the result is the true per-sample mean even when the last batch is
    shorter than the others. The loss weighting assumes ``criterion`` returns
    the batch mean (as ``nn.CrossEntropyLoss`` does by default).

    Args:
        model: Network to evaluate; it is switched to eval mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device that every batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = len(labels)
            # Scale the batch-mean loss back to a per-batch sum.
            loss_sum += loss.item() * batch_size
            correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
            seen += batch_size

    return loss_sum / seen, correct / seen
src/train_with_tuning.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from src.hyperparameter_tuning import run_hyperparameter_search
3
+ from src.model import TrashNetClassifier
4
+ from src.data_loader import get_dataloaders
5
+ from src.train import train_model
6
+ from src import config
7
+
8
if __name__ == "__main__":
    # Stage 1: search for the best learning rate / weight decay pair.
    print("Starting hyperparameter search...")
    best_config = run_hyperparameter_search()

    # Stage 2: full training run with the selected hyperparameters.
    print("\nTraining with best hyperparameters...")

    loader_options = {
        "data_dir": config.DATA_DIR,
        "batch_size": config.BATCH_SIZE,
        "image_size": config.IMAGE_SIZE,
        "num_workers": config.NUM_WORKERS,
    }
    train_loader, val_loader, test_loader, class_names = get_dataloaders(**loader_options)

    # The number of output classes follows the class names found on disk.
    model = TrashNetClassifier(num_classes=len(class_names))

    training_options = {
        "epochs": config.EPOCHS,
        "lr": best_config["lr"],
        "weight_decay": best_config["weight_decay"],
        "device": config.DEVICE,
    }
    train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        **training_options
    )

    print("Training complete!")