| | import torch |
| | from torch import nn |
| | from tqdm import tqdm |
| | import prettytable |
| | import time |
| | import os |
| | import multiprocessing.pool as mpp |
| | import multiprocessing as mp |
| |
|
| | from train import * |
| |
|
| | import argparse |
| | from utils.config import Config |
| | from tools.mask_convert import mask_save |
| | import numpy as np |
| | import csv |
| |
|
| | |
class PRHistogram:
    """Streaming precision/recall accumulator for binary per-pixel scores.

    Scores in [0, 1] are bucketed into ``nbins`` equal-width bins, separately for
    ground-truth-positive and ground-truth-negative pixels. Sweeping the decision
    threshold over the bin edges then yields the whole PR curve without keeping
    individual pixels in memory.
    """

    def __init__(self, nbins: int = 1000):
        """Create an empty histogram with ``nbins`` equal-width bins over [0, 1].

        Raises:
            ValueError: if ``nbins`` is smaller than 1.
        """
        self.nbins = int(nbins)
        if self.nbins < 1:
            raise ValueError(f"nbins must be >= 1, got {nbins}")
        # Per-bin pixel counts, split by ground-truth label.
        self.pos_hist = np.zeros(self.nbins, dtype=np.int64)
        self.neg_hist = np.zeros(self.nbins, dtype=np.int64)
        self.bin_edges = np.linspace(0.0, 1.0, self.nbins + 1)

    @staticmethod
    def _flatten(x):
        """Return ``x`` (torch tensor or array-like) as a flat float64 NumPy array."""
        if hasattr(x, "detach"):
            # torch.Tensor: drop autograd, cast to float32 (NumPy cannot hold
            # bfloat16), and move to host memory.
            x = x.detach().float().cpu().numpy()
        return np.asarray(x, dtype=np.float64).ravel()

    def update(self, probs, mask):
        """Accumulate one batch of scores and labels.

        Args:
            probs: foreground scores (torch tensor or array-like). Values are
                clipped to [0, 1] so none fall outside the histogram range.
            mask: ground truth with the same number of elements as ``probs``;
                values > 0 count as positive.

        Raises:
            ValueError: if ``probs`` and ``mask`` differ in element count.
        """
        p = np.clip(self._flatten(probs), 0.0, 1.0)
        g = self._flatten(mask) > 0
        if p.size != g.size:
            raise ValueError(
                f"probs and mask must have the same number of elements "
                f"(got {p.size} and {g.size})"
            )
        pos_counts, _ = np.histogram(p[g], bins=self.bin_edges)
        neg_counts, _ = np.histogram(p[~g], bins=self.bin_edges)
        self.pos_hist += pos_counts
        self.neg_hist += neg_counts

    def compute_curve(self):
        """Sweep the threshold from high to low over the bin edges.

        Entry ``k`` corresponds to labelling as positive every score that falls
        in a bin whose lower edge is >= ``thresholds[k]``.

        Returns:
            ``(thresholds, precision, recall, f1, iou, TP, FP, FN)``: NumPy arrays
            of length ``nbins``, thresholds in descending order. Ratios whose
            denominator is 0 are reported as 0.
        """
        # Cumulative sums from the highest bin down = counts at or above each threshold.
        TP = np.cumsum(self.pos_hist[::-1])
        FP = np.cumsum(self.neg_hist[::-1])
        FN = self.pos_hist.sum() - TP

        precision = TP / np.maximum(TP + FP, 1)
        recall = TP / np.maximum(TP + FN, 1)
        f1 = 2.0 * precision * recall / np.maximum(precision + recall, 1e-12)
        iou = TP / np.maximum(TP + FP + FN, 1)

        # Lower edge of each bin, highest bin first (drop the top edge 1.0).
        thresholds = self.bin_edges[::-1][1:]
        return thresholds, precision, recall, f1, iou, TP, FP, FN

    def export_csv(self, save_path: str):
        """Write the curve from ``compute_curve`` to ``save_path`` as CSV.

        Columns are threshold, precision, recall, f1 and iou (floats), then TP,
        FP and FN (integer counts). Missing parent directories are created.

        Returns:
            ``save_path``.
        """
        thresholds, precision, recall, f1, iou, TP, FP, FN = self.compute_curve()
        parent = os.path.dirname(save_path)
        if parent:  # os.makedirs("") raises, so skip it for bare file names
            os.makedirs(parent, exist_ok=True)
        np.savetxt(
            save_path,
            np.column_stack([thresholds, precision, recall, f1, iou, TP, FP, FN]),
            delimiter=",",
            fmt=["%.18e"] * 5 + ["%d"] * 3,
            header="threshold,precision,recall,f1,iou,TP,FP,FN",
            comments="",
        )
        return save_path
