import matplotlib.pyplot as plt
import numpy as np
import torch
from skimage.morphology import label, skeletonize
from skimage.util import view_as_windows
from torchmetrics import MeanAbsoluteError, MeanSquaredError
from torchmetrics.classification import (
    BinaryAccuracy,
    BinaryAUROC,
    BinaryCohenKappa,
    BinaryF1Score,
    BinaryJaccardIndex,
    BinaryPrecision,
    BinaryRecall,
    BinarySpecificity,
)
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.segmentation import DiceScore
from tqdm.auto import tqdm

# Metrics fed thresholded {0, 1} predictions and targets.
_CLF_KEYS = ('acc', 'f1', 'prec', 'rec', 'spec', 'dice', 'iou', 'ck')
# Metrics fed the raw (continuous) predictions and targets.
_REG_KEYS = ('mse', 'psnr', 'ssim', 'ae')

# (TensorBoard tag prefix, metrics key) pairs always logged by ``save_metrics``.
_SCALAR_TAGS = (
    ("Loss", 'loss'),
    ("ACC", 'acc'),
    ("F1", 'f1'),
    ("PREC", 'prec'),
    ("REC", 'rec'),
    ("ROC_AUC", 'roc_auc'),
    ("MSE", 'mse'),
    ("PSNR", 'psnr'),
    ("SSIM", 'ssim'),
    ("SPEC", 'spec'),
    ("DICE", 'dice'),
    ("AE", 'ae'),
    ("IoU", 'iou'),
)
# Logged only when the metrics dict contains them (``eval_loop`` results do).
_OPTIONAL_SCALAR_TAGS = (("CK", 'ck'), ("FRAC_SIM", 'frac_sim'))


def remove_junctions(skel: np.ndarray) -> np.ndarray:
    """Erase the neighbourhood of junction points from a binary skeleton.

    Every 3x3 window of ``skel`` that contains more than 4 foreground pixels
    is treated as a junction, and all 9 pixels of that window are cleared.
    The window does not need to be centred on a foreground pixel.

    Args:
        skel: 2-D binary skeleton (any dtype castable to ``uint8``), at least
            3x3 in size (``view_as_windows`` requires it).

    Returns:
        ``uint8`` array of the same shape with junction neighbourhoods zeroed.
    """
    skel = skel.astype(np.uint8)
    mask = np.zeros_like(skel)
    # Foreground count of each 3x3 window; window (i, j) covers skel[i:i+3, j:j+3].
    window_sums = view_as_windows(skel, (3, 3)).sum(axis=(-2, -1))
    # Visit only the windows that trigger instead of looping over all of them.
    for i, j in np.argwhere(window_sums > 4):
        mask[i:i + 3, j:j + 3] = 1
    return skel * (1 - mask)


def fracture_similarity(pred_mask: torch.Tensor, true_mask: torch.Tensor) -> float:
    """Compare the fracture-segment length distributions of two masks.

    Each mask is binarised at 0.1, skeletonised, stripped of junction
    neighbourhoods (``remove_junctions``) and split into connected segments.
    Segment sizes (pixel counts) are histogrammed into 19 equal-width bins
    spanning [0, 260], and the two histograms are compared with the
    chi-squared distance ``0.5 * sum((p - t)**2 / (p + t))``.

    Despite the name this is a *distance*: 0 means identical histograms and
    larger values mean less similar. Segments larger than 260 pixels fall
    outside the bins and are ignored by ``np.histogram``.

    Args:
        pred_mask: 2-D predicted mask tensor (any device).
        true_mask: 2-D ground-truth mask tensor (any device).

    Returns:
        The chi-squared histogram distance as a Python float.
    """
    def _segment_lengths(mask: torch.Tensor) -> np.ndarray:
        """Pixel count of each connected skeleton segment of ``mask``."""
        skel = skeletonize((mask > 0.1).cpu().numpy())
        labeled = label(remove_junctions(skel))
        # Index 0 of the bincount is the background; the rest are segment sizes.
        return np.bincount(labeled.ravel())[1:]

    bins = np.linspace(0, 260, 20)
    # The epsilon keeps the denominator non-zero for bins empty in both histograms.
    pred_hist = np.histogram(_segment_lengths(pred_mask), bins=bins)[0] + 1e-6
    true_hist = np.histogram(_segment_lengths(true_mask), bins=bins)[0] + 1e-6
    chi_dist = 0.5 * np.sum((pred_hist - true_hist) ** 2 / (pred_hist + true_hist))
    return float(chi_dist)


def _forward(model, inputs, mdl):
    """Run ``model`` on ``inputs``; for ``mdl == 'fcn_resnet101'`` take the ``'out'`` entry."""
    return model(inputs)['out'] if mdl == 'fcn_resnet101' else model(inputs)


def _normalize_ignore_index(ignore_index):
    """Return ``ignore_index`` if it is a binary label (0 or 1), otherwise ``None``."""
    return ignore_index if ignore_index in (0, 1) else None


