File size: 5,929 Bytes
b5e53f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import argparse
import os
import sys
import glob
import time
from pathlib import Path
from PIL import Image
import torch
import torchvision.transforms as T

# Output resolution is capped at 768px


def parse_args():
    """Parse command-line options for the watermark-removal inference CLI.

    Exactly one of --image / --folder must be given (mutually exclusive).
    """
    parser = argparse.ArgumentParser(
        description="TorchScript Pipeline Inference for Watermark Removal")

    # Input source: a single image or a whole folder, never both.
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument('-i', '--image', type=str,
                        help="Path to single input watermarked image")
    source.add_argument('-f', '--folder', type=str,
                        help="Path to folder containing watermarked images")

    parser.add_argument('-o', '--output_folder', type=str, default='tests',
                        help="Output folder to save original and clean images")
    parser.add_argument('-m', '--model_path', type=str, default='model.ts',
                        help="Path to TorchScript pipeline model (.ts file)")

    return parser.parse_args()


def calculate_output_dimensions(orig_width, orig_height, max_size):
    """
    Calculate output dimensions maintaining original aspect ratio.
    Caps at max_size (never upscale beyond processing size).

    Args:
        orig_width:  Original image width in pixels.
        orig_height: Original image height in pixels.
        max_size:    Maximum allowed size for the longer side.

    Returns:
        (width, height) tuple; each dimension is at least 1.
    """
    # If image fits within max_size, keep original dimensions
    if orig_width <= max_size and orig_height <= max_size:
        return (orig_width, orig_height)

    # Scale down to fit within max_size, maintaining aspect ratio.
    # Clamp the scaled side to >= 1: int() truncation on extreme aspect
    # ratios (e.g. 10000x1) would otherwise yield 0, which PIL's resize
    # rejects.
    if orig_width >= orig_height:
        output_width = max_size
        output_height = max(1, int(orig_height * (max_size / orig_width)))
    else:
        output_height = max_size
        output_width = max(1, int(orig_width * (max_size / orig_height)))

    return (output_width, output_height)


def load_torchscript_model(model_path):
    """Load a TorchScript pipeline model onto the best available device.

    Args:
        model_path: Path to the serialized TorchScript (.ts) file.

    Returns:
        Tuple of (model set to eval mode, torch.device it was loaded on).
    """
    # Fall back to CPU when CUDA is unavailable instead of crashing at
    # load time (previously this hard-coded 'cuda').
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Loading TorchScript pipeline from: {model_path}")
    model = torch.jit.load(model_path, map_location=device)
    model.eval()

    return model, device


def process_image(img_path, model, device, output_folder=None):
    """Run the watermark-removal pipeline on one image and save the results.

    Writes two files into ``output_folder`` (created if missing): a copy of
    the original under its own filename, and the cleaned result as
    ``<stem>-clean.webp``.

    Args:
        img_path:      Path to the input watermarked image.
        model:         Loaded TorchScript pipeline (callable on a 4-D tensor).
        device:        torch.device the model lives on.
        output_folder: Destination folder for both saved images.
    """
    img = Image.open(img_path).convert('RGB')
    orig_width, orig_height = img.size

    print(f"  [{os.path.basename(img_path)}] Original: {orig_width}x{orig_height}", end="")

    # [1, 3, H, W] float tensor in [0, 1]
    img_tensor = T.ToTensor()(img).unsqueeze(0).to(device)

    # The pipeline handles resize → normalize → model1 → model2 →
    # denormalize → final resize; its output is square in [0, 1].
    with torch.no_grad():
        pred_t = model(img_tensor)

    pipeline_size = pred_t.shape[2]
    print(f" → Pipeline output: {pipeline_size}x{pipeline_size}", end="")

    pred_img = T.ToPILImage()(pred_t.squeeze(0).cpu())

    # Restore the original aspect ratio (capped at pipeline_size) with LANCZOS.
    target_w, target_h = calculate_output_dimensions(orig_width, orig_height, pipeline_size)
    pred_img = pred_img.resize((target_w, target_h), resample=Image.LANCZOS)
    print(f" → Resized: {target_w}x{target_h}", end="")

    final_w, final_h = pred_img.size
    print(f" → Output: {final_w}x{final_h}")

    os.makedirs(output_folder, exist_ok=True)

    # Copy of the input, keeping its original extension.
    img.save(os.path.join(output_folder, os.path.basename(img_path)))

    # Cleaned result as high-quality WebP with a "-clean" suffix.
    stem = os.path.splitext(os.path.basename(img_path))[0]
    pred_img.save(os.path.join(output_folder, f"{stem}-clean.webp"), 'WEBP', quality=95)


def main():
    """CLI entry point: load the TorchScript pipeline, process image(s), report timing."""
    # Enable TensorFloat32 for faster matmul on Ampere+ GPUs
    torch.set_float32_matmul_precision('high')

    args = parse_args()

    # Verify TorchScript model exists before doing any work
    if not os.path.exists(args.model_path):
        print(f"Error: TorchScript model not found: {args.model_path}")
        return

    print(f"TorchScript Pipeline Inference")
    print(f"Model: {args.model_path}")
    print()

    # Load TorchScript pipeline once, shared across all images
    model, device = load_torchscript_model(args.model_path)
    print(f"Pipeline loaded on {device}")
    print()

    num_images = 0

    # Determine output folder based on processing mode
    if args.image:
        # Single image: save directly in output_folder
        output_path = args.output_folder

        # Start timing AFTER model loading
        start_time = time.time()

        process_image(args.image, model, device, output_path)
        num_images = 1
    else:
        # Folder processing: create subfolder {model_name}_{folder_name}_ts
        model_name = os.path.splitext(os.path.basename(args.model_path))[0]
        folder_name = os.path.basename(os.path.normpath(args.folder))
        subfolder_name = f"{model_name}_{folder_name}_ts"
        output_path = os.path.join(args.output_folder, subfolder_name)

        print(f"Saving outputs to: {output_path}")
        print()

        # Process all JPG/WebP in folder
        patterns = ['*.jpg', '*.webp']
        images = []
        for pattern in patterns:
            images.extend(glob.glob(os.path.join(args.folder, pattern)))

        num_images = len(images)

        # Start timing AFTER model loading
        start_time = time.time()

        for img_path in sorted(images):
            process_image(img_path, model, device, output_path)

    # Print total processing time; guard the per-image average against a
    # ZeroDivisionError when the folder contained no matching images.
    elapsed_time = time.time() - start_time
    if num_images:
        print(f"\nProcessed {num_images} image{'s' if num_images != 1 else ''} in {elapsed_time:.2f} seconds ({elapsed_time/num_images:.2f}s per image)")
    else:
        print(f"\nProcessed 0 images in {elapsed_time:.2f} seconds")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()