File size: 3,974 Bytes
02c7167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import lightning.pytorch as pl
import config
from utils import (check_class_accuracy,get_evaluation_bboxes,mean_average_precision,plot_couple_examples)
from lightning.pytorch.callbacks import Callback

class PlotTestExamplesCallback(Callback):
    """Periodically plot model predictions on a batch from one of the module's loaders.

    At the end of every ``every_n_epochs``-th training epoch (1-based), this
    calls ``plot_couple_examples`` with the LightningModule, a dataloader
    obtained from the module, and the module's ``scaled_anchors``.

    Args:
        every_n_epochs: Plot after epochs whose 1-based index is a multiple of
            this value. Must be >= 1.
        thresh: Confidence threshold forwarded to ``plot_couple_examples``.
        iou_thresh: IoU threshold forwarded to ``plot_couple_examples``.
        use_test_loader: If True, draw examples from
            ``pl_module.test_dataloader()``. Otherwise (the default, matching
            the original behaviour) draw them from ``pl_module.train_dataloader()``.

    Raises:
        ValueError: If ``every_n_epochs`` is smaller than 1.
    """

    def __init__(
        self,
        every_n_epochs: int = 1,
        *,
        thresh: float = 0.6,
        iou_thresh: float = 0.5,
        use_test_loader: bool = False,
    ) -> None:
        super().__init__()
        # Reject a zero/negative period up front instead of failing with
        # ZeroDivisionError (or odd modulo behaviour) mid-training.
        if every_n_epochs < 1:
            raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}")
        self.every_n_epochs = every_n_epochs
        self.thresh = thresh
        self.iou_thresh = iou_thresh
        self.use_test_loader = use_test_loader

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Plot examples when the just-finished epoch is on the schedule."""
        # current_epoch is 0-based; +1 makes the schedule count finished epochs.
        if (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return
        # NOTE(review): despite the class name, the default samples the
        # training loader; pass use_test_loader=True for held-out images.
        loader = (
            pl_module.test_dataloader()
            if self.use_test_loader
            else pl_module.train_dataloader()
        )
        plot_couple_examples(
            model=pl_module,
            loader=loader,
            thresh=self.thresh,
            iou_thresh=self.iou_thresh,
            anchors=pl_module.scaled_anchors,
        )


class CheckClassAccuracyCallback(pl.Callback):
    """Periodically compute, log and print class / objectness accuracies.

    At the end of a training epoch, ``check_class_accuracy`` may run on the
    training loader and/or the test loader, each on its own schedule. The three
    returned values are logged as ``{split}_class_acc``, ``{split}_no_obj_acc``
    and ``{split}_obj_acc``, where ``split`` is ``train`` or ``test``. They are
    also printed to stdout.

    Args:
        train_every_n_epochs: Evaluate the training loader after epochs whose
            1-based index is a multiple of this value. Must be >= 1.
        test_every_n_epochs: Same schedule, for the test loader. Must be >= 1.

    Raises:
        ValueError: If either period is smaller than 1.
    """

    def __init__(
        self, train_every_n_epochs: int = 1, test_every_n_epochs: int = 3
    ) -> None:
        super().__init__()
        # Reject zero/negative periods early rather than hitting
        # ZeroDivisionError in the modulo check at epoch end.
        for arg_name, period in (
            ("train_every_n_epochs", train_every_n_epochs),
            ("test_every_n_epochs", test_every_n_epochs),
        ):
            if period < 1:
                raise ValueError(f"{arg_name} must be >= 1, got {period}")
        self.train_every_n_epochs = train_every_n_epochs
        self.test_every_n_epochs = test_every_n_epochs

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Run the train and/or test accuracy checks that are due this epoch."""
        # current_epoch is 0-based; the schedules count finished epochs.
        finished_epochs = trainer.current_epoch + 1

        if finished_epochs % self.train_every_n_epochs == 0:
            accuracies = self._evaluate_and_log(
                pl_module, pl_module.train_dataloader(), "train"
            )
            print("Train Metrics")
            print(f"Epoch: {trainer.current_epoch}")
            # .get(): if the module did not log an epoch-level "train_loss",
            # report None instead of aborting training with a KeyError.
            print(f"Loss: {trainer.callback_metrics.get('train_loss_epoch')}")
            self._print_accuracies(*accuracies)

        if finished_epochs % self.test_every_n_epochs == 0:
            accuracies = self._evaluate_and_log(
                pl_module, pl_module.test_dataloader(), "test"
            )
            print("Test Metrics")
            self._print_accuracies(*accuracies)

    @staticmethod
    def _evaluate_and_log(pl_module: pl.LightningModule, loader, split: str):
        """Run check_class_accuracy on *loader* and log the results under *split*."""
        class_acc, no_obj_acc, obj_acc = check_class_accuracy(
            model=pl_module,
            loader=loader,
            threshold=config.CONF_THRESHOLD,
        )
        pl_module.log_dict(
            {
                f"{split}_class_acc": class_acc,
                f"{split}_no_obj_acc": no_obj_acc,
                f"{split}_obj_acc": obj_acc,
            },
            logger=True,
        )
        return class_acc, no_obj_acc, obj_acc

    @staticmethod
    def _print_accuracies(class_acc, no_obj_acc, obj_acc) -> None:
        """Print the three accuracies with two decimals and a percent sign."""
        # ".2f" = two decimal places. The previous ":2f" meant
        # "minimum width 2, six decimals".
        # NOTE(review): the "%" suffix assumes check_class_accuracy already
        # returns percentages — confirm against utils.
        print(f"Class Accuracy: {class_acc:.2f}%")
        print(f"No Object Accuracy: {no_obj_acc:.2f}%")
        print(f"Object Accuracy: {obj_acc:.2f}%")

