fatimaxa committed on
Commit
5762bbf
·
verified ·
1 Parent(s): 40ce994

Upload 3 files

Browse files
Files changed (3) hide show
  1. data_prep.py +146 -0
  2. test.py +93 -0
  3. train.py +104 -0
data_prep.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from torchvision import transforms
4
+ from datasets import load_dataset
5
+ from utils.config import load_config
6
+
7
# Runtime settings are read once from the project config.
config = load_config()
batch_size = config["batch_size"]
num_workers = config["num_workers"]
mean_nm = config["normalize_mean"]
std_nm = config["normalize_std"]
# Optional flag: when the key is absent the pipeline runs locally.
execute_remotely = config.get("execute_remotely", False)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Use the ClearML dataset when executing remotely, otherwise load from the
# Hugging Face Hub. The branch tests the defaulted flag above instead of
# indexing config["execute_remotely"], which raised KeyError for configs
# that omit the key.
if execute_remotely:
    from clearml import Dataset as ClearMLDataset
    clearml_dataset = ClearMLDataset.get(dataset_id="0c3de7af2d98482dacf41633a0587845")
    dataset_path = clearml_dataset.get_local_copy()
    dataset = load_dataset(dataset_path)
else:
    dataset = load_dataset("DScomp380/plant_village", cache_dir="./data_cache")

# Split the dataset into train (70%) and a remaining 30% for val and test.
# The fixed seed keeps the split reproducible across runs.
splits = dataset["train"].train_test_split(test_size=0.30, seed=42)
train_split = splits["train"]  # training set
remaining = splits["test"]

# Split the remaining 30% into val (15%) and test (15%).
val_test = remaining.train_test_split(test_size=0.5, seed=42)
val_split = val_test["train"]  # validation set
test_split = val_test["test"]  # test set
33
+
34
# Deterministic preprocessing: resize to 224x224, convert to a tensor, then
# normalize with the mean/std taken from the config.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean_nm, std=std_nm),
]
preprocess_transform = transforms.Compose(_preprocess_steps)
40
+
41
def preprocess_batch(batch):
    """Add a ``pixel_values`` column holding the preprocessed images.

    Every entry of ``batch["image"]`` is passed through
    ``preprocess_transform``; the results are stored under
    ``batch["pixel_values"]`` and the same (mutated) batch is returned.
    """
    processed = []
    for image in batch["image"]:
        processed.append(preprocess_transform(image))
    batch["pixel_values"] = processed
    return batch
44
+
45
if execute_remotely:
    def train_transform_batch(batch):
        """Attach preprocessed tensors to a batch as ``pixel_values``."""
        batch["pixel_values"] = [preprocess_transform(img) for img in batch["image"]]
        return batch

    # Remote runs attach the transform to each split instead of writing
    # preprocessed copies to the local cache.
    train_split, val_split, test_split = [
        _split.with_transform(train_transform_batch)
        for _split in (train_split, val_split, test_split)
    ]
else:
    # Local runs preprocess each split once and cache the result on disk.
    _map_options = {"batched": True, "batch_size": 100}

    train_split = train_split.map(
        preprocess_batch,
        remove_columns=["image"],
        cache_file_name="./data_cache/train_preprocessed.arrow",
        **_map_options,
    )
    val_split = val_split.map(
        preprocess_batch,
        remove_columns=["image"],
        cache_file_name="./data_cache/val_preprocessed.arrow",
        **_map_options,
    )
    test_split = test_split.map(
        preprocess_batch,
        remove_columns=["image"],
        cache_file_name="./data_cache/test_preprocessed.arrow",
        **_map_options,
    )

    # NOTE(review): the diff view lost indentation; set_format is kept in the
    # local branch because only the mapped splits have a "pixel_values"
    # column -- confirm against the original file.
    for _split in (train_split, val_split, test_split):
        _split.set_format(type="torch", columns=["pixel_values", "label"])
80
+
81
+ # augmentations
82
# Blur is applied to roughly 30% of the images.
_random_blur = transforms.RandomApply(
    [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))],
    p=0.3,
)

