Spaces:
Sleeping
Sleeping
File size: 3,974 Bytes
02c7167 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import lightning.pytorch as pl
import config
from utils import (check_class_accuracy,get_evaluation_bboxes,mean_average_precision,plot_couple_examples)
from lightning.pytorch.callbacks import Callback
class PlotTestExamplesCallback(Callback):
    """Periodically plot example predictions at the end of a training epoch.

    Every ``every_n_epochs`` epochs the module is passed to
    ``plot_couple_examples`` together with its training dataloader and its
    ``scaled_anchors`` attribute.
    """

    def __init__(self, every_n_epochs: int = 1) -> None:
        """Store the plotting period.

        Args:
            every_n_epochs: Plot once every this many epochs.
        """
        super().__init__()
        self.every_n_epochs = every_n_epochs

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Plot examples when the 1-based epoch count is a multiple of the period.

        Args:
            trainer: Trainer whose ``current_epoch`` drives the schedule.
            pl_module: Module being trained; supplies the dataloader and the
                ``scaled_anchors`` used for plotting.
        """
        completed_epochs = trainer.current_epoch + 1
        if completed_epochs % self.every_n_epochs != 0:
            return

        # NOTE(review): the class name says "test" but the examples are drawn
        # from the training dataloader - confirm this is intended.
        example_loader = pl_module.train_dataloader()
        anchors = pl_module.scaled_anchors
        plot_couple_examples(
            model=pl_module,
            loader=example_loader,
            thresh=0.6,
            iou_thresh=0.5,
            anchors=anchors,
        )
class CheckClassAccuracyCallback(pl.Callback):
    """Log and print class / no-object / object accuracy on a schedule.

    At the end of each training epoch the callback evaluates
    ``check_class_accuracy`` on the training dataloader every
    ``train_every_n_epochs`` epochs and on the test dataloader every
    ``test_every_n_epochs`` epochs. The two schedules are independent.
    """

    def __init__(
        self, train_every_n_epochs: int = 1, test_every_n_epochs: int = 3
    ) -> None:
        """Store the two evaluation periods.

        Args:
            train_every_n_epochs: Period, in epochs, for training-set accuracy.
            test_every_n_epochs: Period, in epochs, for test-set accuracy.
        """
        super().__init__()
        self.train_every_n_epochs = train_every_n_epochs
        self.test_every_n_epochs = test_every_n_epochs

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Run whichever accuracy checks are due for the epoch just finished.

        Args:
            trainer: Trainer providing ``current_epoch`` and ``callback_metrics``.
            pl_module: Module being trained; supplies the dataloaders and
                receives the logged metrics.
        """
        # Schedules are expressed against the 1-based epoch count.
        completed_epochs = trainer.current_epoch + 1

        if completed_epochs % self.train_every_n_epochs == 0:
            accuracies = self._log_accuracies(
                pl_module, pl_module.train_dataloader(), "train"
            )
            print("Train Metrics")
            print(f"Epoch: {trainer.current_epoch}")
            # .get() keeps a missing metric from aborting training with a
            # KeyError when the module logs its loss under another name.
            epoch_loss = trainer.callback_metrics.get("train_loss_epoch", "n/a")
            print(f"Loss: {epoch_loss}")
            self._print_accuracies(*accuracies)

        if completed_epochs % self.test_every_n_epochs == 0:
            accuracies = self._log_accuracies(
                pl_module, pl_module.test_dataloader(), "test"
            )
            print("Test Metrics")
            self._print_accuracies(*accuracies)

    @staticmethod
    def _log_accuracies(pl_module: pl.LightningModule, loader, prefix: str):
        """Compute the three accuracies on ``loader`` and log them as ``<prefix>_*``."""
        class_acc, no_obj_acc, obj_acc = check_class_accuracy(
            model=pl_module,
            loader=loader,
            threshold=config.CONF_THRESHOLD,
        )
        pl_module.log_dict(
            {
                f"{prefix}_class_acc": class_acc,
                f"{prefix}_no_obj_acc": no_obj_acc,
                f"{prefix}_obj_acc": obj_acc,
            },
            logger=True,
        )
        return class_acc, no_obj_acc, obj_acc

    @staticmethod
    def _print_accuracies(class_acc, no_obj_acc, obj_acc) -> None:
        """Print the three accuracies with two decimal places."""
        # ".2f" (two decimals) replaces the original "2f", which only set a
        # minimum field width and printed six decimals.
        print(f"Class Accuracy: {class_acc:.2f}%")
        print(f"No Object Accuracy: {no_obj_acc:.2f}%")
        print(f"Object Accuracy: {obj_acc:.2f}%")
class MAPCallback(pl.Callback):
    """Periodically compute, print and log mean average precision.

    Every ``every_n_epochs`` epochs the module's test dataloader is run
    through ``get_evaluation_bboxes`` and the resulting boxes are scored with
    ``mean_average_precision``; the value is logged under the key ``"MAP"``.
    """

    def __init__(self, every_n_epochs: int = 3) -> None:
        """Store the evaluation period.

        Args:
            every_n_epochs: Evaluate once every this many epochs.
        """
        super().__init__()
        self.every_n_epochs = every_n_epochs

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Evaluate mAP when the 1-based epoch count is a multiple of the period.

        Args:
            trainer: Trainer whose ``current_epoch`` drives the schedule.
            pl_module: Module being trained; supplies the test dataloader and
                receives the logged metric.
        """
        if (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return

        detections, ground_truths = get_evaluation_bboxes(
            loader=pl_module.test_dataloader(),
            model=pl_module,
            iou_threshold=config.NMS_IOU_THRESH,
            anchors=config.ANCHORS,
            threshold=config.CONF_THRESHOLD,
            device=config.DEVICE,
        )
        map_val = mean_average_precision(
            pred_boxes=detections,
            true_boxes=ground_truths,
            iou_threshold=config.MAP_IOU_THRESH,
            box_format="midpoint",
            num_classes=config.NUM_CLASSES,
        )

        # Read the scalar once and reuse it for both the console and the logger.
        map_score = map_val.item()
        print("MAP: ", map_score)
        pl_module.log("MAP", map_score, logger=True)

        # Presumably get_evaluation_bboxes leaves the model in eval mode, so
        # switch back to training mode - TODO confirm against utils.
        # NOTE(review): assumed to belong to the periodic branch only.
        pl_module.train()