import datetime
import os

import mlflow
import pandas
import torch
from torch import nn, optim
from torch.nn import Linear
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import random_split
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassAveragePrecision,
    MulticlassF1Score,
)
from torchvision.models import MobileNet_V3_Small_Weights, mobilenet_v3_small
from torchvision.transforms import v2
from tqdm import tqdm

import data.dataset

if __name__ == '__main__':
    # Fine-tune the classifier head of an ImageNet-pretrained MobileNetV3-small
    # on a 3-class orange dataset, logging metrics/artifacts to a local MLflow server.
    mlflow.set_tracking_uri('http://localhost:5000')

    # Timestamped output directory for checkpoints. Use a strftime stamp
    # (str(datetime) contains spaces and colons, which are invalid on Windows)
    # and makedirs so a missing "outputs/" parent does not crash the run.
    run_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_dir = f"outputs/{run_stamp}"
    os.makedirs(run_dir, exist_ok=True)

    with mlflow.start_run():
        device = torch.device("cuda")

        # GPU-side augmentations, applied per batch after transfer to CUDA.
        augmentation_transforms = v2.Compose([
            v2.RandomHorizontalFlip(),
            v2.RandomVerticalFlip(),
            v2.RandomGrayscale(),
            v2.RandomAutocontrast(),
            v2.RandomRotation(45),
        ]).to(device)

        dataset = data.dataset.OrangeDataset("/home/jarric/orange_dataset/processed/FIELD IMAGES/")

        # 75/25 random train/validation split.
        train_size = int(0.75 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

        batch_size = 32
        # Training data must be shuffled each epoch; validation order is irrelevant.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=8,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=8,
        )

        epochs = 100

        # Load ImageNet weights (DEFAULT == IMAGENET1K_V1 for this model) and
        # freeze the backbone; only the replacement head below is trained.
        mobilenet_v3_model = mobilenet_v3_small(
            weights=MobileNet_V3_Small_Weights.DEFAULT
        ).to(device)
        for param in mobilenet_v3_model.parameters():
            param.requires_grad = False
        # New 3-class head. Freshly constructed parameters default to
        # requires_grad=True, so no explicit unfreeze is needed.
        mobilenet_v3_model.classifier[3] = Linear(in_features=1024, out_features=3, bias=True)
        mobilenet_v3_model.to(device)

        loss_fn = nn.CrossEntropyLoss().to(device)
        optimizer = optim.AdamW(
            mobilenet_v3_model.parameters(), lr=1e-4, weight_decay=1e-4)

        # Per-batch metrics; batch values are averaged over the loader below.
        acc_metric = MulticlassAccuracy(num_classes=3).to(device)
        ap_metric = MulticlassAveragePrecision(num_classes=3, average="macro").to(device)
        f1_metric = MulticlassF1Score(num_classes=3).to(device)

        for epoch in tqdm(range(epochs)):
            # ---- training phase ----
            train_loss = 0.0
            avg_accuracy = 0.0
            average_precision = 0.0
            f1_score_avg = 0.0
            mobilenet_v3_model.train()
            for images, labels in tqdm(train_loader, leave=False):
                images, labels = images.to(device), labels.to(device)
                images = augmentation_transforms(images)

                outputs = mobilenet_v3_model(images)
                loss = loss_fn(outputs, labels)
                train_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                predicted = outputs.argmax(dim=1)
                avg_accuracy += acc_metric(predicted, labels).item()
                average_precision += ap_metric(outputs, labels).item()
                f1_score_avg += f1_metric(predicted, labels).item()

            train_loss /= len(train_loader)
            avg_accuracy /= len(train_loader)
            average_precision /= len(train_loader)
            f1_score_avg /= len(train_loader)
            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("train_avg_accuracy", avg_accuracy, step=epoch)
            mlflow.log_metric("train_average_precision", average_precision, step=epoch)
            mlflow.log_metric("f1_score_avg", f1_score_avg, step=epoch)

            # Forward calls above also accumulate internal torchmetrics state;
            # reset it so it doesn't grow unboundedly or leak across phases.
            acc_metric.reset()
            ap_metric.reset()
            f1_metric.reset()

            # ---- validation phase ----
            val_loss = 0.0
            val_accuracy = 0.0
            average_precision_val = 0.0
            f1_score_avg_val = 0.0
            mobilenet_v3_model.eval()
            with torch.no_grad():
                for images, labels in tqdm(val_loader, leave=False):
                    images, labels = images.to(device), labels.to(device)

                    outputs = mobilenet_v3_model(images)
                    loss = loss_fn(outputs, labels)
                    val_loss += loss.item()

                    predicted = outputs.argmax(dim=1)
                    average_precision_val += ap_metric(outputs, labels).item()
                    val_accuracy += acc_metric(predicted, labels).item()
                    f1_score_avg_val += f1_metric(predicted, labels).item()

            val_loss /= len(val_loader)
            val_accuracy /= len(val_loader)
            average_precision_val /= len(val_loader)
            f1_score_avg_val /= len(val_loader)
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            mlflow.log_metric("val_avg_accuracy", val_accuracy, step=epoch)
            mlflow.log_metric("val_average_precision", average_precision_val, step=epoch)
            mlflow.log_metric("val_f1_score_avg", f1_score_avg_val, step=epoch)

            acc_metric.reset()
            ap_metric.reset()
            f1_metric.reset()

            # Checkpoint the whole model object each epoch (matches original
            # behavior; downstream consumers load with torch.load).
            checkpoint_path = f"{run_dir}/model_{epoch}_finetuned.pt"
            torch.save(mobilenet_v3_model, checkpoint_path)
            mlflow.log_artifact(checkpoint_path)