# Random augmentation pipeline: flips, small rotation, colour jitter and an
# occasional Gaussian blur.
train_augment = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    _random_blur,
])
91
+
92
+
93
def train_collate_fn(batch):
    """Collate training samples into one batch, augmenting every image.

    Args:
        batch: sequence of samples, each with ``"pixel_values"`` and
            ``"label"`` entries.

    Returns:
        Dict with ``"pixel_values"`` (the augmented images stacked into one
        tensor) and ``"labels"`` (a tensor built from the sample labels).
    """
    images = []
    targets = []
    for sample in batch:
        images.append(sample["pixel_values"])
        targets.append(sample["label"])

    # Augmentation runs here so it is applied only while training.
    images = [train_augment(image) for image in images]

    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }
103
+
104
def val_collate_fn(batch):
    """Collate evaluation samples into one batch without augmentation.

    Args:
        batch: sequence of samples, each with ``"pixel_values"`` and
            ``"label"`` entries.

    Returns:
        Dict with ``"pixel_values"`` (images stacked into one tensor) and
        ``"labels"`` (a tensor built from the sample labels).
    """
    images = []
    targets = []
    for sample in batch:
        images.append(sample["pixel_values"])
        targets.append(sample["label"])
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }
109
+
110
+ # create DataLoaders for train, val, and test sets
111
# Options shared by every loader; worker processes are kept alive between
# epochs only when worker processes are actually used.
_loader_options = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "pin_memory": True,
    "persistent_workers": num_workers > 0,
}

# Training loader: shuffled, with augmentation applied in the collate step.
train_loader = DataLoader(
    train_split,
    shuffle=True,
    collate_fn=train_collate_fn,
    **_loader_options,
)

# Validation and test loaders: fixed order, no augmentation.
val_loader = DataLoader(
    val_split,
    shuffle=False,
    collate_fn=val_collate_fn,
    **_loader_options,
)

test_loader = DataLoader(
    test_split,
    shuffle=False,
    collate_fn=val_collate_fn,
    **_loader_options,
)
140
+
141
if __name__ == "__main__":
    # Running the module directly prints a short summary of the prepared data.
    _summary = (
        f"Device: {device}",
        f"Train samples: {len(train_split)}",
        f"Val samples: {len(val_split)}",
        f"Test samples: {len(test_split)}",
        f"Batches per epoch: {len(train_loader)}",
    )
    for _line in _summary:
        print(_line)
