|
|
import datetime |
|
|
import os |
|
|
|
|
|
from torch.nn import Linear |
|
|
from torchvision.transforms import v2 |
|
|
|
|
|
import data.dataset |
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR |
|
|
import pandas |
|
|
from torchmetrics.classification import MulticlassAccuracy, MulticlassAveragePrecision, MulticlassF1Score |
|
|
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights |
|
|
from torch import nn, optim |
|
|
|
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
|
|
|
from torch.utils.data import random_split |
|
|
|
|
|
import mlflow |
|
|
|
|
|
if __name__ == '__main__':
    # Fine-tune a frozen MobileNetV3-Small backbone with a fresh 3-class head
    # on the orange dataset, logging per-epoch metrics and checkpoints to MLflow.
    mlflow.set_tracking_uri('http://localhost:5000')

    # FIX: str(datetime) contains spaces and colons (invalid in Windows paths,
    # awkward in shells) — use a filesystem-safe strftime stamp, and makedirs
    # so a missing "outputs/" parent does not crash the run.
    curr_date = datetime.datetime.now()
    out_dir = f"outputs/{curr_date.strftime('%Y-%m-%d_%H-%M-%S')}"
    os.makedirs(out_dir, exist_ok=True)

    with mlflow.start_run():
        device = torch.device("cuda")

        # GPU-resident augmentations, applied per batch during training only.
        augmentation_transforms = v2.Compose([
            v2.RandomHorizontalFlip(),
            v2.RandomVerticalFlip(),
            v2.RandomGrayscale(),
            v2.RandomAutocontrast(),
            v2.RandomRotation(45),
        ]).to(device)

        dataset = data.dataset.OrangeDataset("/home/jarric/orange_dataset/processed/FIELD IMAGES/")

        # 75/25 train/validation split.
        train_size = int(0.75 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

        batch_size = 32

        # FIX: the original used shuffle=False for training, which feeds
        # batches in a fixed (often class-ordered) sequence and hurts
        # convergence. Training data must be reshuffled each epoch.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=8,
        )

        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=8,
        )

        epochs = 100
        lr = 1e-4
        weight_decay = 1e-4

        # FIX: `IMAGENET1K_V1.DEFAULT` reads one enum member through another —
        # deprecated in Python 3.11 and removed in 3.12. Use the canonical
        # DEFAULT alias directly.
        mobilenet_v3_model = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT).to(device)

        # Freeze the backbone; only the replacement head below is trained.
        for param in mobilenet_v3_model.parameters():
            param.requires_grad = False

        # Replace the final classifier layer with a fresh 3-class head. A
        # newly constructed Linear already has requires_grad=True on its
        # parameters, so no flag-flipping is needed (the original set
        # requires_grad on the Module object, which is a training no-op).
        mobilenet_v3_model.classifier[3] = Linear(in_features=1024, out_features=3, bias=True)
        mobilenet_v3_model.to(device)

        loss_fn = nn.CrossEntropyLoss().to(device)

        # Optimize only the trainable (unfrozen) parameters — i.e. the new head.
        optimizer = optim.AdamW(
            (p for p in mobilenet_v3_model.parameters() if p.requires_grad),
            lr=lr,
            weight_decay=weight_decay)

        # FIX: CosineAnnealingLR was imported but never wired up; anneal the
        # learning rate over the full run as evidently intended.
        scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

        mlflow.log_params({
            "batch_size": batch_size,
            "epochs": epochs,
            "lr": lr,
            "weight_decay": weight_decay,
        })

        acc_metric = MulticlassAccuracy(num_classes=3).to(device)
        ap_metric = MulticlassAveragePrecision(num_classes=3, average="macro").to(device)
        f1_metric = MulticlassF1Score(num_classes=3).to(device)

        for epoch in tqdm(range(epochs)):
            # ---------------- training ----------------
            train_loss = 0.0
            avg_accuracy = 0.0
            average_precision = 0.0
            f1_score_avg = 0.0
            mobilenet_v3_model.train()
            for images, labels in tqdm(train_loader, leave=False):
                images, labels = images.to(device), labels.to(device)
                images = augmentation_transforms(images)

                outputs = mobilenet_v3_model(images)
                loss = loss_fn(outputs, labels)
                train_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # argmax over class logits; avoids the deprecated `.data`.
                predicted = torch.argmax(outputs, dim=1)
                # FIX: accumulate plain floats (.item()) so mlflow receives
                # numbers rather than CUDA tensors.
                avg_accuracy += acc_metric(predicted, labels).item()
                average_precision += ap_metric(outputs, labels).item()
                f1_score_avg += f1_metric(predicted, labels).item()

            train_loss /= len(train_loader)
            avg_accuracy /= len(train_loader)
            average_precision /= len(train_loader)
            f1_score_avg /= len(train_loader)
            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("train_avg_accuracy", avg_accuracy, step=epoch)
            mlflow.log_metric("train_average_precision", average_precision, step=epoch)
            mlflow.log_metric("f1_score_avg", f1_score_avg, step=epoch)

            # FIX: reset stateful torchmetrics each phase —
            # MulticlassAveragePrecision in particular buffers every batch's
            # predictions and otherwise grows without bound across 100 epochs.
            acc_metric.reset()
            ap_metric.reset()
            f1_metric.reset()

            # ---------------- validation ----------------
            val_loss = 0.0
            val_accuracy = 0.0
            average_precision_val = 0.0
            f1_score_avg_val = 0.0
            mobilenet_v3_model.eval()
            with torch.no_grad():
                for images, labels in tqdm(val_loader, leave=False):
                    images, labels = images.to(device), labels.to(device)

                    outputs = mobilenet_v3_model(images)
                    loss = loss_fn(outputs, labels)
                    val_loss += loss.item()

                    predicted = torch.argmax(outputs, dim=1)
                    average_precision_val += ap_metric(outputs, labels).item()
                    val_accuracy += acc_metric(predicted, labels).item()
                    f1_score_avg_val += f1_metric(predicted, labels).item()

            val_loss /= len(val_loader)
            val_accuracy /= len(val_loader)
            average_precision_val /= len(val_loader)
            f1_score_avg_val /= len(val_loader)
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            mlflow.log_metric("val_avg_accuracy", val_accuracy, step=epoch)
            mlflow.log_metric("val_average_precision", average_precision_val, step=epoch)
            mlflow.log_metric("val_f1_score_avg", f1_score_avg_val, step=epoch)

            acc_metric.reset()
            ap_metric.reset()
            f1_metric.reset()

            scheduler.step()

            # Persist a full-model checkpoint per epoch and attach it to the run.
            ckpt_path = f"{out_dir}/model_{epoch}_finetuned.pt"
            torch.save(mobilenet_v3_model, ckpt_path)
            mlflow.log_artifact(ckpt_path)
|
|
|