Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import gradio as gr | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| import os | |
| import torch.nn.functional as F | |
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Normalization layer used throughout ResidualBlock and Generator.
norm_layer = nn.InstanceNorm2d
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip connection.

    Each convolution keeps ``in_features`` channels and, thanks to the 1-pixel
    reflection padding, the spatial size, so the branch output can be added
    directly to the input.

    Layer order inside ``conv_block`` (Pad, Conv, Norm, ReLU, Pad, Conv, Norm)
    fixes the state_dict keys, so it must not change for saved weights to load.
    """

    def __init__(self, in_features):
        super().__init__()
        layers = []
        for stage in range(2):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(in_features, in_features, 3))
            layers.append(norm_layer(in_features))
            # Only the first conv/norm pair is followed by an activation.
            if stage == 0:
                layers.append(nn.ReLU(inplace=True))
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
class Generator(nn.Module):
    """Encoder / residual trunk / decoder network mapping images to images.

    Stages (each an ``nn.Sequential``):
        model0: 7x7 reflection-padded stem, ``input_nc`` -> 64 channels.
        model1: two stride-2 3x3 convolutions, 64 -> 128 -> 256 channels.
        model2: ``n_residual_blocks`` ResidualBlocks at 256 channels.
        model3: two stride-2 transposed convolutions, 256 -> 128 -> 64 channels.
        model4: 7x7 reflection-padded head, 64 -> ``output_nc`` channels,
                optionally followed by a sigmoid.

    For inputs whose height and width are multiples of 4, the output has the
    same spatial size as the input.

    Args:
        input_nc: Number of input channels.
        output_nc: Number of output channels.
        n_residual_blocks: How many ResidualBlocks form the trunk.
        sigmoid: If True, append ``nn.Sigmoid`` so outputs lie in (0, 1).

    The attribute names ``model0``..``model4`` and the layer order within each
    stage determine the state_dict keys. Keep them unchanged so saved
    checkpoints can be loaded with ``load_state_dict``.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super().__init__()

        # Stem.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Downsampling: each stage halves the resolution and doubles channels.
        channels = 64
        down_layers = []
        for _ in range(2):
            widened = channels * 2
            down_layers.extend([
                nn.Conv2d(channels, widened, 3, stride=2, padding=1),
                norm_layer(widened),
                nn.ReLU(inplace=True),
            ])
            channels = widened
        self.model1 = nn.Sequential(*down_layers)

        # Residual trunk at the bottleneck width.
        self.model2 = nn.Sequential(
            *[ResidualBlock(channels) for _ in range(n_residual_blocks)]
        )

        # Upsampling: each stage doubles the resolution and halves channels.
        up_layers = []
        for _ in range(2):
            narrowed = channels // 2
            up_layers.extend([
                nn.ConvTranspose2d(channels, narrowed, 3, stride=2, padding=1, output_padding=1),
                norm_layer(narrowed),
                nn.ReLU(inplace=True),
            ])
            channels = narrowed
        self.model3 = nn.Sequential(*up_layers)

        # Output head.
        head = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x):
        for stage in (self.model0, self.model1, self.model2, self.model3, self.model4):
            x = stage(x)
        return x
# Initialize models
def load_models(model_path='model.pth', model2_path='model2.pth'):
    """Build both line-drawing generators, load their weights, and set eval mode.

    Args:
        model_path: State-dict file for ``model1`` (used by ``predict`` for
            any style other than 'Simple Lines').
        model2_path: State-dict file for ``model2`` (used for 'Simple Lines').

    Returns:
        ``(model1, model2)``: two ``Generator(3, 1, 3)`` instances on ``device``.

    Raises:
        gr.Error: if either model cannot be built or its weights cannot be
            loaded. The original exception is chained as ``__cause__``.
    """
    try:
        model1 = Generator(3, 1, 3).to(device)
        model2 = Generator(3, 1, 3).to(device)
        # NOTE(review): torch.load unpickles the file, so only point these
        # paths at trusted checkpoints.
        model1.load_state_dict(torch.load(model_path, map_location=device))
        model2.load_state_dict(torch.load(model2_path, map_location=device))
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise gr.Error(
            "Failed to load models. Please check if model files exist in the correct location."
        ) from e
    # Inference only: switch off training-time behavior.
    model1.eval()
    model2.eval()
    return model1, model2
