# esrgan / app.py
# devbernie's picture
# Update
# 0ed552c verified
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage
from urllib.request import urlretrieve
import os
# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Constants
# Local file name the pretrained weights are cached under.
MODEL_PATH = "RRDB_ESRGAN_x4.pth"
# Remote location the weights are fetched from when the cache file is absent.
MODEL_URL = "https://github.com/xinntao/ESRGAN/releases/download/v0.1.1/RRDB_ESRGAN_x4.pth"
# Size limit applied to uploaded images by validate_image.
MAX_IMAGE_SIZE = (1024, 1024)
# ESRGAN model architecture
class RRDBNet(torch.nn.Module):
    """ESRGAN generator network built from a trunk of RRDB blocks.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        nf: Number of feature channels used throughout the network.
        nb: Number of RRDB blocks in the trunk.
        gc: Growth channels passed to every RRDB block.

    The forward pass enlarges the spatial size by 4x (two nearest-neighbour
    2x interpolation stages). Attribute names are part of the checkpoint
    format (they become state-dict keys) and must not be renamed.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32):
        super().__init__()

        def conv3x3(channels_in, channels_out):
            # Every convolution in this network is 3x3, stride 1, padding 1.
            return torch.nn.Conv2d(
                channels_in, channels_out,
                kernel_size=3, stride=1, padding=1, bias=True,
            )

        self.conv_first = conv3x3(in_nc, nf)

        self.RRDB_trunk = torch.nn.ModuleList()
        for _ in range(nb):
            self.RRDB_trunk.append(RRDB(nf, gc=gc))

        self.trunk_conv = conv3x3(nf, nf)
        self.upconv1 = conv3x3(nf, nf)
        self.upconv2 = conv3x3(nf, nf)
        self.HRconv = conv3x3(nf, nf)
        self.conv_last = conv3x3(nf, out_nc)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Map an input batch to an output batch with 4x larger height/width."""
        shallow = self.conv_first(x)

        # Run a copy of the shallow features through the RRDB trunk.
        deep = shallow.clone()
        for rrdb in self.RRDB_trunk:
            deep = rrdb(deep)
        deep = self.trunk_conv(deep)

        # Global residual connection around the whole trunk.
        features = shallow + deep

        # Two identical "nearest-neighbour 2x, conv, LeakyReLU" stages.
        for up_conv in (self.upconv1, self.upconv2):
            enlarged = torch.nn.functional.interpolate(
                features, scale_factor=2, mode='nearest'
            )
            features = self.lrelu(up_conv(enlarged))

        hr_features = self.lrelu(self.HRconv(features))
        return self.conv_last(hr_features)
class RRDB(torch.nn.Module):
    """Three ResidualDenseBlocks applied in sequence, wrapped in a scaled skip.

    Args:
        nf: Number of feature channels entering and leaving the block.
        gc: Growth channels passed to each ResidualDenseBlock.
    """

    def __init__(self, nf, gc=32):
        super().__init__()
        self.RDB1 = ResidualDenseBlock(nf, gc)
        self.RDB2 = ResidualDenseBlock(nf, gc)
        self.RDB3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        """Return ``x`` plus 0.2 times the output of the three chained blocks."""
        refined = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            refined = dense_block(refined)
        # The branch is scaled by 0.2 before being added back to the input.
        return refined * 0.2 + x
class ResidualDenseBlock(torch.nn.Module):
    """Five densely connected 3x3 convolutions with a scaled residual.

    Each convolution receives the block input concatenated (along the channel
    dimension) with the outputs of all earlier convolutions in the block.

    Args:
        nf: Channels of the block input and output.
        gc: Channels produced by each of the first four convolutions.
        bias: Whether the convolutions use a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        conv_options = dict(kernel_size=3, stride=1, padding=1, bias=bias)
        # Input width grows by ``gc`` for every earlier convolution's output.
        self.conv1 = torch.nn.Conv2d(nf, gc, **conv_options)
        self.conv2 = torch.nn.Conv2d(nf + gc, gc, **conv_options)
        self.conv3 = torch.nn.Conv2d(nf + 2 * gc, gc, **conv_options)
        self.conv4 = torch.nn.Conv2d(nf + 3 * gc, gc, **conv_options)
        self.conv5 = torch.nn.Conv2d(nf + 4 * gc, nf, **conv_options)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return ``x`` plus 0.2 times the output of the dense convolution stack."""
        # ``stack`` holds the input followed by every intermediate activation.
        stack = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            stack.append(self.lrelu(conv(torch.cat(stack, 1))))
        # The last convolution has no activation; its output is scaled by 0.2.
        fused = self.conv5(torch.cat(stack, 1))
        return fused * 0.2 + x
