| import numpy as np | |
| from tqdm import tqdm | |
| from sklearn.metrics import confusion_matrix | |
| import torch | |
| from utils.metrics import compute_metrics | |
| from inference.sliding_window import batch_sliding_window_inference | |
def eval_epoch(model, loader, criterion, device, crop_size, stride, num_classes):
    """
    Evaluate the model over ``loader`` using sliding-window inference.

    The model is switched to eval mode and every batch is processed under
    ``torch.no_grad()``. One confusion matrix is accumulated over all pixels
    whose ground-truth label differs from ``criterion.ignore_index`` and is
    then passed to ``compute_metrics``.

    Args:
        model: Network to evaluate; ``model.eval()`` is called on it and it is
            forwarded to ``batch_sliding_window_inference``.
        loader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss object; only its ``ignore_index`` attribute is read.
        device: Forwarded unchanged to ``batch_sliding_window_inference``.
        crop_size: Forwarded unchanged to ``batch_sliding_window_inference``.
        stride: Forwarded unchanged to ``batch_sliding_window_inference``.
        num_classes: Number of classes; labels ``0 .. num_classes - 1`` index
            the rows/columns of the confusion matrix.

    Returns:
        Tuple ``(ious, miou, mf1)``: the first, second and fourth values
        returned by ``compute_metrics`` for the accumulated confusion matrix
        (presumably per-class IoU, mean IoU and mean F1 -- confirm against
        ``utils.metrics``).
    """
    model.eval()

    # Pin 64-bit counts: a bare ``int`` dtype is only 32 bits wide on some
    # platforms, and per-pixel counts over a whole validation set can exceed
    # that range.
    conf_mat = np.zeros((num_classes, num_classes), dtype=np.int64)
    # Loop-invariant label list, built once instead of once per sample.
    labels = list(range(num_classes))

    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Validation", leave=False):
            targets = masks.cpu().numpy()
            # NOTE(review): assumes the helper returns a tensor of per-pixel
            # class indices shaped like ``masks`` -- confirm against
            # inference.sliding_window.
            preds = batch_sliding_window_inference(
                images, model, device, crop_size, stride
            ).cpu().numpy()

            # The confusion matrix is additive, so the whole batch is counted
            # in a single call; boolean indexing flattens the kept pixels.
            valid = targets != criterion.ignore_index
            conf_mat += confusion_matrix(
                targets[valid], preds[valid], labels=labels
            )

    ious, miou, _, mf1 = compute_metrics(conf_mat)
    return ious, miou, mf1