File size: 4,670 Bytes
0b69a1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
from argparse import ArgumentParser, Namespace
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from tqdm import tqdm

# Absolute path of the directory holding this script; the results folder is created inside it.
current_dir = os.path.dirname(os.path.abspath(__file__))

from datasets import NWPUTest, Resize2Multiple
from models import get_model
from utils import get_config, sliding_window_predict

def _build_parser() -> ArgumentParser:
    """Create the command-line interface for NWPU-Crowd test-set inference.

    Arguments are registered in a fixed order, which is also the order they
    appear in ``--help``.
    """
    cli = ArgumentParser(description="Generate the test result of a trained model on the NWPU-Crowd test set.")

    # Model location and output naming.
    cli.add_argument("--weight_path", type=str, required=True, help="The directory to the checkpoints. This should also include the model_info.pkl file.")
    cli.add_argument("--output_filename", type=str, default="test_results", help="The name of the output file.")

    # Evaluation behaviour (window size, sliding-window limits, precision, device).
    cli.add_argument("--input_size", type=int, default=224, help="The size of the input image.")
    cli.add_argument("--sliding_window", action="store_true", help="Use sliding window strategy for evaluation.")
    cli.add_argument("--max_input_size", type=int, default=4096, help="The maximum size of the input image in evaluation. Images larger than this will be processed using sliding window by force to avoid OOM.")
    cli.add_argument("--max_num_windows", type=int, default=8, help="The maximum number of windows to be simultaneously processed.")
    cli.add_argument("--resize_to_multiple", action="store_true", help="Resize the image to the nearest multiple of the input size.")
    cli.add_argument("--stride", type=int, default=None, help="The stride for sliding window strategy.")
    cli.add_argument("--amp", action="store_true", help="Use automatic mixed precision for evaluation.")
    cli.add_argument("--device", type=str, default="cuda", help="The device to use for evaluation.")
    return cli


parser = _build_parser()


def _upscale_to_min_size(image: torch.Tensor, min_size: int) -> torch.Tensor:
    """Upscale ``image`` so that neither spatial side is smaller than ``min_size``.

    The aspect ratio is preserved up to integer truncation. The final target size
    is computed first and the image is interpolated once, instead of once per
    side, so bicubic resampling error is not compounded.

    Args:
        image: Tensor whose last two dimensions are (height, width). Bicubic
            ``F.interpolate`` with a 2-element ``size`` needs a 4-D input, so this
            presumably is (1, C, H, W) here — TODO confirm against the dataset.
        min_size: Minimum allowed height and width.

    Returns:
        ``image`` itself if both sides already reach ``min_size``, otherwise the
        bicubically resized tensor.
    """
    height, width = image.shape[-2:]
    if height >= min_size and width >= min_size:
        return image

    aspect_ratio = width / height
    new_height, new_width = height, width
    if new_height < min_size:
        new_height = min_size
        new_width = int(new_height * aspect_ratio)
    # A tall, narrow image can still be too thin after its height was fixed.
    if new_width < min_size:
        new_width = min_size
        new_height = int(new_width / aspect_ratio)
    return F.interpolate(image, size=(new_height, new_width), mode="bicubic", align_corners=False)


def _write_results(path: str, image_ids: list, preds: list) -> None:
    """Write one ``"<image_id> <pred>"`` line per image to ``path``.

    Lines are newline-separated with no newline after the last one, matching the
    original output format of this script.
    """
    lines = (f"{image_id} {pred}" for image_id, pred in zip(image_ids, preds))
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))


def main(args: Namespace) -> None:
    """Run a trained model over the NWPU-Crowd test set and save predicted counts.

    For each test image, the predicted density map is summed into a count. The
    results are written to
    ``<script dir>/nwpu_test_results/<args.output_filename>.txt`` as one
    ``"<image_id> <count>"`` line per image.

    Args:
        args: Parsed command-line arguments (see the module-level ``parser``).
            ``args.stride`` should already be resolved; the ``__main__`` block
            defaults it to ``args.input_size``.
    """
    print("Testing a trained model on the NWPU-Crowd test set.")
    device = torch.device(args.device)
    _ = get_config(vars(args).copy(), mute=False)

    model = get_model(model_info_path=args.weight_path).to(device)
    model.eval()

    transforms = Resize2Multiple(base=args.input_size) if args.resize_to_multiple else None
    dataset = NWPUTest(transforms=transforms, return_filename=True)

    input_size = args.input_size
    # Images with more than max_input_size**2 pixels are forced through the
    # sliding window to avoid running out of memory.
    max_pixels = None if args.max_input_size is None else args.max_input_size ** 2

    image_ids = []
    preds = []
    for idx in tqdm(range(len(dataset)), desc="Testing on NWPU"):
        image, image_path = dataset[idx]
        image = image.unsqueeze(0).to(device)  # add batch dimension, then move to the target device

        # Upscale images smaller than the window size (input_size) on either side.
        image = _upscale_to_min_size(image, input_size)
        image_height, image_width = image.shape[-2:]

        use_sliding_window = args.sliding_window or (max_pixels is not None and image_height * image_width > max_pixels)

        # autocast must target the device actually in use (it was hard-coded to
        # "cuda", so --amp did nothing elsewhere). When AMP is off, autocast is
        # skipped entirely so device types autocast does not support still work.
        amp_context = torch.autocast(device_type=device.type) if args.amp else nullcontext()
        with torch.no_grad(), amp_context:
            if use_sliding_window:
                pred_den_map = sliding_window_predict(model, image, input_size, args.stride, args.max_num_windows)
            else:
                pred_den_map = model(image)

        # Accumulate in float32: under AMP the density map may be fp16/bf16, which
        # would lose precision when summing into a large count.
        pred_count = pred_den_map.float().sum(dim=(1, 2, 3)).item()

        # splitext keeps ids that themselves contain dots intact.
        image_ids.append(os.path.splitext(os.path.basename(image_path))[0])
        preds.append(pred_count)

    result_dir = os.path.join(current_dir, "nwpu_test_results")
    os.makedirs(result_dir, exist_ok=True)
    _write_results(os.path.join(result_dir, f"{args.output_filename}.txt"), image_ids, preds)


if __name__ == "__main__":
    args = parser.parse_args()
    # The sliding-window stride defaults to the window size (args.input_size).
    args.stride = args.stride or args.input_size
    # Validate with parser.error instead of assert: asserts are stripped under
    # `python -O`, while parser.error prints usage and exits with status 2.
    if not os.path.exists(args.weight_path):
        parser.error(f"Checkpoint directory {args.weight_path} does not exist.")
    # Sliding window prediction will be used if args.sliding_window is True, or when the image size is larger than args.max_input_size
    main(args)