test.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from pathlib import Path
4
+ from data_prep import test_loader, device
5
+ from models.model import PlantCNN
6
+ from utils.config import load_config
7
+ from clearml import Task
8
+ import numpy as np
9
+ from utils.vis import visualize_preds, plot_cfm
10
+ from tqdm.auto import tqdm
11
+
12
+
13
def evaluate_on_test(model, loader, loss_fn, device, num_imgs):
    """Run ``model`` over ``loader`` and collect metrics plus display samples.

    Args:
        model: network called as ``model(images)``; switched to eval mode.
        loader: iterable of batches, each a mapping with ``"pixel_values"``
            and ``"labels"`` tensors.
        loss_fn: callable ``loss_fn(output, labels)`` returning a scalar
            tensor.
        device: device the batch tensors are moved to.
        num_imgs: maximum number of samples to keep for visualisation.

    Returns:
        Tuple ``(test_loss, test_acc, all_labels, all_preds,
        imgs_to_display, lbls_to_display, prs_to_display)``. The loss is
        the sample-weighted mean of the per-batch losses, the accuracy is
        the fraction of correct predictions, and the three display lists
        hold the first ``num_imgs`` images (on CPU), their true labels and
        their predicted labels.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    all_labels = []
    all_preds = []
    running_loss = 0.0
    correct = 0
    total = 0
    imgs_to_display = []
    lbls_to_display = []
    prs_to_display = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Val", leave=False):
            images = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            output = model(images)
            loss = loss_fn(output, labels)

            # Weight the batch loss by batch size so a smaller final batch
            # does not skew the overall mean.
            running_loss += loss.item() * labels.size(0)

            _, preds = torch.max(output, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

            if len(imgs_to_display) < num_imgs:
                remaining = num_imgs - len(imgs_to_display)
                # Pair each image with its TRUE label and its prediction.
                # The labels were previously taken from ``preds`` too, so
                # every displayed sample looked correctly classified.
                for img, lbl, pr in zip(images[:remaining], labels[:remaining], preds[:remaining]):
                    imgs_to_display.append(img.cpu())
                    lbls_to_display.append(lbl.item())
                    prs_to_display.append(pr.item())

    test_loss = running_loss / total
    test_acc = correct / total
    return test_loss, test_acc, all_labels, all_preds, imgs_to_display, lbls_to_display, prs_to_display
51
+
52
+
53
def main():
    """Evaluate the locally saved PlantCNN checkpoint on the test split.

    Reads the model hyper-parameters from the project config, registers a
    ClearML task, loads ``saved_models/plant_cnn.pt`` from next to this
    file, runs ``evaluate_on_test`` on ``test_loader`` and reports loss,
    accuracy, sample predictions and a confusion matrix to the task logger.
    """
    config = load_config()
    num_classes = config["num_classes"]
    channels = config["channels"]
    dropout = config["dropout"]
    mean_nm = config["normalize_mean"]
    std_nm = config["normalize_std"]
    project_name = "GAP_plant_disease_classification"
    model_name = "PlantCNN"
    # One value for both collecting and plotting sample predictions, so the
    # two calls below cannot drift apart.
    num_display = 24

    task = Task.init(project_name=project_name, task_name=f"{model_name}_test")
    task.connect(config)
    task.add_tags([model_name, "test"])
    logger = task.get_logger()

    dataset = test_loader.dataset
    class_names = dataset.features["label"].names

    model = PlantCNN(num_classes=num_classes, channels=channels, dropout=dropout).to(device)
    project_root = Path(__file__).resolve().parent
    model_path = project_root / "saved_models" / "plant_cnn.pt"
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed PyTorch version; only load checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    loss_fn = nn.CrossEntropyLoss()

    (test_loss, test_acc, all_labels, all_preds,
     display_images, display_labels, display_preds) = evaluate_on_test(
        model, test_loader, loss_fn, device, num_imgs=num_display)

    print("\nTest results:")
    print(f"Test loss: {test_loss:.3f} | Test accuracy: {test_acc:.3f}")
    logger.report_scalar("loss", "test", test_loss, 0)
    logger.report_scalar("accuracy", "test", test_acc, 0)

    visualize_preds(display_images, display_labels, display_preds, logger, class_names,
                    mean_nm, std_nm, num_images=num_display)
    plot_cfm(all_labels, all_preds, logger, class_names, num_classes)


if __name__ == "__main__":
    main()
train.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from pathlib import Path
4
+ from data_prep import test_loader, device
5
+ from models.model import PlantCNN
6
+ from utils.config import load_config
7
+ from clearml import Task, InputModel
8
+ import numpy as np
9
+ from utils.vis import visualize_preds, plot_cfm
10
+ from tqdm.auto import tqdm
11
+ import ast
12
+
13
+
14
def evaluate_on_test(model, loader, loss_fn, device, num_imgs):
    """Run ``model`` over ``loader`` and collect metrics plus display samples.

    Args:
        model: network called as ``model(images)``; switched to eval mode.
        loader: iterable of batches, each a mapping with ``"pixel_values"``
            and ``"labels"`` tensors.
        loss_fn: callable ``loss_fn(output, labels)`` returning a scalar
            tensor.
        device: device the batch tensors are moved to.
        num_imgs: maximum number of samples to keep for visualisation.

    Returns:
        Tuple ``(test_loss, test_acc, all_labels, all_preds,
        imgs_to_display, lbls_to_display, prs_to_display)``. The loss is
        the sample-weighted mean of the per-batch losses, the accuracy is
        the fraction of correct predictions, and the three display lists
        hold the first ``num_imgs`` images (on CPU), their true labels and
        their predicted labels.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    all_labels = []
    all_preds = []
    running_loss = 0.0
    correct = 0
    total = 0
    imgs_to_display = []
    lbls_to_display = []
    prs_to_display = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Val", leave=False):
            images = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            output = model(images)
            loss = loss_fn(output, labels)

            # Weight the batch loss by batch size so a smaller final batch
            # does not skew the overall mean.
            running_loss += loss.item() * labels.size(0)

            _, preds = torch.max(output, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

            if len(imgs_to_display) < num_imgs:
                remaining = num_imgs - len(imgs_to_display)
                # Pair each image with its TRUE label and its prediction.
                # The labels were previously taken from ``preds`` too, so
                # every displayed sample looked correctly classified.
                for img, lbl, pr in zip(images[:remaining], labels[:remaining], preds[:remaining]):
                    imgs_to_display.append(img.cpu())
                    lbls_to_display.append(lbl.item())
                    prs_to_display.append(pr.item())

    test_loss = running_loss / total
    test_acc = correct / total
    return test_loss, test_acc, all_labels, all_preds, imgs_to_display, lbls_to_display, prs_to_display
52
+
53
+
54
+
55
+
56
def main():
    """Evaluate a ClearML-registered PlantCNN model on the test split.

    Fetches the registered model, reads the hyper-parameters recorded on
    the task that produced it, rebuilds the network with them, loads the
    downloaded weights, runs ``evaluate_on_test`` on ``test_loader`` and
    reports loss, accuracy, sample predictions and a confusion matrix to
    the task logger.

    Raises:
        KeyError: if the training task lacks one of the required
            ``General/*`` parameters.
    """
    project_name = "GAP_plant_disease_classification"
    model_name = "PlantCNN"
    # One value for both collecting and plotting sample predictions, so the
    # two calls below cannot drift apart.
    num_display = 24

    task = Task.init(project_name=project_name, task_name=f"{model_name}_test")
    logger = task.get_logger()

    input_model = InputModel(model_id="b9308022b85e4eea952d78124d1ee597")

    training_task_id = input_model.task
    training_task = Task.get_task(task_id=training_task_id)

    training_params = training_task.get_parameters()
    print(f"Training parameters: {training_params}")

    def _training_param(name):
        # Fail with a clear message instead of an opaque int(None) /
        # literal_eval(None) error when a parameter was never recorded.
        key = f"General/{name}"
        value = training_params.get(key)
        if value is None:
            raise KeyError(f"Training task {training_task_id} has no parameter {key!r}")
        return value

    # The values are parsed from their recorded form: int/float for scalars,
    # ast.literal_eval for list-like settings.
    num_classes = int(_training_param("num_classes"))
    channels = ast.literal_eval(_training_param("channels"))
    dropout = float(_training_param("dropout"))
    mean_nm = ast.literal_eval(_training_param("normalize_mean"))
    std_nm = ast.literal_eval(_training_param("normalize_std"))
    kernel_sizes = ast.literal_eval(_training_param("kernel_sizes"))

    task.add_tags([model_name, "test"])

    dataset = test_loader.dataset
    class_names = dataset.features["label"].names

    model = PlantCNN(num_classes=num_classes, channels=channels, dropout=dropout,
                     kernel_sizes=kernel_sizes).to(device)
    model_path = input_model.get_local_copy()
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed PyTorch version; only load checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    loss_fn = nn.CrossEntropyLoss()

    (test_loss, test_acc, all_labels, all_preds,
     display_images, display_labels, display_preds) = evaluate_on_test(
        model, test_loader, loss_fn, device, num_imgs=num_display)

    print("\nTest results:")
    print(f"Test loss: {test_loss:.3f} | Test accuracy: {test_acc:.3f}")
    logger.report_scalar("loss", "test", test_loss, 0)
    logger.report_scalar("accuracy", "test", test_acc, 0)

    visualize_preds(display_images, display_labels, display_preds, logger, class_names,
                    mean_nm, std_nm, num_images=num_display)
    plot_cfm(all_labels, all_preds, logger, class_names, num_classes)


if __name__ == "__main__":
    main()