fatimaxa committed on
Commit
3d05927
·
verified ·
1 Parent(s): 5762bbf

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +89 -65
train.py CHANGED
@@ -1,26 +1,44 @@
1
  import torch
2
  import torch.nn as nn
3
- from pathlib import Path
4
- from data_prep import test_loader, device
5
  from models.model import PlantCNN
6
  from utils.config import load_config
7
- from clearml import Task, InputModel
8
- import numpy as np
9
- from utils.vis import visualize_preds, plot_cfm
10
  from tqdm.auto import tqdm
11
- import ast
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- def evaluate_on_test(model, loader, loss_fn, device, num_imgs):
15
  model.eval()
16
- all_labels = []
17
- all_preds = []
18
  running_loss = 0.0
19
  correct = 0
20
  total = 0
21
- imgs_to_display = []
22
- lbls_to_display = []
23
- prs_to_display = []
24
 
25
  with torch.no_grad():
26
  for batch_idx, batch in enumerate(tqdm(loader, desc="Val", leave=False)):
@@ -36,69 +54,75 @@ def evaluate_on_test(model, loader, loss_fn, device, num_imgs):
36
  correct += (preds==labels).sum().item()
37
  total += labels.size(0)
38
 
39
- all_labels.extend(labels.cpu().numpy())
40
- all_preds.extend(preds.cpu().numpy())
41
-
42
- if len(imgs_to_display) < num_imgs:
43
- remaining = num_imgs - len(imgs_to_display)
44
- for img, lbl, pr in zip(images[:remaining], preds[:remaining], preds[:remaining]):
45
- imgs_to_display.append(img.cpu())
46
- lbls_to_display.append(lbl.item())
47
- prs_to_display.append(pr.item())
48
-
49
- test_loss = running_loss / total
50
- test_acc = correct / total
51
- return test_loss, test_acc, all_labels, all_preds, imgs_to_display, lbls_to_display, prs_to_display
52
-
53
-
54
-
55
 
56
def main():
    """Evaluate a stored PlantCNN checkpoint on the test split and log to ClearML.

    Architecture and normalisation hyperparameters are read from the
    parameters of the ClearML task that produced the input model, so the
    network is rebuilt with the same settings before its weights are loaded.
    Test loss and accuracy are reported as scalars, then sample predictions
    and a confusion matrix are passed to the plotting helpers.
    """
    project_name = "GAP_plant_disease_classification"
    model_name = "PlantCNN"

    task = Task.init(project_name=project_name, task_name=f"{model_name}_test")
    logger = task.get_logger()

    # Resolve the training task behind the registered model to recover its
    # hyperparameters.
    input_model = InputModel(model_id="b9308022b85e4eea952d78124d1ee597")
    training_task = Task.get_task(task_id=input_model.task)
    training_params = training_task.get_parameters()
    print(f"Training parameters: {training_params}")

    # The values are parsed from their stored form: scalars via int/float,
    # sequences via ast.literal_eval.
    lookup = training_params.get
    num_classes = int(lookup("General/num_classes"))
    channels = ast.literal_eval(lookup("General/channels"))
    dropout = float(lookup("General/dropout"))
    mean_nm = ast.literal_eval(lookup("General/normalize_mean"))
    std_nm = ast.literal_eval(lookup("General/normalize_std"))
    kernel_sizes = ast.literal_eval(lookup("General/kernel_sizes"))

    task.add_tags([model_name, "test"])

    class_names = test_loader.dataset.features["label"].names

    model = PlantCNN(
        num_classes=num_classes,
        channels=channels,
        dropout=dropout,
        kernel_sizes=kernel_sizes,
    ).to(device)
    weights_file = input_model.get_local_copy()
    model.load_state_dict(torch.load(weights_file, map_location=device))

    loss_fn = nn.CrossEntropyLoss()

    results = evaluate_on_test(model, test_loader, loss_fn, device, num_imgs=24)
    (
        test_loss,
        test_acc,
        all_labels,
        all_preds,
        display_images,
        display_labels,
        display_preds,
    ) = results

    print("\nTest results:")
    print(f"Test loss: {test_loss:.3f} | Test accuracy: {test_acc:.3f}")
    logger.report_scalar("loss", "test", test_loss, 0)
    logger.report_scalar("accuracy", "test", test_acc, 0)

    visualize_preds(
        display_images,
        display_labels,
        display_preds,
        logger,
        class_names,
        mean_nm,
        std_nm,
        num_images=24,
    )
    plot_cfm(all_labels, all_preds, logger, class_names, num_classes)


if __name__ == "__main__":
    main()
 
1
  import torch
