File size: 6,903 Bytes
0f52c9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import os
import torch
import cv2
import numpy as np
import math
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
import argparse
from options.test_options import TestOptions
from models.pix2pix_model import Pix2PixModel
from util import util
def get_safe_resolution(width, height, max_height, multiple=16):
    """Compute an inference-safe resolution for an input image.

    The result (a) has height <= ``max_height`` (downscaling proportionally
    when needed) and (b) has both sides divisible by ``multiple`` — 16 by
    default, presumably the stride of the generator's down/upsampling stack
    (TODO confirm against the model architecture).

    Args:
        width: Original image width in pixels.
        height: Original image height in pixels.
        max_height: Height cap; images taller than this are scaled down.
        multiple: Required divisor for both output sides (default 16).

    Returns:
        Tuple ``(safe_width, safe_height)``; each side is at least
        ``multiple`` so degenerate inputs never collapse to zero.
    """
    # Fast path: already within the cap and aligned — return unchanged.
    if height <= max_height and width % multiple == 0 and height % multiple == 0:
        return width, height
    if height > max_height:
        # Scale proportionally so the height lands exactly on the cap.
        ratio = max_height / float(height)
        new_width = int(width * ratio)
        new_height = max_height
    else:
        new_width = width
        new_height = height
    # Floor each side to the nearest multiple, but never below one multiple.
    safe_width = max(new_width - (new_width % multiple), multiple)
    safe_height = max(new_height - (new_height % multiple), multiple)
    return safe_width, safe_height
class CustomTestOptions(TestOptions):
    """Test-time options extended with time-of-day inference controls.

    Adds three CLI flags on top of the base ``TestOptions``:
      --hours          one or more fractional hours (0-24) to render per image
      --add_timestamp  draw an HH:MM label on each output image
      --max_height     resize cap applied before inference (default 720)

    Also forces ``preprocess_mode='none'`` so the loader does not resize;
    resizing is handled explicitly via --max_height instead.
    """
    def initialize(self, parser):
        # Let the base class register its own arguments first.
        parser = super().initialize(parser)
        parser.add_argument('--hours', type=float, nargs='+', required=True,
                            help='List of hours to apply (e.g., 8.5 13 18.75). Range: 0h-24h.')
        parser.add_argument('--add_timestamp', action='store_true',
                            help='If enabled, a timestamp (HH:MM) will be added at the top-left corner of the image.')
        parser.add_argument('--max_height', type=int, default=720,
                            help='Maximum height for processed images. If set, images will be resized. Example: 720.')
        parser.set_defaults(preprocess_mode='none')
        return parser
def _draw_timestamp(image_bgr, hour):
    """Draw an HH:MM label in a filled black box at the image's top-left.

    Mutates ``image_bgr`` (an OpenCV BGR ndarray) in place. ``hour`` is a
    fractional hour, e.g. 8.5 -> "08:30".
    """
    h = int(hour)
    m = int((hour - h) * 60)
    timestamp_text = f"{h:02d}:{m:02d}"
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 1
    padding = 8
    (text_width, text_height), baseline = cv2.getTextSize(timestamp_text, font, font_scale,
                                                          thickness)
    margin = 10
    rect_tl = (margin, margin)
    rect_br = (margin + text_width + padding * 2, margin + text_height + padding * 2)
    text_origin = (margin + padding, margin + text_height + padding)
    cv2.rectangle(image_bgr, rect_tl, rect_br, (0, 0, 0), cv2.FILLED)
    cv2.putText(image_bgr, timestamp_text, text_origin, font, font_scale,
                (255, 255, 255), thickness, cv2.LINE_AA)


