Spaces:
Sleeping
Sleeping
| """ | |
| Main file for training Yolo model on Pascal VOC and COCO dataset | |
| """ | |
| import config | |
| import torch | |
| import torch.optim as optim | |
| from model import YOLOv3 | |
| from tqdm import tqdm | |
| from utils import ( | |
| mean_average_precision, | |
| cells_to_bboxes, | |
| get_evaluation_bboxes, | |
| save_checkpoint, | |
| load_checkpoint, | |
| check_class_accuracy, | |
| get_loaders, | |
| plot_couple_examples | |
| ) | |
| from loss import YoloLoss | |
| import warnings | |
# Process-wide runtime configuration, applied once when this module is loaded.
# Suppress every warning category (equivalent to filterwarnings("ignore")).
warnings.simplefilter("ignore")
# Turn on PyTorch's cuDNN benchmark (autotuning) mode for convolutions.
torch.backends.cudnn.benchmark = True
def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):
    """Run one training epoch with autocast and a gradient scaler.

    Args:
        train_loader: Iterable of ``(x, y)`` batches. ``y`` is indexed as
            ``y[0]``, ``y[1]``, ``y[2]`` and each entry is paired with the
            matching ``model(x)`` output (presumably one target per detection
            scale -- confirm against the dataset).
        model: Network whose output is indexed ``out[0..2]``.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        loss_fn: Callable ``loss_fn(prediction, target, anchors)`` applied to
            each of the three output/target/anchor triples; the results are summed.
        scaler: Gradient scaler providing ``scale``/``step``/``update``
            (``main`` passes a ``torch.cuda.amp.GradScaler``).
        scaled_anchors: Anchors indexed in the same order as ``out``.

    Returns:
        Mean of the per-batch losses over the epoch, or ``float("nan")`` if
        ``train_loader`` yielded no batches.
    """
    loop = tqdm(train_loader, leave=True)
    # Keep a running sum/count rather than a list that is re-summed every
    # batch (the old approach cost O(n) per step, O(n^2) per epoch). Adding in
    # the same order yields exactly the same mean as sum(list) / len(list).
    loss_total = 0.0
    num_batches = 0
    for x, y in loop:
        x = x.to(config.DEVICE)
        y0, y1, y2 = (
            y[0].to(config.DEVICE),
            y[1].to(config.DEVICE),
            y[2].to(config.DEVICE),
        )

        # Forward pass and loss under autocast (mixed precision).
        with torch.cuda.amp.autocast():
            out = model(x)
            loss = (
                loss_fn(out[0], y0, scaled_anchors[0])
                + loss_fn(out[1], y1, scaled_anchors[1])
                + loss_fn(out[2], y2, scaled_anchors[2])
            )

        loss_total += loss.item()
        num_batches += 1

        # Backward/step through the scaler so it can rescale gradients.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Show the running mean loss of the epoch so far in the progress bar.
        loop.set_postfix(loss=loss_total / num_batches)

    return loss_total / num_batches if num_batches else float("nan")
def _evaluate(model, test_loader):
    """Run class-accuracy and mAP evaluation on *test_loader*, printing the mAP."""
    check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)
    pred_boxes, true_boxes = get_evaluation_bboxes(
        test_loader,
        model,
        iou_threshold=config.NMS_IOU_THRESH,
        anchors=config.ANCHORS,
        threshold=config.CONF_THRESHOLD,
    )
    mapval = mean_average_precision(
        pred_boxes,
        true_boxes,
        iou_threshold=config.MAP_IOU_THRESH,
        box_format="midpoint",
        num_classes=config.NUM_CLASSES,
    )
    print(f"MAP: {mapval.item()}")
    # Make sure the model is back in training mode before the next epoch.
    model.train()


def main():
    """Train YOLOv3 using settings from ``config``.

    Builds the model, Adam optimizer, loss and gradient scaler; optionally
    resumes from ``config.CHECKPOINT_FILE``; then trains for
    ``config.NUM_EPOCHS`` epochs. After each epoch a checkpoint is written when
    ``config.SAVE_MODEL`` is set, and every third epoch (excluding epoch 0)
    the model is evaluated on the test loader.
    """
    model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
    optimizer = optim.Adam(
        model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
    )
    loss_fn = YoloLoss()
    scaler = torch.cuda.amp.GradScaler()
    train_loader, test_loader, _train_eval_loader = get_loaders(
        train_csv_path=config.DATASET + "/train.csv",
        test_csv_path=config.DATASET + "/test.csv",
    )

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE
        )

    # Scale each anchor set by its grid size: S is lifted to (len(S), 1, 1)
    # and repeated to (len(S), 3, 2) before the element-wise multiply.
    scaled_anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(config.DEVICE)

    for epoch in range(config.NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)

        # Previously commented out, which discarded all training on exit and
        # ignored SAVE_MODEL. Save to the same file LOAD_MODEL reads so a run
        # can be resumed; a missing flag keeps the old no-save behavior.
        if getattr(config, "SAVE_MODEL", False):
            save_checkpoint(model, optimizer, filename=config.CHECKPOINT_FILE)

        if epoch > 0 and epoch % 3 == 0:
            _evaluate(model, test_loader)
# Script entry point: start training only when this file is executed
# directly, not when it is imported for its helpers.
if __name__ == "__main__":
    main()