File size: 6,903 Bytes
0f52c9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import torch
import cv2
import numpy as np
import math
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
import argparse
from options.test_options import TestOptions
from models.pix2pix_model import Pix2PixModel
from util import util


def get_safe_resolution(width, height, max_height, alignment=16):
    """Return a (width, height) that fits within max_height and is aligned.

    The result is scaled down (preserving aspect ratio) when height exceeds
    max_height, then each dimension is rounded DOWN to the nearest multiple
    of `alignment` (networks with stride-16 downsampling require this).

    Args:
        width: original image width in pixels.
        height: original image height in pixels.
        max_height: maximum allowed output height.
        alignment: dimension multiple required by the model (default 16,
            matching the original hard-coded behavior).

    Returns:
        (safe_width, safe_height) tuple; each is at least `alignment`.
    """
    # Fast path: already small enough and already aligned — keep as-is.
    if height <= max_height and width % alignment == 0 and height % alignment == 0:
        return width, height

    if height > max_height:
        # Scale down proportionally so the height matches max_height.
        ratio = max_height / float(height)
        new_width = int(width * ratio)
        new_height = max_height
    else:
        new_width = width
        new_height = height

    # Snap both dimensions down to the required multiple.
    safe_width = new_width - (new_width % alignment)
    safe_height = new_height - (new_height % alignment)

    # Guard against degenerate zero-sized results for tiny inputs.
    if safe_width == 0:
        safe_width = alignment
    if safe_height == 0:
        safe_height = alignment

    return safe_width, safe_height


class CustomTestOptions(TestOptions):
    """Test options extended with time-of-day relighting flags.

    Adds --hours (required), --add_timestamp and --max_height on top of the
    base test options, and forces preprocess_mode to 'none' by default.
    """

    def initialize(self, parser):
        # Let the base class register its standard options first.
        parser = super().initialize(parser)

        parser.add_argument(
            '--hours',
            type=float,
            nargs='+',
            required=True,
            help='List of hours to apply (e.g., 8.5 13 18.75). Range: 0h-24h.',
        )
        parser.add_argument(
            '--add_timestamp',
            action='store_true',
            help='If enabled, a timestamp (HH:MM) will be added at the top-left corner of the image.',
        )
        parser.add_argument(
            '--max_height',
            type=int,
            default=720,
            help='Maximum height for processed images. If set, images will be resized. Example: 720.',
        )

        # Images are fed at (safe) native resolution; disable preprocessing.
        parser.set_defaults(preprocess_mode='none')
        return parser


def _overlay_timestamp(image_bgr, hour):
    """Draw an HH:MM timestamp in a filled black box at the top-left corner.

    Modifies `image_bgr` in place.

    Args:
        image_bgr: numpy image in BGR channel order (OpenCV convention).
        hour: fractional hour, e.g. 8.5 renders as "08:30".
    """
    h = int(hour)
    m = int((hour - h) * 60)
    timestamp_text = f"{h:02d}:{m:02d}"

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 1
    padding = 8
    margin = 10

    (text_width, text_height), baseline = cv2.getTextSize(timestamp_text, font, font_scale,
                                                          thickness)

    # Background rectangle sized to the text plus padding on every side.
    rect_tl = (margin, margin)
    rect_br = (margin + text_width + padding * 2, margin + text_height + padding * 2)
    # putText's origin is the text BASELINE, hence the + text_height offset.
    text_origin = (margin + padding, margin + text_height + padding)

    cv2.rectangle(image_bgr, rect_tl, rect_br, (0, 0, 0), cv2.FILLED)
    cv2.putText(image_bgr, timestamp_text, text_origin, font, font_scale,
                (255, 255, 255), thickness, cv2.LINE_AA)