# Load both generators once at import time. If loading fails, report it and
# leave both handles as None so the rest of the module still runs.
try:
    model1, model2 = load_models()
except Exception as load_err:
    print(f"Model initialization failed: {str(load_err)}")
    model1, model2 = None, None
def apply_style_transfer(img, strength=1.0):
    """Apply artistic style transfer effect.

    Despite the name, this involves no learned model. It bilinearly resizes
    the image to 256x256 and multiplies the pixel values by ``strength``.

    Args:
        img: A PIL image or array-like of shape (H, W) (grayscale) or
            (H, W, C) (channels-last, e.g. RGB).
        strength: Multiplier applied to the resized pixel values.

    Returns:
        A float32 tensor of shape (1, C, 256, 256), with C == 1 for 2-D input.
        Values stay in the input's original range (e.g. 0-255), scaled by
        ``strength``.

    Raises:
        ValueError: if the image array is not 2-D or 3-D.
    """
    # np.array copies, so the buffer is writable for torch.from_numpy.
    img_array = np.array(img, dtype=np.float32)
    if img_array.ndim == 2:
        # Grayscale: add a trailing channel axis -> (H, W, 1).
        img_array = img_array[:, :, np.newaxis]
    elif img_array.ndim != 3:
        raise ValueError(
            f"expected a 2-D or 3-D image array, got shape {img_array.shape}"
        )
    # F.interpolate treats dims as (N, C, H, W). Move channels-last image data
    # into that layout so H and W (not W and C) are the axes being resized.
    tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0)
    processed = F.interpolate(
        tensor,
        size=(256, 256),
        mode='bilinear',
        align_corners=False
    )
    return processed * strength
def enhance_lines(img, contrast=1.0, brightness=1.0):
    """Enhance line drawing with contrast and brightness adjustments.

    Pixel values are first normalized from 8-bit (0-255) to [0, 1]. Then:
      * contrast scales values around mid-gray (0.5), so values above 1 push
        lights lighter and darks darker;
      * brightness is a multiplicative gain.
    With the defaults (1.0, 1.0) the image is returned unchanged. The previous
    version added ``brightness`` to un-normalized 0-255 values before clipping
    to [0, 1], which turned every output pure white.

    Args:
        img: An 8-bit image (e.g. a PIL image in mode 'L' or 'RGB').
        contrast: Contrast factor around mid-gray; 1.0 means no change.
        brightness: Brightness gain; 1.0 means no change.

    Returns:
        A new PIL image with the same shape as ``img``.
    """
    pixels = np.asarray(img, dtype=np.float32) / 255.0
    pixels = (pixels - 0.5) * contrast + 0.5
    pixels = np.clip(pixels * brightness, 0.0, 1.0)
    # Round rather than truncate so the defaults round-trip every value exactly.
    return Image.fromarray(np.rint(pixels * 255.0).astype(np.uint8))