class MAPCallback(pl.Callback):
    """Periodically compute and log mean average precision on the test loader.

    At the end of every ``every_n_epochs``-th training epoch (1-based), the
    steps are:

    1. Collect predicted and ground-truth boxes with ``get_evaluation_bboxes``
       on ``pl_module.test_dataloader()``.
    2. Score them with ``mean_average_precision``.
    3. Print the result and log it under the key ``"MAP"``.
    4. Put the module back into training mode.

    Args:
        every_n_epochs: Evaluate after epochs whose 1-based index is a
            multiple of this value. Must be >= 1.

    Raises:
        ValueError: If ``every_n_epochs`` is smaller than 1.
    """

    def __init__(self, every_n_epochs: int = 3) -> None:
        super().__init__()
        # Fail at construction rather than with ZeroDivisionError at epoch end.
        if every_n_epochs < 1:
            raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}")
        self.every_n_epochs = every_n_epochs

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Evaluate mAP when the just-finished epoch is on the schedule."""
        # current_epoch is 0-based; +1 makes the schedule count finished epochs.
        if (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return

        pred_boxes, true_boxes = get_evaluation_bboxes(
            loader=pl_module.test_dataloader(),
            model=pl_module,
            iou_threshold=config.NMS_IOU_THRESH,
            # NOTE(review): unscaled config anchors here, whereas plotting uses
            # pl_module.scaled_anchors; presumably get_evaluation_bboxes scales
            # them per grid size itself — confirm against utils.
            anchors=config.ANCHORS,
            threshold=config.CONF_THRESHOLD,
            # Use the device Lightning actually placed the module on, so the
            # inputs match the weights. config.DEVICE can differ, e.g. on a
            # non-zero DDP rank or when training on CPU.
            device=pl_module.device,
        )

        # float() handles both a 0-d tensor and a plain Python number, and
        # converts only once.
        map_val = float(
            mean_average_precision(
                pred_boxes=pred_boxes,
                true_boxes=true_boxes,
                iou_threshold=config.MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=config.NUM_CLASSES,
            )
        )
        print("MAP: ", map_val)
        pl_module.log(
            "MAP",
            map_val,
            logger=True,
        )
        # Restore training mode in case the evaluation helper switched the
        # module to eval mode.
        pl_module.train()