File size: 2,054 Bytes
626b231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
run.py — Apply your trained style to any photo

Usage:
    python run.py --model starry_night.pth --input my_photo.jpg --output result.jpg
    python run.py --model mosaic.pth --input my_photo.jpg --output result.jpg

No GPU needed — runs on CPU in under 1 second.
"""

import torch
from torchvision import transforms
from PIL import Image
import argparse
from model import StyleNet


def stylize(model_path, input_path, output_path):
    """Apply a trained StyleNet model to a single image and save the result.

    The model runs on CUDA when available, otherwise on CPU. The stylized
    image is resized back to the input's original dimensions before saving.

    Args:
        model_path: Path to a ``.pth`` file containing a ``StyleNet`` state dict
            (it is passed to ``load_state_dict``).
        input_path: Path to the photo to stylize. It is converted to RGB.
        output_path: Destination file. PIL chooses the format from the file
            extension; ``quality=95`` only affects formats that use it
            (e.g. JPEG).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running on: {device}")

    # Load trained model.
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code — only load checkpoints you trust. Since only a state dict is needed
    # here, consider weights_only=True (available from torch 1.13, default
    # since torch 2.6).
    model = StyleNet()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    model.to(device)

    # Load and prepare input image. The context manager closes the underlying
    # file handle right away; convert() returns a new, fully loaded image, so
    # `img` remains usable after the file is closed.
    with Image.open(input_path) as src:
        img = src.convert("RGB")
    original_size = img.size           # save so we can restore it at the end
    print(f"Input image: {input_path}  ({img.width}x{img.height})")

    # NOTE(review): the input is normalized with ImageNet mean/std, but the
    # output below is only clamped to [0, 1] with no de-normalization. This
    # assumes StyleNet emits pixel values in [0, 1] — confirm against the
    # model definition / training script.
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    tensor = to_tensor(img).unsqueeze(0).to(device)    # shape: [1, 3, H, W]

    # Run inference
    with torch.no_grad():
        output = model(tensor).squeeze(0).clamp(0, 1)  # shape: [3, H, W]

    # Convert back to PIL image and save. Move to CPU explicitly so the
    # conversion never depends on how ToPILImage treats CUDA tensors.
    to_pil = transforms.ToPILImage()
    result = to_pil(output.cpu())
    # The model's output size may differ from the input's, so resize back.
    result = result.resize(original_size, Image.LANCZOS)  # restore original size
    result.save(output_path, quality=95)

    print(f"Styled image saved to: {output_path}")
    print("Open the file to see your result!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model",  required=True, help="Path to your .pth model file")
    parser.add_argument("--input",  required=True, help="Path to your input photo")
    parser.add_argument("--output", default="output.jpg", help="Where to save the result")
    args = parser.parse_args()
    stylize(args.model, args.input, args.output)