# test_miou.py (measure model mIoU on S3DIS Area_1 val)
import torch
from pointcept.engines.defaults import default_config_parser, default_setup
from pointcept.models import build_model
from pointcept.datasets import build_dataset, point_collate_fn
from pointcept.utils.config import DictAction
from pointcept.models.loss import build_criteria
from pointcept.utils.metrics import build_metrics
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import os


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from each key.

    Keys without the prefix are kept unchanged; insertion order is preserved.
    """
    prefix = "module."
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict


def _load_weights(model, weight_path, max_listed=10):
    """Load checkpoint weights into *model* (non-strict) and report any key mismatch.

    Args:
        model: The ``torch.nn.Module`` to load into.
        weight_path: Checkpoint path. The file may hold a bare state dict or a
            dict with a ``"state_dict"`` entry.
        max_listed: Maximum number of key names printed per mismatch category.
    """
    print(f"INFO: Loading weights from {weight_path}...")
    weight = torch.load(weight_path, map_location="cpu", weights_only=True)
    if "state_dict" in weight:
        weight = weight["state_dict"]
    incompatible = model.load_state_dict(_strip_module_prefix(weight), strict=False)
    # strict=False tolerates mismatches, but ignoring them silently can make the
    # reported mIoU meaningless (e.g. layers left at random init), so surface them.
    for label, keys in (
        ("missing", incompatible.missing_keys),
        ("unexpected", incompatible.unexpected_keys),
    ):
        if keys:
            preview = ", ".join(keys[:max_listed])
            more = f" (+{len(keys) - max_listed} more)" if len(keys) > max_listed else ""
            print(f"WARNING: {len(keys)} {label} key(s) when loading weights: {preview}{more}")


def test_miou(config_file, weight_path, options=None):
    """Evaluate a checkpoint on ``cfg.data.val`` and print average loss and mIoU.

    NOTE(review): because of its name (and the file name), pytest will try to
    collect this function as a test; run it as a script instead.

    Args:
        config_file: Path to the Pointcept config file.
        weight_path: Path to the checkpoint to evaluate.
        options: Optional config overrides (as parsed by ``DictAction``).

    Returns:
        ``(avg_loss, miou)`` where ``miou`` is a percentage.

    Raises:
        RuntimeError: If the validation dataloader yields no batches.
    """
    print("=" * 40)
    print(f"Starting mIoU Test for: {config_file.split('/')[-1]}")
    print(f"Using weights: {weight_path}")
    print("=" * 40)

    cfg = default_config_parser(config_file, options)
    cfg = default_setup(cfg)

    # Build the model.
    model = build_model(cfg.model)

    # Convert to the quantized variant if the config asks for it; this must happen
    # before loading weights so the checkpoint keys match the converted modules.
    if cfg.get("quantize", False):
        print("INFO: Quantization flag detected. Converting model to Bi-PTV3...")
        from pointcept.models.quantization.quant_utils import convert_ptv3_to_bi_ptv3

        model = convert_ptv3_to_bi_ptv3(model)

    # Load weights.
    _load_weights(model, weight_path)
    model = model.cuda()
    model.eval()

    # Build the validation dataloader.
    dataloader = torch.utils.data.DataLoader(
        build_dataset(cfg.data.val),
        batch_size=1,
        shuffle=False,
        collate_fn=point_collate_fn,
        num_workers=0,
    )

    # Build criteria and metrics.
    criteria = build_criteria(cfg.criteria)
    metrics = build_metrics(cfg.metrics)

    # Evaluation loop.
    total_loss = 0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    batch[key] = value.cuda()
            output = model(batch)
            loss = criteria(output, batch['segment'])
            metrics.update(output, batch['segment'])
            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            print(f"DEBUG: Batch {num_batches}, Loss: {loss_value}")

    if num_batches == 0:
        raise RuntimeError(
            "Validation dataloader yielded no batches; check cfg.data.val."
        )

    avg_loss = total_loss / num_batches
    miou = metrics.compute()['mIoU'] * 100  # as a percentage

    print("\n" + "=" * 40)
    print("🏆 Test Complete! 🏆")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"mIoU: {miou:.2f}%")
    print("=" * 40 + "\n")
    return avg_loss, miou


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config-file", required=True)
    parser.add_argument("--weight", required=True)
    parser.add_argument("--options", nargs="+", action=DictAction)
    args = parser.parse_args()
    test_miou(args.config_file, args.weight, args.options)