Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from torchvision import transforms | |
| from PIL import Image | |
| from cdan import CDAN | |
def load_model(repo_id="hossshakiba/CDAN", filename="CDAN.pt"):
    """Download a CDAN checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository that holds the checkpoint. Defaults to the
            original hard-coded repository.
        filename: Checkpoint file inside ``repo_id``. Its contents are passed to
            ``load_state_dict`` of a freshly constructed ``CDAN()``.

    Returns:
        The ``CDAN`` network with the downloaded weights (tensors mapped to CPU),
        switched to eval mode.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    model = CDAN()
    # NOTE(security): torch.load unpickles the file, and unpickling can execute
    # arbitrary code. Only point this at trusted repositories. On torch >= 1.13,
    # consider passing weights_only=True, since a plain state_dict is expected here.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    # Inference mode for layers that behave differently at train time
    # (e.g. dropout / batch-norm, if CDAN uses any).
    model.eval()
    return model
# Build the network once at import time; process_image reuses this single
# instance for every request.
model = load_model()

# Input preprocessing: convert the PIL image to a float tensor (ToTensor scales
# 8-bit pixels into the 0-1 range), then resize to the fixed (400, 600)
# (height, width) resolution fed to the network.
_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Resize((400, 600)),  # adjust size as needed
]
preprocess = transforms.Compose(_preprocess_steps)
def enhance_contrast(images, contrast_factor=1.5):
    """Stretch image contrast around each channel's mean intensity.

    Args:
        images: Batched image tensor; the mean is taken over dims 2 and 3
            (presumably (batch, channels, height, width)). If any value exceeds
            1.0, the whole batch is assumed to be in 0-255 scale and is divided
            by 255 first.
        contrast_factor: Multiplier applied to each pixel's deviation from its
            channel mean; values above 1 increase contrast.

    Returns:
        Tensor of the same shape with values clamped to [0, 1].
    """
    # Heuristic range detection: any value above 1.0 means 0-255 input.
    scaled = images / 255.0 if images.max() > 1.0 else images
    centre = scaled.mean(dim=(2, 3), keepdim=True)
    stretched = centre + contrast_factor * (scaled - centre)
    return stretched.clamp(min=0.0, max=1.0)
def enhance_color(images, saturation_factor=1.5):
    """Scale colour saturation by pushing each pixel away from its gray level.

    Args:
        images: Batched image tensor indexed as [:, channel, :, :]; channels 0,
            1 and 2 are assumed to be R, G and B. If any value exceeds 1.0, the
            whole batch is assumed to be in 0-255 scale and is divided by 255 first.
        saturation_factor: 1.0 leaves colours unchanged, 0.0 yields grayscale,
            and values above 1 boost saturation.

    Returns:
        Tensor clamped to [0, 1], broadcast against the single-channel gray level.
    """
    if images.max() > 1.0:
        images = images / 255.0
    red = images[:, 0, :, :]
    green = images[:, 1, :, :]
    blue = images[:, 2, :, :]
    # BT.601 luma weights (the rounding used by MATLAB's rgb2gray); re-add the
    # channel axis so the gray level broadcasts against every colour channel.
    luma = (0.2989 * red + 0.5870 * green + 0.1140 * blue).unsqueeze(1)
    boosted = luma + saturation_factor * (images - luma)
    return boosted.clamp(0.0, 1.0)
# Inference function
def process_image(input_image):
    """Enhance a low-light image with CDAN and return the result as a PIL image.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        The enhanced PIL image at the fixed resolution produced by
        ``preprocess``, or ``None`` when there is no input.
    """
    if input_image is None:
        # Nothing to enhance; leave the output component empty instead of
        # crashing inside the transforms.
        return None

    # Force 3-channel RGB. A grayscale ("L") image would otherwise give a
    # 1-channel tensor, and RGBA/palette PNGs 4 or other channel counts, while
    # enhance_color indexes channels 0-2 and the network presumably expects RGB.
    # Gradio's Image component likely converts to RGB already; this guards
    # direct calls and non-default image modes.
    rgb_image = input_image.convert("RGB")
    input_tensor = preprocess(rgb_image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        output_tensor = model(input_tensor)

    # Clamp BEFORE post-processing. enhance_contrast/enhance_color treat any
    # value above 1.0 as a 0-255 image and divide the whole tensor by 255, so a
    # slight overshoot in the network output would otherwise darken the result.
    output_tensor = output_tensor.clamp(0.0, 1.0)

    # Mild contrast and saturation boost applied to the network output.
    output_tensor = enhance_contrast(output_tensor, contrast_factor=1.12)
    output_tensor = enhance_color(output_tensor, saturation_factor=1.35)

    # Drop the batch dimension and convert back to a PIL image.
    output_tensor = output_tensor.squeeze(0).clamp(0, 1)
    return transforms.ToPILImage()(output_tensor)
# Gradio interface: a single uploaded image in, the enhanced image out.
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Image(type="pil", label="Enhanced Image"),
    title="Low-light Image Enhancement",
    description="CDAN: Convolutional Dense Attention-guided Network for Low-light Image Enhancement, 2024",
    # Example images offered in the UI; each inner list holds the arguments
    # for the single input component.
    examples=[
        ["examples/example1.png"],
        ["examples/example2.png"],
        ["examples/example3.png"],
        ["examples/example4.png"],
        ["examples/example5.png"],
        ["examples/example6.png"],
        ["examples/example7.png"],
        ["examples/example8.png"],
        ["examples/example9.png"],
        ["examples/example10.png"],
        ["examples/example11.jpg"],
        ["examples/example12.png"],
    ],
)

if __name__ == "__main__":
    # Start the (blocking) web server only when this file is executed as a
    # script (e.g. `python app.py`), not when it is merely imported.
    # NOTE(review): confirm the hosting runner executes this file as __main__.
    interface.launch()