biptv3 / code /pointcept_framework /test_miou.py
YYYYYYUUU's picture
Add core reproduction code (binarization layers, PTv3, superpoint ops, min-repro pack)
7b95dc2 verified
Raw
History Blame Contribute Delete
3.1 kB
# test_miou.py (measure a model's mIoU on the S3DIS Area_1 val split)
import torch
from pointcept.engines.defaults import default_config_parser, default_setup
from pointcept.models import build_model
from pointcept.datasets import build_dataset, point_collate_fn
from pointcept.utils.config import DictAction
from pointcept.models.loss import build_criteria
from pointcept.utils.metrics import build_metrics
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import os
def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from each key."""
    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def test_miou(config_file, weight_path, options=None):
    """Evaluate a checkpoint on the configured validation set and print loss / mIoU.

    The model is built from ``cfg.model``, optionally converted with
    ``convert_ptv3_to_bi_ptv3`` when the config has a truthy ``quantize``
    entry, loaded with the given weights (non-strict), moved to CUDA and run
    over ``cfg.data.val`` one sample per batch with gradients disabled.

    Args:
        config_file: Path to the config file handed to ``default_config_parser``.
        weight_path: Path to the checkpoint read with ``torch.load``. Either a
            raw state dict or a dict holding one under ``"state_dict"``.
        options: Optional config overrides forwarded to ``default_config_parser``.

    Returns:
        A dict with ``"avg_loss"`` (mean of the per-batch loss values),
        ``"mIoU"`` (the ``'mIoU'`` entry of ``metrics.compute()`` times 100)
        and ``"num_batches"``.

    Raises:
        RuntimeError: If the validation dataloader yields no batches, since
            neither an average loss nor an mIoU can be computed then.
    """
    print("=" * 40)
    print(f"Starting mIoU Test for: {os.path.basename(config_file)}")
    print(f"Using weights: {weight_path}")
    print("=" * 40)

    cfg = default_config_parser(config_file, options)
    cfg = default_setup(cfg)

    # Build the model.
    model = build_model(cfg.model)

    # Convert to the quantized variant when the config asks for it.
    if cfg.get("quantize", False):
        print("INFO: Quantization flag detected. Converting model to Bi-PTV3...")
        from pointcept.models.quantization.quant_utils import convert_ptv3_to_bi_ptv3
        model = convert_ptv3_to_bi_ptv3(model)

    # Load the weights on CPU first; weights_only=True restricts unpickling.
    print(f"INFO: Loading weights from {weight_path}...")
    weight = torch.load(weight_path, map_location="cpu", weights_only=True)
    if "state_dict" in weight:
        weight = weight["state_dict"]
    # Keys may carry a "module." prefix (presumably from a DataParallel/DDP
    # wrapper at save time -- confirm against the training script).
    new_state_dict = _strip_module_prefix(weight)

    # The load stays non-strict as before, but mismatches are reported so a
    # checkpoint that matches few or no parameters is not evaluated silently.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys:
        print(
            f"WARNING: {len(load_result.missing_keys)} model key(s) missing from "
            f"checkpoint, e.g. {load_result.missing_keys[:10]}"
        )
    if load_result.unexpected_keys:
        print(
            f"WARNING: {len(load_result.unexpected_keys)} unexpected checkpoint "
            f"key(s), e.g. {load_result.unexpected_keys[:10]}"
        )

    model = model.cuda()
    model.eval()

    # Build the validation dataloader.
    dataloader = torch.utils.data.DataLoader(
        build_dataset(cfg.data.val),
        batch_size=1,
        shuffle=False,
        collate_fn=point_collate_fn,
        num_workers=0,
    )

    # Build the loss criteria and the metric accumulator.
    criteria = build_criteria(cfg.criteria)
    metrics = build_metrics(cfg.metrics)

    # Evaluation loop.
    total_loss = 0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            # Move tensor entries to the GPU; non-tensor entries stay as they are.
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    batch[key] = value.cuda()
            output = model(batch)
            loss = criteria(output, batch['segment'])
            metrics.update(output, batch['segment'])
            # Read the scalar once; each .item() call forces a device sync.
            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            print(f"DEBUG: Batch {num_batches}, Loss: {loss_value}")

    if num_batches == 0:
        raise RuntimeError(
            "Validation dataloader produced no batches; cannot compute loss or mIoU."
        )

    avg_loss = total_loss / num_batches
    mIoU = metrics.compute()['mIoU'] * 100  # as a percentage

    print("\n" + "=" * 40)
    print("🏆 Test Complete! 🏆")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"mIoU: {mIoU:.2f}%")
    print("=" * 40 + "\n")

    return {"avg_loss": avg_loss, "mIoU": mIoU, "num_batches": num_batches}
if __name__ == "__main__":
    import argparse

    # Command-line entry point: a config file and a checkpoint are mandatory;
    # --options takes one or more values handled by DictAction.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config-file", required=True)
    cli.add_argument("--weight", required=True)
    cli.add_argument("--options", nargs="+", action=DictAction)
    parsed = cli.parse_args()

    test_miou(
        config_file=parsed.config_file,
        weight_path=parsed.weight,
        options=parsed.options,
    )