| import torch |
| from torch import nn |
| from tqdm import tqdm |
| import prettytable |
| import time |
| import os |
| import multiprocessing.pool as mpp |
| import multiprocessing as mp |
|
|
| from train import * |
|
|
| import argparse |
| from utils.config import Config |
| from tools.mask_convert import mask_save |
| import numpy as np |
| import csv |
|
|
| |
class PRHistogram:
    """Streaming precision/recall accumulator based on fixed-width score histograms.

    Scores are bucketed into ``nbins`` equal-width bins over [0, 1], separately
    for positive and negative ground-truth pixels.  Keeping only two integer
    histograms lets the full curve be derived after the last batch without
    storing any per-pixel scores.
    """

    def __init__(self, nbins: int = 1000):
        """Create empty positive/negative histograms with ``nbins`` bins on [0, 1]."""
        self.nbins = int(nbins)
        self.pos_hist = np.zeros(self.nbins, dtype=np.int64)
        self.neg_hist = np.zeros(self.nbins, dtype=np.int64)
        self.bin_edges = np.linspace(0.0, 1.0, self.nbins + 1)

    def update(self, probs, mask):
        """Add one batch of scores to the histograms.

        Args:
            probs: tensor-like exposing ``.detach().float().cpu().numpy()``;
                flattened before use.
            mask: tensor-like exposing ``.detach().cpu().numpy()``; elements
                ``> 0`` are treated as positives.  Must flatten to the same
                length as ``probs``.
        """
        scores = probs.detach().float().cpu().numpy().ravel()
        is_pos = mask.detach().cpu().numpy().ravel() > 0
        # np.histogram only counts values that fall inside the given edges.
        pos_counts, _ = np.histogram(scores[is_pos], bins=self.bin_edges)
        neg_counts, _ = np.histogram(scores[~is_pos], bins=self.bin_edges)
        self.pos_hist += pos_counts
        self.neg_hist += neg_counts

    def compute_curve(self):
        """Derive the curve from the accumulated histograms.

        Returns:
            Tuple ``(thresholds, precision, recall, f1, iou, TP, FP, FN)`` of
            length-``nbins`` arrays, ordered from the highest threshold down
            to 0.0.  Entry ``k`` counts as predicted-positive every score that
            landed in a bin whose lower edge is ``>= thresholds[k]``.
        """
        # Reverse before the cumulative sum so that index k accumulates the
        # k+1 highest-score bins.
        TP = np.cumsum(self.pos_hist[::-1])
        FP = np.cumsum(self.neg_hist[::-1])
        FN = self.pos_hist.sum() - TP

        # Clamp denominators so empty sets yield 0 instead of dividing by zero.
        precision = TP / np.maximum(TP + FP, 1)
        recall = TP / np.maximum(TP + FN, 1)
        f1 = 2.0 * precision * recall / np.maximum(precision + recall, 1e-12)
        iou = TP / np.maximum(TP + FP + FN, 1)

        # Drop the upper edge 1.0: each remaining value is the lower edge of
        # the lowest bin included at that position.
        thresholds = self.bin_edges[::-1][1:]
        return thresholds, precision, recall, f1, iou, TP, FP, FN

    def export_csv(self, save_path: str):
        """Write the curve to ``save_path`` as CSV and return ``save_path``.

        Columns: threshold, precision, recall, f1, iou, TP, FP, FN.
        """
        columns = self.compute_curve()
        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        np.savetxt(
            save_path,
            np.column_stack(columns),
            delimiter=",",
            header="threshold,precision,recall,f1,iou,TP,FP,FN",
            comments="",
        )
        return save_path
|
|
| |
# Process-wide histogram shared by the pr_* helpers; created on first use.
_PR = None


def pr_init(nbins: int = 1000):
    """Return the shared PRHistogram, building it with ``nbins`` bins if absent.

    When the shared histogram already exists it is returned unchanged and
    ``nbins`` is ignored.
    """
    global _PR
    if _PR is not None:
        return _PR
    _PR = PRHistogram(nbins=nbins)
    return _PR
