import tempfile

import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2


# ---------------------------
# MODEL ARCHITECTURE
# ---------------------------
# NOTE: attribute names and layer order below must match the keys stored in
# the checkpoint's 'generator' state dict; do not rename or reorder them.
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a ReLU between them, plus an identity skip.

    Stride 1 with padding 1 keeps the spatial size unchanged, which is what
    allows the input to be added back to the block's output.
    """

    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
        )

    def forward(self, x):
        return x + self.block(x)


class Generator(nn.Module):
    """3->64 conv, three 64-channel residual blocks, 64->3 conv + Sigmoid.

    Every convolution is 3x3 with stride 1 and padding 1, and there is no
    upsampling layer, so the output has the same height and width as the
    input. The final Sigmoid bounds output values to (0, 1).
    """

    def __init__(self):
        super().__init__()
        self.entry = nn.Conv2d(3, 64, 3, 1, 1)
        self.res_blocks = nn.Sequential(
            ResidualBlock(64),
            ResidualBlock(64),
            ResidualBlock(64),
        )
        self.exit = nn.Sequential(
            nn.Conv2d(64, 3, 3, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.entry(x)
        x = self.res_blocks(x)
        return self.exit(x)


# ---------------------------
# LOAD MODEL
# ---------------------------
device = torch.device("cpu")
model = Generator().to(device)
# torch.load unpickles the file, so only ever point this at a trusted
# checkpoint. The weights are expected under the 'generator' key.
checkpoint = torch.load("final_sr_model_v3.pth", map_location=device)
model.load_state_dict(checkpoint['generator'])
model.eval()


# ---------------------------
# TRANSFORM
# ---------------------------
# PIL image -> float tensor of shape (C, H, W) scaled to [0, 1].
transform = transforms.ToTensor()

# 3x3 sharpening kernel. Its weights sum to 1, so flat regions keep their
# brightness while edges are boosted.
_SHARPEN_KERNEL = np.array(
    [
        [0, -1, 0],
        [-1, 5, -1],
        [0, -1, 0],
    ],
    dtype=np.float32,
)


# ---------------------------
# INFERENCE FUNCTION
# ---------------------------
def _run_model(img):
    """Run the generator on an RGB PIL image; return an HxWx3 uint8 array."""
    input_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)

    # Take batch item 0 explicitly. A bare squeeze() would also drop a
    # height or width of 1 and make the permute below fail.
    output = output[0].permute(1, 2, 0).cpu().numpy()

    # The Sigmoid exit already keeps values in (0, 1). This remap only
    # matters if the exit layer is ever swapped for a [-1, 1] (tanh-style)
    # output.
    if output.min() < 0:
        output = (output + 1) / 2
    output = np.clip(output, 0, 1)

    return (output * 255).astype(np.uint8)


def _postprocess(model_rgb, original_img):
    """Clean up the model output and blend it with the original image.

    Args:
        model_rgb: HxWx3 uint8 RGB array produced by ``_run_model``.
        original_img: the RGB PIL image that was fed to the model.

    Returns:
        HxWx3 uint8 RGB array at ``original_img``'s size.
    """
    # Match the original's size. For this Generator the sizes already agree,
    # because no layer changes the spatial dimensions.
    resized = Image.fromarray(model_rgb).resize(original_img.size, Image.BICUBIC)
    result = np.array(resized)

    # 1. Very light 3x3 Gaussian smoothing to soften artifacts.
    result = cv2.GaussianBlur(result, (3, 3), 0)

    # 2. Mild sharpening. ddepth=-1 keeps uint8, so values saturate to 0..255.
    result = cv2.filter2D(result, -1, _SHARPEN_KERNEL)

    # 3. Blend: 80% original pixels + 20% processed model output. The heavy
    #    weight on the original keeps its colors from shifting.
    result = cv2.addWeighted(np.array(original_img), 0.8, result, 0.2, 0)

    # 4. Slight contrast (x1.05) and brightness (+2) lift. convertScaleAbs
    #    returns saturated uint8, so no further clipping is needed.
    return cv2.convertScaleAbs(result, alpha=1.05, beta=2)


def _save_png(image_rgb):
    """Write an RGB uint8 array to a new temporary .png and return its path.

    The image is written through the open handle instead of reopening the
    file by name, which fails on Windows while the handle is open. The
    ``with`` block closes the handle. ``delete=False`` keeps the file on disk
    so Gradio can serve it for download; nothing here removes it later.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        Image.fromarray(image_rgb).save(tmp, format="PNG")
        return tmp.name


def enhance_image(input_image):
    """Enhance an uploaded image and also return it as a downloadable PNG.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            None when the button is clicked without an upload.

    Returns:
        A tuple of (enhanced HxWx3 uint8 RGB array, path to a temporary PNG
        containing the same pixels).

    Raises:
        gr.Error: if no image was uploaded.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # Normalize RGBA / grayscale / palette uploads to 3-channel RGB.
    img = input_image.convert("RGB")

    output_img = _postprocess(_run_model(img), img)
    return output_img, _save_png(output_img)


# ---------------------------
# GRADIO UI
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🔍 AI Image Enhancer")
    gr.Markdown("Upload a low-quality image and enhance it using deep learning")

    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload Image")
        output_img = gr.Image(label="Enhanced Image")

    download_file = gr.File(label="Download Enhanced Image")

    btn = gr.Button("Enhance Image")

    btn.click(
        fn=enhance_image,
        inputs=input_img,
        outputs=[output_img, download_file],
    )

demo.launch()