|
|
import config |
|
|
import torch |
|
|
import torch.optim as optim |
|
|
|
|
|
from model import YOLOv3 |
|
|
from tqdm import tqdm |
|
|
from utils import ( |
|
|
mean_average_precision, |
|
|
cells_to_bboxes, |
|
|
get_evaluation_bboxes, |
|
|
save_checkpoint, |
|
|
load_checkpoint, |
|
|
check_class_accuracy, |
|
|
get_loaders, |
|
|
plot_couple_examples |
|
|
) |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
|
|
|
from loss import YoloLoss |
|
|
import warnings |
|
|
# Suppress every Python warning for the whole process, not just this module.
# NOTE(review): this also hides deprecation/numerical warnings emitted by
# torch, Lightning and the project helpers; consider narrowing by category.
warnings.filterwarnings("ignore")
|
|
|
|
|
from pytorch_lightning import LightningModule |
|
|
from pytorch_lightning.callbacks import LearningRateMonitor |
|
|
from pytorch_lightning.callbacks.progress import TQDMProgressBar |
|
|
from torch.optim.lr_scheduler import OneCycleLR |
|
|
from torchmetrics.functional import accuracy |
|
|
|
|
|
class LitYolo(LightningModule):
    """LightningModule wrapping YOLOv3 trained with the three-scale YoloLoss.

    Optimization uses Lightning's *automatic* optimization. ``training_step``
    only returns the loss. The Trainer runs ``zero_grad``/``backward``/``step``
    on the Adam optimizer and steps the OneCycleLR schedule after every batch.
    Doing a manual backward/step inside ``training_step`` as well would update
    the weights twice per batch. Enable mixed precision through the Trainer
    (e.g. ``Trainer(precision=16)``) instead of a hand-rolled GradScaler.

    Args:
        batch_size: Only recorded by ``save_hyperparameters``; the dataloaders
            use ``config.BATCH_SIZE``.
    """

    def __init__(self, batch_size=64):
        super().__init__()
        self.lr = config.LEARNING_RATE
        self.weight_decay = config.WEIGHT_DECAY
        self.model = YOLOv3(num_classes=config.NUM_CLASSES)
        self.save_hyperparameters()
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        self.loss_fn = YoloLoss()
        # config.ANCHORS multiplied by the matching entry of config.S; the
        # repeat(1, 3, 2) gives a (len(S), 3, 2) multiplier.
        # NOTE(review): presumably 3 anchors per scale given as (w, h) - confirm.
        self.scaled_anchors = (
            torch.tensor(config.ANCHORS)
            * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        ).to(config.DEVICE)
        # Per-batch training losses of the current epoch (reset at epoch end).
        self.losses = []

    def forward(self, x):
        """Run the wrapped YOLOv3 model."""
        return self.model(x)

    def _yolo_loss(self, out, y):
        """Sum ``YoloLoss`` over the three prediction scales.

        Args:
            out: Model output, indexed as ``out[0]``..``out[2]`` (one per scale).
            y: Targets, indexed as ``y[0]``..``y[2]``; moved to
                ``config.DEVICE`` here.

        Returns:
            The summed loss tensor.
        """
        y0, y1, y2 = (
            y[0].to(config.DEVICE),
            y[1].to(config.DEVICE),
            y[2].to(config.DEVICE),
        )
        return (
            self.loss_fn(out[0], y0, self.scaled_anchors[0])
            + self.loss_fn(out[1], y1, self.scaled_anchors[1])
            + self.loss_fn(out[2], y2, self.scaled_anchors[2])
        )

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and log the running epoch mean.

        Returns:
            The loss tensor. Lightning performs backward, the optimizer step
            and the per-step scheduler step on it.
        """
        x, y = batch
        x = x.to(config.DEVICE)
        out = self.model(x)
        loss = self._yolo_loss(out, y)

        self.losses.append(loss.item())
        mean_loss = sum(self.losses) / len(self.losses)
        # Metric key kept as-is: callbacks/monitors may reference it verbatim.
        self.log("mean_loss = ", mean_loss, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        """Checkpoint, periodically visualise/evaluate, and reset epoch state.

        With ``epoch`` = ``current_epoch + 1``:

        - every epoch: print the epoch index, and save model + optimizer to
          ``config.CHECKPOINT_FILE`` when ``config.SAVE_MODEL`` is set;
        - when ``epoch % 10 == 0``: plot examples from the test loader and run
          ``check_class_accuracy`` on the train loader;
        - when ``epoch > 30`` and ``epoch % 8 == 0``: run
          ``check_class_accuracy`` plus mAP on the test loader.
        """
        epoch = self.trainer.current_epoch + 1
        print(f"Currently epoch {epoch-1}")

        if config.SAVE_MODEL:
            save_checkpoint(self.model, self.optimizer, filename=config.CHECKPOINT_FILE)

        if epoch > 1 and epoch % 10 == 0:
            plot_couple_examples(
                self.model, self.test_dataloader(), 0.6, 0.5, self.scaled_anchors
            )
            print("On Train loader:")
            check_class_accuracy(
                self.model, self.train_dataloader(), threshold=config.CONF_THRESHOLD
            )

        if epoch > 30 and epoch % 8 == 0:
            # One loader reused for both passes instead of building three.
            test_loader = self.test_dataloader()
            check_class_accuracy(self.model, test_loader, threshold=config.CONF_THRESHOLD)
            pred_boxes, true_boxes = get_evaluation_bboxes(
                test_loader,
                self.model,
                iou_threshold=config.NMS_IOU_THRESH,
                anchors=config.ANCHORS,
                threshold=config.CONF_THRESHOLD,
            )
            mapval = mean_average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=config.MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=config.NUM_CLASSES,
            )
            print(f"MAP: {mapval.item()}")

        # Next epoch starts a fresh running mean for the logged loss.
        self.losses = []
        # NOTE(review): the helpers above may leave the model in eval mode;
        # restore training mode unconditionally.
        self.model.train()

    def lr_finder(self, num_iter=50):
        """Run torch_lr_finder's exponential LR range test on the train loader.

        Needs the datasets to exist (``prepare_data``/``setup``) and the
        optional ``torch_lr_finder`` package. ``reset()`` is called afterwards,
        which per torch_lr_finder restores the initial model/optimizer state.

        Returns:
            The learning rate suggested by ``LRFinder.plot``.
        """
        from torch_lr_finder import LRFinder

        finder = LRFinder(self.model, self.optimizer, self._yolo_loss, device=config.DEVICE)
        finder.range_test(
            self.train_dataloader(), end_lr=1, num_iter=num_iter, step_mode="exp"
        )
        ax, suggested_lr = finder.plot()
        finder.reset()
        return suggested_lr

    def configure_optimizers(self):
        """Return Adam plus a per-step, linear-annealing OneCycleLR schedule.

        ``max_lr`` is hard-coded (presumably the value found with
        ``lr_finder``). Warm-up spans 5 epochs, capped at the whole run when
        ``max_epochs < 5``, because OneCycleLR rejects ``pct_start > 1``.
        """
        suggested_lr = 6.25E-03
        warmup_epochs = 5
        max_epochs = self.trainer.max_epochs
        steps_per_epoch = len(self.train_dataloader())

        scheduler_dict = {
            "scheduler": OneCycleLR(
                self.optimizer,
                max_lr=suggested_lr,
                steps_per_epoch=steps_per_epoch,
                epochs=max_epochs,
                pct_start=min(1.0, warmup_epochs / max_epochs),
                three_phase=False,
                div_factor=80,
                final_div_factor=400,
                anneal_strategy='linear',
            ),
            "interval": "step",
        }
        return {"optimizer": self.optimizer, "lr_scheduler": scheduler_dict}

    def _build_datasets(self):
        """Create the train/test/val YOLODatasets from ``config``.

        The datasets read ``train.csv``/``test.csv`` under ``config.DATASET``.
        """
        from dataset import YOLODataset

        image_size = config.IMAGE_SIZE
        train_csv_path = config.DATASET + "/train.csv"
        test_csv_path = config.DATASET + "/test.csv"

        def make_dataset(csv_path, transform):
            return YOLODataset(
                csv_path,
                transform=transform,
                S=[image_size // 32, image_size // 16, image_size // 8],
                img_dir=config.IMG_DIR,
                label_dir=config.LABEL_DIR,
                anchors=config.ANCHORS,
            )

        self.train_dataset = make_dataset(train_csv_path, config.train_transforms)
        self.test_dataset = make_dataset(test_csv_path, config.test_transforms)
        # Validation reuses the training CSV, but with config.test_transforms.
        self.val_dataset = make_dataset(train_csv_path, config.test_transforms)

    def prepare_data(self):
        """Build the datasets and optionally restore a checkpoint.

        Lightning calls this hook on a single process per node, so ``setup``
        also builds the datasets on any process where they are missing.
        """
        self._build_datasets()
        if config.LOAD_MODEL:
            load_checkpoint(
                config.CHECKPOINT_FILE, self.model, self.optimizer, config.LEARNING_RATE
            )

    def setup(self, stage=None):
        """Ensure the datasets exist on this process."""
        if getattr(self, "train_dataset", None) is None:
            self._build_datasets()

    def _make_dataloader(self, dataset, shuffle):
        """Wrap *dataset* in a DataLoader configured from ``config``."""
        return DataLoader(
            dataset=dataset,
            batch_size=config.BATCH_SIZE,
            num_workers=config.NUM_WORKERS,
            pin_memory=config.PIN_MEMORY,
            # DataLoader raises ValueError for persistent workers with num_workers == 0.
            persistent_workers=config.NUM_WORKERS > 0,
            shuffle=shuffle,
            drop_last=False,
        )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation dataset."""
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the test dataset."""
        return self._make_dataloader(self.test_dataset, shuffle=False)
|
|