File size: 2,200 Bytes
daf9d89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e364ae
daf9d89
 
 
 
 
 
 
 
 
 
c163ea5
d97814c
 
d4d5245
d97814c
daf9d89
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import functools
import os

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from transformer_net import TransformerNet


# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load model from file
@functools.lru_cache(maxsize=8)
def load_model(style_name):
    """Build a TransformerNet and load the pretrained weights for ``style_name``.

    The result is memoized per style name, so repeated requests reuse the
    already-loaded network. Without this, every call would re-read
    ``models/<style_name>.pth`` from disk and move the model to ``device``.
    All callers receive the same shared instance and should only run
    inference with it. Call ``load_model.cache_clear()`` to force a reload,
    for example after a checkpoint file changes on disk.

    Args:
        style_name: file stem of the checkpoint inside ``models/``.

    Returns:
        The network on ``device``, switched to evaluation mode.
    """
    model_path = f"models/{style_name}.pth"
    model = TransformerNet()
    state_dict = torch.load(model_path, map_location=device)

    # Remove running_mean / running_var entries before the strict
    # load_state_dict call below.
    # NOTE(review): presumably these come from checkpoints saved by an older
    # PyTorch whose InstanceNorm layers tracked running stats, and
    # TransformerNet no longer declares them. Confirm.
    # Delete in place rather than rebuilding the mapping, so any metadata
    # attached to the loaded state dict object is preserved.
    for key in list(state_dict):
        if "running_mean" in key or "running_var" in key:
            del state_dict[key]

    model.load_state_dict(state_dict)
    model.to(device)
    return model.eval()

# Image loader and processor
def preprocess_image(image):
    """Turn a PIL image into a batched float tensor on ``device``.

    Steps:
    - Force the image to 3-channel RGB. RGBA PNGs or grayscale/palette images
      would otherwise yield 4- or 1-channel tensors, and TransformerNet
      presumably expects 3 channels.
    - Resize it with ``transforms.Resize(512)``.
    - Scale pixel values to the 0-255 range (``postprocess_image`` undoes this
      scale with ``div(255)``).

    Args:
        image: a PIL image.

    Returns:
        A tensor with a leading batch axis of size 1, on ``device``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize(512),
        transforms.ToTensor(),
        # ToTensor yields values in [0, 1]; multiply back up to the 0-255 scale.
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    return transform(image).unsqueeze(0).to(device)

def postprocess_image(tensor):
    """Convert a network output tensor back into a PIL image.

    Assumes ``tensor`` carries a leading batch axis of size 1 (it is removed
    with ``squeeze(0)``). Values are presumably on the same 0-255 scale that
    ``preprocess_image`` feeds in. They are clamped to that range and divided
    by 255 before conversion. TODO confirm the model's output scale.

    Args:
        tensor: the model output, on any device.

    Returns:
        A ``PIL.Image.Image``.
    """
    # Copy to host memory and drop the batch axis.
    single = tensor.cpu().clone().squeeze(0)
    # Clip out-of-range pixels, then bring values into [0, 1].
    normalized = single.clamp(0, 255) / 255
    to_pil = transforms.ToPILImage()
    return to_pil(normalized)

# Style transfer pipeline
def apply_style(content_img, style_name):
    """Stylize ``content_img`` with the pretrained network for ``style_name``.

    Args:
        content_img: PIL image from the upload widget, or None when nothing
            was uploaded.
        style_name: checkpoint file stem inside ``models/`` (a value of
            ``style_choices``, e.g. "mosaic").

    Returns:
        The stylized image as a PIL image.

    Raises:
        gr.Error: if no content image was supplied. The message is shown to
            the user in the Gradio UI.
    """
    if content_img is None:
        # Gradio's Image input delivers None when Submit is pressed without an
        # upload. Fail with a readable message instead of an opaque
        # AttributeError from inside the torchvision transforms.
        raise gr.Error("Please upload a content image first.")
    content_tensor = preprocess_image(content_img)
    model = load_model(style_name)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output_tensor = model(content_tensor)
    return postprocess_image(output_tensor)

# Style options (pretrained models).
# Maps each dropdown label to the checkpoint file stem inside models/. Each stem
# is the label lower-cased with spaces replaced by underscores. Insertion order
# determines the order of the dropdown entries.
style_choices = {
    label: label.lower().replace(" ", "_")
    for label in ("Mosaic", "Candy", "Rain Princess", "Udnie")
}

# Gradio interface.
# The dropdown shows the labels of style_choices; the selected label is mapped
# to its checkpoint stem before calling apply_style.
interface = gr.Interface(
    fn=lambda img, style: apply_style(img, style_choices[style]),
    inputs=[
        gr.Image(type="pil", label="Upload Content Image"),
        gr.Dropdown(
            choices=list(style_choices.keys()),
            # Preselect the first style. With no default, submitting before
            # picking a style passes None and style_choices[None] raises KeyError.
            value=next(iter(style_choices)),
            label="Choose Style",
        ),
    ],
    outputs=gr.Image(type="pil", label="Stylized Output"),
    title="🎨 Fast Neural Style Transfer",
    description="Upload an image and select a painting style to apply  style transfer",
    theme=gr.themes.Soft(),
    examples=[
        ["examples/amber.jpg", "Mosaic"],
        ["examples/sunset.jpg", "Mosaic"],
    ],
)

if __name__ == "__main__":
    # Start the app only when run as a script, not on import.
    # debug=True keeps the main thread blocked while the server runs and
    # surfaces errors.
    # share=True additionally asks Gradio for a public share URL, so anyone
    # with the link can use the app while it runs.
    interface.launch(debug=True, share=True)