import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from torchvision import transforms
from PIL import Image
from cdan import CDAN

def load_model():
    """Download the pretrained CDAN checkpoint from the Hub and load it on CPU."""
    model_repo = "hossshakiba/CDAN"
    model_path = hf_hub_download(repo_id=model_repo, filename="CDAN.pt")
    model = CDAN()
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()  # inference mode: disables dropout / batch-norm updates
    return model


model = load_model()
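
# A minimal sketch (assumption, not part of the original app): on a GPU machine,
# the model and the input tensors could be moved to CUDA before inference. The
# Space as written runs on CPU.
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)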

# Pre- and post-processing
preprocess = transforms.Compose([
    transforms.ToTensor(),          # Convert PIL Image to tensor (0-1 range)
    transforms.Resize((400, 600)),  # (height, width); adjust as needed
])
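
# Sanity sketch of what `preprocess` produces for any RGB input: a
# (3, 400, 600) float tensor in [0, 1]. (Illustrative only; not run by the app.)
#   t = preprocess(Image.new("RGB", (123, 456)))
#   assert t.shape == (3, 400, 600)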

def enhance_contrast(images, contrast_factor=1.5):
    """Stretch each image's pixel values about its per-channel mean intensity."""
    if images.max() > 1.0:  # normalize 0-255 inputs to [0, 1]
        images = images / 255.0
    # Mean over the spatial dims (H, W), kept as (N, C, 1, 1) for broadcasting
    mean_intensity = images.mean(dim=(2, 3), keepdim=True)
    enhanced_images = (images - mean_intensity) * contrast_factor + mean_intensity
    enhanced_images = torch.clamp(enhanced_images, 0.0, 1.0)
    return enhanced_images
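
# Sanity sketch: a constant mid-gray batch equals its own mean intensity, so
# any contrast factor leaves it unchanged. (Illustrative only; not run by the app.)
#   x = torch.full((1, 3, 4, 4), 0.5)
#   assert torch.allclose(enhance_contrast(x, contrast_factor=2.0), x)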

def enhance_color(images, saturation_factor=1.5):
    """Boost saturation by pushing each pixel away from its grayscale value."""
    if images.max() > 1.0:  # normalize 0-255 inputs to [0, 1]
        images = images / 255.0
    # ITU-R BT.601 luma weights for the R, G, B channels
    grayscale = 0.2989 * images[:, 0, :, :] + 0.5870 * images[:, 1, :, :] + 0.1140 * images[:, 2, :, :]
    grayscale = grayscale.unsqueeze(1)  # Restore the channel dimension: (N, 1, H, W)
    enhanced_images = grayscale + saturation_factor * (images - grayscale)
    enhanced_images = torch.clamp(enhanced_images, 0.0, 1.0)
    return enhanced_images
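
# Sanity sketch: a pure-gray image (R == G == B) is almost exactly its own
# grayscale, so saturation enhancement barely changes it. The luma weights sum
# to 0.9999, hence the loose tolerance. (Illustrative only; not run by the app.)
#   x = torch.full((1, 3, 4, 4), 0.3)
#   assert torch.allclose(enhance_color(x, saturation_factor=2.0), x, atol=1e-4)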

# Inference function
def process_image(input_image):
    # Convert the input PIL Image to a (1, C, H, W) tensor
    input_tensor = preprocess(input_image).unsqueeze(0)  # Add batch dimension
    # Run the model without tracking gradients
    with torch.no_grad():
        output_tensor = model(input_tensor)
    # Optional post-processing, matching the factors used in the test code
    output_tensor = enhance_contrast(output_tensor, contrast_factor=1.12)
    output_tensor = enhance_color(output_tensor, saturation_factor=1.35)
    # Convert the tensor back to a PIL Image
    output_tensor = output_tensor.squeeze(0).clamp(0, 1)  # Remove batch dim, clamp to [0, 1]
    output_image = transforms.ToPILImage()(output_tensor)
    return output_image
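
# Note: the returned image is at the model's working resolution of (400, 600).
# A sketch (assumption, not in the original app) for restoring the input's own
# size before returning:
#   output_image = output_image.resize(input_image.size)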

# Gradio interface
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Image(type="pil", label="Enhanced Image"),
    title="Low-light Image Enhancement",
    description="CDAN: Convolutional Dense Attention-guided Network for Low-light Image Enhancement, 2024",
    examples=[
        ["examples/example1.png"],
        ["examples/example2.png"],
        ["examples/example3.png"],
        ["examples/example4.png"],
        ["examples/example5.png"],
        ["examples/example6.png"],
        ["examples/example7.png"],
        ["examples/example8.png"],
        ["examples/example9.png"],
        ["examples/example10.png"],
        ["examples/example11.jpg"],
        ["examples/example12.png"],
    ],
)
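
# Sketch (assumption): when running locally rather than on a Space, Gradio can
# also serve a temporary public link via `interface.launch(share=True)`.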
interface.launch()