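# NOTE(review): the lines below are residue from the web page this file was
# copied from (file-hosting header/toolbar text); kept as comments so the
# module parses.
# MooreMuaMu's picture
# Add SAMIHS ICH segmentation package
# 29aaa12 verified
# Raw
# History Blame Contribute Delete
# 11 kB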
# -*- coding: utf-8 -*-
# This file evaluates models at the 2D-slice level.
from torch.autograd import Variable
from torch.utils.data import DataLoader
import os
import numpy as np
import torch
import torch.nn.functional as F
import utils.metrics as metrics
from hausdorff import hausdorff_distance
from utils.generate_prompts import get_click_prompt, get_click_prompt_eval
import time
import pandas as pd
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as Func
from tqdm import tqdm
import cv2
import nibabel as nib # ★ 新
# ------------------------------- Visualization helpers ------------------------------- #
def _overlay(gray, pred_mask, gt_mask=None, alpha=0.4):
    """Blend GT and prediction masks onto a grayscale image.

    ``gray`` is expected in [0, 1]; it is scaled by 255 and cast to uint8.
    In a tinted copy of the image, pixels where ``gt_mask > 0`` get their
    green channel set to 255. Pixels where ``pred_mask > 0`` get their red
    channel set to 255. Channel order is OpenCV BGR. Either mask may be
    ``None`` to skip it. The result is
    ``(1 - alpha) * image + alpha * tinted``.
    """
    base = cv2.cvtColor((gray * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
    tinted = base.copy()
    # BGR channel indices: 1 = green (ground truth), 2 = red (prediction).
    # GT is painted first, then the prediction.
    for layer_mask, channel in ((gt_mask, 1), (pred_mask, 2)):
        if layer_mask is None:
            continue
        tinted[layer_mask > 0, channel] = 255
    return cv2.addWeighted(base, 1 - alpha, tinted, alpha, 0)
def _color_overlay(gray, mask, color, alpha=0.8):
    """Return ``alpha * colour_layer + (1 - alpha) * gray_as_BGR``.

    The colour layer is ``color`` (a BGR triple) wherever ``mask > 0`` and
    black elsewhere. Pixels outside the mask are therefore dimmed to
    ``(1 - alpha)`` of their brightness. ``gray`` is expected in [0, 1]; it
    is scaled by 255 and cast to uint8.
    """
    scaled = gray * 255
    base_bgr = cv2.cvtColor(scaled.astype(np.uint8), cv2.COLOR_GRAY2BGR)
    tint = np.zeros(base_bgr.shape, dtype=base_bgr.dtype)
    inside = mask > 0
    tint[inside] = color
    return cv2.addWeighted(tint, alpha, base_bgr, 1 - alpha, 0)
def _two_panel(img_left, img_right):
    """Horizontally concatenate two images (left | right) with ``cv2.hconcat``."""
    pair = [img_left, img_right]
    return cv2.hconcat(pair)
def _hstack_many(panels):
"""横向拼接多张同尺寸 BGR 图"""
assert len(panels) >= 2
h, w = panels[0].shape[:2]
for p in panels:
assert p.shape[:2] == (h, w), f"shape mismatch: {p.shape} vs {(h,w)}"
return cv2.hconcat(panels)
# ----------------------------------------------------------------------------- #
def _debug_check_prompt(pt, batch_idx):
    """Print a diagnostic if the click prompt is not a finite (coords, labels) pair."""
    if not isinstance(pt, (tuple, list)):
        print(f"[DEBUG][VAL] pt type unexpected: {type(pt)} at batch={batch_idx}")
        return
    coords, labels = pt
    if not torch.isfinite(coords).all():
        print(f"[DEBUG][VAL] pt coords NaN/Inf at batch={batch_idx}")
    if not torch.isfinite(labels.float()).all():
        print(f"[DEBUG][VAL] pt labels NaN/Inf at batch={batch_idx}")


def _val_loss_with_checks(criterion, pred, masks, imgs, batch_idx):
    """Return ``criterion(pred, masks)``, printing diagnostics for non-finite GT or loss."""
    if not torch.isfinite(masks).all():
        print(f"[DEBUG][VAL] masks NaN/Inf at batch={batch_idx}")
    val_loss = criterion(pred, masks)
    if not torch.isfinite(val_loss):
        per_sample_sum = masks.sum(dim=tuple(range(1, masks.ndim)))
        logits = pred['masks']
        print(f"[DEBUG][VAL] loss NaN/Inf at batch={batch_idx} "
              f"(loss={val_loss})")
        print(" imgs finite:", torch.isfinite(imgs).all().item(),
              " min/max:", float(imgs.min()), float(imgs.max()))
        print(" logits finite:", torch.isfinite(logits).all().item(),
              " min/max:", float(logits.min()), float(logits.max()))
        print(" masks finite:", torch.isfinite(masks).all().item(),
              " sum per-sample:", per_sample_sum.tolist())
    return val_loss


def _slice_metrics(pred_i, gt_i):
    """Return ``(dice, iou, acc, se, sp, hd)`` for one binary prediction/GT pair."""
    # Both masks empty is scored as a perfect prediction. Presumably this
    # avoids 0/0 inside the ratio-based metric helpers.
    if pred_i.sum() + gt_i.sum() == 0:
        return 1.0, 1.0, 1.0, 1.0, 1.0, 0.0
    dice = metrics.dice_coefficient(pred_i[None], gt_i[None])
    iou, acc, se, sp = metrics.sespiou_coefficient2(pred_i[None], gt_i[None], all=False)
    hd = hausdorff_distance(pred_i, gt_i)
    return dice, iou, acc, se, sp, hd


def _scalar(value):
    """Convert a metric result (number, 0-d or size-1 array) to a Python float."""
    return float(np.asarray(value).item())


def _class_matrix(values, n_cols):
    """Return a ``(len(values), n_cols)`` zero matrix with ``values`` in column 1."""
    mat = np.zeros((len(values), n_cols))
    mat[:, 1] = values
    return mat


def _save_comparison(vis_dir, image_name, image, pred_i, gt_i):
    """Write ``<vis_dir>/<stem>_compare.png`` as image | GT overlay | prediction overlay.

    When ``gt_i`` is None, only image | prediction overlay is written.
    Channel 0 of ``image`` (one sample of the input batch) is min-max
    normalised and resized to the mask size. GT is drawn green and the
    prediction red, in OpenCV BGR colours.
    """
    h, w = pred_i.shape
    gray = image[0].detach().cpu().numpy()
    # Use np.ptp(): the ndarray.ptp() method was removed in NumPy 2.0.
    gray = (gray - gray.min()) / (np.ptp(gray) + 1e-6)
    if gray.shape != (h, w):
        gray = cv2.resize(gray, (w, h), interpolation=cv2.INTER_LINEAR)
    gray_bgr = cv2.cvtColor((gray * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
    panel_pred = _color_overlay(gray, pred_i, color=(0, 0, 255), alpha=0.8)
    if gt_i is not None:
        panel_gt = _color_overlay(gray, gt_i, color=(0, 255, 0), alpha=0.8)
        combined = _hstack_many([gray_bgr, panel_gt, panel_pred])
    else:
        combined = _two_panel(gray_bgr, panel_pred)
    stem = os.path.splitext(image_name)[0]
    # The panels are already BGR, which cv2.imwrite expects. Flipping the
    # channels would save the red prediction overlay as blue.
    cv2.imwrite(os.path.join(vis_dir, f"{stem}_compare.png"), combined)


def eval_mask_slice2(valloader, model, criterion, opt, args):
    """Evaluate a promptable segmentation model slice by slice on ``valloader``.

    For every batch:

    * the click prompt is built with ``get_click_prompt_eval``;
    * the model is run as ``model(imgs, pt, bbox=None)``;
    * ``sigmoid(pred['masks'])`` is resized and thresholded at 0.5.

    When a batch contains both ``'low_mask'`` and ``'label'``, two things are
    recorded. The loss is computed against ``low_mask``. Per-slice Dice, IoU,
    accuracy, se and sp (presumably sensitivity/specificity) and Hausdorff
    distance are computed against ``label``. Comparison PNGs are written to
    ``args.vis_dir`` when it is set.

    Args:
        valloader: iterable of dicts with at least ``'image'`` and ``'image_name'``.
        model: must return a dict holding a ``'masks'`` logits tensor.
        criterion: called as ``criterion(pred, low_mask)``; must return a scalar tensor.
        opt: reads ``device``, ``classes`` and ``mode``, plus whatever
            ``get_click_prompt_eval`` needs.
        args: optional ``vis_dir`` (save visualizations) and ``debug``
            (print per-batch logit mean/std).

    Returns:
        ``None`` if no slice had ground truth. If ``opt.mode == "train"``,
        ``(dices, fore_dice_mean, hds_fore_mean, mean_val_loss)``.
        Otherwise the 12-tuple ``(fore_dice_mean, hds_fore_mean, iou_mean,
        acc_mean, se_mean, sp_mean, dice_std_fg, hd_std_fg, iou_std,
        acc_std, se_std, sp_std)``. The per-metric matrices have one row per
        GT slice and ``max(opt.classes, 2)`` columns; only column 1
        (foreground) is filled. The ``fore_*`` means cover only slices with
        non-empty GT (NaN if there are none).

    Raises:
        TypeError: if the model output is not a dict with a ``'masks'`` key.
    """
    vis_dir = getattr(args, 'vis_dir', None)
    if vis_dir:
        os.makedirs(vis_dir, exist_ok=True)
    debug = bool(getattr(args, 'debug', False))
    model.eval()

    n_cols = max(int(opt.classes), 2)  # column 1 holds the foreground scores
    slice_rows = []                    # (dice, iou, acc, se, sp, hd) per GT slice
    fore_dice, hds_fore = [], []       # restricted to slices with non-empty GT
    val_losses, n_loss_batches = 0.0, 0
    eval_number, sum_time = 0, 0.0

    with torch.no_grad(), tqdm(total=len(valloader), desc='Validation round',
                               unit='batch', leave=False) as pbar:
        for batch_idx, datapack in enumerate(valloader):
            imgs = datapack['image'].to(dtype=torch.float32, device=opt.device)
            image_names = datapack['image_name']
            pt = get_click_prompt_eval(datapack, opt)
            _debug_check_prompt(pt, batch_idx)

            # CUDA kernels run asynchronously. Synchronize so the timer
            # measures the forward pass itself, not just its launch.
            if imgs.is_cuda:
                torch.cuda.synchronize(imgs.device)
            start_time = time.perf_counter()
            pred = model(imgs, pt, bbox=None)
            if imgs.is_cuda:
                torch.cuda.synchronize(imgs.device)
            sum_time += time.perf_counter() - start_time

            if not isinstance(pred, dict) or 'masks' not in pred:
                found = list(pred.keys()) if isinstance(pred, dict) else type(pred)
                raise TypeError(
                    f"pred format unexpected at batch={batch_idx}: keys={found}")
            logits = pred['masks']
            if debug:
                print(f"pred mean = {logits.mean().item()}, std = {logits.std().item()}")
            if not torch.isfinite(logits).all():
                print(f"[DEBUG][VAL] logits NaN/Inf at batch={batch_idx}",
                      " min/max:", float(logits.min()), float(logits.max()))

            has_gt = 'low_mask' in datapack and 'label' in datapack
            label = (datapack['label'].to(dtype=torch.float32, device=opt.device)
                     if has_gt else None)
            # Predictions are compared with `label` pixel by pixel, so resize
            # them to its resolution. Without GT, keep the 512x512 default.
            out_hw = tuple(label.shape[-2:]) if has_gt else (512, 512)
            probs = Func.resize(torch.sigmoid(logits), out_hw, InterpolationMode.BILINEAR)
            predict = probs.cpu().numpy()[:, 0] > 0.5
            b = predict.shape[0]

            gts = None
            if has_gt:
                masks = datapack['low_mask'].to(dtype=torch.float32, device=opt.device)
                gts = label.cpu().numpy()[:, 0]
                val_loss = _val_loss_with_checks(criterion, pred, masks, imgs, batch_idx)
                val_losses += val_loss.item()
                n_loss_batches += 1
                pbar.set_postfix(**{'loss (batch)': val_loss.item()})

            for j in range(b):
                pred_i = predict[j].astype(np.uint8)
                gt_i = None
                if gts is not None:
                    gt_i = (gts[j] > 0).astype(np.uint8)
                    scores = tuple(_scalar(v) for v in _slice_metrics(pred_i, gt_i))
                    slice_rows.append(scores)
                    if gt_i.sum() > 0:
                        fore_dice.append(scores[0])
                        hds_fore.append(scores[5])
                if vis_dir:
                    _save_comparison(vis_dir, image_names[j], imgs[j], pred_i, gt_i)

            eval_number += b
            pbar.update()

    # Decide on "no GT" over the whole loader, not just the last batch.
    if not slice_rows:
        print("⚠️ Warning: No ground-truth masks detected. Only predictions are visualized.")
        return None

    dices, ious, accs, ses, sps, hds = (
        _class_matrix(column, n_cols) for column in zip(*slice_rows))
    val_losses /= n_loss_batches  # average over batches that had a loss

    fore_dice_mean = np.mean(fore_dice) if fore_dice else np.nan
    print("fore_dice_mean", fore_dice_mean)
    hds_fore_mean = np.mean(hds_fore) if hds_fore else np.nan
    dice_mean = np.mean(dices, axis=0)
    print("dice_mean", dice_mean)
    dices_std = np.std(dices, axis=0)
    hd_std = np.std(hds, axis=0)
    if sum_time > 0:
        print("test speed", eval_number / sum_time)

    if opt.mode == "train":
        return dices, fore_dice_mean, hds_fore_mean, val_losses
    iou_mean, iou_std = np.mean(ious, axis=0), np.std(ious, axis=0)
    acc_mean, acc_std = np.mean(accs, axis=0), np.std(accs, axis=0)
    se_mean, se_std = np.mean(ses, axis=0), np.std(ses, axis=0)
    sp_mean, sp_std = np.mean(sps, axis=0), np.std(sps, axis=0)
    return (fore_dice_mean, hds_fore_mean, iou_mean, acc_mean, se_mean, sp_mean,
            dices_std[1], hd_std[1], iou_std, acc_std, se_std, sp_std)
# ------------------------------------------------------------ #
def get_eval(valloader, model, criterion, opt, args):
    """Dispatch to the evaluation routine selected by ``opt.eval_mode``.

    Only ``"mask_slice"`` is supported. It runs ``eval_mask_slice2`` with the
    same arguments and returns its result.

    Raises:
        RuntimeError: if ``opt.eval_mode`` names an unknown mode.
    """
    if opt.eval_mode == "mask_slice":
        return eval_mask_slice2(valloader, model, criterion, opt, args)
    # Format the mode into one message string. Passing it as a second
    # exception argument would make str(err) render as a tuple.
    raise RuntimeError(f"Could not find the eval mode: {opt.eval_mode!r}")