Kamalikinuthia committed on
Commit
7eaf521
·
verified ·
1 Parent(s): 43534eb

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +245 -0
model.py CHANGED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import mlflow
4
+ import mlflow.pytorch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from PIL import Image
7
+ from datasets import load_dataset
8
+ from torchvision import transforms
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ from torchvision import models
13
+ from torch.utils.data import random_split
14
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
15
+ from sklearn.metrics import accuracy_score, f1_score
16
+ from sklearn.model_selection import KFold
17
+ from tqdm import tqdm
18
+
19
# Define argument parser for configuration.
# allow_abbrev=False stops unrelated flags that a host process may carry in
# its argv from being treated as abbreviations of the options below.
parser = argparse.ArgumentParser(
    description='Geothermal Classification Training',
    allow_abbrev=False,
)
parser.add_argument('--batch_size', type=int, default=32, help='batch size for training')
parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--n_splits', type=int, default=5, help='number of folds for cross-validation')
parser.add_argument('--test_image', type=str, help='path to external image for testing')

# Read the real command line. The previous hard-coded argument list meant no
# flag (including --test_image) could ever be changed by the user. When no
# flags are supplied the resulting values are the same defaults that list
# used to spell out (32 / 50 / 0.001 / 5).
# parse_known_args() ignores arguments it does not recognise, so importing
# this module from a notebook or another program with its own argv does not
# abort with a usage error.
args, _unknown_args = parser.parse_known_args()
30
+
31
# MLflow experiment under which the runs started by this script are grouped
mlflow.set_experiment("Geothermal Classification without Metadata")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-channel normalisation step shared by both pipelines below
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: random crop / flip / rotation augmentation, then
# tensor conversion and normalisation
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    _normalize,
]
train_transform = transforms.Compose(_train_steps)

