File size: 5,778 Bytes
e81e6d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import datetime
import os
from torch.nn import Linear
from torchvision.transforms import v2
import data.dataset
from torch.optim.lr_scheduler import CosineAnnealingLR
import pandas
from torchmetrics.classification import MulticlassAccuracy, MulticlassAveragePrecision, MulticlassF1Score
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
from torch import nn, optim
import torch
from tqdm import tqdm
from torch.utils.data import random_split
import mlflow
def _build_model(device: torch.device, num_classes: int = 3) -> nn.Module:
    """Return an ImageNet-pretrained MobileNetV3-Small with a new trainable head.

    Every pretrained parameter is frozen; only the replacement final linear
    layer (``classifier[3]``) is left trainable.

    Args:
        device: Device the finished model is moved to.
        num_classes: Number of output classes of the new head.

    Returns:
        The model, already placed on ``device``.
    """
    # Name the weights member directly: reaching one enum member through
    # another (``IMAGENET1K_V1.DEFAULT``) is discouraged by the enum docs and
    # is not reliable across Python versions.
    model = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
    for param in model.parameters():  # freeze the pretrained layers
        param.requires_grad = False
    # Parameters of a freshly constructed Linear already require gradients, so
    # replacing the layer after freezing is what makes the head trainable.
    # (Assigning ``module.requires_grad = True`` only sets a plain attribute on
    # the module and does not touch its parameters.)
    head_inputs = model.classifier[3].in_features
    model.classifier[3] = Linear(in_features=head_inputs, out_features=num_classes, bias=True)
    # Move once, after the head swap, so the new layer lands on the device too.
    return model.to(device)


def _run_epoch(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    loss_fn: nn.Module,
    metrics: dict,
    device: torch.device,
    optimizer: "optim.Optimizer | None" = None,
    augment: "nn.Module | None" = None,
):
    """Run one full pass over ``loader`` and return its loss and metrics.

    Args:
        model: Network to train or evaluate.
        loader: Yields ``(images, labels)`` batches.
        loss_fn: Criterion called as ``loss_fn(outputs, labels)``.
        metrics: Mapping with the keys ``"accuracy"``, ``"average_precision"``
            and ``"f1"`` holding torchmetrics metric objects. Each one is reset
            before returning, so the same objects can be reused for the next
            pass without state leaking between passes.
        device: Device the batches are moved to.
        optimizer: When given, the pass is a training pass (train mode,
            gradients enabled, one optimizer step per batch). When ``None``,
            the pass is evaluation only (eval mode, gradients disabled).
        augment: Optional callable applied to each image batch after it has
            been moved to ``device``.

    Returns:
        ``(mean_batch_loss, results)`` where ``results`` maps each key of
        ``metrics`` to its value computed over the whole pass, as a float.
    """
    training = optimizer is not None
    # NOTE(review): train mode also lets BatchNorm layers update their running
    # statistics even though their weights are frozen — confirm this is wanted.
    model.train(training)
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for images, labels in tqdm(loader, leave=False):
            images, labels = images.to(device), labels.to(device)
            if augment is not None:
                images = augment(images)
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            scores = outputs.detach()
            predicted = scores.argmax(dim=1)
            # update() only accumulates state; the value is computed once per
            # pass below instead of averaging noisy per-batch values.
            metrics["accuracy"].update(predicted, labels)
            metrics["average_precision"].update(scores, labels)
            metrics["f1"].update(predicted, labels)
    mean_loss = total_loss / len(loader)
    results = {}
    for name, metric in metrics.items():
        results[name] = metric.compute().item()
        # Without reset() the state keeps growing across passes and epochs.
        metric.reset()
    return mean_loss, results


def main() -> None:
    """Fine-tune MobileNetV3-Small on the orange dataset, logging to MLflow.

    Creates ``outputs/<timestamp>/``, trains for a fixed number of epochs,
    logs train/validation loss and metrics per epoch, and saves and logs a
    checkpoint after every epoch.
    """
    mlflow.set_tracking_uri('http://localhost:5000')
    curr_date = datetime.datetime.now()
    run_dir = f"outputs/{curr_date}"
    # makedirs also creates a missing "outputs" parent, which mkdir would not.
    os.makedirs(run_dir, exist_ok=True)

    with mlflow.start_run():
        # Fall back to CPU so the script still runs on machines without CUDA.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        augmentation_transforms = v2.Compose([
            v2.RandomHorizontalFlip(),
            v2.RandomVerticalFlip(),
            v2.RandomGrayscale(),
            v2.RandomAutocontrast(),
            v2.RandomRotation(45),
        ]).to(device)

        dataset = data.dataset.OrangeDataset("/home/jarric/orange_dataset/processed/FIELD IMAGES/")
        train_size = int(0.75 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

        # Create data loaders
        batch_size = 32
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,  # reshuffle every epoch so batch order is not fixed
            num_workers=8,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=8,
        )

        epochs = 100
        mobilenet_v3_model = _build_model(device, num_classes=3)
        loss_fn = nn.CrossEntropyLoss().to(device)
        # Only hand the optimizer the parameters that can actually be updated.
        trainable_params = [p for p in mobilenet_v3_model.parameters() if p.requires_grad]
        optimizer = optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)

        # metrics (shared by both phases; _run_epoch resets them after each pass)
        metrics = {
            "accuracy": MulticlassAccuracy(num_classes=3).to(device),
            "average_precision": MulticlassAveragePrecision(num_classes=3, average="macro").to(device),
            "f1": MulticlassF1Score(num_classes=3).to(device),
        }

        for epoch in tqdm(range(0, epochs)):
            train_loss, train_metrics = _run_epoch(
                mobilenet_v3_model,
                train_loader,
                loss_fn,
                metrics,
                device,
                optimizer=optimizer,
                augment=augmentation_transforms,
            )
            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("train_avg_accuracy", train_metrics["accuracy"], step=epoch)
            mlflow.log_metric("train_average_precision", train_metrics["average_precision"], step=epoch)
            # Key kept without the "train_" prefix so existing runs stay comparable.
            mlflow.log_metric("f1_score_avg", train_metrics["f1"], step=epoch)

            val_loss, val_metrics = _run_epoch(
                mobilenet_v3_model,
                val_loader,
                loss_fn,
                metrics,
                device,
            )
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            mlflow.log_metric("val_avg_accuracy", val_metrics["accuracy"], step=epoch)
            mlflow.log_metric("val_average_precision", val_metrics["average_precision"], step=epoch)
            mlflow.log_metric("val_f1_score_avg", val_metrics["f1"], step=epoch)

            # The whole module is pickled (not just the state_dict) to keep the
            # checkpoint format unchanged for existing loaders.
            checkpoint_path = f"outputs/{curr_date}/model_{epoch}_finetuned.pt"
            torch.save(mobilenet_v3_model, checkpoint_path)
            mlflow.log_artifact(checkpoint_path)


if __name__ == '__main__':
    main()
|