2
  import torch.nn as nn
3
+ from data_prep import train_loader, val_loader, device
 
4
  from models.model import PlantCNN
5
  from utils.config import load_config
6
+ from clearml import Task
7
+ from pathlib import Path
 
8
  from tqdm.auto import tqdm
 
9
 
10
def train_step(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``loader`` and summarise it.

    Args:
        model: Network to train; put into training mode by this call.
        loader: Iterable of batches, each indexable by ``"pixel_values"``
            and ``"labels"`` and yielding tensors.
        optimizer: Optimizer whose gradients are reset and stepped per batch.
        loss_fn: Callable taking ``(output, labels)`` and returning the loss.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the loss averaged over all samples seen
        and the fraction of samples whose arg-max prediction matched the label.
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch in tqdm(loader, desc="Train", leave=False):
        inputs = batch["pixel_values"].to(device)
        targets = batch["labels"].to(device)
        batch_size = targets.size(0)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = loss_fn(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its sample count so the epoch figure is a
        # per-sample average (assumes loss_fn returns a batch mean -- confirm).
        loss_sum += batch_loss.item() * batch_size

        predicted = torch.max(logits, dim=1)[1]
        n_correct += (predicted == targets).sum().item()
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen
35
 
36
+ def test_step(model, loader, loss_fn, device):
37
  model.eval()
38
+
 
39
  running_loss = 0.0
40
  correct = 0
41
  total = 0
 
 
 
42
 
43
  with torch.no_grad():
44
  for batch_idx, batch in enumerate(tqdm(loader, desc="Val", leave=False)):
 
54
  correct += (preds==labels).sum().item()
55
  total += labels.size(0)
56
 
57
+ epoch_loss = running_loss/total
58
+ epoch_acc = correct/total
59
+ return epoch_loss, epoch_acc
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
def main():
    """Train PlantCNN with early stopping and register the best weights.

    Hyperparameters come from ``load_config()``. Each epoch runs
    ``train_step`` on ``train_loader`` and ``test_step`` on ``val_loader``,
    prints the metrics and reports them to the ClearML logger. Training stops
    once validation accuracy has failed to improve for
    ``early_stopping_patience`` consecutive epochs. The weights from the
    best-scoring epoch are restored, written to
    ``saved_models/plant_cnn.pt`` next to this file, and attached to the
    ClearML task as its output model.
    """
    config = load_config()
    num_classes = config["num_classes"]
    channels = config["channels"]
    dropout = config["dropout"]
    lr = config["lr"]
    weight_decay = config["weight_decay"]
    num_epochs = config["num_epochs"]
    patience = config["early_stopping_patience"]
    project_name = "GAP_plant_disease_classification"
    model_name = "PlantCNN"

    model = PlantCNN(num_classes=num_classes, channels=channels, dropout=dropout).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    task = Task.init(project_name=project_name, task_name=f"{model_name}_training")
    task.connect(config)
    task.add_tags([model_name, "train"])
    logger = task.get_logger()

    best_val_acc = 0.0
    best_state_dict = None
    patience_cnt = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch: {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_step(
            model, train_loader, optimizer, loss_fn, device
        )
        val_loss, val_acc = test_step(
            model, val_loader, loss_fn, device
        )

        print(f"Train loss: {train_loss:.3f} | Train accuracy: {train_acc:.3f}")
        print(f"Validation loss: {val_loss:.3f} | Validation accuracy: {val_acc:.3f}")
        logger.report_scalar("loss", "train", train_loss, epoch)
        logger.report_scalar("loss", "val", val_loss, epoch)
        logger.report_scalar("accuracy", "train", train_acc, epoch)
        logger.report_scalar("accuracy", "val", val_acc, epoch)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() hands back references to the live parameter
            # tensors, which later optimizer steps overwrite in place. Clone
            # them so the snapshot really holds this epoch's weights.
            best_state_dict = {
                key: tensor.detach().clone()
                for key, tensor in model.state_dict().items()
            }
            patience_cnt = 0
        else:
            patience_cnt += 1

        if patience_cnt >= patience:
            print(f"\nEarly stopping triggered after {epoch+1} epochs.")
            break

    # Roll back to the best validation epoch before saving; None means no
    # epoch ever beat the initial 0.0 accuracy (or no epochs ran).
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

    project_rt = Path(__file__).resolve().parent
    model_dir = project_rt / "saved_models"
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "plant_cnn.pt"

    torch.save(model.state_dict(), model_path)
    print(f"Saved best model to {model_path}")

    task.update_output_model(model_path=str(model_path), name="plant_cnn_best")


if __name__ == "__main__":
    main()