def main():
    """Relight every image in --image_dir at each requested hour.

    For each input image and each hour in --hours, runs the Pix2Pix model in
    inference mode with `opt.phi` set to the hour converted to a sun angle in
    radians (hour * 15 degrees), and writes the result to the results
    directory as `<name>_hour_<hour><ext>`. Images taller than --max_height
    are downscaled (and snapped to a multiple of 16) before inference, then
    upscaled back to the original size afterwards.

    Raises:
        IOError: if --image_dir does not exist or contains no images.
    """
    opt = CustomTestOptions().parse()

    model = Pix2PixModel(opt)
    model.eval()
    print(f"Model {opt.name} epoch {opt.which_epoch} has been successfully loaded.")

    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
    try:
        image_files = sorted(
            [f for f in os.listdir(opt.image_dir) if os.path.splitext(f)[1].lower() in image_extensions])
    except FileNotFoundError:
        raise IOError(f"Error: Cannot find input image directory: {opt.image_dir}")

    if not image_files:
        raise IOError(f"No image files found in directory: {opt.image_dir}")

    # Honor --how_many by truncating the (sorted) file list.
    if opt.how_many < len(image_files):
        image_files = image_files[:int(opt.how_many)]
        print(f"Processing the first {len(image_files)} images according to --how_many parameter.")

    output_dir = os.path.join(opt.results_dir, opt.name, f'{opt.phase}_{opt.which_epoch}')
    os.makedirs(output_dir, exist_ok=True)
    print(f"Results will be saved to: {output_dir}")

    # One progress tick per (image, hour) pair.
    total_iterations = len(image_files) * len(opt.hours)

    # Map [0, 255] RGB to the [-1, 1] range the generator expects.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    printed_resize_message = False

    with tqdm(total=total_iterations, desc="Processing images") as pbar:
        for image_name in image_files:
            image_path = os.path.join(opt.image_dir, image_name)
            # Hours completed for THIS image; lets the except-branch advance
            # the bar by exactly the unprocessed remainder.
            hours_done = 0

            try:
                input_image_pil = Image.open(image_path).convert('RGB')
                original_width, original_height = input_image_pil.size
                image_to_process = input_image_pil

                if opt.max_height is not None and opt.max_height > 0:
                    proc_width, proc_height = get_safe_resolution(original_width, original_height, opt.max_height)
                    if (proc_width, proc_height) != (original_width, original_height):
                        image_to_process = input_image_pil.resize((proc_width, proc_height), Image.Resampling.LANCZOS)
                        if not printed_resize_message:
                            print(
                                f"Note: Large images will be resized. Example: {original_width}x{original_height} -> {proc_width}x{proc_height}.")
                            printed_resize_message = True

                input_tensor = transform(image_to_process).unsqueeze(0)
                data_i = {'day': input_tensor, 'cpath': [image_path]}

                for hour in opt.hours:
                    # 24 h -> 360 degrees, i.e. 15 degrees per hour.
                    angle_degrees = hour * 15.0
                    phi_val = math.radians(angle_degrees)
                    phi_tensor_1d = torch.tensor([phi_val], device=model.device)
                    model.opt.phi = phi_tensor_1d

                    with torch.no_grad():
                        generated_tensor = model(data_i, mode='inference', arbitrary_input=True)

                    generated_numpy_rgb = util.tensor2im(generated_tensor)[0]

                    # Restore the original resolution if we downscaled above.
                    if image_to_process.size != input_image_pil.size:
                        generated_pil = Image.fromarray(generated_numpy_rgb)
                        generated_pil_resized = generated_pil.resize(input_image_pil.size, Image.Resampling.LANCZOS)
                        generated_numpy_rgb = np.array(generated_pil_resized)

                    # OpenCV writes BGR; convert before imwrite/drawing.
                    generated_bgr = cv2.cvtColor(generated_numpy_rgb, cv2.COLOR_RGB2BGR)

                    if opt.add_timestamp:
                        _overlay_timestamp(generated_bgr, hour)

                    base_name, ext = os.path.splitext(image_name)
                    output_filename = f"{base_name}_hour_{hour}{ext}"
                    output_path = os.path.join(output_dir, output_filename)
                    cv2.imwrite(output_path, generated_bgr)

                    pbar.update(1)
                    hours_done += 1

            except Exception as e:
                print(f"\nError while processing image {image_name}: {e}")
                if 'cuda' in str(e).lower():
                    torch.cuda.empty_cache()
                # Advance past this image's unprocessed hours so the total
                # stays accurate. (The previous guard skipped this update
                # when a failure happened before ANY hour was processed,
                # leaving the bar short by len(opt.hours).)
                pbar.update(len(opt.hours) - hours_done)

    print(f"\nProcessing completed. Results are saved in: {output_dir}")


# Script entry point: run inference only when executed directly, not on import.
if __name__ == '__main__':
    main()