# Validation pipeline: deterministic resize to 224x224, then tensor
# conversion and normalisation (no augmentation)
_val_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    _normalize,
]
val_transform = transforms.Compose(_val_steps)
51
+
52
class GeothermalNet(nn.Module):
    """Image classifier built on a torchvision ResNet-18.

    The network is created with ``weights='DEFAULT'`` and its final ``fc``
    layer is replaced by Linear -> ReLU -> Dropout(0.5) -> Linear, ending in
    ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        """Build the backbone and attach a head with ``num_classes`` outputs."""
        super().__init__()
        backbone = models.resnet18(weights='DEFAULT')
        # Width of the features that feed the original fully connected layer
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )
        # Stored as ``resnet`` so parameter names keep the ``resnet.`` prefix
        self.resnet = backbone

    def forward(self, image):
        """Return the classifier outputs for a batch of image tensors."""
        logits = self.resnet(image)
        return logits
65
+
66
class CustomDataset(Dataset):
    """Pair a sequence of images with a parallel sequence of labels.

    Args:
        images: indexable collection of image objects exposing ``.mode`` and
            ``.convert(mode)`` (presumably PIL images - confirm against the
            dataset loader).
        labels: indexable collection of labels, aligned with ``images``.
        transform: optional callable applied to each image after it has been
            converted to RGB.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The image is converted to RGB whenever its mode is anything else.
        Previously only RGBA was converted, so other modes (for example
        grayscale or palette images) would reach a 3-channel normalisation
        step with the wrong number of channels.
        """
        img = self.images[idx]
        if img.mode != 'RGB':
            img = img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        label = self.labels[idx]
        return img, label
85
+
86
def create_model(num_classes):
    """Return a new ``GeothermalNet`` with ``num_classes`` output units."""
    network = GeothermalNet(num_classes)
    return network
88
+
89
def _run_epoch(model, loader, criterion, optimizer=None, desc=None):
    """Run one pass over ``loader``; train if ``optimizer`` is given, else evaluate.

    Returns ``(mean_loss, predictions, targets)``. The mean is taken over the
    samples actually yielded by the loader, not ``len(loader.dataset)``, so
    it stays correct when a sampler restricts the loader to a subset.
    """
    training = optimizer is not None
    model.train(training)

    # Mixed precision is only enabled on CUDA; on CPU the context is a no-op.
    # NOTE(review): no gradient scaler is used with autocast here - consider
    # adding one if reduced-precision training shows gradient underflow.
    use_amp = device.type == "cuda"

    loss_sum = 0.0
    samples_seen = 0
    all_preds, all_targets = [], []

    batches = tqdm(loader, desc=desc) if desc is not None else loader
    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()

            # autocast requires an explicit device_type; calling it with no
            # arguments raises TypeError.
            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            batch_size = images.size(0)
            # Weight by batch size so a smaller final batch is not over-counted
            loss_sum += loss.item() * batch_size
            samples_seen += batch_size

            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    # max(..., 1) avoids a ZeroDivisionError for an empty loader
    return loss_sum / max(samples_seen, 1), all_preds, all_targets


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs):
    """Train ``model`` with per-epoch validation, LR scheduling and early stopping.

    Each epoch runs a training pass and a validation pass, steps ``scheduler``
    with the validation loss, logs loss / accuracy / weighted F1 for both
    passes to MLflow, and prints a summary line. Whenever the validation loss
    improves, the weights are written to ``best_model.pth`` (overwritten on
    each improvement). Training stops early after 10 consecutive epochs
    without improvement.

    Args:
        model: network to train; expected to already be on the module-level
            ``device``.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        criterion: loss callable taking ``(outputs, labels)``.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: object whose ``step(val_loss)`` is called once per epoch.
        num_epochs: maximum number of epochs.

    Returns:
        ``model``, with the best-validation-loss weights reloaded if a
        checkpoint was written during this call; otherwise unchanged from
        its final training state.
    """
    best_val_loss = float('inf')
    patience = 10
    early_stopping_counter = 0
    checkpoint_path = 'best_model.pth'
    checkpoint_written = False

    for epoch in range(num_epochs):
        epoch_loss, train_preds, train_labels = _run_epoch(
            model, train_loader, criterion,
            optimizer=optimizer, desc=f"Epoch {epoch+1}/{num_epochs}",
        )
        train_acc = accuracy_score(train_labels, train_preds)
        train_f1 = f1_score(train_labels, train_preds, average='weighted')

        val_loss, val_preds, val_labels = _run_epoch(model, val_loader, criterion)
        val_acc = accuracy_score(val_labels, val_preds)
        val_f1 = f1_score(val_labels, val_preds, average='weighted')

        scheduler.step(val_loss)

        epoch_metrics = {
            "train_loss": epoch_loss,
            "train_acc": train_acc,
            "train_f1": train_f1,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "val_f1": val_f1,
        }
        for metric_name, metric_value in epoch_metrics.items():
            mlflow.log_metric(metric_name, metric_value, step=epoch)

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, Train Acc: {train_acc:.4f}, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_written = True
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1

        if early_stopping_counter >= patience:
            print("Early stopping triggered")
            break

    # Hand back the best weights rather than whatever the last epoch left
    # behind (which, after early stopping, is `patience` epochs past the best).
    if checkpoint_written:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    return model
159
+
160
def load_model(model_path, num_classes):
    """Build a model and load weights saved as a ``state_dict`` at ``model_path``.

    Args:
        model_path: path to a file written with ``torch.save(state_dict, ...)``.
        num_classes: number of output classes the checkpoint was trained with.

    Returns:
        The model in evaluation mode. It is not moved to any device here;
        the caller is responsible for ``.to(...)`` if needed.
    """
    model = create_model(num_classes)
    # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
    # machine without one; the freshly created model is not moved off the CPU
    # in this function, so the tensors land where the parameters are.
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # from trusted sources.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model
165
+
166
def preprocess_image(image_path):
    """Load an image file and return it as a single-item batch tensor.

    The image is converted to RGB and passed through the module-level
    ``val_transform`` pipeline (resize to 224x224, tensor conversion,
    normalisation), so inference preprocessing cannot drift from the one used
    for validation. A leading batch dimension is added with ``unsqueeze(0)``.

    Args:
        image_path: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The transformed image with an extra leading dimension of size 1.
    """
    # The context manager closes the underlying file as soon as the pixel
    # data has been copied by convert(), instead of leaving it to the GC.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    return val_transform(image).unsqueeze(0)
174
+
175
+ #function to test on external images(images not in the dataset)
176
+ # def test_external_image(model, image_path, device):
177
+ # model.eval()
178
+ # image = preprocess_image(image_path).to(device)
179
+
180
+ # with torch.no_grad():
181
+ # outputs = model(image)
182
+ # _, predicted = torch.max(outputs, 1)
183
+
184
+ # return predicted.item()
185
+
186
def main():
    """Run K-fold cross-validated training and log each fold to MLflow.

    Loads the ``train`` split of the ``Kamalikinuthia/geothermal-dataset``
    dataset, splits it into ``args.n_splits`` folds, and for every fold trains
    a fresh model, then reports accuracy and weighted F1 on that fold's
    held-out part. Exits with status 1 if the dataset cannot be loaded.

    Note: the "test" metrics are computed on the fold's validation split; no
    separate test set is used in this function.
    """
    # Load and prepare dataset
    try:
        dataset = load_dataset("Kamalikinuthia/geothermal-dataset")
        train_images = dataset['train']['image']
        train_labels = dataset['train']['label']
    except Exception as e:
        print(f"Error loading dataset: {e}")
        # SystemExit directly: the `exit` builtin is injected by `site` and
        # is not guaranteed to exist in every interpreter setup.
        raise SystemExit(1)

    # Two views over the same images: augmentation for training only, the
    # deterministic pipeline for validation. Previously both loaders shared a
    # dataset built with train_transform, so validation metrics were computed
    # on randomly cropped / flipped / rotated images.
    train_dataset = CustomDataset(images=train_images, labels=train_labels, transform=train_transform)
    val_dataset = CustomDataset(images=train_images, labels=train_labels, transform=val_transform)

    # Loop-invariant, so computed once instead of once per fold.
    # Assumes labels are the integers 0..K-1 - TODO confirm against the dataset.
    num_classes = len(set(train_labels))

    # Cross-validation
    kf = KFold(n_splits=args.n_splits, shuffle=True, random_state=42)
    sample_positions = list(range(len(train_dataset)))

    for fold, (train_idx, val_idx) in enumerate(kf.split(sample_positions)):
        print(f"Fold {fold+1}")

        with mlflow.start_run(run_name=f"fold_{fold+1}"):
            mlflow.log_params(vars(args))

            # Subset (rather than a sampler over the full dataset) also makes
            # len(loader.dataset) equal to the number of samples in the fold.
            train_subset = torch.utils.data.Subset(train_dataset, train_idx.tolist())
            val_subset = torch.utils.data.Subset(val_dataset, val_idx.tolist())

            train_loader = DataLoader(train_subset, batch_size=args.batch_size, shuffle=True)
            val_loader = DataLoader(val_subset, batch_size=args.batch_size, shuffle=False)

            model = create_model(num_classes=num_classes).to(device)
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(model.parameters(), lr=args.lr)
            scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.1)

            model = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, args.epochs)

            # Test the model (on this fold's validation split)
            model.eval()
            test_preds, test_labels = [], []
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(device), labels.to(device)
                    outputs = model(images)
                    _, preds = torch.max(outputs, 1)
                    test_preds.extend(preds.cpu().numpy())
                    test_labels.extend(labels.cpu().numpy())

            test_acc = accuracy_score(test_labels, test_preds)
            test_f1 = f1_score(test_labels, test_preds, average='weighted')

            mlflow.log_metric("test_acc", test_acc)
            mlflow.log_metric("test_f1", test_f1)

            print(f"Fold {fold+1} Test Accuracy: {test_acc:.4f}, Test F1: {test_f1:.4f}")

    # # test with external image
    # if args.test_image:
    #     prediction = test_external_image(model, args.test_image, device)
    #     print(f"Prediction for external image: {prediction}")


if __name__ == "__main__":
    main()