|
|
import os |
|
|
import argparse |
|
|
import mlflow |
|
|
import mlflow.pytorch |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from PIL import Image |
|
|
from datasets import load_dataset |
|
|
from torchvision import transforms |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torchvision import models |
|
|
from torch.utils.data import random_split |
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
|
from sklearn.metrics import accuracy_score, f1_score |
|
|
from sklearn.model_selection import KFold |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Geothermal Classification Training') |
|
|
parser.add_argument('--batch_size', type=int, default=32, help='batch size for training') |
|
|
parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train') |
|
|
parser.add_argument('--lr', type=float, default=0.001, help='learning rate') |
|
|
parser.add_argument('--n_splits', type=int, default=5, help='number of folds for cross-validation') |
|
|
parser.add_argument('--test_image', type=str, help='path to external image for testing') |
|
|
args = parser.parse_args(['--batch_size', '32', |
|
|
'--epochs', '50', |
|
|
'--lr', '0.001', |
|
|
'--n_splits', '5']) |
|
|
|
|
|
|
|
|
mlflow.set_experiment("Geothermal Classification without Metadata") |
|
|
|
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
|
train_transform = transforms.Compose([ |
|
|
transforms.RandomResizedCrop(224), |
|
|
transforms.RandomHorizontalFlip(), |
|
|
transforms.RandomRotation(15), |
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
val_transform = transforms.Compose([ |
|
|
transforms.Resize((224, 224)), |
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
class GeothermalNet(nn.Module): |
|
|
def __init__(self, num_classes): |
|
|
super(GeothermalNet, self).__init__() |
|
|
self.resnet = models.resnet18(weights='DEFAULT') |
|
|
self.resnet.fc = nn.Sequential( |
|
|
nn.Linear(self.resnet.fc.in_features, 256), |
|
|
nn.ReLU(), |
|
|
nn.Dropout(0.5), |
|
|
nn.Linear(256, num_classes) |
|
|
) |
|
|
|
|
|
def forward(self, image): |
|
|
return self.resnet(image) |
|
|
|
|
|
class CustomDataset(Dataset): |
|
|
def __init__(self, images, labels, transform=None): |
|
|
self.images = images |
|
|
self.labels = labels |
|
|
self.transform = transform |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.images) |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
img = self.images[idx] |
|
|
if img.mode=='RGBA': |
|
|
img = img.convert('RGB') |
|
|
|
|
|
if self.transform: |
|
|
img = self.transform(img) |
|
|
|
|
|
label = self.labels[idx] |
|
|
return img, label |
|
|
|
|
|
def create_model(num_classes): |
|
|
return GeothermalNet(num_classes) |
|
|
|
|
|
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs): |
|
|
best_val_loss = float('inf') |
|
|
patience = 10 |
|
|
early_stopping_counter = 0 |
|
|
|
|
|
for epoch in range(num_epochs): |
|
|
model.train() |
|
|
running_loss = 0.0 |
|
|
train_preds, train_labels = [], [] |
|
|
|
|
|
for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"): |
|
|
images, labels = images.to(device), labels.to(device) |
|
|
optimizer.zero_grad() |
|
|
with torch.amp.autocast(): |
|
|
outputs = model(images) |
|
|
loss = criterion(outputs, labels) |
|
|
loss.backward() |
|
|
optimizer.step() |
|
|
|
|
|
running_loss += loss.item() * images.size(0) |
|
|
_, preds = torch.max(outputs, 1) |
|
|
train_preds.extend(preds.cpu().numpy()) |
|
|
train_labels.extend(labels.cpu().numpy()) |
|
|
|
|
|
epoch_loss = running_loss / len(train_loader.dataset) |
|
|
train_acc = accuracy_score(train_labels, train_preds) |
|
|
train_f1 = f1_score(train_labels, train_preds, average='weighted') |
|
|
|
|
|
model.eval() |
|
|
val_loss = 0.0 |
|
|
val_preds, val_labels = [], [] |
|
|
with torch.no_grad(): |
|
|
for images, labels in val_loader: |
|
|
images, labels = images.to(device), labels.to(device) |
|
|
with torch.amp.autocast(): |
|
|
outputs = model(images) |
|
|
loss = criterion(outputs, labels) |
|
|
val_loss += loss.item() * images.size(0) |
|
|
_, preds = torch.max(outputs, 1) |
|
|
val_preds.extend(preds.cpu().numpy()) |
|
|
val_labels.extend(labels.cpu().numpy()) |
|
|
|
|
|
val_loss /= len(val_loader.dataset) |
|
|
val_acc = accuracy_score(val_labels, val_preds) |
|
|
val_f1 = f1_score(val_labels, val_preds, average='weighted') |
|
|
|
|
|
scheduler.step(val_loss) |
|
|
|
|
|
mlflow.log_metric("train_loss", epoch_loss, step=epoch) |
|
|
mlflow.log_metric("train_acc", train_acc, step=epoch) |
|
|
mlflow.log_metric("train_f1", train_f1, step=epoch) |
|
|
mlflow.log_metric("val_loss", val_loss, step=epoch) |
|
|
mlflow.log_metric("val_acc", val_acc, step=epoch) |
|
|
mlflow.log_metric("val_f1", val_f1, step=epoch) |
|
|
|
|
|
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, Train Acc: {train_acc:.4f}, ' |
|
|
f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}') |
|
|
|
|
|
if val_loss < best_val_loss: |
|
|
best_val_loss = val_loss |
|
|
torch.save(model.state_dict(), 'best_model.pth') |
|
|
early_stopping_counter = 0 |
|
|
else: |
|
|
early_stopping_counter += 1 |
|
|
|
|
|
if early_stopping_counter >= patience: |
|
|
print("Early stopping triggered") |
|
|
break |
|
|
|
|
|
return model |
|
|
|
|
|
def load_model(model_path, num_classes): |
|
|
model = create_model(num_classes) |
|
|
model.load_state_dict(torch.load(model_path)) |
|
|
model.eval() |
|
|
return model |
|
|
|
|
|
def preprocess_image(image_path): |
|
|
image = Image.open(image_path).convert("RGB") |
|
|
preprocess = transforms.Compose([ |
|
|
transforms.Resize((224, 224)), |
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
return preprocess(image).unsqueeze(0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
|
|
try: |
|
|
dataset = load_dataset("Kamalikinuthia/geothermal-dataset") |
|
|
train_images = dataset['train']['image'] |
|
|
train_labels = dataset['train']['label'] |
|
|
except Exception as e: |
|
|
print(f"Error loading dataset: {e}") |
|
|
exit(1) |
|
|
|
|
|
full_dataset = CustomDataset(images=train_images, labels=train_labels, transform=train_transform) |
|
|
|
|
|
|
|
|
kf = KFold(n_splits=args.n_splits, shuffle=True, random_state=42) |
|
|
|
|
|
for fold, (train_idx, val_idx) in enumerate(kf.split(full_dataset)): |
|
|
print(f"Fold {fold+1}") |
|
|
|
|
|
with mlflow.start_run(run_name=f"fold_{fold+1}"): |
|
|
mlflow.log_params(vars(args)) |
|
|
|
|
|
train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx) |
|
|
val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx) |
|
|
|
|
|
train_loader = DataLoader(full_dataset, batch_size=args.batch_size, sampler=train_subsampler) |
|
|
val_loader = DataLoader(full_dataset, batch_size=args.batch_size, sampler=val_subsampler) |
|
|
|
|
|
model = create_model(num_classes=len(set(train_labels))).to(device) |
|
|
criterion = nn.CrossEntropyLoss() |
|
|
optimizer = optim.Adam(model.parameters(), lr=args.lr) |
|
|
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.1) |
|
|
|
|
|
model = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, args.epochs) |
|
|
|
|
|
|
|
|
model.eval() |
|
|
test_preds, test_labels = [], [] |
|
|
with torch.no_grad(): |
|
|
for images, labels in val_loader: |
|
|
images, labels = images.to(device), labels.to(device) |
|
|
outputs = model(images) |
|
|
_, preds = torch.max(outputs, 1) |
|
|
test_preds.extend(preds.cpu().numpy()) |
|
|
test_labels.extend(labels.cpu().numpy()) |
|
|
|
|
|
test_acc = accuracy_score(test_labels, test_preds) |
|
|
test_f1 = f1_score(test_labels, test_preds, average='weighted') |
|
|
|
|
|
mlflow.log_metric("test_acc", test_acc) |
|
|
mlflow.log_metric("test_f1", test_f1) |
|
|
|
|
|
print(f"Fold {fold+1} Test Accuracy: {test_acc:.4f}, Test F1: {test_f1:.4f}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|
|
|
|