# Source-viewer header (scrape residue): file size 5,380 bytes, 145 lines;
# blame commits f73e05b, 8c70566, 0ed552c, ecf2564. Per-line gutter columns omitted.
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage
from urllib.request import urlretrieve
import os
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained ESRGAN x4 checkpoint: remote location, and the local file it is
# cached in (a relative path, so it resolves against the current working directory).
MODEL_URL = "https://github.com/xinntao/ESRGAN/releases/download/v0.1.1/RRDB_ESRGAN_x4.pth"
MODEL_PATH = "RRDB_ESRGAN_x4.pth"

# Largest accepted input size; reported as "<[0]>x<[1]>" in validation errors.
MAX_IMAGE_SIZE = (1024, 1024)
# ESRGAN model architecture
class RRDBNet(torch.nn.Module):
    """ESRGAN generator network.

    Pipeline: a 3x3 input conv, a trunk of ``nb`` RRDB blocks followed by a conv,
    added back onto the input features (global residual), then two
    nearest-neighbour 2x upsampling stages (4x overall), each followed by a
    conv + LeakyReLU, and finally ``HRconv`` + LeakyReLU + ``conv_last``.

    Attribute names are what ``load_state_dict`` matches checkpoint keys
    against, so they must not be renamed.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32):
        super().__init__()
        conv = torch.nn.Conv2d
        # Shallow feature extraction.
        self.conv_first = conv(in_nc, nf, 3, 1, 1, bias=True)
        # Deep feature extraction: a chain of nb residual-in-residual dense blocks.
        self.RRDB_trunk = torch.nn.ModuleList(RRDB(nf, gc=gc) for _ in range(nb))
        self.trunk_conv = conv(nf, nf, 3, 1, 1, bias=True)
        # One conv per 2x upsampling stage.
        self.upconv1 = conv(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = conv(nf, nf, 3, 1, 1, bias=True)
        # Reconstruction head.
        self.HRconv = conv(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = conv(nf, out_nc, 3, 1, 1, bias=True)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        features = self.conv_first(x)

        body = features.clone()
        for rrdb in self.RRDB_trunk:
            body = rrdb(body)
        features = features + self.trunk_conv(body)

        # Two successive 2x nearest-neighbour upsamplings -> 4x spatial size.
        for upconv in (self.upconv1, self.upconv2):
            upsampled = torch.nn.functional.interpolate(features, scale_factor=2, mode='nearest')
            features = self.lrelu(upconv(upsampled))

        return self.conv_last(self.lrelu(self.HRconv(features)))
class RRDB(torch.nn.Module):
    """Residual-in-Residual Dense Block.

    Feeds the input through three ``ResidualDenseBlock``s in sequence, then adds
    the chain's output, scaled by 0.2, back onto the block input.
    """

    def __init__(self, nf, gc=32):
        super().__init__()
        # Attribute names form part of the state_dict keys the checkpoint is loaded by.
        self.RDB1 = ResidualDenseBlock(nf, gc)
        self.RDB2 = ResidualDenseBlock(nf, gc)
        self.RDB3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        chained = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            chained = dense_block(chained)
        # Down-weighted residual branch plus identity path.
        return chained * 0.2 + x
class ResidualDenseBlock(torch.nn.Module):
    """Five-conv dense block with a scaled residual connection.

    Each of ``conv1``..``conv4`` sees the channel-wise concatenation of the block
    input and every earlier activation, so the input width grows by ``gc`` per
    layer. ``conv5`` maps the full ``nf + 4*gc`` stack back to ``nf`` channels;
    that result is scaled by 0.2 and added to the input.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        conv = torch.nn.Conv2d
        # Input widths: nf, nf+gc, nf+2gc, nf+3gc, nf+4gc (dense connectivity).
        self.conv1 = conv(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = conv(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = conv(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = conv(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = conv(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # `dense` accumulates [x, x1, x2, ...] along the channel axis (dim 1).
        dense = x
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            activation = self.lrelu(layer(dense))
            dense = torch.cat((dense, activation), 1)
        return self.conv5(dense) * 0.2 + x
def load_model() -> torch.nn.Module:
    """Return the pretrained ESRGAN x4 network, downloading the weights on first use.

    If ``MODEL_PATH`` does not exist, the checkpoint is fetched from ``MODEL_URL``.
    The download is written to ``MODEL_PATH + ".part"`` and only renamed to
    ``MODEL_PATH`` once ``urlretrieve`` has returned. Without this, an interrupted
    or failed download would leave a truncated file at ``MODEL_PATH``; every later
    start-up would then skip the download and fail inside ``torch.load``.

    Returns:
        An ``RRDBNet`` with default hyper-parameters, the checkpoint loaded,
        moved to ``device`` and switched to eval mode.

    Raises:
        Whatever ``urlretrieve``, ``torch.load`` or ``load_state_dict`` raise
        (network failure, unreadable file, missing/unexpected/mis-shaped keys).
    """
    if not os.path.exists(MODEL_PATH):
        print("Downloading ESRGAN model...")
        partial_path = MODEL_PATH + ".part"
        try:
            urlretrieve(MODEL_URL, partial_path)
            os.replace(partial_path, MODEL_PATH)
        finally:
            # The partial file only survives to here if the download or the
            # rename failed; remove it so no half-written checkpoint lingers.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    net = RRDBNet()
    # NOTE(review): torch.load unpickles the file. On torch versions where
    # weights_only defaults to False, a tampered checkpoint could execute code.
    # Consider passing weights_only=True (available since torch 1.13) once the
    # minimum supported torch version is confirmed.
    state_dict = torch.load(MODEL_PATH, map_location=device)
    net.load_state_dict(state_dict)
    return net.to(device).eval()
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a single-item batch tensor on ``device``.

    ``ToTensor`` produces a channels-first float tensor (scaled to [0, 1] for
    8-bit images, per torchvision). A leading batch axis of size 1 is then
    added, and the result is moved to ``device``.
    """
    channels_first = ToTensor()(image)
    batch = channels_first.unsqueeze(0)
    return batch.to(device)
def postprocess_image(tensor: torch.Tensor) -> Image.Image:
    """Convert a model output tensor into a PIL image.

    Args:
        tensor: model output. A leading axis of size 1 (the batch axis added by
            ``preprocess_image``) is squeezed away. Values are clamped to [0, 1].

    Returns:
        The PIL image ``ToPILImage`` builds from the 8-bit quantized tensor.
    """
    transform = ToPILImage()
    tensor = tensor.squeeze(0).detach().cpu().clamp(0, 1)
    # Quantize with rounding. Given a float tensor, ToPILImage scales by 255 and
    # truncates (mul(255).byte()), which biases every pixel down by up to one level.
    tensor = tensor.mul(255).round().to(torch.uint8)
    return transform(tensor)
def validate_image(image: Image.Image) -> None:
    """Reject inputs the enhancement pipeline cannot handle.

    ``MAX_IMAGE_SIZE`` is read as (width, height), the same order as PIL's
    ``Image.size`` and as the "WxH" shown in the error message. Each axis is
    checked against its own limit.

    Raises:
        gr.Error: if no image was supplied (Gradio passes ``None`` for an empty
            input), if the mode is not RGB/RGBA, or if the width or height
            exceeds its limit.
    """
    if image is None:
        raise gr.Error("Please upload an image")
    if image.mode not in ("RGB", "RGBA"):
        raise gr.Error("Only RGB/RGBA images supported")
    max_width, max_height = MAX_IMAGE_SIZE
    if image.width > max_width or image.height > max_height:
        raise gr.Error(f"Max image dimension exceeded ({max_width}x{max_height})")
def enhance_image(
    input_image: Image.Image,
    scale_factor: float = 2.0
) -> Image.Image:
    """Super-resolve ``input_image`` with ESRGAN, then resize to the requested scale.

    The network output (RRDBNet upsamples 4x) is resampled with Lanczos to exactly
    ``scale_factor`` times the input's width and height.

    Args:
        input_image: image from the Gradio input. RGBA is converted to RGB,
            which discards alpha.
        scale_factor: output size relative to the input.

    Returns:
        The enhanced PIL image.

    Raises:
        gr.Error: user-facing errors (e.g. from ``validate_image``) are passed
            through unchanged. Any other failure is reported as
            "Processing error: ...", with the original exception chained.
    """
    try:
        validate_image(input_image)
        if input_image.mode == 'RGBA':
            # The network takes 3 input channels.
            input_image = input_image.convert('RGB')
        with torch.no_grad():
            output_tensor = model(preprocess_image(input_image))
            result = postprocess_image(output_tensor)
        target_size = (
            int(input_image.width * scale_factor),
            int(input_image.height * scale_factor),
        )
        return result.resize(target_size, Image.LANCZOS)
    except gr.Error:
        # Already a user-facing message; re-wrapping it would prefix
        # "Processing error:" to validation feedback.
        raise
    except Exception as e:
        raise gr.Error(f"Processing error: {str(e)}") from e
# Build the network once at import time; enhance_image reads this module-level name.
model = load_model()

# Gradio UI: an image plus a scale slider in, the enhanced image out.
interface = gr.Interface(
    fn=enhance_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        # Range 2.0-4.0 with step 2.0 leaves exactly two choices: 2x or 4x.
        gr.Slider(minimum=2.0, maximum=4.0, value=2.0, step=2.0, label="Scale Factor"),
    ],
    outputs=gr.Image(type="pil", label="Enhanced Image"),
    title="🎨 AI Image Enhancer",
    css=".gradio-container {max-width: 800px !important}",
    examples=[["examples/example1.jpg", 2.0]],
)
if __name__ == "__main__":
    # 0.0.0.0 binds every network interface, not just localhost.
    interface.launch(server_name="0.0.0.0")