| | |
| |
|
| | import torch |
| |
|
| | from ultralytics.yolo.data import ClassificationDataset, build_dataloader |
| | from ultralytics.yolo.engine.validator import BaseValidator |
| | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER |
| | from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix |
| | from ultralytics.yolo.utils.plotting import plot_images |
| |
|
| |
|
class ClassificationValidator(BaseValidator):
    """Validator for image-classification models, reporting top-1 / top-5 accuracy and a confusion matrix."""

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.args.task = 'classify'
        self.metrics = ClassifyMetrics()

    def get_desc(self):
        """Returns a formatted header string for the classification metrics table."""
        return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')

    def init_metrics(self, model):
        """Cache class names/count, create the confusion matrix and reset prediction/target accumulators.

        Args:
            model: Model being validated; only its ``names`` mapping is read here.
        """
        self.names = model.names
        self.nc = len(model.names)
        self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify')
        self.pred = []  # per-batch tensors of top-k predicted class indices
        self.targets = []  # per-batch tensors of ground-truth class indices

    def preprocess(self, batch):
        """Move images and labels to the validation device and cast images to half/float precision.

        Args:
            batch (dict): Batch with 'img' and 'cls' tensors.

        Returns:
            (dict): The same batch, modified in place.
        """
        batch['img'] = batch['img'].to(self.device, non_blocking=True)
        batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
        batch['cls'] = batch['cls'].to(self.device)
        return batch

    def update_metrics(self, preds, batch):
        """Append this batch's top-k (k = min(nc, 5)) predicted class indices and its targets.

        Args:
            preds (torch.Tensor): Per-class scores; classes are assumed to be along dim 1.
            batch (dict): Batch whose 'cls' entry holds the ground-truth class indices.
        """
        # Use the class count cached by init_metrics so this does not depend on a self.model attribute.
        n5 = min(self.nc, 5)
        self.pred.append(preds.argsort(1, descending=True)[:, :n5])
        self.targets.append(batch['cls'])

    def finalize_metrics(self, *args, **kwargs):
        """Build the confusion matrix, optionally plot it, and attach speed and confusion matrix to metrics."""
        self.confusion_matrix.process_cls_preds(self.pred, self.targets)
        if self.args.plots:
            # Save both the normalized and the raw-count confusion matrix plots.
            for normalize in True, False:
                self.confusion_matrix.plot(save_dir=self.save_dir,
                                           names=self.names.values(),
                                           normalize=normalize,
                                           on_plot=self.on_plot)
        self.metrics.speed = self.speed
        self.metrics.confusion_matrix = self.confusion_matrix

    def get_stats(self):
        """Returns a dictionary of metrics obtained by processing targets and predictions."""
        self.metrics.process(self.targets, self.pred)
        return self.metrics.results_dict

    def build_dataset(self, img_path):
        """Create a non-augmented ClassificationDataset rooted at ``img_path``."""
        return ClassificationDataset(root=img_path, args=self.args, augment=False)

    def get_dataloader(self, dataset_path, batch_size):
        """Builds and returns a data loader for classification tasks with given parameters."""
        dataset = self.build_dataset(dataset_path)
        return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)

    def print_results(self):
        """Logs overall top-1 and top-5 accuracy for the classification model."""
        pf = '%22s' + '%11.3g' * len(self.metrics.keys)
        LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))

    def plot_val_samples(self, batch, ni):
        """Plot validation images annotated with their ground-truth class labels.

        Args:
            batch (dict): Batch with 'img' and 'cls' tensors.
            ni (int): Batch index, used in the output file name.
        """
        plot_images(images=batch['img'],
                    batch_idx=torch.arange(len(batch['img'])),
                    # Use .view(-1), not .squeeze(-1): squeeze turns a single-image batch's
                    # (1,) label tensor into a 0-d tensor, which breaks per-image indexing.
                    cls=batch['cls'].view(-1),
                    fname=self.save_dir / f'val_batch{ni}_labels.jpg',
                    names=self.names,
                    on_plot=self.on_plot)

    def plot_predictions(self, batch, preds, ni):
        """Plot validation images annotated with their top-1 predicted class and save the result.

        Args:
            batch (dict): Batch with an 'img' tensor.
            preds (torch.Tensor): Per-class scores; the argmax over dim 1 is plotted.
            ni (int): Batch index, used in the output file name.
        """
        plot_images(batch['img'],
                    batch_idx=torch.arange(len(batch['img'])),
                    cls=torch.argmax(preds, dim=1),
                    fname=self.save_dir / f'val_batch{ni}_pred.jpg',
                    names=self.names,
                    on_plot=self.on_plot)
| |
|
| |
|
def val(cfg=DEFAULT_CFG, use_python=False):
    """Run classification validation from a config object.

    Args:
        cfg: Configuration whose ``model`` and ``data`` fields are used; falsy values fall
            back to 'yolov8n-cls.pt' and 'mnist160' respectively.
        use_python (bool): If True, validate through the high-level ``YOLO`` API; otherwise
            drive a ``ClassificationValidator`` directly.
    """
    args = dict(model=cfg.model or 'yolov8n-cls.pt', data=cfg.data or 'mnist160')
    if not use_python:
        ClassificationValidator(args=args)(model=args['model'])
        return
    # Imported lazily so the direct-validator path does not need the top-level package import.
    from ultralytics import YOLO
    YOLO(args['model']).val(**args)
| |
|
| |
|
if __name__ == '__main__':
    # Script entry point: validate with the default configuration via the direct-validator path.
    val(cfg=DEFAULT_CFG, use_python=False)
| |
|