def predict(input_img, version, line_thickness=1.0, contrast=1.0, brightness=1.0, enable_enhancement=False):
    """Turn an uploaded image into a line drawing.

    Args:
        input_img: Path to the input image (the UI's ``gr.Image`` uses
            ``type="filepath"``).
        version: 'Simple Lines' selects ``model2``; any other value selects
            ``model1``.
        line_thickness: Multiplier applied to the raw model output before it is
            clamped to [0, 1]. NOTE(review): this scales output intensity
            rather than stroke width.
        contrast: Passed to ``enhance_lines`` when enhancement is enabled.
        brightness: Passed to ``enhance_lines`` when enhancement is enabled.
        enable_enhancement: Whether to post-process with ``enhance_lines``.

    Returns:
        A PIL image resized back to the input image's original size.

    Raises:
        gr.Error: if no image was provided, the selected model is not loaded,
            or any step of processing fails.
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")
    model = model2 if version == 'Simple Lines' else model1
    if model is None:
        # Model loading failed at startup, so both handles are None.
        raise gr.Error(
            "Models are not loaded. Please check if model files exist in the correct location."
        )
    try:
        # Open inside `with` so the file handle is released. The generators
        # take 3-channel input, so grayscale/RGBA/palette images are
        # converted to RGB.
        with Image.open(input_img) as opened:
            original_img = opened.convert('RGB')
        original_size = original_img.size
        # Shorter side -> 256 px, then map pixels from [0, 1] to [-1, 1].
        transform = transforms.Compose([
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        input_tensor = transform(original_img).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(input_tensor)
        output = output * line_thickness
        # output[0] drops the batch axis; ToPILImage maps the 1-channel float
        # tensor in [0, 1] to a grayscale image.
        output_img = transforms.ToPILImage()(output[0].cpu().clamp(0, 1))
        if enable_enhancement:
            output_img = enhance_lines(output_img, contrast, brightness)
        return output_img.resize(original_size, Image.BICUBIC)
    except Exception as e:
        raise gr.Error(f"Error processing image: {str(e)}") from e
# Custom CSS for better UI.
# `transition` sits on the base `.gr-button` rule, not on `:hover`, so the
# hover lift animates both when the pointer enters and when it leaves.
# NOTE(review): confirm the `.gr-button` / `.gr-input` class names exist in
# the installed Gradio version; unknown selectors silently match nothing.
custom_css = """
.gradio-container {
    font-family: 'Helvetica Neue', Arial, sans-serif;
}
.gr-button {
    border-radius: 8px;
    background: linear-gradient(45deg, #3498db, #2980b9);
    border: none;
    color: white;
    transition: all 0.3s ease;
}
.gr-button:hover {
    background: linear-gradient(45deg, #2980b9, #3498db);
    transform: translateY(-2px);
}
.gr-input {
    border-radius: 8px;
    border: 2px solid #3498db;
}
"""
# Create Gradio interface with enhanced UI
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# 🎨 Advanced Line Drawing Generator")
    gr.Markdown("Transform your images into beautiful line drawings with advanced controls")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="filepath", label="Upload Image")
            version = gr.Radio(
                choices=['Complex Lines', 'Simple Lines'],
                value='Simple Lines',
                label="Drawing Style"
            )
            with gr.Accordion("Advanced Settings", open=False):
                line_thickness = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Line Thickness"
                )
                enable_enhancement = gr.Checkbox(
                    label="Enable Enhancement",
                    value=False
                )
                # Contrast/brightness controls stay hidden until enhancement is on.
                with gr.Group(visible=False) as enhancement_controls:
                    contrast = gr.Slider(
                        minimum=0.5,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Contrast"
                    )
                    brightness = gr.Slider(
                        minimum=0.5,
                        maximum=1.5,
                        value=1.0,
                        step=0.1,
                        label="Brightness"
                    )
            # gr.update works as a visibility update across Gradio 3.x and 4.x.
            enable_enhancement.change(
                fn=lambda enabled: gr.update(visible=enabled),
                inputs=[enable_enhancement],
                outputs=[enhancement_controls]
            )
        with gr.Column():
            output_image = gr.Image(type="pil", label="Generated Line Drawing")
            with gr.Row():
                generate_btn = gr.Button("Generate Drawing", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")

    # Offer every image in the working directory as an example, in a stable
    # (sorted) order rather than the arbitrary os.listdir order.
    example_images = sorted(
        name for name in os.listdir('.')
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if example_images:
        gr.Examples(
            examples=[[img, "Simple Lines"] for img in example_images],
            inputs=[input_image, version],
            outputs=output_image,
            fn=predict,
            # Caching runs predict at startup. Skip it when the models failed to
            # load, so the app can still start instead of crashing.
            cache_examples=model1 is not None and model2 is not None
        )

    # Set up event handlers
    generate_btn.click(
        fn=predict,
        inputs=[
            input_image,
            version,
            line_thickness,
            contrast,
            brightness,
            enable_enhancement
        ],
        outputs=output_image
    )
    # Reset every control to its default and also clear the generated drawing.
    clear_btn.click(
        fn=lambda: (None, "Simple Lines", 1.0, 1.0, 1.0, False, None),
        inputs=[],
        outputs=[
            input_image,
            version,
            line_thickness,
            contrast,
            brightness,
            enable_enhancement,
            output_image
        ]
    )

# Launch the interface
iface.launch()