Spaces:
Sleeping
Sleeping
| import lightning.pytorch as pl | |
| import config | |
| from utils import (check_class_accuracy,get_evaluation_bboxes,mean_average_precision,plot_couple_examples) | |
| from lightning.pytorch.callbacks import Callback | |
class PlotTestExamplesCallback(Callback):
    """Periodically plot model predictions via ``plot_couple_examples``.

    Every ``every_n_epochs`` completed training epochs this calls
    ``plot_couple_examples`` with ``pl_module`` as the model and
    ``pl_module.train_dataloader()`` as the data source. Note that, despite the
    class name, the *training* loader is used.
    """

    def __init__(
        self,
        every_n_epochs: int = 1,
        thresh: float = 0.6,
        iou_thresh: float = 0.5,
    ) -> None:
        """
        Args:
            every_n_epochs: Plot after every N-th completed training epoch.
                Must be >= 1.
            thresh: Forwarded as ``thresh`` to ``plot_couple_examples``
                (presumably a confidence cut-off; confirm in ``utils``).
            iou_thresh: Forwarded as ``iou_thresh`` to
                ``plot_couple_examples`` (presumably the NMS IoU threshold).

        Raises:
            ValueError: If ``every_n_epochs`` is smaller than 1.
        """
        super().__init__()
        if every_n_epochs < 1:
            # Otherwise 0 would surface later as ZeroDivisionError in the
            # modulo check inside ``on_train_epoch_end``.
            raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}")
        self.every_n_epochs = every_n_epochs
        self.thresh = thresh
        self.iou_thresh = iou_thresh

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Plot examples when the 1-based epoch count is a multiple of N."""
        # ``current_epoch`` is 0-based; +1 makes N=3 fire after epochs 3, 6, ...
        if (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return
        plot_couple_examples(
            model=pl_module,
            loader=pl_module.train_dataloader(),
            thresh=self.thresh,
            iou_thresh=self.iou_thresh,
            anchors=pl_module.scaled_anchors,
        )
class CheckClassAccuracyCallback(pl.Callback):
    """Report class / objectness accuracy on the train and test loaders.

    At the end of a training epoch, ``check_class_accuracy`` is run on
    ``pl_module.train_dataloader()`` every ``train_every_n_epochs`` epochs and
    on ``pl_module.test_dataloader()`` every ``test_every_n_epochs`` epochs.
    The three returned values are logged via ``pl_module.log_dict`` under
    ``train_*`` / ``test_*`` keys and printed with a ``%`` suffix.
    """

    def __init__(
        self, train_every_n_epochs: int = 1, test_every_n_epochs: int = 3
    ) -> None:
        """
        Args:
            train_every_n_epochs: Evaluate on the training loader after every
                N-th completed epoch. Must be >= 1.
            test_every_n_epochs: Evaluate on the test loader after every N-th
                completed epoch. Must be >= 1.

        Raises:
            ValueError: If either interval is smaller than 1.
        """
        super().__init__()
        for name, value in (
            ("train_every_n_epochs", train_every_n_epochs),
            ("test_every_n_epochs", test_every_n_epochs),
        ):
            # 0 would otherwise raise ZeroDivisionError in the epoch hook.
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
        self.train_every_n_epochs = train_every_n_epochs
        self.test_every_n_epochs = test_every_n_epochs

    @staticmethod
    def _evaluate(pl_module: pl.LightningModule, loader, prefix: str):
        """Run ``check_class_accuracy`` on ``loader`` and log under ``prefix``."""
        class_acc, no_obj_acc, obj_acc = check_class_accuracy(
            model=pl_module,
            loader=loader,
            threshold=config.CONF_THRESHOLD,
        )
        pl_module.log_dict(
            {
                f"{prefix}_class_acc": class_acc,
                f"{prefix}_no_obj_acc": no_obj_acc,
                f"{prefix}_obj_acc": obj_acc,
            },
            logger=True,
        )
        return class_acc, no_obj_acc, obj_acc

    @staticmethod
    def _print_accuracies(class_acc, no_obj_acc, obj_acc) -> None:
        """Print the three accuracies with two decimal places."""
        # ``.2f`` means two decimals; the previous ``2f`` only set a minimum
        # field width of 2 and printed six decimals.
        # Assumes scalar values (float or 0-dim tensor) — TODO confirm in utils.
        print(f"Class Accuracy: {class_acc:.2f}%")
        print(f"No Object Accuracy: {no_obj_acc:.2f}%")
        print(f"Object Accuracy: {obj_acc:.2f}%")

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Evaluate/log/print train and/or test accuracy on schedule."""
        epoch = trainer.current_epoch
        # ``current_epoch`` is 0-based; +1 so intervals count completed epochs.
        if (epoch + 1) % self.train_every_n_epochs == 0:
            accs = self._evaluate(pl_module, pl_module.train_dataloader(), "train")
            print("Train Metrics")
            print(f"Epoch: {epoch}")
            # ``.get`` so a module that doesn't log an epoch-level "train_loss"
            # prints ``None`` instead of aborting training with KeyError.
            print(f"Loss: {trainer.callback_metrics.get('train_loss_epoch')}")
            self._print_accuracies(*accs)
        if (epoch + 1) % self.test_every_n_epochs == 0:
            accs = self._evaluate(pl_module, pl_module.test_dataloader(), "test")
            print("Test Metrics")
            self._print_accuracies(*accs)
class MAPCallback(pl.Callback):
    """Compute and log mean average precision on the test loader every N epochs.

    Boxes are collected with ``get_evaluation_bboxes`` over
    ``pl_module.test_dataloader()``, scored with ``mean_average_precision``,
    printed, and logged under the key ``"MAP"``.
    """

    def __init__(self, every_n_epochs: int = 3) -> None:
        """
        Args:
            every_n_epochs: Evaluate after every N-th completed training epoch.
                Must be >= 1.

        Raises:
            ValueError: If ``every_n_epochs`` is smaller than 1.
        """
        super().__init__()
        if every_n_epochs < 1:
            # 0 would otherwise raise ZeroDivisionError in the epoch hook.
            raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}")
        self.every_n_epochs = every_n_epochs

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Run the mAP evaluation when the 1-based epoch count is a multiple of N."""
        if (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return
        pred_boxes, true_boxes = get_evaluation_bboxes(
            loader=pl_module.test_dataloader(),
            model=pl_module,
            iou_threshold=config.NMS_IOU_THRESH,
            anchors=config.ANCHORS,
            threshold=config.CONF_THRESHOLD,
            # Use the device Lightning actually placed the module on; a fixed
            # config.DEVICE can disagree with it (CPU runs, non-zero GPU ranks).
            device=pl_module.device,
        )
        map_val = mean_average_precision(
            pred_boxes=pred_boxes,
            true_boxes=true_boxes,
            iou_threshold=config.MAP_IOU_THRESH,
            box_format="midpoint",
            num_classes=config.NUM_CLASSES,
        ).item()
        print("MAP: ", map_val)
        pl_module.log(
            "MAP",
            map_val,
            logger=True,
        )
        # Put the module back in training mode in case the evaluation helpers
        # switched it to eval mode.
        pl_module.train()