|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from torchvision.transforms import ToTensor, ToPILImage |
|
|
from urllib.request import urlretrieve |
|
|
import os |
|
|
|
|
|
|
|
|
# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Pretrained ESRGAN x4 generator weights from the official xinntao release.
MODEL_URL = "https://github.com/xinntao/ESRGAN/releases/download/v0.1.1/RRDB_ESRGAN_x4.pth"


# Local cache path for the downloaded weights (relative to the working directory).
MODEL_PATH = "RRDB_ESRGAN_x4.pth"


# Largest accepted input size as (width, height); larger images are rejected
# by validate_image to bound memory use during the x4 upscale.
MAX_IMAGE_SIZE = (1024, 1024)
|
|
|
|
|
|
|
|
class RRDBNet(torch.nn.Module):
    """ESRGAN generator network.

    A shallow feature conv feeds a trunk of ``nb`` RRDB blocks wrapped in a
    long residual connection, followed by two nearest-neighbour x2
    upsampling stages (x4 total) and a final reconstruction conv.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32):
        super(RRDBNet, self).__init__()
        # Attribute names must stay exactly as-is: they define the
        # state_dict keys the pretrained checkpoint is loaded against.
        self.conv_first = torch.nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = torch.nn.ModuleList([RRDB(nf, gc=gc) for _ in range(nb)])
        self.trunk_conv = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv1 = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = torch.nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Shallow features, kept for the long skip connection.
        shallow = self.conv_first(x)
        deep = shallow.clone()
        for rrdb in self.RRDB_trunk:
            deep = rrdb(deep)
        deep = self.trunk_conv(deep)
        features = shallow + deep

        # Two x2 nearest-neighbour upsampling stages -> x4 overall.
        upsample = torch.nn.functional.interpolate
        features = self.lrelu(self.upconv1(upsample(features, scale_factor=2, mode='nearest')))
        features = self.lrelu(self.upconv2(upsample(features, scale_factor=2, mode='nearest')))
        return self.conv_last(self.lrelu(self.HRconv(features)))
|
|
|
|
|
class RRDB(torch.nn.Module):
    """Residual-in-Residual Dense Block.

    Chains three residual dense blocks and adds a 0.2-scaled skip
    connection around the whole stack.
    """

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        # Attribute names RDB1..RDB3 match the pretrained checkpoint keys.
        self.RDB1 = ResidualDenseBlock(nf, gc)
        self.RDB2 = ResidualDenseBlock(nf, gc)
        self.RDB3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        residual = self.RDB3(self.RDB2(self.RDB1(x)))
        # Residual scaling (0.2) stabilises training of very deep ESRGAN nets.
        return residual * 0.2 + x
|
|
|
|
|
class ResidualDenseBlock(torch.nn.Module):
    """Five-conv dense block with residual scaling.

    Each 3x3 conv sees the channel-concatenation of the block input and
    every earlier activation; the final conv output is scaled by 0.2 and
    added back onto the input.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock, self).__init__()

        def conv3x3(cin, cout):
            return torch.nn.Conv2d(cin, cout, 3, 1, 1, bias=bias)

        # Creation order matters: it fixes both state_dict keys and the
        # RNG stream consumed by weight initialisation.
        self.conv1 = conv3x3(nf, gc)
        self.conv2 = conv3x3(nf + gc, gc)
        self.conv3 = conv3x3(nf + 2 * gc, gc)
        self.conv4 = conv3x3(nf + 3 * gc, gc)
        self.conv5 = conv3x3(nf + 4 * gc, nf)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Dense connectivity: each stage consumes all previous activations.
        feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        out = self.conv5(torch.cat(feats, 1))
        return out * 0.2 + x
|
|
|
|
|
def load_model() -> torch.nn.Module:
    """Download the ESRGAN checkpoint if absent and return the loaded model.

    Returns:
        The RRDBNet generator on the global ``device``, in eval mode.

    Raises:
        Propagates download errors (urllib) and checkpoint/shape mismatches
        (``load_state_dict``) to the caller.
    """
    if not os.path.exists(MODEL_PATH):
        print("Downloading ESRGAN model...")
        # Download to a temporary name and rename atomically so an
        # interrupted download never leaves a truncated file at MODEL_PATH
        # that later runs would try to load.
        tmp_path = MODEL_PATH + ".part"
        urlretrieve(MODEL_URL, tmp_path)
        os.replace(tmp_path, MODEL_PATH)

    model = RRDBNet()
    try:
        # weights_only=True refuses arbitrary pickled objects, closing the
        # code-execution hole of loading a downloaded pickle (torch >= 1.13).
        state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    except TypeError:
        # Older torch without the weights_only parameter.
        state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()
