import os
import tempfile
from urllib.request import urlretrieve

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToPILImage, ToTensor

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constants
MODEL_URL = "https://github.com/xinntao/ESRGAN/releases/download/v0.1.1/RRDB_ESRGAN_x4.pth"
MODEL_PATH = "RRDB_ESRGAN_x4.pth"
# Largest accepted input as (width, height) in pixels; each axis is checked separately.
MAX_IMAGE_SIZE = (1024, 1024)
# Example rows for the UI; rows whose image file does not exist are skipped.
EXAMPLES = [["examples/example1.jpg", 2.0]]


# ESRGAN model architecture.
# NOTE: attribute names (conv1, RDB1, RRDB_trunk, upconv1, ...) form the
# state_dict keys and must match the pretrained checkpoint -- do not rename.
class ResidualDenseBlock(torch.nn.Module):
    """Five-conv densely connected block with a scaled residual connection.

    Each of conv1..conv5 sees the channel-wise concatenation of the block
    input and all earlier conv outputs; the result is ``conv5 * 0.2 + x``.

    Args:
        nf: Number of input/output feature channels.
        gc: Channels produced by each of conv1..conv4 ("growth channels").
        bias: Whether the convolutions use a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = torch.nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = torch.nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = torch.nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = torch.nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        # inplace=True is safe: it is only ever applied to fresh conv outputs.
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(torch.nn.Module):
    """Residual-in-Residual Dense Block: three ResidualDenseBlocks plus a scaled skip.

    Args:
        nf: Number of feature channels.
        gc: Growth channels forwarded to each ResidualDenseBlock.
    """

    def __init__(self, nf, gc=32):
        super().__init__()
        self.RDB1 = ResidualDenseBlock(nf, gc)
        self.RDB2 = ResidualDenseBlock(nf, gc)
        self.RDB3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class RRDBNet(torch.nn.Module):
    """ESRGAN generator: a trunk of ``nb`` RRDB blocks followed by 4x upsampling.

    Upsampling is two successive nearest-neighbour 2x interpolations, each
    followed by a conv + LeakyReLU, so the output is 4x the input in each
    spatial dimension.

    Args:
        in_nc: Input image channels.
        out_nc: Output image channels.
        nf: Feature channels used throughout the network.
        nb: Number of RRDB blocks in the trunk.
        gc: Growth channels inside each ResidualDenseBlock.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32):
        super().__init__()
        self.conv_first = torch.nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = torch.nn.ModuleList([RRDB(nf, gc=gc) for _ in range(nb)])
        self.trunk_conv = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv1 = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = torch.nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        # No clone needed: every trunk block returns a new tensor and never
        # modifies its input in place, so `fea` stays intact for the skip below.
        trunk = fea
        for block in self.RRDB_trunk:
            trunk = block(trunk)
        # Global residual connection around the whole trunk.
        fea = fea + self.trunk_conv(trunk)
        fea = self.lrelu(
            self.upconv1(torch.nn.functional.interpolate(fea, scale_factor=2, mode="nearest"))
        )
        fea = self.lrelu(
            self.upconv2(torch.nn.functional.interpolate(fea, scale_factor=2, mode="nearest"))
        )
        return self.conv_last(self.lrelu(self.HRconv(fea)))