| |
|
| | |
# Shared PR histogram used by pr_init, pr_update_from_outputs and pr_export.
_PR = None


def pr_init(nbins: int = 1000):
    """Return the shared PRHistogram, creating it with ``nbins`` bins on first use.

    If the histogram already exists it is returned unchanged and ``nbins`` is ignored.
    """
    global _PR
    if _PR is not None:
        return _PR
    _PR = PRHistogram(nbins=nbins)
    return _PR
| |
|
def pr_update_from_outputs(raw_predictions, mask, cfg):
    """Turn one batch of raw model outputs into foreground probabilities and add
    them, together with the binarised ground truth, to the shared PR histogram.

    Args:
        raw_predictions: model output tensor. When ``cfg.net == 'maskcd'`` (and
            ``cfg.argmax`` is falsy) it may also be a list/tuple, whose first
            element is used.
        mask: ground-truth tensor; a singleton channel dim (``[B, 1, ...]``) is
            squeezed, and values > 0 count as positive.
        cfg: config object; only ``argmax`` (default False) and ``net``
            (default '') are read.

    The histogram is created with 1000 bins if it does not exist yet.

    NOTE(review): the main script thresholds non-argmax outputs at 0.5 with no
    sigmoid, while the paths below apply ``torch.sigmoid``. If the model already
    emits probabilities, the scores here are squashed into roughly [0.5, 0.73].
    Confirm which one the model produces.
    """
    import torch
    global _PR
    if _PR is None:
        _PR = PRHistogram(nbins=1000)

    if getattr(cfg, 'argmax', False):
        multi_channel = raw_predictions.dim() == 4 and raw_predictions.size(1) >= 2
        if multi_channel:
            # Softmax over channels, keep channel 1 as the foreground probability.
            probs = torch.softmax(raw_predictions, dim=1)[:, 1, :, :]
        else:
            probs = torch.sigmoid(raw_predictions.squeeze(1))
    elif getattr(cfg, 'net', '') == 'maskcd':
        is_sequence = isinstance(raw_predictions, (list, tuple))
        head = raw_predictions[0] if is_sequence else raw_predictions
        probs = torch.sigmoid(head).squeeze(1)
    else:
        scores = raw_predictions
        if scores.dim() == 4 and scores.size(1) == 1:
            scores = scores.squeeze(1)
        probs = torch.sigmoid(scores)

    has_channel_dim = mask.dim() == 4 and mask.size(1) == 1
    target = mask.squeeze(1) if has_channel_dim else mask
    _PR.update(probs, (target > 0).to(probs.dtype))
| |
|
def pr_export(base_dir: str, cfg):
    """Write the shared PR curve to ``<base_dir>/pr_<cfg.net>.csv``.

    ``cfg.net`` falls back to ``'model'`` when absent. Returns the written path,
    or None when no PR histogram has been created yet.
    """
    import os
    global _PR
    histogram = _PR
    if histogram is None:
        return None
    model_name = getattr(cfg, 'net', 'model')
    written = histogram.export_csv(os.path.join(base_dir, f"pr_{model_name}.csv"))
    print(f"[PR] saved: {written}")
    return written
