# CR-Net / test.py
# (repository upload metadata: datnguyentien204, "Upload 147 files", rev 0f52c9d verified)
import os
import torch
import cv2
import numpy as np
import math
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
import argparse
from options.test_options import TestOptions
from models.pix2pix_model import Pix2PixModel
from util import util
def get_safe_resolution(width, height, max_height, multiple=16):
    """Return a (width, height) no taller than max_height with both sides
    divisible by `multiple`.

    The divisibility constraint exists because the generator's down/upsampling
    stages require dimensions that are multiples of 16 (the default).  When the
    image is taller than max_height it is scaled down with the aspect ratio
    approximately preserved; both sides are then snapped *down* to the nearest
    multiple, with a floor of one `multiple` so a degenerate 0 is never
    returned.

    Args:
        width: Original image width in pixels.
        height: Original image height in pixels.
        max_height: Maximum allowed output height.
        multiple: Required divisor for both output dimensions (default 16).

    Returns:
        Tuple (safe_width, safe_height).
    """
    # Fast path: already short enough and already aligned.
    if height <= max_height and width % multiple == 0 and height % multiple == 0:
        return width, height
    if height > max_height:
        # Downscale to max_height, preserving aspect ratio.
        ratio = max_height / float(height)
        new_width = int(width * ratio)
        new_height = max_height
    else:
        new_width = width
        new_height = height
    # Snap down to the nearest multiple; never collapse a side to zero.
    safe_width = max(new_width - (new_width % multiple), multiple)
    safe_height = max(new_height - (new_height % multiple), multiple)
    return safe_width, safe_height
class CustomTestOptions(TestOptions):
    """Base test options extended with time-of-day relighting controls."""

    def initialize(self, parser):
        """Register CLI flags: base options first, then this script's extras."""
        parser = super().initialize(parser)
        parser.add_argument(
            '--hours', type=float, nargs='+', required=True,
            help='List of hours to apply (e.g., 8.5 13 18.75). Range: 0h-24h.')
        parser.add_argument(
            '--add_timestamp', action='store_true',
            help='If enabled, a timestamp (HH:MM) will be added at the top-left corner of the image.')
        parser.add_argument(
            '--max_height', type=int, default=720,
            help='Maximum height for processed images. If set, images will be resized. Example: 720.')
        # Inputs arrive at arbitrary sizes; resizing is handled in this script,
        # so the framework's own preprocessing is disabled.
        parser.set_defaults(preprocess_mode='none')
        return parser
def _overlay_timestamp(image_bgr, hour):
    """Draw an 'HH:MM' label on a filled black box at the top-left of image_bgr (in place)."""
    h = int(hour)
    m = int((hour - h) * 60)  # fractional hour -> minutes (truncated)
    timestamp_text = f"{h:02d}:{m:02d}"
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 1
    padding = 8
    (text_width, text_height), baseline = cv2.getTextSize(timestamp_text, font, font_scale,
                                                          thickness)
    margin = 10
    rect_tl = (margin, margin)
    rect_br = (margin + text_width + padding * 2, margin + text_height + padding * 2)
    text_origin = (margin + padding, margin + text_height + padding)
    cv2.rectangle(image_bgr, rect_tl, rect_br, (0, 0, 0), cv2.FILLED)
    cv2.putText(image_bgr, timestamp_text, text_origin, font, font_scale,
                (255, 255, 255), thickness, cv2.LINE_AA)


