Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import tempfile | |
| # --------------------------- | |
| # MODEL ARCHITECTURE | |
| # --------------------------- | |
class ResidualBlock(nn.Module):
    """Conv-ReLU-Conv block with an identity skip connection.

    Output has the same shape as the input: y = x + f(x).
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute name 'block' is kept as-is: the checkpoint's
        # state-dict keys (e.g. 'res_blocks.0.block.0.weight') depend on it.
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
        )

    def forward(self, x):
        """Return the input plus the learned residual."""
        residual = self.block(x)
        return x + residual
class Generator(nn.Module):
    """Residual CNN mapping an RGB image to an enhanced RGB image in [0, 1].

    Same spatial size in and out: 3 -> 64 feature lift, three residual
    blocks, then projection back to 3 channels through a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Attribute names (entry / res_blocks / exit) match the checkpoint's
        # state-dict keys and must not be renamed.
        self.entry = nn.Conv2d(3, 64, 3, 1, 1)
        # Sequential of three identical blocks — indices 0..2 match the
        # original literal construction, so state-dict keys are unchanged.
        self.res_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(3)])
        self.exit = nn.Sequential(
            nn.Conv2d(64, 3, 3, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Feature lift -> residual refinement -> projection back to RGB."""
        features = self.entry(x)
        features = self.res_blocks(features)
        return self.exit(features)
# ---------------------------
# LOAD MODEL
# ---------------------------
# Inference is pinned to CPU (no GPU assumed in the hosting environment).
device = torch.device("cpu")
model = Generator().to(device)
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — acceptable for a trusted local checkpoint, but consider
# weights_only=True if the installed torch version supports it.
checkpoint = torch.load("final_sr_model_v3.pth", map_location=device)
# The checkpoint stores the generator weights under the 'generator' key.
model.load_state_dict(checkpoint['generator'])
model.eval()
# ---------------------------
# TRANSFORM
# ---------------------------
# PIL image -> float32 tensor in [0, 1], shape (C, H, W).
transform = transforms.ToTensor()
| # --------------------------- | |
| # INFERENCE FUNCTION | |
| # --------------------------- | |
def enhance_image(input_image):
    """Enhance a low-quality image with the SR model plus light post-processing.

    Args:
        input_image: PIL.Image uploaded by the user (any mode; converted to RGB).

    Returns:
        tuple: (enhanced image as a uint8 RGB numpy array, path to a saved
        PNG copy for the download widget).
    """
    img = input_image.convert("RGB")
    original_size = img.size

    # (1, 3, H, W) float tensor in [0, 1].
    input_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)

    # Drop only the batch dim — a bare squeeze() would also remove a
    # size-1 spatial dim for degenerate 1-pixel inputs — then -> HWC numpy.
    output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()

    # Handle range safely: rescale tanh-style [-1, 1] outputs to [0, 1].
    if output.min() < 0:
        output = (output + 1) / 2
    output = np.clip(output, 0, 1)
    output_img = (output * 255).astype(np.uint8)

    # Resize back to the exact input dimensions.
    output_img = Image.fromarray(output_img)
    output_img = output_img.resize(original_size, Image.BICUBIC)
    output_img = np.array(output_img)

    # ---------------------------
    # FINAL BALANCED PROCESSING
    # ---------------------------
    # 1. Very light smoothing (remove artifacts).
    output_img = cv2.GaussianBlur(output_img, (3, 3), 0)

    # 2. Mild sharpening (safe).
    sharpen_kernel = np.array([
        [0, -1, 0],
        [-1, 5, -1],
        [0, -1, 0]
    ])
    output_img = cv2.filter2D(output_img, -1, sharpen_kernel)

    # 3. Color-safe blending (MOST IMPORTANT): keep 80% of the original
    #    colors so the network cannot shift the palette.
    #    img is already original_size, so no resize is needed here.
    original_np = np.array(img)
    output_img = cv2.addWeighted(original_np, 0.8, output_img, 0.2, 0)

    # 4. Very light contrast (SAFE — no color shift).
    #    convertScaleAbs already saturates to uint8, so no extra clip needed.
    alpha = 1.05  # slight contrast
    beta = 2      # slight brightness
    output_img = cv2.convertScaleAbs(output_img, alpha=alpha, beta=beta)

    # Save a PNG copy for the download component. Close the handle before
    # writing by name: avoids leaking a file descriptor per request and
    # fails on Windows otherwise (file open twice).
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    temp_file.close()
    Image.fromarray(output_img).save(temp_file.name)
    return output_img, temp_file.name
# ---------------------------
# GRADIO UI
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🔍 AI Image Enhancer")
    gr.Markdown("Upload a low-quality image and enhance it using deep learning")
    with gr.Row():
        # Side-by-side panes: uploaded image and enhanced result.
        input_img = gr.Image(type="pil", label="Upload Image")
        output_img = gr.Image(label="Enhanced Image")
    # Download link for the PNG copy saved by enhance_image.
    download_file = gr.File(label="Download Enhanced Image")
    btn = gr.Button("Enhance Image")
    # enhance_image returns (numpy image, file path) — one value per output.
    btn.click(
        fn=enhance_image,
        inputs=input_img,
        outputs=[output_img, download_file]
    )
demo.launch()