|
|
|
|
|
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a batched float tensor on the target device."""
    # ToTensor maps HWC uint8 [0,255] -> CHW float [0,1]; unsqueeze adds
    # the batch dimension the model expects.
    tensor = ToTensor()(image)
    return tensor.unsqueeze(0).to(device)
|
|
|
|
|
def postprocess_image(tensor: torch.Tensor) -> Image.Image:
    """Convert a model output tensor back into a PIL image."""
    # Drop the batch dim, detach from autograd, move to CPU, and clip to
    # the valid [0, 1] range before conversion.
    image_tensor = tensor.squeeze(0).detach().cpu().clamp(0, 1)
    return ToPILImage()(image_tensor)
|
|
|
|
|
def validate_image(image: Image.Image):
    """Reject images the pipeline cannot handle.

    Raises:
        gr.Error: if the image mode is unsupported or it exceeds the
            configured maximum dimension.
    """
    supported_modes = ("RGB", "RGBA")
    if image.mode not in supported_modes:
        raise gr.Error("Only RGB/RGBA images supported")
    # A single limit suffices because MAX_IMAGE_SIZE is square.
    limit = max(MAX_IMAGE_SIZE)
    if max(image.size) > limit:
        raise gr.Error(f"Max image dimension exceeded ({MAX_IMAGE_SIZE[0]}x{MAX_IMAGE_SIZE[1]})")
|
|
|
|
|
def enhance_image(
    input_image: Image.Image,
    scale_factor: float = 2.0
) -> Image.Image:
    """Super-resolve an image with ESRGAN, then resize to the requested scale.

    Args:
        input_image: RGB or RGBA PIL image within MAX_IMAGE_SIZE.
        scale_factor: Desired output scale relative to the input (the model
            always produces x4; the result is resized down/up to this factor).

    Returns:
        The enhanced PIL image at ``input size * scale_factor``.

    Raises:
        gr.Error: on validation failure or any processing error.
    """
    try:
        validate_image(input_image)

        # Model expects 3 channels; drop alpha if present.
        if input_image.mode == 'RGBA':
            input_image = input_image.convert('RGB')

        with torch.no_grad():
            input_tensor = preprocess_image(input_image)
            output_tensor = model(input_tensor)

        result = postprocess_image(output_tensor)
        target_size = (
            int(input_image.width * scale_factor),
            int(input_image.height * scale_factor),
        )
        return result.resize(target_size, Image.LANCZOS)

    except gr.Error:
        # Bug fix: validation errors were previously re-wrapped as
        # "Processing error: ...", obscuring the user-facing message.
        raise
    except Exception as e:
        raise gr.Error(f"Processing error: {str(e)}")
|
|
|
|
|
|
|
|
# Load the model once at import time so one instance is reused across
# requests (module-level side effect; downloads weights on first run).
model = load_model()
|
|
|
|
|
|
|
|
# Gradio UI wiring enhance_image to an image input plus a scale slider.
# NOTE(review): the examples entry assumes an `examples/example1.jpg` file
# exists next to this script — confirm before deploying.
interface = gr.Interface(
    fn=enhance_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        # step=2.0 over [2.0, 4.0] restricts choices to exactly 2.0 and 4.0.
        gr.Slider(2.0, 4.0, 2.0, step=2.0, label="Scale Factor")
    ],
    outputs=gr.Image(type="pil", label="Enhanced Image"),
    title="🎨 AI Image Enhancer",
    examples=[["examples/example1.jpg", 2.0]],
    css=".gradio-container {max-width: 800px !important}"
)
|
|
|
|
|
if __name__ == "__main__":
    # 0.0.0.0 binds all interfaces, exposing the app on the local network.
    interface.launch(server_name="0.0.0.0")