Spaces:
Running
Running
File size: 3,777 Bytes
3dbf5d4 5acf400 3dbf5d4 e2ec1ae 3dbf5d4 133c244 3dbf5d4 f80e710 3dbf5d4 f80e710 3dbf5d4 5acf400 3dbf5d4 f80e710 3dbf5d4 78e35d6 133c244 f80e710 e2ec1ae 133c244 5acf400 8b91a17 78e35d6 8b91a17 ab9ab5b 78e35d6 f80e710 ab9ab5b e2ec1ae 8b91a17 133c244 8b91a17 e2ec1ae 8b91a17 78e35d6 133c244 78e35d6 e2ec1ae 78e35d6 e2ec1ae f80e710 5acf400 133c244 3c9c540 f80e710 3dbf5d4 f80e710 8b91a17 5acf400 3dbf5d4 5acf400 3c9c540 5acf400 3dbf5d4 3c9c540 3dbf5d4 5acf400 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
import tempfile
# ---------------------------
# MODEL ARCHITECTURE
# ---------------------------
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a ReLU between them, plus an identity skip.

    Both convolutions keep the channel count and use stride 1 / padding 1,
    so the residual has the same shape as the input and can be added to it.

    Args:
        channels: number of input channels, also used as the number of
            output channels of both convolutions.
    """

    def __init__(self, channels):
        super().__init__()
        # NOTE: the attribute name ``block`` and this layer order produce the
        # state_dict keys ``block.0.*`` / ``block.2.*`` that a saved
        # checkpoint must match — do not rename or reorder.
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the convolutional residual computed from ``x``."""
        residual = self.block(x)
        return x + residual
class Generator(nn.Module):
    """Same-resolution image-to-image network used by the enhancer.

    Pipeline: a 3->64 channel 3x3 entry convolution, three 64-channel
    ``ResidualBlock``s, then a 64->3 channel 3x3 convolution followed by a
    sigmoid, so every output value lies in [0, 1].  Every convolution uses
    stride 1 and padding 1, so the output keeps the input's height and width.
    """

    def __init__(self):
        super().__init__()
        # The attribute names (entry / res_blocks / exit) and the sub-module
        # indices form the state_dict keys the checkpoint is loaded against.
        self.entry = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.res_blocks = nn.Sequential(*(ResidualBlock(64) for _ in range(3)))
        self.exit = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map a 3-channel input ``x`` to a 3-channel output of the same size."""
        features = self.res_blocks(self.entry(x))
        return self.exit(features)
# ---------------------------
# LOAD MODEL
# ---------------------------
# Inference is pinned to the CPU.
device = torch.device("cpu")

# Build the network and move it to the chosen device (Module.to works in place
# and returns the same module).
model = Generator()
model.to(device)

# The checkpoint is a mapping whose 'generator' entry holds the state_dict for
# the network above; map_location remaps tensors saved on another device.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load("final_sr_model_v3.pth", map_location=device)
model.load_state_dict(checkpoint['generator'])

# Generator contains no dropout/batch-norm layers, so this is conventional
# rather than required, but it keeps the model in inference mode.
model.eval()

# ---------------------------
# TRANSFORM
# ---------------------------
# ToTensor turns the RGB PIL image into a float (C, H, W) tensor in [0, 1].
transform = transforms.ToTensor()
# ---------------------------
# INFERENCE FUNCTION
# ---------------------------
def enhance_image(input_image):
    """Run the generator on an uploaded image and post-process the result.

    Steps: run ``model`` on the RGB image, convert its [0, 1] output to 8-bit,
    lightly blur then sharpen it, blend it 20/80 with the original, apply a
    small contrast/brightness lift, and write the result to a PNG file so it
    can be offered for download.

    Args:
        input_image: the PIL image delivered by the ``gr.Image(type="pil")``
            input, or ``None`` when the button is clicked with no upload.

    Returns:
        ``(enhanced, path)`` where ``enhanced`` is an ``(H, W, 3)`` ``uint8``
        RGB array with the same size as the input, and ``path`` is the file
        system path of a PNG copy of it.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_image is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image first.")

    img = input_image.convert("RGB")
    original_size = img.size  # (width, height)
    original_np = np.array(img)

    input_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)

    # Drop only the batch axis: a bare squeeze() would also remove a height
    # or width of 1 and make the permute fail on 1-pixel-high/wide images.
    output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # The generator ends in a sigmoid, so values are already in [0, 1]; the
    # clip only guards against float round-off.
    output = np.clip(output, 0.0, 1.0)
    # Round rather than truncate so pixel values are not biased downward.
    enhanced = np.round(output * 255.0).astype(np.uint8)

    # The current generator preserves spatial size; resize back only if a
    # different model ever changes it.
    if (enhanced.shape[1], enhanced.shape[0]) != original_size:
        enhanced = np.array(
            Image.fromarray(enhanced).resize(original_size, Image.BICUBIC)
        )

    # 1. Very light 3x3 Gaussian smoothing to soften artifacts.
    enhanced = cv2.GaussianBlur(enhanced, (3, 3), 0)

    # 2. Mild sharpening with a standard 3x3 kernel.
    sharpen_kernel = np.array(
        [[0, -1, 0],
         [-1, 5, -1],
         [0, -1, 0]],
        dtype=np.float32,
    )
    enhanced = cv2.filter2D(enhanced, -1, sharpen_kernel)

    # 3. Colour-safe blend: keep 80% of the original pixels, 20% processed.
    enhanced = cv2.addWeighted(original_np, 0.8, enhanced, 0.2, 0)

    # 4. Slight contrast (alpha) and brightness (beta) lift.  convertScaleAbs
    #    saturates the result to uint8, so no extra clipping is needed.
    enhanced = cv2.convertScaleAbs(enhanced, alpha=1.05, beta=2)

    # Create the download file and close its handle before PIL writes to it
    # by path.  delete=False keeps the file on disk after the handle closes so
    # the returned path is still valid for Gradio.
    # NOTE(review): these files are never removed; consider periodic cleanup.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        download_path = tmp.name
    Image.fromarray(enhanced).save(download_path)

    return enhanced, download_path
# ---------------------------
# GRADIO UI
# ---------------------------
# UI: the uploaded image and the enhanced result side by side, a file
# component offering the PNG for download, and a button that runs
# enhance_image.
with gr.Blocks() as demo:
    gr.Markdown("# 🔍 AI Image Enhancer")
    gr.Markdown("Upload a low-quality image and enhance it using deep learning")

    with gr.Row():
        # type="pil" makes Gradio pass enhance_image a PIL image.
        input_img = gr.Image(label="Upload Image", type="pil")
        output_img = gr.Image(label="Enhanced Image")

    download_file = gr.File(label="Download Enhanced Image")
    btn = gr.Button("Enhance Image")

    # enhance_image returns (image array, file path), matching the outputs
    # list in order.
    btn.click(
        enhance_image,
        inputs=[input_img],
        outputs=[output_img, download_file],
    )

demo.launch()