def _build_metrics(device, ignore_index, **dice_kwargs):
    """Create all metric objects on ``device``, keyed and ordered as in the result dicts."""
    return {
        'mse': MeanSquaredError().to(device),
        'psnr': PeakSignalNoiseRatio(data_range=1.0).to(device),
        'ssim': StructuralSimilarityIndexMeasure().to(device),
        'ae': MeanAbsoluteError().to(device),
        'acc': BinaryAccuracy(ignore_index=ignore_index).to(device),
        'f1': BinaryF1Score(ignore_index=ignore_index).to(device),
        'prec': BinaryPrecision(ignore_index=ignore_index).to(device),
        'rec': BinaryRecall(ignore_index=ignore_index).to(device),
        'spec': BinarySpecificity(ignore_index=ignore_index).to(device),
        'dice': DiceScore(num_classes=1, average="micro", **dice_kwargs).to(device),
        'iou': BinaryJaccardIndex(ignore_index=ignore_index).to(device),
        # NOTE(review): kappa is built without ``ignore_index``, unlike the
        # other classification metrics — confirm this is intended.
        'ck': BinaryCohenKappa().to(device),
        'roc_auc': BinaryAUROC(ignore_index=ignore_index).to(device),
    }


def _update_metrics(metrics, pred, target, pred_clf, target_clf):
    """Accumulate one batch into every metric except AUROC (callers handle AUROC)."""
    # ``update`` (not ``__call__``/forward) avoids computing a discarded per-batch value.
    for key in _CLF_KEYS:
        metrics[key].update(pred_clf, target_clf)
    for key in _REG_KEYS:
        metrics[key].update(pred, target)


def _class_presence(target, ignore_index):
    """Return ``(has_positive, has_negative)`` for a {0, 1} target, skipping ``ignore_index``."""
    has_pos = ignore_index != 1 and bool((target == 1).any())
    has_neg = ignore_index != 0 and bool((target == 0).any())
    return has_pos, has_neg


def _compute_metrics(metrics, auroc_valid):
    """Compute every metric as a float; ``roc_auc`` is NaN unless ``auroc_valid``."""
    results = {}
    for name, metric in metrics.items():
        if name == 'roc_auc' and not auroc_valid:
            # AUROC is undefined without both classes (and compute() would fail
            # on an empty state), so report NaN instead of crashing.
            results[name] = float('nan')
        else:
            results[name] = metric.compute().item()
    return results


def train_loop(model, optimizer, criterion, train_loader, device='cpu', mdl=None):
    """Train ``model`` for one epoch and return the sample-weighted mean loss.

    Args:
        model: network to train; moved to ``device`` and put in train mode.
        optimizer: optimizer updating ``model``'s parameters.
        criterion: loss called as ``criterion(output, labels)``.
        train_loader: DataLoader yielding ``(images, labels)`` batches.
        device: device the model and every batch are moved to.
        mdl: model name; for ``'fcn_resnet101'`` the output is read from the
            ``'out'`` entry of the model's return value.

    Returns:
        Sum of ``loss * batch_size`` over all batches divided by
        ``len(train_loader.sampler)`` (a per-sample mean, assuming
        ``criterion`` averages over the batch).
    """
    running_loss = 0
    model = model.to(device)
    model.train()
    pbar = tqdm(train_loader, desc="Iterating over train data")
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        out = _forward(model, images, mdl)
        loss = criterion(out, labels)
        # Weight by batch size so a smaller final batch is averaged correctly.
        running_loss += loss.item() * images.shape[0]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    running_loss /= len(train_loader.sampler)
    return running_loss