def load_model() -> torch.nn.Module:
    """Download (if needed) and load the ESRGAN model.

    The weights are fetched from ``MODEL_URL`` only when ``MODEL_PATH`` does
    not exist yet. The download is written to a temporary sibling file and
    moved into place once complete, so an interrupted download can never be
    mistaken for a valid cached checkpoint on the next start.

    Returns:
        An ``RRDBNet`` with the checkpoint loaded, moved to ``device`` and
        switched to evaluation mode.

    Raises:
        Whatever ``urlretrieve``, ``torch.load`` or ``load_state_dict`` raise
        (network failure, unreadable file, mismatched state-dict keys).
    """
    if not os.path.exists(MODEL_PATH):
        print("Downloading ESRGAN model...")
        partial_path = MODEL_PATH + ".part"
        try:
            urlretrieve(MODEL_URL, partial_path)
            # Publish the file only after the download finished completely.
            os.replace(partial_path, MODEL_PATH)
        finally:
            # After a successful replace this path is gone; after a failure
            # it removes the truncated leftovers.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    model = RRDBNet()
    # NOTE(review): torch.load unpickles the downloaded file. On PyTorch
    # versions that support it, consider passing weights_only=True so that a
    # tampered checkpoint cannot execute arbitrary code - confirm the
    # minimum torch version of the deployment first.
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a single-item batch tensor on ``device``.

    The image is converted with torchvision's ``ToTensor``, given a leading
    batch dimension, and moved to the configured device.
    """
    pixels = ToTensor()(image)
    batch = torch.unsqueeze(pixels, 0)
    return batch.to(device)
def postprocess_image(tensor: torch.Tensor) -> Image.Image:
    """Turn a single-item batch tensor back into a PIL image.

    The leading batch dimension is dropped, the tensor is detached and moved
    to the CPU, values are clamped to the range [0, 1], and the result is
    converted with torchvision's ``ToPILImage``.
    """
    host_tensor = tensor.squeeze(0).detach().cpu()
    bounded = torch.clamp(host_tensor, 0, 1)
    return ToPILImage()(bounded)
def validate_image(image: Image.Image):
    """Validate input image constraints.

    Args:
        image: The uploaded image, or ``None`` when nothing was provided.

    Raises:
        gr.Error: If no image was given, if its mode is neither RGB nor
            RGBA, or if its width or height exceeds ``MAX_IMAGE_SIZE``.
    """
    # A missing upload would otherwise surface as an AttributeError below.
    if image is None:
        raise gr.Error("No image provided")

    if image.mode not in ("RGB", "RGBA"):
        raise gr.Error("Only RGB/RGBA images supported")

    # Check each dimension against its own limit; comparing the two maxima
    # is only correct while the limit is square. MAX_IMAGE_SIZE is read as
    # (width, height), matching the "WxH" order of the error message.
    max_width, max_height = MAX_IMAGE_SIZE
    if image.width > max_width or image.height > max_height:
        raise gr.Error(f"Max image dimension exceeded ({max_width}x{max_height})")
def enhance_image(
    input_image: Image.Image,
    scale_factor: float = 2.0
) -> Image.Image:
    """Main processing function: upscale an image with the ESRGAN model.

    The image is validated, converted from RGBA to RGB when necessary, run
    through the module-level ``model``, and the network output is finally
    resized with Lanczos resampling to ``scale_factor`` times the input size.

    Args:
        input_image: Image supplied by the UI; ``None`` if nothing was given.
        scale_factor: Output size relative to the input size.

    Returns:
        The enhanced image at ``scale_factor`` times the input dimensions.

    Raises:
        gr.Error: For a missing or invalid image (with the validation message
            unchanged), or with a "Processing error: ..." message for any
            other failure during processing.
    """
    if input_image is None:
        raise gr.Error("No image provided")

    try:
        validate_image(input_image)

        # Convert RGBA to RGB
        if input_image.mode == 'RGBA':
            input_image = input_image.convert('RGB')

        # Inference only: no gradients are needed.
        with torch.no_grad():
            input_tensor = preprocess_image(input_image)
            output_tensor = model(input_tensor)
        result = postprocess_image(output_tensor)

        target_size = (
            int(input_image.width * scale_factor),
            int(input_image.height * scale_factor),
        )
        return result.resize(target_size, Image.LANCZOS)
    except gr.Error:
        # Already a user-facing message (e.g. from validate_image); do not
        # bury it under a "Processing error:" prefix.
        raise
    except Exception as e:
        # Keep the original exception chained for server-side debugging.
        raise gr.Error(f"Processing error: {e}") from e
# Initialize model
# Loaded once at import time; enhance_image reads this module-level name.
model = load_model()

# Gradio interface
# Two inputs (image + scale slider limited to 2.0 or 4.0) and one image output.
interface = gr.Interface(
    fn=enhance_image,
    inputs=[
        gr.Image(label="Input Image", type="pil"),
        gr.Slider(
            minimum=2.0,
            maximum=4.0,
            value=2.0,
            step=2.0,
            label="Scale Factor",
        ),
    ],
    outputs=gr.Image(label="Enhanced Image", type="pil"),
    title="🎨 AI Image Enhancer",
    examples=[["examples/example1.jpg", 2.0]],
    css=".gradio-container {max-width: 800px !important}",
)

if __name__ == "__main__":
    # Serve on 0.0.0.0 so the app is reachable from outside the local host.
    interface.launch(server_name="0.0.0.0")