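# NOTE: The following lines are Hugging Face file-viewer residue that was
# captured together with the source; they are kept as comments so that the
# module remains valid Python.
#   InPeerReview's picture
#   Upload 9 files
#   032c113 verified
#   raw
#   history blame
#   13.6 kB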
import torch
from torch import nn
from tqdm import tqdm
import prettytable
import time
import os
import multiprocessing.pool as mpp
import multiprocessing as mp
from train import *
import argparse
from utils.config import Config
from tools.mask_convert import mask_save
import numpy as np # [PR] for histogram-based PR accumulation
import csv
# =========================== [PR] Utilities BEGIN ===========================
class PRHistogram:
    """Memory-friendly precision/recall accumulator based on score histograms.

    Call :meth:`update` repeatedly inside the test loop, then call
    :meth:`export_csv` (or :meth:`compute_curve`) after the loop. Only two
    fixed-size histograms are kept, so memory use does not grow with the
    number of pixels seen.

    Expected inputs of :meth:`update` (as documented by the original author):
      - probs: torch.Tensor in [0, 1], shape [B, H, W], "change" probability
      - mask:  torch.Tensor of 0/1 (or 0/255), shape [B, H, W]
    """

    def __init__(self, nbins: int = 1000):
        """Create empty histograms with ``nbins`` equal-width bins on [0, 1]."""
        self.nbins = int(nbins)
        # Score histograms of ground-truth positive / negative pixels.
        self.pos_hist = np.zeros(self.nbins, dtype=np.int64)
        self.neg_hist = np.zeros(self.nbins, dtype=np.int64)
        self.bin_edges = np.linspace(0.0, 1.0, self.nbins + 1)

    def update(self, probs, mask):
        """Accumulate one batch of probabilities and ground-truth labels.

        Both tensors are flattened, so only their element counts must agree.
        Any mask value greater than zero counts as a positive pixel. Scores
        outside [0, 1] fall outside the bin range and are not counted.

        Raises:
            ValueError: if ``probs`` and ``mask`` differ in element count.
        """
        scores = probs.detach().float().cpu().numpy().ravel()
        positives = mask.detach().cpu().numpy().ravel() > 0
        if scores.shape != positives.shape:
            raise ValueError(
                f"probs and mask must have the same number of elements, "
                f"got {scores.size} and {positives.size}"
            )
        pos_counts, _ = np.histogram(scores[positives], bins=self.bin_edges)
        neg_counts, _ = np.histogram(scores[~positives], bins=self.bin_edges)
        self.pos_hist += pos_counts
        self.neg_hist += neg_counts

    def compute_curve(self):
        """Return the PR curve derived from the accumulated histograms.

        Returns:
            Tuple ``(thresholds, precision, recall, f1, iou, TP, FP, FN)`` of
            arrays with ``nbins`` entries each, ordered from the highest
            threshold to the lowest. Entry ``k`` counts every pixel whose
            score falls in one of the top ``k + 1`` bins as predicted
            positive; ``thresholds[k]`` is the lower edge of that range.
            Ratios whose denominator is zero are reported as 0.
        """
        # Cumulate from the highest bin downwards to get TP/FP per threshold.
        TP = np.cumsum(self.pos_hist[::-1])
        FP = np.cumsum(self.neg_hist[::-1])
        FN = self.pos_hist.sum() - TP
        # TN is not needed for any of the curve quantities.
        precision = TP / np.maximum(TP + FP, 1)
        recall = TP / np.maximum(TP + FN, 1)
        # F1 = 2PR / (P + R)
        f1 = 2.0 * precision * recall / np.maximum(precision + recall, 1e-12)
        # IoU = TP / (TP + FP + FN)
        iou = TP / np.maximum(TP + FP + FN, 1)
        # Threshold sequence in the same (descending) order as the cumsums.
        thresholds = self.bin_edges[::-1][1:]
        return thresholds, precision, recall, f1, iou, TP, FP, FN

    def export_csv(self, save_path: str):
        """Write the curve to ``save_path`` as CSV and return ``save_path``.

        The parent directory is created when ``save_path`` has one. A bare
        filename is written to the current working directory.
        """
        thresholds, precision, recall, f1, iou, TP, FP, FN = self.compute_curve()
        out_dir = os.path.dirname(save_path)
        # os.makedirs("") raises FileNotFoundError, so skip it for bare names.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        np.savetxt(
            save_path,
            np.column_stack([thresholds, precision, recall, f1, iou, TP, FP, FN]),
            delimiter=",",
            header="threshold,precision,recall,f1,iou,TP,FP,FN",
            comments="",
        )
        return save_path
# Module-wide PR accumulator; stays None until it is first requested.
_PR = None


def pr_init(nbins: int = 1000):
    """Return the shared PRHistogram, creating it on the first call.

    ``nbins`` only takes effect when the accumulator does not exist yet; an
    already-created accumulator is returned unchanged.
    """
    global _PR
    if _PR is not None:
        return _PR
    _PR = PRHistogram(nbins=nbins)
    return _PR
def pr_update_from_outputs(raw_predictions, mask, cfg):
    """Derive "change" probabilities from model outputs and feed the PR accumulator.

    Handled output layouts:
      - ``cfg.argmax`` truthy: 4-D outputs with at least two channels go
        through softmax and the class-1 channel is used; anything else is
        squeezed on dim 1 and passed through sigmoid.
      - ``cfg.net == 'maskcd'``: the first element is used when the output is
        a list/tuple, then sigmoid and squeeze on dim 1.
      - otherwise: a single-channel 4-D output is squeezed on dim 1, then
        sigmoid is applied.
    Adjust this function if a network has a different head.

    NOTE(review): the evaluation loop thresholds non-argmax outputs at 0.5
    without a sigmoid, which suggests they may already be probabilities;
    confirm that applying sigmoid here is intended.
    """
    global _PR
    # Lazily create the shared accumulator with the default bin count.
    if _PR is None:
        _PR = PRHistogram(nbins=1000)

    if getattr(cfg, 'argmax', False):
        scores = raw_predictions
        multi_channel = scores.dim() == 4 and scores.size(1) >= 2
        if multi_channel:
            probs = torch.softmax(scores, dim=1)[:, 1, :, :]
        else:
            probs = torch.sigmoid(scores.squeeze(1))
    elif getattr(cfg, 'net', '') == 'maskcd':
        is_sequence = isinstance(raw_predictions, (list, tuple))
        scores = raw_predictions[0] if is_sequence else raw_predictions
        probs = torch.sigmoid(scores).squeeze(1)
    else:
        scores = raw_predictions
        if scores.dim() == 4 and scores.size(1) == 1:
            scores = scores.squeeze(1)
        probs = torch.sigmoid(scores)

    # Drop a singleton channel axis from the mask so it lines up with probs.
    has_channel_axis = mask.dim() == 4 and mask.size(1) == 1
    target = mask.squeeze(1) if has_channel_axis else mask
    binary_target = (target > 0).to(probs.dtype)
    _PR.update(probs, binary_target)
def pr_export(base_dir: str, cfg):
    """Write the accumulated PR curve to ``base_dir/pr_<net>.csv``.

    ``<net>`` is ``cfg.net`` when present, otherwise ``"model"``. Returns the
    written path, or ``None`` when no accumulator has been created.
    """
    if _PR is None:
        return None
    net_name = getattr(cfg, 'net', 'model')
    target_path = os.path.join(base_dir, f"pr_{net_name}.csv")
    out = _PR.export_csv(target_path)
    print(f"[PR] saved: {out}")
    return out
# ============================ [PR] Utilities END ============================
# -------------------- [Per-Image] 逐图指标工具 --------------------
def _safe_div(a, b, eps=1e-12):
return a / max(b, eps)
def per_image_stats(pred_np: np.ndarray, gt_np: np.ndarray):
    """Compute the confusion counts and derived metrics of one image.

    Args:
        pred_np: predicted map; every value > 0 counts as positive.
        gt_np: ground-truth map of the same shape; every value > 0 counts as
            positive.

    Returns:
        dict with the integer counts ``TP``/``FP``/``TN``/``FN`` and the
        ratios ``OA``, ``Precision``, ``Recall``, ``F1`` and ``IoU``. A ratio
        whose denominator is zero evaluates to 0.
    """
    predicted = pred_np > 0
    actual = gt_np > 0
    not_predicted = ~predicted
    not_actual = ~actual

    tp = int((predicted & actual).sum())
    fp = int((predicted & not_actual).sum())
    tn = int((not_predicted & not_actual).sum())
    fn = int((not_predicted & actual).sum())

    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)
    f1 = _safe_div(2 * precision * recall, precision + recall)
    iou = _safe_div(tp, tp + fp + fn)
    oa = _safe_div(tp + tn, tp + tn + fp + fn)

    stats = {"TP": tp, "FP": fp, "TN": tn, "FN": fn}
    stats["OA"] = oa
    stats["Precision"] = precision
    stats["Recall"] = recall
    stats["F1"] = f1
    stats["IoU"] = iou
    return stats
# --------------------------------------------------------------------
def get_args():
    """Parse the command-line options of the evaluation script.

    Options:
        -c / --config: path of the config file (default ``configs/cdlama.py``).
        --ckpt: checkpoint path (default ``None``).
        --output_dir: directory for all outputs (default ``None``).
        --tables-only: flag; only tables/CSV files are produced and no mask
            visualisation images are written.

    Returns:
        The ``argparse.Namespace`` parsed from ``sys.argv``.
    """
    # The text must be passed as ``description=``: the first positional
    # argument of ArgumentParser is ``prog``, which would turn the whole
    # sentence into the program name shown in the usage line.
    parser = argparse.ArgumentParser(
        description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    # Tables-only mode: skip exporting the visualisation images.
    parser.add_argument("--tables-only", action="store_true",
                        help="仅生成表格与CSV(总体表、逐图CSV、逐图TXT、小计PR曲线CSV),不生成mask可视化图片")
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: run a checkpoint over the test loader, print and
    # append the overall metric table, optionally write mask images, and
    # export per-image metrics plus a PR-curve CSV.
    args = get_args()
    cfg = Config.fromfile(args.config)

    # A checkpoint given on the command line takes precedence over the config.
    ckpt = args.ckpt
    if ckpt is None:
        ckpt = cfg.test_ckpt_path
    if ckpt is None:
        # Explicit check instead of ``assert``, which is stripped under -O.
        raise ValueError("no checkpoint given: pass --ckpt or set test_ckpt_path in the config")

    # Outputs go next to the checkpoint unless --output_dir is provided.
    if args.output_dir:
        base_dir = args.output_dir
    else:
        base_dir = os.path.dirname(ckpt)

    # Directory for the visualisation images (only used when images are written).
    masks_output_dir = os.path.join(base_dir, "mask_rgb")
    # Directory for the per-image .txt tables; in tables-only mode they go to
    # a separate "tables_only" folder.
    tables_output_dir = os.path.join(base_dir, "tables_only" if args.tables_only else "mask_rgb")
    os.makedirs(tables_output_dir, exist_ok=True)

    model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1': 'cuda:0'}, cfg=cfg)
    model = model.to('cuda')
    model.eval()

    metric_cfg_1 = cfg.metric_cfg1
    metric_cfg_2 = cfg.metric_cfg2

    test_oa = torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
    test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
    test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
    test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
    test_iou = torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')

    results = []          # image-writing jobs; only filled when images are produced
    per_image_rows = []   # [Per-Image] one metrics dict per test image

    with torch.no_grad():
        test_loader = build_dataloader(cfg.dataset_config, mode='test')

        # === PR call 1: initialise the accumulator ===
        pr_init(nbins=1000)

        # ``batch`` instead of ``input`` so the builtin is not shadowed.
        for batch in tqdm(test_loader):
            raw_predictions = model(batch[0].cuda(), batch[1].cuda())
            mask = batch[2].cuda()
            img_id = batch[3]

            # === PR call 2: update with the raw outputs ===
            # NOTE(review): this runs before the SARASNet mask resize and
            # logit offset below; if resize_label changes the mask size, the
            # PR update sees mismatched shapes for that network -- confirm.
            pr_update_from_outputs(raw_predictions, mask, cfg)

            if cfg.net == 'SARASNet':
                mask = Variable(resize_label(
                    mask.data.cpu().numpy(),
                    size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
                param = 1  # This parameter is balance precision and recall to get higher F1-score
                raw_predictions[:, 1, :, :] = raw_predictions[:, 1, :, :] + param

            if cfg.argmax:
                pred = raw_predictions.argmax(dim=1)
            else:
                if cfg.net == 'maskcd':
                    pred = raw_predictions[0]
                    pred = pred > 0.5
                    pred.squeeze_(1)
                else:
                    pred = raw_predictions.squeeze(1)
                    pred = pred > 0.5

            # ====== accumulate the overall metrics ======
            test_oa(pred, mask)
            test_iou(pred, mask)
            test_prec(pred, mask)
            test_f1(pred, mask)
            test_recall(pred, mask)

            # ====== [Per-Image] compute and collect per-image metrics ======
            # Iterate over ``pred`` (always a tensor here) rather than
            # ``raw_predictions``, which has no ``.shape`` when the model
            # returns a list/tuple (the 'maskcd' case).
            for i in range(pred.shape[0]):
                mask_real = mask[i].detach().cpu().numpy()
                mask_pred = pred[i].detach().cpu().numpy()
                mask_name = str(img_id[i])

                stats = per_image_stats(mask_pred, mask_real)
                per_image_rows.append({
                    "img_id": mask_name,
                    "TP": stats["TP"], "FP": stats["FP"], "TN": stats["TN"], "FN": stats["FN"],
                    "OA": stats["OA"], "Precision": stats["Precision"],
                    "Recall": stats["Recall"], "F1": stats["F1"], "IoU": stats["IoU"]
                })

                # Queue an image-writing job only when images are requested.
                if not args.tables_only:
                    results.append((mask_real, mask_pred, masks_output_dir, mask_name))

    # ====== print the overall metrics ======
    metrics = [test_prec.compute(),
               test_recall.compute(),
               test_f1.compute(),
               test_iou.compute()]

    total_metrics = [test_oa.compute().cpu().numpy(),
                     np.mean([item.cpu() for item in metrics[0]]),
                     np.mean([item.cpu() for item in metrics[1]]),
                     np.mean([item.cpu() for item in metrics[2]]),
                     np.mean([item.cpu() for item in metrics[3]])]

    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']

    # One row per class (two classes), OA is only reported in the total row.
    for i in range(2):
        item = [i, '--']
        for j in range(len(metrics)):
            item.append(np.round(metrics[j][i].cpu().numpy(), 4))
        result_table.add_row(item)

    total = [np.round(v, 4) for v in total_metrics]
    total.insert(0, 'total')
    result_table.add_row(total)

    print(result_table)

    # Append the table to the running results log; the context manager makes
    # sure the handle is flushed and closed.
    file_name = os.path.join(base_dir, "test_res.txt")
    current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net), time.localtime(time.time()))
    with open(file_name, "a") as f:
        f.write(current_time + '\n')
        f.write(str(result_table) + '\n')

    # ====== write images or skip them, depending on the mode ======
    if not args.tables_only:
        if not os.path.exists(masks_output_dir):
            os.makedirs(masks_output_dir)
        print(masks_output_dir)

        # Write the images with a process pool; the ``with`` block tears the
        # workers down once the blocking map() has returned.
        t0 = time.time()
        with mpp.Pool(processes=mp.cpu_count()) as pool:
            pool.map(mask_save, results)
        t1 = time.time()
        img_write_time = t1 - t0
        print('images writing spends: {} s'.format(img_write_time))
    else:
        print("[Mode] --tables-only: 跳过可视化图片的生成,仅导出表格/CSV。")

    # ====== [Per-Image] write all per-image metrics into one CSV ======
    per_image_csv = os.path.join(base_dir, f"per_image_metrics_{getattr(cfg,'net','model')}.csv")
    with open(per_image_csv, "w", newline="") as wf:
        writer = csv.DictWriter(
            wf,
            fieldnames=["img_id", "TP", "FP", "TN", "FN", "OA", "Precision", "Recall", "F1", "IoU"]
        )
        writer.writeheader()
        for row in per_image_rows:
            row_out = dict(row)
            for k in ["OA", "Precision", "Recall", "F1", "IoU"]:
                row_out[k] = float(np.round(row_out[k], 6))
            writer.writerow(row_out)
    print(f"[Per-Image] saved CSV: {per_image_csv}")

    # ====== [Per-Image] write one small table (.txt) per image ======
    for row in per_image_rows:
        txt_path = os.path.join(tables_output_dir, f"{row['img_id']}_metrics.txt")
        pt = prettytable.PrettyTable()
        pt.field_names = ["Metric", "Value"]
        # Confusion-matrix counts first.
        pt.add_row(["TP", row["TP"]])
        pt.add_row(["FP", row["FP"]])
        pt.add_row(["TN", row["TN"]])
        pt.add_row(["FN", row["FN"]])
        # Then the ratio metrics.
        pt.add_row(["OA", f"{row['OA']:.6f}"])
        pt.add_row(["Precision", f"{row['Precision']:.6f}"])
        pt.add_row(["Recall", f"{row['Recall']:.6f}"])
        pt.add_row(["F1", f"{row['F1']:.6f}"])
        pt.add_row(["IoU", f"{row['IoU']:.6f}"])
        with open(txt_path, "w") as wf:
            wf.write(str(pt))
    print(f"[Per-Image] per-image tables saved to: {tables_output_dir}")

    # ===== [PR] export at program end =====
    # Best effort: a failed PR export must not discard the results above.
    try:
        pr_export(base_dir, cfg)
    except Exception as e:
        print(f"[PR] export skipped or failed: {e}")