def main():
    """Run relighting inference over every image in --image_dir for each --hours value.

    For each (image, hour) pair the generator is conditioned on a sun angle
    phi = hour * 15 degrees and the result is written to
    <results_dir>/<name>/<phase>_<which_epoch>/<base>_hour_<hour><ext>.
    """
    opt = CustomTestOptions().parse()
    model = Pix2PixModel(opt)
    model.eval()
    print(f"Model {opt.name} epoch {opt.which_epoch} has been successfully loaded.")

    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
    try:
        image_files = sorted(
            [f for f in os.listdir(opt.image_dir) if os.path.splitext(f)[1].lower() in image_extensions])
    except FileNotFoundError:
        raise IOError(f"Error: Cannot find input image directory: {opt.image_dir}")
    if not image_files:
        raise IOError(f"No image files found in directory: {opt.image_dir}")
    if opt.how_many < len(image_files):
        image_files = image_files[:int(opt.how_many)]
        print(f"Processing the first {len(image_files)} images according to --how_many parameter.")

    output_dir = os.path.join(opt.results_dir, opt.name, f'{opt.phase}_{opt.which_epoch}')
    os.makedirs(output_dir, exist_ok=True)
    print(f"Results will be saved to: {output_dir}")

    # Progress counts (image, hour) pairs, not images.
    total_iterations = len(image_files) * len(opt.hours)
    # Map [0, 255] RGB to the [-1, 1] range the generator was trained on.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    printed_resize_message = False
    with tqdm(total=total_iterations, desc="Processing images") as pbar:
        for image_name in image_files:
            image_path = os.path.join(opt.image_dir, image_name)
            try:
                input_image_pil = Image.open(image_path).convert('RGB')
                original_width, original_height = input_image_pil.size
                image_to_process = input_image_pil
                if opt.max_height is not None and opt.max_height > 0:
                    proc_width, proc_height = get_safe_resolution(original_width, original_height, opt.max_height)
                    if (proc_width, proc_height) != (original_width, original_height):
                        image_to_process = input_image_pil.resize((proc_width, proc_height), Image.Resampling.LANCZOS)
                        if not printed_resize_message:
                            # Inform the user once; subsequent resizes are silent.
                            print(
                                f"Note: Large images will be resized. Example: {original_width}x{original_height} -> {proc_width}x{proc_height}.")
                            printed_resize_message = True

                input_tensor = transform(image_to_process).unsqueeze(0)
                data_i = {'day': input_tensor, 'cpath': [image_path]}
                for hour in opt.hours:
                    # Sun angle: 15 degrees per hour (360 deg / 24 h), in radians.
                    angle_degrees = hour * 15.0
                    phi_val = math.radians(angle_degrees)
                    phi_tensor_1d = torch.tensor([phi_val], device=model.device)
                    # NOTE(review): the model reads phi from its options object;
                    # mutating model.opt.phi per hour appears to be the intended API.
                    model.opt.phi = phi_tensor_1d
                    with torch.no_grad():
                        generated_tensor = model(data_i, mode='inference', arbitrary_input=True)
                    generated_numpy_rgb = util.tensor2im(generated_tensor)[0]

                    # Restore the original resolution if the input was downscaled.
                    if image_to_process.size != input_image_pil.size:
                        generated_pil = Image.fromarray(generated_numpy_rgb)
                        generated_pil_resized = generated_pil.resize(input_image_pil.size, Image.Resampling.LANCZOS)
                        generated_numpy_rgb = np.array(generated_pil_resized)

                    generated_bgr = cv2.cvtColor(generated_numpy_rgb, cv2.COLOR_RGB2BGR)
                    if opt.add_timestamp:
                        _overlay_timestamp(generated_bgr, hour)

                    base_name, ext = os.path.splitext(image_name)
                    output_filename = f"{base_name}_hour_{hour}{ext}"
                    output_path = os.path.join(output_dir, output_filename)
                    cv2.imwrite(output_path, generated_bgr)
                    pbar.update(1)
            except Exception as e:
                # Best-effort per-image recovery: log and move on to the next image.
                print(f"\nError while processing image {image_name}: {e}")
                if 'cuda' in str(e).lower():
                    torch.cuda.empty_cache()
                # Advance the bar past this image's unprocessed hours.
                # pbar.n % len(opt.hours) is the number of hours already rendered
                # for this image; the original guard skipped the update entirely
                # when that was 0 (failure before any hour), leaving the bar
                # short by a whole image.
                pbar.update(len(opt.hours) - (pbar.n % len(opt.hours)))

    print(f"\nProcessing completed. Results are saved in: {output_dir}")
# Script entry point: parse CLI options and run inference over all images.
if __name__ == '__main__':
    main()