# OrangeRecognizer / main.py
# jarric's picture
# Initial upload
# e81e6d0 verified
import datetime
import os
from torch.nn import Linear
from torchvision.transforms import v2
import data.dataset
from torch.optim.lr_scheduler import CosineAnnealingLR
import pandas
from torchmetrics.classification import MulticlassAccuracy, MulticlassAveragePrecision, MulticlassF1Score
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
from torch import nn, optim
import torch
from tqdm import tqdm
from torch.utils.data import random_split
import mlflow
def run_epoch(model, batches, loss_fn, metrics, device, optimizer=None, augment=None):
    """Run one full pass over ``batches`` and return ``(mean_loss, scores)``.

    When ``optimizer`` is given the pass is a training pass: the model is put
    in train mode and an optimisation step is taken after every batch.
    Without an optimizer the model is put in eval mode and gradients are
    disabled.

    Args:
        model: Module called as ``model(images)``; its output is passed to
            ``loss_fn`` and to every metric.
        batches: Iterable yielding ``(images, labels)`` tensor pairs.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        metrics: Mapping of name -> object exposing ``reset()``,
            ``update(preds, target)`` and ``compute()`` (the torchmetrics
            ``Metric`` interface). May be empty.
        device: Device every batch is moved to before the forward pass.
        optimizer: Optimizer to step after each batch, or ``None`` to evaluate.
        augment: Optional callable applied to the image batch before the
            forward pass.

    Returns:
        Tuple of the per-batch mean loss (float) and a dict mapping each
        metric name to its value computed over the whole pass (float).

    Raises:
        ValueError: If ``batches`` yields nothing.
    """
    training = optimizer is not None
    model.train(training)

    # Start from a clean state so results from a previous pass (e.g. the
    # training pass when validating) never leak into this one.
    for metric in metrics.values():
        metric.reset()

    total_loss = 0.0
    num_batches = 0
    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)
            if augment is not None:
                images = augment(images)

            outputs = model(images)
            loss = loss_fn(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            # Metrics never need the autograd graph. The multiclass
            # torchmetrics classes accept raw logits and take the argmax
            # themselves where they need hard predictions.
            logits = outputs.detach()
            for metric in metrics.values():
                metric.update(logits, labels)

    if num_batches == 0:
        raise ValueError("run_epoch received no batches")

    scores = {name: float(metric.compute()) for name, metric in metrics.items()}

    # Release accumulated state (some metrics buffer every prediction).
    for metric in metrics.values():
        metric.reset()

    return total_loss / num_batches, scores


def main():
    """Fine-tune the classifier head of MobileNetV3-Small and log to MLflow.

    Only the final linear layer is trained; every pretrained parameter is
    frozen. After each epoch the train/validation loss and metrics are logged
    to the MLflow server and the whole model is saved under
    ``outputs/<start timestamp>/`` and logged as an artifact.
    """
    mlflow.set_tracking_uri('http://localhost:5000')

    curr_date = datetime.datetime.now()
    output_dir = f"outputs/{curr_date}"
    # makedirs also creates the parent "outputs" directory on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)

    # Fall back to CPU so the script still runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 3
    batch_size = 32
    epochs = 100

    with mlflow.start_run():
        # Applied to whole batches after they have been moved to the device.
        augmentation_transforms = v2.Compose([
            v2.RandomHorizontalFlip(),
            v2.RandomVerticalFlip(),
            v2.RandomGrayscale(),
            v2.RandomAutocontrast(),
            v2.RandomRotation(45),
        ]).to(device)

        # 75% / 25% random train / validation split.
        dataset = data.dataset.OrangeDataset("/home/jarric/orange_dataset/processed/FIELD IMAGES/")
        train_size = int(0.75 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,  # reshuffle every epoch so batches differ between epochs
            num_workers=8,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=8,
        )

        mobilenet_v3_model = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
        for param in mobilenet_v3_model.parameters():  # freeze pretrained layers
            param.requires_grad = False
        # The replacement head is created after the freeze, so its parameters
        # keep the default requires_grad=True and are the only ones trained.
        old_head = mobilenet_v3_model.classifier[3]
        mobilenet_v3_model.classifier[3] = Linear(
            in_features=old_head.in_features, out_features=num_classes, bias=True)
        # Move once, after the new head exists, so every layer ends up on `device`.
        mobilenet_v3_model.to(device)

        loss_fn = nn.CrossEntropyLoss().to(device)
        optimizer = optim.AdamW(
            [p for p in mobilenet_v3_model.parameters() if p.requires_grad],
            lr=1e-4,
            weight_decay=1e-4)

        # One set of metric objects is shared by both phases; run_epoch resets
        # them before and after every pass.
        metrics = {
            "accuracy": MulticlassAccuracy(num_classes=num_classes).to(device),
            "average_precision": MulticlassAveragePrecision(
                num_classes=num_classes, average="macro").to(device),
            "f1": MulticlassF1Score(num_classes=num_classes).to(device),
        }

        for epoch in tqdm(range(epochs)):
            train_loss, train_scores = run_epoch(
                mobilenet_v3_model,
                tqdm(train_loader, leave=False),
                loss_fn,
                metrics,
                device,
                optimizer=optimizer,
                augment=augmentation_transforms,
            )
            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("train_avg_accuracy", train_scores["accuracy"], step=epoch)
            mlflow.log_metric("train_average_precision", train_scores["average_precision"], step=epoch)
            # Key kept without a "train_" prefix to stay comparable with earlier runs.
            mlflow.log_metric("f1_score_avg", train_scores["f1"], step=epoch)

            val_loss, val_scores = run_epoch(
                mobilenet_v3_model,
                tqdm(val_loader, leave=False),
                loss_fn,
                metrics,
                device,
            )
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            mlflow.log_metric("val_avg_accuracy", val_scores["accuracy"], step=epoch)
            mlflow.log_metric("val_average_precision", val_scores["average_precision"], step=epoch)
            mlflow.log_metric("val_f1_score_avg", val_scores["f1"], step=epoch)

            # Save the full model after every epoch and attach it to the run.
            checkpoint_path = f"{output_dir}/model_{epoch}_finetuned.pt"
            torch.save(mobilenet_v3_model, checkpoint_path)
            mlflow.log_artifact(checkpoint_path)


if __name__ == '__main__':
    main()