# Geothermal_model/model.py
import argparse

import mlflow
import mlflow.pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import KFold
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from tqdm import tqdm
# Define argument parser for configuration
parser = argparse.ArgumentParser(description='Geothermal Classification Training')
parser.add_argument('--batch_size', type=int, default=32, help='batch size for training')
parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--n_splits', type=int, default=5, help='number of folds for cross-validation')
parser.add_argument('--test_image', type=str, help='path to external image for testing')
args = parser.parse_args()
# Set up MLflow
mlflow.set_experiment("Geothermal Classification without Metadata")
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Training transformations with data augmentation; normalization uses the ImageNet
# mean/std expected by the pretrained ResNet-18 backbone
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
class GeothermalNet(nn.Module):
    """ImageNet-pretrained ResNet-18 backbone with a small dropout-regularized classification head."""
    def __init__(self, num_classes):
        super(GeothermalNet, self).__init__()
        self.resnet = models.resnet18(weights='DEFAULT')
        # Replace the final fully connected layer with a two-layer head
self.resnet.fc = nn.Sequential(
nn.Linear(self.resnet.fc.in_features, 256),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(256, num_classes)
)
def forward(self, image):
return self.resnet(image)
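# Example (sketch, not part of training): forward pass on a dummy batch, assuming 2 classes.
# net = GeothermalNet(num_classes=2).to(device)
# logits = net(torch.randn(4, 3, 224, 224).to(device))  # -> shape [4, 2]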
class CustomDataset(Dataset):
    """Wraps in-memory PIL images and integer labels, applying an optional transform."""
    def __init__(self, images, labels, transform=None):
self.images = images
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.images)
    def __getitem__(self, idx):
        img = self.images[idx]
        # Ensure a 3-channel RGB image (handles RGBA, grayscale, and palette modes)
        if img.mode != 'RGB':
            img = img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = self.labels[idx]
        return img, label
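# Example (sketch, hypothetical variables): indexing yields a transformed tensor and its label.
# ds = CustomDataset(images=pil_images, labels=int_labels, transform=val_transform)
# img_tensor, label = ds[0]  # img_tensor shape: [3, 224, 224]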
def create_model(num_classes):
return GeothermalNet(num_classes)
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs):
    """Train with mixed-precision forward passes, LR scheduling on validation loss, and early stopping."""
    best_val_loss = float('inf')
    patience = 10  # epochs without validation-loss improvement before early stopping
    early_stopping_counter = 0
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
train_preds, train_labels = [], []
for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
            # Autocast the forward pass; gradients remain in full precision (no GradScaler)
            with torch.amp.autocast(device_type=device.type):
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
_, preds = torch.max(outputs, 1)
train_preds.extend(preds.cpu().numpy())
train_labels.extend(labels.cpu().numpy())
epoch_loss = running_loss / len(train_loader.dataset)
train_acc = accuracy_score(train_labels, train_preds)
train_f1 = f1_score(train_labels, train_preds, average='weighted')
model.eval()
val_loss = 0.0
val_preds, val_labels = [], []
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.to(device), labels.to(device)
                with torch.amp.autocast(device_type=device.type):
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item() * images.size(0)
_, preds = torch.max(outputs, 1)
val_preds.extend(preds.cpu().numpy())
val_labels.extend(labels.cpu().numpy())
val_loss /= len(val_loader.dataset)
val_acc = accuracy_score(val_labels, val_preds)
val_f1 = f1_score(val_labels, val_preds, average='weighted')
scheduler.step(val_loss)
mlflow.log_metric("train_loss", epoch_loss, step=epoch)
mlflow.log_metric("train_acc", train_acc, step=epoch)
mlflow.log_metric("train_f1", train_f1, step=epoch)
mlflow.log_metric("val_loss", val_loss, step=epoch)
mlflow.log_metric("val_acc", val_acc, step=epoch)
mlflow.log_metric("val_f1", val_f1, step=epoch)
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, Train Acc: {train_acc:.4f}, '
f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Checkpoint the best epoch (note: this file is overwritten on each cross-validation fold)
            torch.save(model.state_dict(), 'best_model.pth')
early_stopping_counter = 0
else:
early_stopping_counter += 1
if early_stopping_counter >= patience:
print("Early stopping triggered")
break
return model
def load_model(model_path, num_classes):
    model = create_model(num_classes)
    # map_location lets a CUDA-trained checkpoint load on a CPU-only machine
    model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
return model
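# Example (sketch): reload the best-epoch weights saved by train_model.
# num_classes here is assumed; in main() it is len(set(train_labels)).
# best = load_model('best_model.pth', num_classes=2)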
def preprocess_image(image_path):
image = Image.open(image_path).convert("RGB")
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
return preprocess(image).unsqueeze(0)
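# Example (sketch, hypothetical path): the result is a batched tensor ready for the model.
# x = preprocess_image("field_photo.jpg")  # -> shape [1, 3, 224, 224]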
# Function to test on external images (images not in the dataset)
# def test_external_image(model, image_path, device):
# model.eval()
# image = preprocess_image(image_path).to(device)
# with torch.no_grad():
# outputs = model(image)
# _, predicted = torch.max(outputs, 1)
# return predicted.item()
def main():
# Load and prepare dataset
try:
dataset = load_dataset("Kamalikinuthia/geothermal-dataset")
train_images = dataset['train']['image']
train_labels = dataset['train']['label']
    except Exception as e:
        print(f"Error loading dataset: {e}")
        raise SystemExit(1)
    # Two views of the same images: augmented for training, deterministic for validation
    train_dataset = CustomDataset(images=train_images, labels=train_labels, transform=train_transform)
    val_dataset = CustomDataset(images=train_images, labels=train_labels, transform=val_transform)
    # Cross-validation
    kf = KFold(n_splits=args.n_splits, shuffle=True, random_state=42)
    for fold, (train_idx, val_idx) in enumerate(kf.split(train_dataset)):
        print(f"Fold {fold+1}")
        with mlflow.start_run(run_name=f"fold_{fold+1}"):
            mlflow.log_params(vars(args))
            train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
            val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)
            train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_subsampler)
            val_loader = DataLoader(val_dataset, batch_size=args.batch_size, sampler=val_subsampler)
model = create_model(num_classes=len(set(train_labels))).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.1)
model = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, args.epochs)
            # Evaluate on this fold's validation split (final-epoch weights; the best epoch is in 'best_model.pth')
model.eval()
test_preds, test_labels = [], []
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, preds = torch.max(outputs, 1)
test_preds.extend(preds.cpu().numpy())
test_labels.extend(labels.cpu().numpy())
test_acc = accuracy_score(test_labels, test_preds)
test_f1 = f1_score(test_labels, test_preds, average='weighted')
mlflow.log_metric("test_acc", test_acc)
mlflow.log_metric("test_f1", test_f1)
print(f"Fold {fold+1} Test Accuracy: {test_acc:.4f}, Test F1: {test_f1:.4f}")
# # test with external image
# if args.test_image:
# prediction = test_external_image(model, args.test_image, device)
# print(f"Prediction for external image: {prediction}")
if __name__ == "__main__":
main()