Spaces:
Sleeping
Sleeping
import os
from functools import lru_cache

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from transformer_net import TransformerNet
# Run inference on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
| # Load model from file | |
| def load_model(style_name): | |
| model_path = f"models/{style_name}.pth" | |
| model = TransformerNet() | |
| state_dict = torch.load(model_path, map_location=device) | |
| # Clean deprecated keys if necessary | |
| for k in list(state_dict.keys()): | |
| if "running_mean" in k or "running_var" in k: | |
| del state_dict[k] | |
| model.load_state_dict(state_dict) | |
| model.to(device) | |
| return model.eval() | |
| # Image loader and processor | |
| def preprocess_image(image): | |
| transform = transforms.Compose([ | |
| transforms.Resize(512), | |
| transforms.ToTensor(), | |
| transforms.Lambda(lambda x: x.mul(255)) | |
| ]) | |
| return transform(image).unsqueeze(0).to(device) | |
def postprocess_image(tensor):
    """Turn the network's output tensor back into a PIL image.

    Expects pixel values on a 0-255 scale and a leading size-1 batch
    dimension, which ``squeeze(0)`` removes. Presumably (1, C, H, W), as
    produced from ``preprocess_image`` input; confirm for other callers.
    Values outside [0, 255] are clipped before rescaling to [0, 1].
    """
    to_pil = transforms.ToPILImage()
    # Work on a CPU copy so the caller's (possibly GPU) tensor is untouched.
    pixels = tensor.cpu().clone().squeeze(0).clamp(0, 255)
    return to_pil(pixels.div(255))
def apply_style(content_img, style_name):
    """Run the full style-transfer pipeline on one image.

    Args:
        content_img: The PIL image to stylize.
        style_name: Checkpoint stem passed to ``load_model`` (e.g. "mosaic").

    Returns:
        The stylized PIL image.
    """
    batch = preprocess_image(content_img)
    net = load_model(style_name)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        stylized = net(batch)
    return postprocess_image(stylized)
# Available styles: dropdown label shown in the UI -> checkpoint stem
# (the weights file read is models/<stem>.pth). Insertion order is the
# order the options appear in the dropdown.
style_choices = dict(
    [
        ("Mosaic", "mosaic"),
        ("Candy", "candy"),
        ("Rain Princess", "rain_princess"),
        ("Udnie", "udnie"),
    ]
)
# Gradio interface
def _stylize(img, style):
    """Gradio callback: validate the UI inputs, then run ``apply_style``.

    Args:
        img: The uploaded PIL image, or None if nothing was uploaded.
        style: The selected dropdown label (a key of ``style_choices``).

    Returns:
        The stylized PIL image.

    Raises:
        gr.Error: If no image was uploaded or no known style is selected.
            This is Gradio's exception for showing a message to the user. It
            replaces an opaque failure inside the image transforms (no image)
            or a ``KeyError`` from the ``style_choices`` lookup (no style).
    """
    if img is None:
        raise gr.Error("Please upload a content image.")
    if style not in style_choices:
        raise gr.Error("Please choose a style.")
    return apply_style(img, style_choices[style])


interface = gr.Interface(
    fn=_stylize,
    inputs=[
        gr.Image(type="pil", label="Upload Content Image"),
        # Preselect the first style so a first submit always carries a
        # valid selection.
        gr.Dropdown(
            choices=list(style_choices.keys()),
            value=next(iter(style_choices)),
            label="Choose Style",
        ),
    ],
    outputs=gr.Image(type="pil", label="Stylized Output"),
    title="🎨 Fast Neural Style Transfer",
    description="Upload an image and select a painting style to apply style transfer",
    theme=gr.themes.Soft(),
    examples=[
        ["examples/amber.jpg", "Mosaic"],
        ["examples/sunset.jpg", "Mosaic"],
    ],
)

if __name__ == "__main__":
    interface.launch(share=True, debug=True)