# dev1461's picture
# Update app.py
# 78e35d6 verified
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
import tempfile
# ---------------------------
# MODEL ARCHITECTURE
# ---------------------------
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a ReLU between them, plus a skip connection.

    Both convolutions keep the channel count and use stride 1 with padding 1,
    so the output has the same shape as the input.
    """

    def __init__(self, channels):
        """Build the convolutional branch for ``channels`` feature maps."""
        super().__init__()
        # The attribute name ``block`` and the layer order fix the state_dict
        # keys (block.0.*, block.2.*); keep both stable so saved weights load.
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the convolutional branch applied to ``x``."""
        residual = self.block(x)
        return x + residual
class Generator(nn.Module):
    """Same-resolution image network: conv stem, three residual blocks, conv + sigmoid.

    Takes a 3-channel input, widens it to 64 feature channels, refines it
    through the residual trunk, and projects back to 3 channels. The final
    sigmoid keeps every output value inside (0, 1).
    """

    def __init__(self):
        """Create the stem, the residual trunk and the output head."""
        super().__init__()
        width = 64       # feature channels used by the stem and the trunk
        num_blocks = 3   # residual blocks in the trunk

        # Attribute names (entry / res_blocks / exit) and child order define
        # the state_dict keys, so they must stay as they are for checkpoints.
        self.entry = nn.Conv2d(3, width, kernel_size=3, stride=1, padding=1)
        self.res_blocks = nn.Sequential(
            *(ResidualBlock(width) for _ in range(num_blocks))
        )
        self.exit = nn.Sequential(
            nn.Conv2d(width, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map an ``(N, 3, H, W)`` batch to an ``(N, 3, H, W)`` batch in (0, 1)."""
        features = self.entry(x)
        features = self.res_blocks(features)
        return self.exit(features)
# ---------------------------
# LOAD MODEL
# ---------------------------
# Inference is pinned to the CPU; the checkpoint is remapped to match.
device = torch.device("cpu")

# NOTE(review): torch.load deserializes with pickle unless weights_only is
# enabled, so only trusted checkpoint files should ever be placed here.
checkpoint = torch.load("final_sr_model_v3.pth", map_location=device)

# The checkpoint is indexed by name; the generator weights sit under 'generator'.
model = Generator()
model.load_state_dict(checkpoint['generator'])
model.to(device)
model.eval()  # switch to evaluation mode before serving requests

# ---------------------------
# TRANSFORM
# ---------------------------
# Converts the PIL input into the tensor the model is fed with.
transform = transforms.ToTensor()
# ---------------------------
# INFERENCE FUNCTION
# ---------------------------
def enhance_image(input_image):
    """Run the generator on an image and return a lightly post-processed result.

    The model output is blurred, sharpened, blended 20/80 with the original
    pixels and given a small contrast/brightness lift.

    Args:
        input_image: PIL image supplied by the UI, or ``None`` when nothing
            has been uploaded yet.

    Returns:
        A ``(image, path)`` tuple: the enhanced picture as an ``H x W x 3``
        ``uint8`` RGB array, and the path of a temporary PNG file holding the
        same pixels for download. ``(None, None)`` is returned when
        ``input_image`` is ``None``.
    """
    # The button can be pressed before any image is uploaded; return empty
    # outputs instead of failing on ``None.convert``.
    if input_image is None:
        return None, None

    img = input_image.convert("RGB")
    original_size = img.size  # PIL order: (width, height)

    input_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)

    # Convert tensor -> numpy (H, W, C). Only the batch axis is removed: a
    # bare squeeze() would also drop a height or width of 1 and make the
    # permute below fail on 1-pixel-wide/tall images.
    output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()

    # Handle range safely: a negative minimum is treated as a [-1, 1] output
    # and rescaled to [0, 1] before clipping.
    if output.min() < 0:
        output = (output + 1) / 2
    output = np.clip(output, 0, 1)
    output_img = (output * 255).astype(np.uint8)

    # Resize back to the input resolution in case the model changed it.
    output_img = Image.fromarray(output_img)
    output_img = output_img.resize(original_size, Image.BICUBIC)
    output_img = np.array(output_img)

    # ---------------------------
    # FINAL BALANCED PROCESSING
    # ---------------------------
    # 1. Very light smoothing (remove artifacts)
    output_img = cv2.GaussianBlur(output_img, (3, 3), 0)

    # 2. Mild sharpening (safe). The kernel sums to 1, so flat regions keep
    #    their brightness; float32 is given explicitly because filter2D
    #    works with a floating-point kernel.
    sharpen_kernel = np.array(
        [
            [0, -1, 0],
            [-1, 5, -1],
            [0, -1, 0],
        ],
        dtype=np.float32,
    )
    output_img = cv2.filter2D(output_img, -1, sharpen_kernel)

    # 3. Color-safe blending (MOST IMPORTANT): 80% original, 20% model output.
    #    ``img`` already has ``original_size``, so no resize is needed here.
    original_np = np.array(img)
    output_img = cv2.addWeighted(original_np, 0.8, output_img, 0.2, 0)

    # 4. Very light contrast (SAFE - no color shift)
    alpha = 1.05  # slight contrast
    beta = 2      # slight brightness
    # convertScaleAbs saturates to uint8, so no further clipping is required.
    output_img = cv2.convertScaleAbs(output_img, alpha=alpha, beta=beta)

    # Save for download. The temporary file is only used to reserve a unique
    # name; its handle is closed before PIL writes to that path so no file
    # descriptor is leaked per request and the path is not held open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        download_path = temp_file.name
    Image.fromarray(output_img).save(download_path)

    return output_img, download_path
# ---------------------------
# GRADIO UI
# ---------------------------
# NOTE(review): the nesting below is reconstructed (title, a row holding the
# two images, then the download slot and button) - confirm against the
# intended layout.
with gr.Blocks() as demo:
    # Page title and a one-line description.
    gr.Markdown("# 🔍 AI Image Enhancer")
    gr.Markdown("Upload a low-quality image and enhance it using deep learning")

    # Uploaded image and enhanced result, side by side.
    with gr.Row():
        input_img = gr.Image(label="Upload Image", type="pil")
        output_img = gr.Image(label="Enhanced Image")

    download_file = gr.File(label="Download Enhanced Image")
    btn = gr.Button("Enhance Image")

    # enhance_image returns (image, file path); they fill the two outputs
    # in the order listed.
    btn.click(
        enhance_image,
        inputs=input_img,
        outputs=[output_img, download_file],
    )

demo.launch()