|
|
def pr_update_from_outputs(raw_predictions, mask, cfg):
    """Convert one batch of model outputs to scores and feed the shared histogram.

    Score extraction depends on ``cfg``:
      * ``cfg.argmax`` truthy: a 4-D output with at least two channels goes
        through softmax over dim 1 and channel 1 is kept; anything else goes
        through ``squeeze(1)`` followed by sigmoid.
      * ``cfg.net == 'maskcd'``: the first element is taken when the output is
        a list/tuple, then sigmoid and ``squeeze(1)``.
      * otherwise: a 4-D single-channel output is squeezed on dim 1, then
        sigmoid is applied.

    The mask is squeezed on dim 1 when it is 4-D with one channel, and
    binarised with ``> 0`` before being passed on.  The shared histogram is
    created with 1000 bins if it does not exist yet.
    """
    global _PR
    if _PR is None:
        _PR = PRHistogram(nbins=1000)

    # NOTE(review): the evaluation script thresholds these same non-argmax
    # outputs at 0.5 without a sigmoid.  If the network already emits
    # probabilities, the sigmoid below squeezes scores into roughly
    # [0.5, 0.73] -- confirm the outputs really are logits.
    if getattr(cfg, 'argmax', False):
        has_class_channels = raw_predictions.dim() == 4 and raw_predictions.size(1) >= 2
        if has_class_channels:
            probs = torch.softmax(raw_predictions, dim=1)[:, 1, :, :]
        else:
            probs = torch.sigmoid(raw_predictions.squeeze(1))
    elif getattr(cfg, 'net', '') == 'maskcd':
        head = raw_predictions[0] if isinstance(raw_predictions, (list, tuple)) else raw_predictions
        probs = torch.sigmoid(head).squeeze(1)
    else:
        single_channel = raw_predictions.dim() == 4 and raw_predictions.size(1) == 1
        probs = torch.sigmoid(raw_predictions.squeeze(1) if single_channel else raw_predictions)

    target = mask.squeeze(1) if (mask.dim() == 4 and mask.size(1) == 1) else mask
    _PR.update(probs, (target > 0).to(probs.dtype))
|
|
def pr_export(base_dir: str, cfg):
    """Write the shared PR curve to ``<base_dir>/pr_<net>.csv``.

    ``<net>`` is ``cfg.net``, or ``'model'`` when the attribute is missing.
    Returns the written path, or ``None`` when no histogram has been created.
    """
    if _PR is None:
        return None
    net_name = getattr(cfg, 'net', 'model')
    out = _PR.export_csv(os.path.join(base_dir, f"pr_{net_name}.csv"))
    print(f"[PR] saved: {out}")
    return out
| |
|
|
| |
| def _safe_div(a, b, eps=1e-12): |
| return a / max(b, eps) |
|
|
def per_image_stats(pred_np: np.ndarray, gt_np: np.ndarray):
    """Compute confusion counts and derived metrics for a single image.

    pred_np, gt_np: 0/1 binary numpy arrays, shape [H, W]; any value > 0 is
    treated as foreground.
    Returns: dict with TP/FP/TN/FN and OA, Precision, Recall, F1, IoU.
    """
    pred_fg = pred_np > 0
    gt_fg = gt_np > 0

    tp = int(np.count_nonzero(pred_fg & gt_fg))
    fp = int(np.count_nonzero(pred_fg & ~gt_fg))
    tn = int(np.count_nonzero(~pred_fg & ~gt_fg))
    fn = int(np.count_nonzero(~pred_fg & gt_fg))

    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)

    stats = {"TP": tp, "FP": fp, "TN": tn, "FN": fn}
    stats["OA"] = _safe_div(tp + tn, tp + tn + fp + fn)
    stats["Precision"] = precision
    stats["Recall"] = recall
    stats["F1"] = _safe_div(2 * precision * recall, precision + recall)
    stats["IoU"] = _safe_div(tp, tp + fp + fn)
    return stats
