File size: 3,099 Bytes
7b95dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# test_miou.py (evaluate model mIoU on the S3DIS Area_1 validation split)

import torch
from pointcept.engines.defaults import default_config_parser, default_setup
from pointcept.models import build_model
from pointcept.datasets import build_dataset, point_collate_fn
from pointcept.utils.config import DictAction
from pointcept.models.loss import build_criteria
from pointcept.utils.metrics import build_metrics
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import os

def _clean_state_dict(checkpoint):
    """Return a plain, prefix-free state dict from a loaded checkpoint object.

    Unwraps a top-level ``"state_dict"`` entry when present, and strips one
    leading ``"module."`` prefix from each key (the prefix that appears when a
    ``DistributedDataParallel``-wrapped model is saved).

    Args:
        checkpoint: Mapping returned by ``torch.load``.

    Returns:
        OrderedDict of cleaned parameter names to values, in original order.
    """
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in checkpoint.items()
    )


def _load_weights(model, weight_path):
    """Load checkpoint weights into *model* non-strictly and report mismatches."""
    print(f"INFO: Loading weights from {weight_path}...")
    checkpoint = torch.load(weight_path, map_location="cpu", weights_only=True)
    state_dict = _clean_state_dict(checkpoint)
    # strict=False tolerates key mismatches; surface them so a wrong or
    # partially matching checkpoint does not silently yield a bogus mIoU.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(
            f"WARNING: {len(incompatible.missing_keys)} model keys missing "
            f"from checkpoint, e.g. {incompatible.missing_keys[:5]}"
        )
    if incompatible.unexpected_keys:
        print(
            f"WARNING: {len(incompatible.unexpected_keys)} unexpected keys "
            f"in checkpoint, e.g. {incompatible.unexpected_keys[:5]}"
        )


def test_miou(config_file, weight_path, options=None):
    """Evaluate a checkpoint on ``cfg.data.val`` and print average loss and mIoU.

    Args:
        config_file: Path to the config file parsed by ``default_config_parser``.
        weight_path: Path to the checkpoint to evaluate.
        options: Optional config overrides forwarded to ``default_config_parser``.

    Returns:
        dict with ``"avg_loss"`` (mean per-batch loss) and ``"mIoU"``
        (``metrics.compute()['mIoU']`` scaled by 100, i.e. a percentage).

    Raises:
        RuntimeError: if the validation dataloader yields no batches.
    """
    print("=" * 40)
    print(f"Starting mIoU Test for: {os.path.basename(config_file)}")
    print(f"Using weights: {weight_path}")
    print("=" * 40)

    cfg = default_config_parser(config_file, options)
    cfg = default_setup(cfg)

    # Build the model.
    model = build_model(cfg.model)

    # Convert to the quantized variant first (if requested) so the checkpoint's
    # parameter names match the converted architecture.
    if cfg.get("quantize", False):
        print("INFO: Quantization flag detected. Converting model to Bi-PTV3...")
        from pointcept.models.quantization.quant_utils import convert_ptv3_to_bi_ptv3
        model = convert_ptv3_to_bi_ptv3(model)

    # Load weights.
    _load_weights(model, weight_path)

    model = model.cuda()
    model.eval()

    # Build the validation dataloader (one sample per batch, fixed order).
    dataloader = torch.utils.data.DataLoader(
        build_dataset(cfg.data.val),
        batch_size=1,
        shuffle=False,
        collate_fn=point_collate_fn,
        num_workers=0,
    )

    # Build the loss criteria and evaluation metrics.
    criteria = build_criteria(cfg.criteria)
    metrics = build_metrics(cfg.metrics)

    # Evaluation loop.
    total_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        # Move tensor entries to the GPU; non-tensor entries are left as-is.
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch[key] = value.cuda()
        with torch.no_grad():
            output = model(batch)
            # NOTE(review): the raw model output is passed to both criteria and
            # metrics; if the model returns a dict (e.g. of logits) it may need
            # unwrapping first — confirm against the model/criteria APIs.
            loss = criteria(output, batch["segment"])
            metrics.update(output, batch["segment"])
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        print(f"DEBUG: Batch {num_batches}, Loss: {loss_value}")

    if num_batches == 0:
        raise RuntimeError(
            "Validation dataloader yielded no batches; check cfg.data.val."
        )

    avg_loss = total_loss / num_batches
    mIoU = metrics.compute()["mIoU"] * 100  # as a percentage

    print("\n" + "=" * 40)
    print("🏆 Test Complete! 🏆")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"mIoU: {mIoU:.2f}%")
    print("=" * 40 + "\n")

    return {"avg_loss": avg_loss, "mIoU": mIoU}


# The script and function are named test_*, so pytest would try to collect
# this function and fail on its missing "fixtures"; opt it out of collection.
test_miou.__test__ = False

if __name__ == "__main__":
    # Command-line entry point: parse the config/weight/override arguments
    # and run the evaluation once.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--config-file", required=True)
    cli.add_argument("--weight", required=True)
    cli.add_argument("--options", nargs="+", action=DictAction)
    parsed = cli.parse_args()

    test_miou(
        config_file=parsed.config_file,
        weight_path=parsed.weight,
        options=parsed.options,
    )