"""remove-watermarks-fast / inference.py

TorchScript pipeline inference script for watermark removal.
(Originally added by jasonengage in commit b5e53f5:
"Add inference script, model, and project setup".)
"""
import argparse
import os
import sys
import glob
import time
from pathlib import Path
from PIL import Image
import torch
import torchvision.transforms as T
# Output resolution is capped at 768px
def parse_args(argv=None):
    """Parse command-line arguments for the inference script.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
              case argparse reads sys.argv[1:] — identical to the original
              behavior. The parameter exists so the parser can be exercised
              in isolation (tests, programmatic use).

    Returns:
        argparse.Namespace with `image`, `folder`, `output_folder`,
        `model_path`. Exactly one of -i/--image or -f/--folder is required.
    """
    parser = argparse.ArgumentParser(description="TorchScript Pipeline Inference for Watermark Removal")
    # Exactly one input mode: a single image or a whole folder.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-i', '--image', type=str, help="Path to single input watermarked image")
    group.add_argument('-f', '--folder', type=str, help="Path to folder containing watermarked images")
    parser.add_argument('-o', '--output_folder', type=str, default='tests', help="Output folder to save original and clean images")
    parser.add_argument('-m', '--model_path', type=str, default='model.ts', help="Path to TorchScript pipeline model (.ts file)")
    return parser.parse_args(argv)
def calculate_output_dimensions(orig_width, orig_height, max_size):
    """Return (width, height) that fits within max_size, keeping aspect ratio.

    Images already within the cap keep their original dimensions — this
    function never upscales beyond the processing size.
    """
    # Already small enough on both sides: pass through unchanged.
    if max(orig_width, orig_height) <= max_size:
        return (orig_width, orig_height)
    # Pin the longer side to max_size and scale the shorter side to match.
    if orig_width >= orig_height:
        return (max_size, int(orig_height * (max_size / orig_width)))
    return (int(orig_width * (max_size / orig_height)), max_size)
def load_torchscript_model(model_path):
    """Load a TorchScript pipeline model in eval mode.

    BUG FIX: the device was hard-coded to 'cuda', which made the script crash
    on CPU-only machines. It now falls back to CPU when CUDA is unavailable.

    Args:
        model_path: Path to a serialized TorchScript module (.ts file).

    Returns:
        (model, device): the loaded ScriptModule (in eval mode) and the
        torch.device it was mapped to.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Loading TorchScript pipeline from: {model_path}")
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model, device
def process_image(img_path, model, device, output_folder=None):
    """Run the watermark-removal pipeline on one image and save the results.

    Saves two files into `output_folder` (created if missing): the original
    image under its original name, and the cleaned result as
    "{stem}-clean.webp" (quality 95).

    BUG FIX: the progress messages contained mojibake ("β†’", garbled UTF-8
    for the arrow character); the strings now print a proper "→".

    Args:
        img_path: Path to the watermarked input image.
        model: TorchScript pipeline; maps a [1, 3, H, W] tensor in [0, 1] to
               a square [1, 3, S, S] tensor in [0, 1] (resize/normalize and
               both models happen inside the pipeline).
        device: torch.device the model lives on.
        output_folder: Destination directory for both saved images.
    """
    # Load image and remember the source dimensions.
    img = Image.open(img_path).convert('RGB')
    orig_width, orig_height = img.size
    base_name = os.path.basename(img_path)
    print(f" [{base_name}] Original: {orig_width}x{orig_height}", end="")

    # Convert to a [1, 3, H, W] float tensor in [0, 1] on the model's device.
    img_tensor = T.ToTensor()(img).unsqueeze(0).to(device)

    # Inference with the TorchScript pipeline.
    # Pipeline handles: resize → normalize → model1 → model2 → denormalize → final resize
    with torch.no_grad():
        pred_t = model(img_tensor)  # Output: [1, 3, final_size, final_size] in [0, 1]

    # The pipeline emits a square image; read its side length.
    _, _, pipeline_size, _ = pred_t.shape
    print(f" → Pipeline output: {pipeline_size}x{pipeline_size}", end="")

    # Convert tensor back to PIL (square output at pipeline_size).
    pred_img = T.ToPILImage()(pred_t.squeeze(0).cpu())

    # Restore the original aspect ratio with LANCZOS (capped at pipeline_size).
    output_width, output_height = calculate_output_dimensions(orig_width, orig_height, pipeline_size)
    pred_img = pred_img.resize((output_width, output_height), resample=Image.LANCZOS)
    print(f" → Resized: {output_width}x{output_height}", end="")
    output_width, output_height = pred_img.size
    print(f" → Output: {output_width}x{output_height}")

    # Save the original (keeps its extension) and the cleaned WebP side by side.
    stem = os.path.splitext(os.path.basename(img_path))[0]
    os.makedirs(output_folder, exist_ok=True)
    img.save(os.path.join(output_folder, os.path.basename(img_path)))
    pred_img.save(os.path.join(output_folder, f"{stem}-clean.webp"), 'WEBP', quality=95)
def main():
    """Entry point: load the TorchScript pipeline and clean one image or a folder.

    BUG FIX: the timing summary divided by `num_images`, which raised
    ZeroDivisionError when a folder contained no *.jpg / *.webp files.
    That case now prints a message and returns cleanly.
    """
    # Enable TensorFloat32 for faster matmul on Ampere+ GPUs.
    torch.set_float32_matmul_precision('high')
    args = parse_args()

    # Verify the TorchScript model exists before doing anything else.
    if not os.path.exists(args.model_path):
        print(f"Error: TorchScript model not found: {args.model_path}")
        return

    print("TorchScript Pipeline Inference")
    print(f"Model: {args.model_path}")
    print()

    # Load the TorchScript pipeline once, reused for every image.
    model, device = load_torchscript_model(args.model_path)
    print(f"Pipeline loaded on {device}")
    print()

    num_images = 0
    if args.image:
        # Single image: save directly in output_folder.
        output_path = args.output_folder
        # Start timing AFTER model loading.
        start_time = time.time()
        process_image(args.image, model, device, output_path)
        num_images = 1
    elif args.folder:
        # Folder processing: outputs go to subfolder {model_name}_{folder_name}_ts.
        model_name = os.path.splitext(os.path.basename(args.model_path))[0]
        folder_name = os.path.basename(os.path.normpath(args.folder))
        output_path = os.path.join(args.output_folder, f"{model_name}_{folder_name}_ts")
        print(f"Saving outputs to: {output_path}")
        print()

        # Collect all JPG/WebP files in the folder.
        images = []
        for pattern in ('*.jpg', '*.webp'):
            images.extend(glob.glob(os.path.join(args.folder, pattern)))
        num_images = len(images)

        # Start timing AFTER model loading.
        start_time = time.time()
        for img_path in sorted(images):
            process_image(img_path, model, device, output_path)

    # Guard: no images found means no timing to report (and avoids /0).
    if num_images == 0:
        print("\nNo images found to process.")
        return

    elapsed_time = time.time() - start_time
    print(f"\nProcessed {num_images} image{'s' if num_images != 1 else ''} in {elapsed_time:.2f} seconds ({elapsed_time/num_images:.2f}s per image)")
# Standard script guard: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()