| import argparse |
| import os |
| import sys |
| import glob |
| import time |
| from pathlib import Path |
| from PIL import Image |
| import torch |
| import torchvision.transforms as T |
|
|
| |
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the watermark-removal script.

    Exactly one of ``-i/--image`` or ``-f/--folder`` must be given; argparse
    exits with a usage error otherwise.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script did before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``image``, ``folder``,
        ``output_folder`` (default ``'tests'``) and ``model_path``
        (default ``'model.ts'``).
    """
    parser = argparse.ArgumentParser(description="TorchScript Pipeline Inference for Watermark Removal")

    # A single image and a folder are alternative input sources.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-i', '--image', type=str, help="Path to single input watermarked image")
    group.add_argument('-f', '--folder', type=str, help="Path to folder containing watermarked images")

    parser.add_argument('-o', '--output_folder', type=str, default='tests', help="Output folder to save original and clean images")
    parser.add_argument('-m', '--model_path', type=str, default='model.ts', help="Path to TorchScript pipeline model (.ts file)")
    return parser.parse_args(argv)
|
|
|
|
def calculate_output_dimensions(orig_width, orig_height, max_size):
    """Fit an image size inside a ``max_size`` square, keeping aspect ratio.

    Images that already fit are returned unchanged (never upscaled).
    Otherwise the longer side becomes ``max_size`` and the shorter side is
    scaled by the same factor, rounded down, but never below one pixel.

    Args:
        orig_width: Original image width in pixels.
        orig_height: Original image height in pixels.
        max_size: Largest allowed value for either output side.

    Returns:
        Tuple ``(output_width, output_height)``.
    """
    # Already within bounds: keep the original size.
    if orig_width <= max_size and orig_height <= max_size:
        return (orig_width, orig_height)

    # Integer floor division gives the exact floor; the float form
    # ``int(h * (max_size / w))`` can land just below a whole number and
    # lose a pixel. ``max(1, ...)`` prevents a zero-sized side for extreme
    # aspect ratios, which an image resize cannot handle.
    if orig_width >= orig_height:
        output_width = max_size
        output_height = max(1, int(orig_height * max_size // orig_width))
    else:
        output_height = max_size
        output_width = max(1, int(orig_width * max_size // orig_height))

    return (output_width, output_height)
|
|
|
|
def load_torchscript_model(model_path):
    """Load a TorchScript pipeline model and switch it to eval mode.

    The model is placed on CUDA when a CUDA device is available and on the
    CPU otherwise, so the script also runs on machines without a GPU.

    Args:
        model_path: Path to the TorchScript file to load.

    Returns:
        Tuple ``(model, device)``: the loaded module and the ``torch.device``
        it was mapped to.
    """
    # Fall back to CPU instead of failing when CUDA is not available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Loading TorchScript pipeline from: {model_path}")
    model = torch.jit.load(model_path, map_location=device)
    model.eval()

    return model, device
|
|
|
|
def process_image(img_path, model, device, output_folder=None):
    """Run the pipeline on one image and optionally save the results.

    The image is loaded as RGB, passed through ``model`` as a batch of one,
    and the prediction is resized to the size returned by
    ``calculate_output_dimensions``. Progress is printed on a single line.

    Args:
        img_path: Path of the watermarked input image.
        model: Callable pipeline; it is called with a 4-D tensor and its
            result is unpacked as a 4-D shape (presumably (1, C, H, W) -
            confirm against the exported pipeline).
        device: ``torch.device`` the input tensor is moved to.
        output_folder: Folder that receives a copy of the input (saved under
            its own file name) and the cleaned ``<stem>-clean.webp``. When
            ``None`` nothing is written to disk.

    Returns:
        The cleaned, resized ``PIL.Image.Image``.
    """
    # The context manager releases the file handle; convert() returns an
    # independent in-memory copy that remains valid afterwards.
    with Image.open(img_path) as source:
        img = source.convert('RGB')
    orig_width, orig_height = img.size

    file_name = os.path.basename(img_path)
    print(f" [{file_name}] Original: {orig_width}x{orig_height}", end="")

    img_tensor = T.ToTensor()(img).unsqueeze(0).to(device)

    with torch.no_grad():
        pred_t = model(img_tensor)

    # The third dimension of the output is reported and used as the
    # pipeline's processing size.
    _, _, pipeline_size, _ = pred_t.shape
    print(f" → Pipeline output: {pipeline_size}x{pipeline_size}", end="")

    pred = pred_t.squeeze(0).cpu()
    if pred.is_floating_point():
        # ToPILImage scales float tensors by 255 and casts to uint8; values
        # outside [0, 1] would wrap around and show up as speckle artifacts.
        pred = pred.float().clamp(0.0, 1.0)
    pred_img = T.ToPILImage()(pred)

    output_width, output_height = calculate_output_dimensions(orig_width, orig_height, pipeline_size)
    pred_img = pred_img.resize((output_width, output_height), resample=Image.LANCZOS)
    print(f" → Resized: {output_width}x{output_height}", end="")

    output_width, output_height = pred_img.size
    print(f" → Output: {output_width}x{output_height}")

    # No destination given: hand the result back without touching the disk
    # (previously this path crashed inside os.makedirs(None)).
    if output_folder is None:
        return pred_img

    os.makedirs(output_folder, exist_ok=True)

    # Keep a copy of the input next to the result, but never overwrite the
    # input file itself with a re-encoded version.
    orig_save_path = os.path.join(output_folder, file_name)
    if os.path.realpath(orig_save_path) != os.path.realpath(img_path):
        img.save(orig_save_path)

    stem = os.path.splitext(file_name)[0]
    clean_path = os.path.join(output_folder, f"{stem}-clean.webp")
    pred_img.save(clean_path, 'WEBP', quality=95)

    return pred_img
|
|
|
|
def main():
    """Command-line entry point: load the pipeline and process the inputs.

    In single-image mode the results go directly into the output folder. In
    folder mode every ``*.jpg`` and ``*.webp`` file of the folder is
    processed in sorted order and the results go into the sub-folder
    ``<model name>_<folder name>_ts`` of the output folder.

    Returns:
        ``0`` on success, ``1`` when the model, the input image/folder, or
        any matching image is missing.
    """
    torch.set_float32_matmul_precision('high')

    args = parse_args()

    if not os.path.exists(args.model_path):
        print(f"Error: TorchScript model not found: {args.model_path}")
        return 1

    # Resolve and validate the inputs before the (slow) model load.
    folder_mode = args.image is None
    if folder_mode:
        if not os.path.isdir(args.folder):
            print(f"Error: input folder not found: {args.folder}")
            return 1

        model_name = os.path.splitext(os.path.basename(args.model_path))[0]
        folder_name = os.path.basename(os.path.normpath(args.folder))
        output_path = os.path.join(args.output_folder, f"{model_name}_{folder_name}_ts")

        # Escape the folder so characters such as '[' in its name are not
        # interpreted as glob wildcards.
        search_root = glob.escape(args.folder)
        images = sorted(
            path
            for pattern in ('*.jpg', '*.webp')
            for path in glob.glob(os.path.join(search_root, pattern))
        )
        if not images:
            # Previously this case ended in a ZeroDivisionError in the summary.
            print(f"Error: no .jpg or .webp images found in: {args.folder}")
            return 1
    else:
        if not os.path.isfile(args.image):
            print(f"Error: input image not found: {args.image}")
            return 1
        output_path = args.output_folder
        images = [args.image]

    print("TorchScript Pipeline Inference")
    print(f"Model: {args.model_path}")
    print()

    model, device = load_torchscript_model(args.model_path)
    print(f"Pipeline loaded on {device}")
    print()

    if folder_mode:
        print(f"Saving outputs to: {output_path}")
        print()

    # Only the per-image processing is timed, not the model load.
    start_time = time.perf_counter()
    for img_path in images:
        process_image(img_path, model, device, output_path)
    elapsed_time = time.perf_counter() - start_time

    num_images = len(images)
    print(f"\nProcessed {num_images} image{'s' if num_images != 1 else ''} in {elapsed_time:.2f} seconds ({elapsed_time/num_images:.2f}s per image)")
    return 0
|
|
|
|
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status; a return
    # value of None still exits with status 0.
    sys.exit(main())
|
|