def main():
    """Run time-of-day relighting inference over a directory of images.

    For every image in ``--image_dir`` and every hour in ``--hours``, the
    model generates one relit output. Images taller than ``--max_height``
    are downscaled (and snapped to a /16 grid) for inference, then the
    result is upscaled back to the original size before saving.
    """
    opt = CustomTestOptions().parse()
    model = Pix2PixModel(opt)
    model.eval()
    print(f"Model {opt.name} epoch {opt.which_epoch} has been successfully loaded.")

    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
    try:
        image_files = sorted(
            [f for f in os.listdir(opt.image_dir) if os.path.splitext(f)[1].lower() in image_extensions])
    except FileNotFoundError:
        raise IOError(f"Error: Cannot find input image directory: {opt.image_dir}")
    if not image_files:
        raise IOError(f"No image files found in directory: {opt.image_dir}")
    if opt.how_many < len(image_files):
        image_files = image_files[:int(opt.how_many)]
        print(f"Processing the first {len(image_files)} images according to --how_many parameter.")

    output_dir = os.path.join(opt.results_dir, opt.name, f'{opt.phase}_{opt.which_epoch}')
    os.makedirs(output_dir, exist_ok=True)
    print(f"Results will be saved to: {output_dir}")

    total_iterations = len(image_files) * len(opt.hours)
    # Map PIL [0, 255] RGB to the [-1, 1] tensor range the generator expects.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    printed_resize_message = False

    with tqdm(total=total_iterations, desc="Processing images") as pbar:
        for image_name in image_files:
            image_path = os.path.join(opt.image_dir, image_name)
            try:
                input_image_pil = Image.open(image_path).convert('RGB')
                original_width, original_height = input_image_pil.size
                image_to_process = input_image_pil
                if opt.max_height is not None and opt.max_height > 0:
                    proc_width, proc_height = get_safe_resolution(original_width, original_height, opt.max_height)
                    if (proc_width, proc_height) != (original_width, original_height):
                        image_to_process = input_image_pil.resize((proc_width, proc_height), Image.Resampling.LANCZOS)
                        if not printed_resize_message:
                            # Announce the resize policy once, not per image.
                            print(
                                f"Note: Large images will be resized. Example: {original_width}x{original_height} -> {proc_width}x{proc_height}.")
                            printed_resize_message = True

                input_tensor = transform(image_to_process).unsqueeze(0)
                data_i = {'day': input_tensor, 'cpath': [image_path]}

                for hour in opt.hours:
                    # Hour-angle convention: 15 degrees per hour (24h = 360°).
                    angle_degrees = hour * 15.0
                    phi_val = math.radians(angle_degrees)
                    phi_tensor_1d = torch.tensor([phi_val], device=model.device)
                    # NOTE(review): the sun-angle is passed via opt rather than
                    # the data dict — presumably how the generator reads phi.
                    model.opt.phi = phi_tensor_1d

                    with torch.no_grad():
                        generated_tensor = model(data_i, mode='inference', arbitrary_input=True)
                    generated_numpy_rgb = util.tensor2im(generated_tensor)[0]

                    # Restore the original resolution if inference ran downscaled.
                    if image_to_process.size != input_image_pil.size:
                        generated_pil = Image.fromarray(generated_numpy_rgb)
                        generated_pil_resized = generated_pil.resize(input_image_pil.size, Image.Resampling.LANCZOS)
                        generated_numpy_rgb = np.array(generated_pil_resized)

                    generated_bgr = cv2.cvtColor(generated_numpy_rgb, cv2.COLOR_RGB2BGR)
                    if opt.add_timestamp:
                        _draw_timestamp(generated_bgr, hour)

                    base_name, ext = os.path.splitext(image_name)
                    output_filename = f"{base_name}_hour_{hour}{ext}"
                    output_path = os.path.join(output_dir, output_filename)
                    cv2.imwrite(output_path, generated_bgr)
                    pbar.update(1)
            except Exception as e:
                # Best-effort batch processing: log the failure and move on to
                # the next image rather than aborting the whole run.
                print(f"\nError while processing image {image_name}: {e}")
                if 'cuda' in str(e).lower():
                    torch.cuda.empty_cache()
                # Advance the bar past this image's unfinished hours so the
                # total stays consistent. pbar.n % len(hours) is the number of
                # hours already completed for this image; at exception time it
                # is in [0, len-1], so the increment is always in [1, len].
                # (The previous guard skipped the update when the image failed
                # before its first hour, permanently undercounting the bar.)
                pbar.update(len(opt.hours) - (pbar.n % len(opt.hours)))

    print(f"\nProcessing completed. Results are saved in: {output_dir}")


if __name__ == '__main__':
    main()
|