def load_model() -> torch.nn.Module:
    """Download (if needed) and load the pretrained ESRGAN x4 model.

    On first use the checkpoint is fetched from MODEL_URL into a temporary
    file next to MODEL_PATH and renamed into place only once the download
    completes. An interrupted download therefore never leaves a truncated
    checkpoint that would break every later start-up.

    Returns:
        An RRDBNet with the pretrained weights, on ``device``, in eval mode.

    Raises:
        OSError / urllib.error.URLError: if the download fails.
        RuntimeError: if the checkpoint does not match the architecture.
    """
    if not os.path.exists(MODEL_PATH):
        print("Downloading ESRGAN model...")
        target_dir = os.path.dirname(os.path.abspath(MODEL_PATH))
        fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
        os.close(fd)
        try:
            urlretrieve(MODEL_URL, tmp_path)
            # Atomic rename: MODEL_PATH only ever exists as a complete file.
            os.replace(tmp_path, MODEL_PATH)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    model = RRDBNet()
    # weights_only=True: the checkpoint comes from the network, and an
    # unrestricted pickle load could execute arbitrary embedded code.
    # (Requires PyTorch >= 1.13.)
    state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model.to(device).eval()


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a batch-of-one float tensor on ``device``.

    ToTensor produces a (C, H, W) float tensor scaled to [0, 1]; a leading
    batch dimension is added.
    """
    return ToTensor()(image).unsqueeze(0).to(device)


def postprocess_image(tensor: torch.Tensor) -> Image.Image:
    """Convert a batch-of-one model output tensor to an 8-bit PIL image.

    Values are clamped to [0, 1] and rounded to the nearest 8-bit level.
    Passing a float tensor straight to ToPILImage would truncate instead of
    round, biasing every pixel slightly darker.
    """
    tensor = tensor.squeeze(0).detach().cpu().clamp(0, 1)
    tensor = tensor.mul(255).round().to(torch.uint8)
    return ToPILImage()(tensor)


def validate_image(image: Image.Image) -> None:
    """Check that *image* is present, RGB/RGBA, and within MAX_IMAGE_SIZE.

    Width and height are compared separately against the (width, height)
    limits in MAX_IMAGE_SIZE.

    Raises:
        gr.Error: with a user-facing message when any check fails.
    """
    if image is None:
        raise gr.Error("Please provide an input image")
    if image.mode not in ["RGB", "RGBA"]:
        raise gr.Error("Only RGB/RGBA images supported")
    max_width, max_height = MAX_IMAGE_SIZE
    width, height = image.size
    if width > max_width or height > max_height:
        raise gr.Error(f"Max image dimension exceeded ({max_width}x{max_height})")


def enhance_image(
    input_image: Image.Image,
    scale_factor: float = 2.0,
) -> Image.Image:
    """Upscale *input_image* with ESRGAN, then resize to ``scale_factor`` x its size.

    The network always produces a 4x result; that result is resampled with
    Lanczos to exactly ``scale_factor`` times the input dimensions.

    Args:
        input_image: RGB or RGBA PIL image (RGBA is converted to RGB, which
            drops the alpha channel).
        scale_factor: Output size relative to the input.

    Returns:
        The enhanced PIL image.

    Raises:
        gr.Error: validation failures are re-raised unchanged. Any other
            failure is reported as "Processing error: ..." with the original
            exception chained.
    """
    # Validation raises gr.Error with a user-facing message; keep it outside
    # the generic handler so it is not re-wrapped as a "Processing error".
    validate_image(input_image)
    if input_image.mode == "RGBA":
        input_image = input_image.convert("RGB")
    target_size = (
        int(input_image.width * scale_factor),
        int(input_image.height * scale_factor),
    )
    try:
        with torch.no_grad():
            input_tensor = preprocess_image(input_image)
            output_tensor = model(input_tensor)
        result = postprocess_image(output_tensor)
        return result.resize(target_size, Image.LANCZOS)
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Processing error: {e}") from e


# Initialize model (module-level: `enhance_image` reads this global).
model = load_model()

# Gradio interface
interface = gr.Interface(
    fn=enhance_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(minimum=2.0, maximum=4.0, value=2.0, step=2.0, label="Scale Factor"),
    ],
    outputs=gr.Image(type="pil", label="Enhanced Image"),
    title="🎨 AI Image Enhancer",
    # Only offer examples whose files exist, so a missing example file
    # cannot break interface construction.
    examples=[row for row in EXAMPLES if os.path.exists(row[0])] or None,
    css=".gradio-container {max-width: 800px !important}",
)

if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0")