| | |
| |
|
| | |
| | def _safe_div(a, b, eps=1e-12): |
| | return a / max(b, eps) |
| |
|
def per_image_stats(pred_np: np.ndarray, gt_np: np.ndarray):
    """Confusion counts and scores for one binary prediction/ground-truth pair.

    Args:
        pred_np: prediction array (documented as [H, W]); values > 0 are positive.
        gt_np: ground-truth array of the same shape; values > 0 are positive.

    Returns:
        dict with integer "TP", "FP", "TN", "FN" and float "OA", "Precision",
        "Recall", "F1", "IoU". A ratio with a zero denominator evaluates to 0.
    """
    predicted = pred_np > 0
    actual = gt_np > 0

    tp = int(np.count_nonzero(predicted & actual))
    fp = int(np.count_nonzero(predicted & ~actual))
    tn = int(np.count_nonzero(~predicted & ~actual))
    fn = int(np.count_nonzero(~predicted & actual))

    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)

    return {
        "TP": tp,
        "FP": fp,
        "TN": tn,
        "FN": fn,
        "OA": _safe_div(tp + tn, tp + tn + fp + fn),
        "Precision": precision,
        "Recall": recall,
        "F1": _safe_div(2 * precision * recall, precision + recall),
        "IoU": _safe_div(tp, tp + fp + fn),
    }
| | |
| |
|
def get_args():
    """Parse the test script's command-line options.

    Returns:
        argparse.Namespace with ``config``, ``ckpt``, ``output_dir`` and ``tables_only``.
    """
    # The text must be passed as ``description=``: the first positional argument
    # of ArgumentParser is ``prog``, which would replace the program name in the
    # usage line.
    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py",
                        help="path to the config file")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="checkpoint to evaluate (defaults to the config's test_ckpt_path)")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="output directory (defaults to the checkpoint's directory)")
    # Help text: "only generate tables and CSVs (summary table, per-image CSV,
    # per-image TXT, PR-curve CSV); do not generate mask visualization images".
    parser.add_argument("--tables-only", action="store_true",
                        help="仅生成表格与CSV(总体表、逐图CSV、逐图TXT、小计PR曲线CSV),不生成mask可视化图片")
    return parser.parse_args()