def eval_loop(model, scheduler, criterion, eval_loader, threshold=0.5, device='cpu', mdl=None,
              ignore_index=None):
    """Evaluate ``model`` on a validation/test loader and return a metrics dict.

    Predictions are thresholded at ``threshold`` and targets at ``> 0`` for the
    classification metrics. Regression metrics (MSE, PSNR, SSIM, MAE) and AUROC
    use the continuous model output.

    NOTE(review): the threshold and ``PeakSignalNoiseRatio(data_range=1.0)``
    suggest model outputs are probabilities in [0, 1] — confirm.

    Args:
        model: network to evaluate; moved to ``device`` and put in eval mode.
        scheduler: unused; kept for call-site compatibility.
        criterion: loss called as ``criterion(output, labels)``.
        eval_loader: DataLoader yielding ``(images, labels)`` batches whose
            masks are indexed as ``[:, 0]`` (channel 0).
        threshold: binarisation threshold for predictions.
        device: device the model, batches and metrics live on.
        mdl: model name. ``'fcn_resnet101'`` reads the ``'out'`` entry of the
            output. ``'Segformer'`` zeroes output values above 0.99 before
            scoring.
        ignore_index: 0 or 1 to exclude that target label from the
            classification metrics; any other value is treated as ``None``.

    Returns:
        Dict with keys ``mse, psnr, ssim, ae, acc, f1, prec, rec, spec, dice,
        iou, ck, roc_auc, loss, frac_sim``. ``roc_auc`` is NaN if the set never
        contained both classes. ``frac_sim`` is the mean ``fracture_similarity``
        distance over all images (NaN if there were none).
    """
    running_loss = 0
    model = model.to(device)
    model.eval()
    ignore_index = _normalize_ignore_index(ignore_index)

    with torch.no_grad():
        metrics = _build_metrics(device, ignore_index, aggregation_level='global')
        fracture_sim_scores = []
        seen_pos = seen_neg = False

        pbar = tqdm(eval_loader, desc='Iterating over evaluation/test data')
        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)
            out = _forward(model, imgs, mdl)
            loss = criterion(out, labels)
            running_loss += loss.item() * imgs.shape[0]

            predicted = out
            if mdl == 'Segformer':
                # In place on ``out`` (``predicted`` aliases it); the loss is already computed.
                predicted[predicted > 0.99] = 0.
            predicted_clf = (predicted > threshold).float()
            labels_clf = (labels > 0.).float()
            labels = labels.float()

            _update_metrics(metrics, predicted, labels, predicted_clf, labels_clf)

            # AUROC needs continuous scores and an *integer* target (torchmetrics
            # rejects float targets). It is accumulated over the whole set, so
            # single-class batches are included rather than skipped; validity is
            # decided once, from the classes seen overall.
            labels_int = labels_clf.int()
            metrics['roc_auc'].update(predicted, labels_int)
            has_pos, has_neg = _class_presence(labels_int, ignore_index)
            seen_pos, seen_neg = seen_pos or has_pos, seen_neg or has_neg

            # One device->host transfer per batch instead of one per image.
            for pred_mask, true_mask in zip(predicted_clf[:, 0].cpu(), labels_clf[:, 0].cpu()):
                fracture_sim_scores.append(fracture_similarity(pred_mask, true_mask))

        avg_fracture_sim = (
            float(np.mean(fracture_sim_scores)) if fracture_sim_scores else float('nan')
        )
        results = _compute_metrics(metrics, seen_pos and seen_neg)

    results['loss'] = running_loss / len(eval_loader.sampler)
    results['frac_sim'] = avg_fracture_sim
    return results


def eval_single(gt, pred, threshold=0.5, device="cpu", ignore_index=None):
    """Evaluate metrics for a single prediction / ground-truth pair.

    Args:
        gt: 2-D NumPy ground-truth array; values ``> 0`` are the positive class.
        pred: 2-D NumPy prediction array of the same shape; binarised at
            ``threshold`` for the classification metrics.
        threshold: binarisation threshold for ``pred``.
        device: device the tensors and metrics are placed on.
        ignore_index: 0 or 1 to exclude that target label from the
            classification metrics; any other value is treated as ``None``.

    Returns:
        Dict with keys ``mse, psnr, ssim, ae, acc, f1, prec, rec, spec, dice,
        iou, ck, roc_auc``. ``roc_auc`` is NaN when ``gt`` does not contain both
        classes.
    """
    # Add batch and channel dims: (H, W) -> (1, 1, H, W).
    gt = torch.from_numpy(gt).to(device).float().unsqueeze(0).unsqueeze(0)
    pred = torch.from_numpy(pred).to(device).float().unsqueeze(0).unsqueeze(0)
    pred_clf = (pred > threshold).long()
    gt_clf = (gt > 0).long()
    ignore_index = _normalize_ignore_index(ignore_index)

    metrics = _build_metrics(device, ignore_index)
    _update_metrics(metrics, pred, gt, pred_clf, gt_clf)

    has_pos, has_neg = _class_presence(gt_clf, ignore_index)
    auroc_valid = has_pos and has_neg
    if auroc_valid:
        # Continuous scores with an integer target, as BinaryAUROC requires.
        metrics['roc_auc'].update(pred, gt_clf.int())
    return _compute_metrics(metrics, auroc_valid)


def save_metrics(metrics: dict, kind: str, writer, epoch: int):
    """Log a metrics dict to a TensorBoard-style writer and flush it.

    Each value is written as scalar ``"<TAG>/<kind>"`` at step ``epoch``.
    ``'ck'`` and ``'frac_sim'`` are logged only when present in ``metrics``.

    Args:
        metrics: dict as returned by ``eval_loop`` (must contain ``'loss'``).
        kind: split name appended to every tag, e.g. ``"train"`` or ``"val"``.
        writer: object exposing ``add_scalar(tag, value, step)`` and ``flush()``.
        epoch: global step for the logged values.
    """
    for tag, key in _SCALAR_TAGS:
        writer.add_scalar(f"{tag}/{kind}", metrics[key], epoch)
    for tag, key in _OPTIONAL_SCALAR_TAGS:
        if key in metrics:
            writer.add_scalar(f"{tag}/{kind}", metrics[key], epoch)
    writer.flush()