| |
|
|
| import torch |
| from pointcept.engines.defaults import default_config_parser, default_setup |
| from pointcept.models import build_model |
| from pointcept.datasets import build_dataset, point_collate_fn |
| from pointcept.utils.config import DictAction |
| from pointcept.models.loss import build_criteria |
| from pointcept.utils.metrics import build_metrics |
| from collections import OrderedDict |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| import torch.distributed as dist |
| import os |
|
|
def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Checkpoints saved from a (Distributed)DataParallel-wrapped model typically
    prefix every parameter name with ``module.``; the bare model built by
    ``build_model`` does not, so the prefix must be dropped before loading.
    """
    prefix = "module."
    stripped = OrderedDict()
    for key, value in state_dict.items():
        stripped[key[len(prefix):] if key.startswith(prefix) else key] = value
    return stripped


def _batch_to_cuda(batch):
    """Move every ``torch.Tensor`` value of the mapping *batch* to CUDA, in place.

    Non-tensor values are left unchanged. Returns *batch* for convenience.
    """
    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            batch[key] = batch[key].cuda()
    return batch


def test_miou(config_file, weight_path, options=None):
    """Evaluate a checkpoint on the validation split and report loss and mIoU.

    Args:
        config_file: Path to the config file passed to ``default_config_parser``.
        weight_path: Path to the checkpoint. If it holds a top-level
            ``"state_dict"`` entry, that entry is used as the weights.
        options: Optional config overrides forwarded to
            ``default_config_parser`` (``--options`` on the command line).

    Returns:
        dict with ``"loss"`` (mean per-batch loss) and ``"mIoU"`` (the
        metric's ``"mIoU"`` value multiplied by 100).

    Raises:
        RuntimeError: If the validation dataloader yields no batches.
    """
    print("=" * 40)
    print(f"Starting mIoU Test for: {os.path.basename(config_file)}")
    print(f"Using weights: {weight_path}")
    print("=" * 40)

    cfg = default_config_parser(config_file, options)
    cfg = default_setup(cfg)

    model = build_model(cfg.model)

    if cfg.get("quantize", False):
        print("INFO: Quantization flag detected. Converting model to Bi-PTV3...")
        # Imported lazily so the quantization package is only needed when used.
        from pointcept.models.quantization.quant_utils import convert_ptv3_to_bi_ptv3

        model = convert_ptv3_to_bi_ptv3(model)

    print(f"INFO: Loading weights from {weight_path}...")
    # weights_only=True restricts unpickling to tensors and plain containers.
    checkpoint = torch.load(weight_path, map_location="cpu", weights_only=True)
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    # strict=False tolerates name mismatches between checkpoint and model, but a
    # silent mismatch leaves layers at their initial values and makes the mIoU
    # meaningless -- so report whatever did not line up.
    incompatible = model.load_state_dict(_strip_module_prefix(checkpoint), strict=False)
    if incompatible.missing_keys:
        print(
            f"WARNING: {len(incompatible.missing_keys)} model keys missing from "
            f"checkpoint (first 10): {incompatible.missing_keys[:10]}"
        )
    if incompatible.unexpected_keys:
        print(
            f"WARNING: {len(incompatible.unexpected_keys)} checkpoint keys unused "
            f"by model (first 10): {incompatible.unexpected_keys[:10]}"
        )

    model = model.cuda()
    model.eval()

    dataloader = torch.utils.data.DataLoader(
        build_dataset(cfg.data.val),
        batch_size=1,
        shuffle=False,
        collate_fn=point_collate_fn,
        num_workers=0,
    )

    # NOTE(review): assumes `criteria` and `metrics` both accept the raw model
    # output together with batch["segment"] -- confirm against their builders.
    criteria = build_criteria(cfg.criteria)
    metrics = build_metrics(cfg.metrics)

    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            _batch_to_cuda(batch)
            output = model(batch)
            batch_loss = criteria(output, batch["segment"]).item()
            metrics.update(output, batch["segment"])
            total_loss += batch_loss
            num_batches += 1
            print(f"DEBUG: Batch {num_batches}, Loss: {batch_loss}")

    # Fail with a clear message instead of a ZeroDivisionError below.
    if num_batches == 0:
        raise RuntimeError(
            "Validation dataloader produced no batches; check cfg.data.val."
        )

    avg_loss = total_loss / num_batches
    # NOTE(review): the * 100 assumes compute() reports mIoU as a fraction in [0, 1].
    miou = metrics.compute()["mIoU"] * 100

    print("\n" + "=" * 40)
    print("🏆 Test Complete! 🏆")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"mIoU: {miou:.2f}%")
    print("=" * 40 + "\n")
    return {"loss": avg_loss, "mIoU": miou}
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the config/weight paths and optional
    # config overrides, then run the validation evaluation.
    import argparse

    cli = argparse.ArgumentParser()
    for required_flag in ("--config-file", "--weight"):
        cli.add_argument(required_flag, required=True)
    cli.add_argument("--options", nargs="+", action=DictAction)
    parsed = cli.parse_args()

    test_miou(
        config_file=parsed.config_file,
        weight_path=parsed.weight,
        options=parsed.options,
    )