| |
|
|
def get_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        Namespace with ``config``, ``ckpt``, ``output_dir`` and ``tables_only``.
    """
    # The text must be passed as ``description=``: the first positional
    # parameter of ArgumentParser is ``prog``, so a positional string replaces
    # the program name in the usage line instead of describing the tool.
    parser = argparse.ArgumentParser(
        description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    # Tables/CSV only: skip writing the mask visualisation images.
    parser.add_argument("--tables-only", action="store_true",
                        help="仅生成表格与CSV(总体表、逐图CSV、逐图TXT、小计PR曲线CSV),不生成mask可视化图片")
    return parser.parse_args(argv)
|
|
if __name__ == "__main__":
    # Evaluation entry point: run the checkpoint over the test split, then
    # write the overall table, per-image CSV/TXT tables, the PR-curve CSV and
    # (unless --tables-only) the mask visualisations.
    args = get_args()
    cfg = Config.fromfile(args.config)

    # A checkpoint given on the command line wins over the one in the config.
    ckpt = args.ckpt if args.ckpt is not None else cfg.test_ckpt_path
    if ckpt is None:
        # Explicit raise instead of assert: asserts are stripped under -O.
        raise ValueError("no checkpoint: pass --ckpt or set test_ckpt_path in the config")

    # Outputs go next to the checkpoint unless --output_dir is given.
    base_dir = args.output_dir or os.path.dirname(ckpt)

    # Directory for the mask visualisation images.
    masks_output_dir = os.path.join(base_dir, "mask_rgb")
    # Directory for per-image tables; without --tables-only it is the same
    # directory as the masks.
    tables_output_dir = os.path.join(base_dir, "tables_only" if args.tables_only else "mask_rgb")
    os.makedirs(tables_output_dir, exist_ok=True)

    model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1': 'cuda:0'}, cfg=cfg)
    model = model.to('cuda')
    model.eval()

    metric_cfg_1 = cfg.metric_cfg1
    metric_cfg_2 = cfg.metric_cfg2

    test_oa = torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
    test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
    test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
    test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
    test_iou = torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')

    results = []         # argument tuples for mask_save, one per image
    per_image_rows = []  # one metrics dict per image

    with torch.no_grad():
        test_loader = build_dataloader(cfg.dataset_config, mode='test')
        # Create the shared PR histogram before the first batch.
        pr_init(nbins=1000)

        for batch in tqdm(test_loader):
            raw_predictions = model(batch[0].cuda(), batch[1].cuda())
            mask = batch[2].cuda()
            img_id = batch[3]

            # NOTE(review): the PR histogram is updated before the SARASNet
            # label resize and logit offset below; if prediction and label
            # sizes differ for that network this call fails -- confirm.
            pr_update_from_outputs(raw_predictions, mask, cfg)

            if cfg.net == 'SARASNet':
                # Resize the labels to the spatial size of the predictions.
                mask = Variable(resize_label(mask.data.cpu().numpy(),
                                             size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
                param = 1
                raw_predictions[:, 1, :, :] = raw_predictions[:, 1, :, :] + param

            if cfg.argmax:
                pred = raw_predictions.argmax(dim=1)
            else:
                if cfg.net == 'maskcd':
                    pred = raw_predictions[0]
                    pred = pred > 0.5
                    pred.squeeze_(1)
                else:
                    pred = raw_predictions.squeeze(1)
                    pred = pred > 0.5

            # Accumulate dataset-level metrics.
            test_oa(pred, mask)
            test_iou(pred, mask)
            test_prec(pred, mask)
            test_f1(pred, mask)
            test_recall(pred, mask)

            # Per-image statistics.
            # NOTE(review): relies on raw_predictions exposing .shape; verify
            # this holds for the 'maskcd' outputs indexed with [0] above.
            for i in range(raw_predictions.shape[0]):
                mask_real = mask[i].detach().cpu().numpy()
                mask_pred = pred[i].detach().cpu().numpy()
                mask_name = str(img_id[i])

                stats = per_image_stats(mask_pred, mask_real)
                per_image_rows.append({"img_id": mask_name, **stats})

                # Visualisations are only queued when they will be written.
                if not args.tables_only:
                    results.append((mask_real, mask_pred, masks_output_dir, mask_name))

    # Dataset-level summary table.
    metrics = [test_prec.compute(),
               test_recall.compute(),
               test_f1.compute(),
               test_iou.compute()]

    total_metrics = [test_oa.compute().cpu().numpy(),
                     np.mean([item.cpu() for item in metrics[0]]),
                     np.mean([item.cpu() for item in metrics[1]]),
                     np.mean([item.cpu() for item in metrics[2]]),
                     np.mean([item.cpu() for item in metrics[3]])]

    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']

    for i in range(2):
        item = [i, '--']
        for metric in metrics:
            item.append(np.round(metric[i].cpu().numpy(), 4))
        result_table.add_row(item)

    total = [np.round(v, 4) for v in total_metrics]
    total.insert(0, 'total')
    result_table.add_row(total)
    print(result_table)

    # Append the summary to the running log.  The with-block closes the file
    # so the text is flushed even if a later step fails.
    file_name = os.path.join(base_dir, "test_res.txt")
    current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net), time.localtime(time.time()))
    with open(file_name, "a") as f:
        f.write(current_time + '\n')
        f.write(str(result_table) + '\n')

    # Mask visualisations (skipped with --tables-only).
    if not args.tables_only:
        os.makedirs(masks_output_dir, exist_ok=True)
        print(masks_output_dir)

        t0 = time.time()
        # The context manager releases the worker processes once map() returns.
        with mpp.Pool(processes=mp.cpu_count()) as pool:
            pool.map(mask_save, results)
        t1 = time.time()
        img_write_time = t1 - t0
        print('images writing spends: {} s'.format(img_write_time))
    else:
        print("[Mode] --tables-only: 跳过可视化图片的生成,仅导出表格/CSV。")

    # Per-image metrics as a single CSV.
    per_image_csv = os.path.join(base_dir, f"per_image_metrics_{getattr(cfg,'net','model')}.csv")
    with open(per_image_csv, "w", newline="") as wf:
        writer = csv.DictWriter(
            wf,
            fieldnames=["img_id", "TP", "FP", "TN", "FN", "OA", "Precision", "Recall", "F1", "IoU"]
        )
        writer.writeheader()
        for row in per_image_rows:
            row_out = dict(row)
            for k in ["OA", "Precision", "Recall", "F1", "IoU"]:
                row_out[k] = float(np.round(row_out[k], 6))
            writer.writerow(row_out)
    print(f"[Per-Image] saved CSV: {per_image_csv}")

    # Per-image metrics as one small text table per image.
    for row in per_image_rows:
        txt_path = os.path.join(tables_output_dir, f"{row['img_id']}_metrics.txt")
        pt = prettytable.PrettyTable()
        pt.field_names = ["Metric", "Value"]
        # Raw counts.
        pt.add_row(["TP", row["TP"]])
        pt.add_row(["FP", row["FP"]])
        pt.add_row(["TN", row["TN"]])
        pt.add_row(["FN", row["FN"]])
        # Ratios, fixed to six decimals.
        pt.add_row(["OA", f"{row['OA']:.6f}"])
        pt.add_row(["Precision", f"{row['Precision']:.6f}"])
        pt.add_row(["Recall", f"{row['Recall']:.6f}"])
        pt.add_row(["F1", f"{row['F1']:.6f}"])
        pt.add_row(["IoU", f"{row['IoU']:.6f}"])
        with open(txt_path, "w") as wf:
            wf.write(str(pt))
    print(f"[Per-Image] per-image tables saved to: {tables_output_dir}")

    # PR-curve export is best-effort: a failure here must not discard the
    # tables already written, so it is reported rather than raised.
    try:
        pr_export(base_dir, cfg)
    except Exception as e:
        print(f"[PR] export skipped or failed: {e}")
|
|