| |
|
if __name__ == "__main__":
    args = get_args()
    cfg = Config.fromfile(args.config)

    # Checkpoint: --ckpt wins, otherwise fall back to the config's test_ckpt_path.
    ckpt = args.ckpt if args.ckpt is not None else cfg.test_ckpt_path
    if ckpt is None:
        raise ValueError("no checkpoint given: pass --ckpt or set test_ckpt_path in the config")

    # All outputs go to --output_dir, or next to the checkpoint by default.
    base_dir = args.output_dir if args.output_dir else os.path.dirname(ckpt)

    # Mask visualizations are written to <base_dir>/mask_rgb.
    masks_output_dir = os.path.join(base_dir, "mask_rgb")
    # Per-image TXT tables share the mask folder unless --tables-only is given.
    tables_output_dir = os.path.join(base_dir, "tables_only" if args.tables_only else "mask_rgb")
    os.makedirs(tables_output_dir, exist_ok=True)

    # Remap tensors saved on cuda:1 onto cuda:0 while loading.
    model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1': 'cuda:0'}, cfg=cfg)
    model = model.to('cuda')
    model.eval()

    metric_cfg_1 = cfg.metric_cfg1
    metric_cfg_2 = cfg.metric_cfg2

    test_oa = torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
    test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
    test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
    test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
    test_iou = torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')

    # Keys produced by per_image_stats, in output column order.
    count_keys = ("TP", "FP", "TN", "FN")
    score_keys = ("OA", "Precision", "Recall", "F1", "IoU")
    stat_keys = count_keys + score_keys

    results = []         # (gt, pred, out_dir, name) tuples handed to mask_save
    per_image_rows = []  # one row of per-image statistics per test image

    with torch.no_grad():
        test_loader = build_dataloader(cfg.dataset_config, mode='test')
        pr_init(nbins=1000)

        for batch in tqdm(test_loader):
            # batch[0], batch[1] feed the model (presumably the bi-temporal image
            # pair); batch[2] is the label mask, batch[3] the image ids.
            raw_predictions = model(batch[0].cuda(), batch[1].cuda())
            mask = batch[2].cuda()
            img_id = batch[3]

            if cfg.net == 'SARASNet':
                # Resize the label to the prediction's spatial size first, so the
                # PR histogram below and the metrics see pixel-aligned tensors.
                mask = Variable(resize_label(mask.data.cpu().numpy(),
                                             size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()

            # Accumulate PR statistics before the SARASNet class-1 bias is applied.
            pr_update_from_outputs(raw_predictions, mask, cfg)

            if cfg.net == 'SARASNet':
                # Add a fixed +1 bias to the class-1 logit before argmax.
                param = 1
                raw_predictions[:, 1, :, :] = raw_predictions[:, 1, :, :] + param

            if cfg.argmax:
                pred = raw_predictions.argmax(dim=1)
            elif cfg.net == 'maskcd':
                pred = (raw_predictions[0] > 0.5).squeeze(1)
            else:
                pred = raw_predictions.squeeze(1) > 0.5

            test_oa(pred, mask)
            test_iou(pred, mask)
            test_prec(pred, mask)
            test_f1(pred, mask)
            test_recall(pred, mask)

            # Per-image statistics. Iterate over pred's batch dim: for maskcd the raw
            # output may be a list/tuple (cf. pr_update_from_outputs) with no .shape.
            for i in range(pred.shape[0]):
                mask_real = mask[i].detach().cpu().numpy()
                mask_pred = pred[i].detach().cpu().numpy()
                mask_name = str(img_id[i])

                stats = per_image_stats(mask_pred, mask_real)
                row = {"img_id": mask_name}
                row.update((k, stats[k]) for k in stat_keys)
                per_image_rows.append(row)

                # Visualization inputs are only kept when images will be written.
                if not args.tables_only:
                    results.append((mask_real, mask_pred, masks_output_dir, mask_name))

    # Per-class metric tensors (indexed by class id) for the summary table.
    metrics = [test_prec.compute(),
               test_recall.compute(),
               test_f1.compute(),
               test_iou.compute()]

    # Overall row: OA, then the class-mean of each per-class metric.
    total_metrics = [test_oa.compute().cpu().numpy()]
    total_metrics += [np.mean([item.cpu() for item in m]) for m in metrics]

    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']
    # Classes 0 and 1; OA is only reported in the overall row.
    for cls in range(2):
        result_table.add_row([cls, '--'] + [np.round(m[cls].cpu().numpy(), 4) for m in metrics])
    result_table.add_row(['total'] + [np.round(v, 4) for v in total_metrics])
    print(result_table)

    # Append a timestamped copy of the summary table to test_res.txt.
    file_name = os.path.join(base_dir, "test_res.txt")
    current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net), time.localtime(time.time()))
    with open(file_name, "a") as f:
        f.write(current_time + '\n')
        f.write(str(result_table) + '\n')

    if not args.tables_only:
        os.makedirs(masks_output_dir, exist_ok=True)
        print(masks_output_dir)

        # Write the mask visualizations in parallel; the pool is closed on exit.
        t0 = time.time()
        with mpp.Pool(processes=mp.cpu_count()) as pool:
            pool.map(mask_save, results)
        img_write_time = time.time() - t0
        print('images writing spends: {} s'.format(img_write_time))
    else:
        print("[Mode] --tables-only: 跳过可视化图片的生成,仅导出表格/CSV。")

    # Per-image CSV, with ratio metrics rounded to 6 decimals.
    per_image_csv = os.path.join(base_dir, f"per_image_metrics_{getattr(cfg,'net','model')}.csv")
    with open(per_image_csv, "w", newline="") as wf:
        writer = csv.DictWriter(wf, fieldnames=["img_id", *stat_keys])
        writer.writeheader()
        for row in per_image_rows:
            row_out = dict(row)
            for k in score_keys:
                row_out[k] = float(np.round(row_out[k], 6))
            writer.writerow(row_out)
    print(f"[Per-Image] saved CSV: {per_image_csv}")

    # One small "Metric | Value" table per image.
    for row in per_image_rows:
        txt_path = os.path.join(tables_output_dir, f"{row['img_id']}_metrics.txt")
        pt = prettytable.PrettyTable()
        pt.field_names = ["Metric", "Value"]
        for k in count_keys:
            pt.add_row([k, row[k]])
        for k in score_keys:
            pt.add_row([k, f"{row[k]:.6f}"])
        with open(txt_path, "w") as wf:
            wf.write(str(pt))
    print(f"[Per-Image] per-image tables saved to: {tables_output_dir}")

    # PR export is best-effort: a failure must not discard the results written above.
    try:
        pr_export(base_dir, cfg)
    except Exception as e:
        print(f"[PR] export skipped or failed: {e}")
| |
|