File size: 2,165 Bytes
4f763cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import argparse
import torch

from resselt import load_from_file
from pepeline import read, save, ImgColor, ImgFormat


def parse_args():
    """Parse the command-line options of the batch upscaler from ``sys.argv``.

    Returns:
        argparse.Namespace with the required string options ``input_dir``,
        ``output_dir`` and ``weights``, plus ``device`` (a string, or ``None``
        when the flag is omitted).
    """
    cli = argparse.ArgumentParser(description="Batch image upscaling script")

    # The three path options are all mandatory strings; only the help differs.
    required_paths = (
        ("--input_dir", "Path to input images"),
        ("--output_dir", "Path to save results"),
        ("--weights", "Path to model weights"),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument("--device", type=str, default=None, help="cuda or cpu")
    return cli.parse_args()


def load_model(weights_path: str, device: torch.device):
    """Load a network from ``weights_path`` and ready it for inference.

    Args:
        weights_path: path passed straight to resselt's ``load_from_file``.
        device: device the parameters and buffers are moved to.

    Returns:
        The loaded module, placed on ``device`` and switched to eval mode.
    """
    network = load_from_file(weights_path)
    # preserve_format keeps each tensor's existing memory layout. non_blocking
    # only has an effect for copies from pinned host memory.
    network = network.to(
        device,
        memory_format=torch.preserve_format,
        non_blocking=True,
    )
    # Module.eval() switches in place (and returns self); done as a separate step here.
    network.eval()
    return network


def process_image(model, img_path: str, device: torch.device):
    """Run ``model`` on one image file and return the result as a NumPy array.

    Args:
        model: callable taking a ``(1, C, H, W)`` float tensor on ``device``
            (presumably the module returned by ``load_model``).
        img_path: path of the image to read; it is loaded as float32 RGB.
        device: device the model lives on; the input tensor is moved there.

    Returns:
        A C-contiguous ``float32`` array of shape ``(H', W', C')``. This is the
        model output with the batch dimension dropped and channels moved last.
    """
    # read(...) gives an (H, W, C) array; the model consumes channel-first input.
    img = read(img_path, ImgColor.RGB, ImgFormat.F32).transpose(2, 0, 1)
    tensor = (
        torch.tensor(img)
        .to(
            device,
            memory_format=torch.preserve_format,
            non_blocking=True,
        )
        .unsqueeze(0)
    )

    with torch.inference_mode():
        if device.type == "cuda":
            # fp16 autocast is only requested on CUDA. On CPU, float16 autocast
            # is disabled with a warning or lacks/slows Half kernels, depending on
            # the PyTorch version. Other device types may be rejected outright.
            with torch.autocast("cuda", dtype=torch.float16):
                output = model(tensor)
        else:
            output = model(tensor)

    # The cast to float32 matters: under fp16 autocast the network output is
    # float16, whereas the pipeline reads (and should save) float32 images.
    # .contiguous() turns the permuted view into a dense C-order array.
    return output[0].permute(1, 2, 0).float().contiguous().cpu().numpy()


def main():
    """Upscale every file in ``--input_dir`` and write results to ``--output_dir``.

    The device defaults to CUDA when available, else CPU. Each output file gets
    the same name as its input file. Progress is printed on a single
    carriage-return-updated line.
    """
    args = parse_args()

    device = torch.device(
        args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    )

    os.makedirs(args.output_dir, exist_ok=True)

    model = load_model(args.weights, device)

    # Keep only regular files: os.listdir also returns sub-directories, which
    # cannot be read as images and would abort the whole batch. Sort because
    # os.listdir returns entries in arbitrary order.
    img_list = sorted(
        name
        for name in os.listdir(args.input_dir)
        if os.path.isfile(os.path.join(args.input_dir, name))
    )
    total = len(img_list)

    for index, img_name in enumerate(img_list, start=1):
        print(
            f"\rProcessing {index}/{total} | {img_name}",
            end="",
            flush=True,
        )

        img_path = os.path.join(args.input_dir, img_name)
        result = process_image(model, img_path, device)

        # .copy() hands save() an independent, contiguous array.
        save(result.copy(), os.path.join(args.output_dir, img_name))

    print("\nDone.")


if __name__ == "__main__":
    # Run the batch upscaler only when executed as a script, not on import.
    main()