File size: 5,380 Bytes
f73e05b
 
 
 
 
8c70566
 
f73e05b
 
 
 
 
8c70566
 
f73e05b
 
8c70566
 
 
 
0ed552c
 
 
 
 
 
 
 
8c70566
0ed552c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c70566
f73e05b
8c70566
 
 
 
 
 
 
 
f73e05b
 
 
8c70566
f73e05b
8c70566
f73e05b
 
8c70566
f73e05b
ecf2564
f73e05b
 
8c70566
 
f73e05b
 
8c70566
 
f73e05b
 
 
 
 
8c70566
f73e05b
 
ecf2564
8c70566
ecf2564
 
f73e05b
 
 
 
 
ecf2564
8c70566
 
 
ecf2564
 
8c70566
f73e05b
8c70566
f73e05b
8c70566
f73e05b
 
8c70566
f73e05b
 
 
8c70566
 
f73e05b
8c70566
 
 
 
f73e05b
 
 
8c70566
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage
from urllib.request import urlretrieve
import os

# Device configuration: prefer GPU when available, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constants
# Pretrained 4x ESRGAN weights published by the original ESRGAN authors.
MODEL_URL = "https://github.com/xinntao/ESRGAN/releases/download/v0.1.1/RRDB_ESRGAN_x4.pth"
# Local cache path for the downloaded checkpoint.
MODEL_PATH = "RRDB_ESRGAN_x4.pth"
# Largest accepted input dimensions; larger images are rejected by
# validate_image to bound memory use during inference.
MAX_IMAGE_SIZE = (1024, 1024)

# ESRGAN model architecture
class RRDBNet(torch.nn.Module):
    """RRDB-based 4x super-resolution generator (ESRGAN).

    Pipeline: shallow feature extraction -> a trunk of ``nb`` RRDB blocks
    with a global residual skip -> two nearest-neighbour 2x upsampling
    stages (4x total) -> final reconstruction convolutions.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32):
        super(RRDBNet, self).__init__()
        conv = torch.nn.Conv2d
        self.conv_first = conv(in_nc, nf, 3, 1, 1, bias=True)
        # Attribute names mirror the released checkpoint's state_dict keys;
        # do not rename them.
        self.RRDB_trunk = torch.nn.ModuleList([RRDB(nf, gc=gc) for _ in range(nb)])
        self.trunk_conv = conv(nf, nf, 3, 1, 1, bias=True)
        self.upconv1 = conv(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = conv(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = conv(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = conv(nf, out_nc, 3, 1, 1, bias=True)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        shallow = self.conv_first(x)
        # Run the RRDB trunk, then add the shallow features back in
        # (global residual connection).
        deep = shallow.clone()
        for rrdb in self.RRDB_trunk:
            deep = rrdb(deep)
        deep = self.trunk_conv(deep)
        features = shallow + deep
        # Two nearest-neighbour 2x upsamples => 4x overall scale.
        upsample = torch.nn.functional.interpolate
        features = self.lrelu(self.upconv1(upsample(features, scale_factor=2, mode='nearest')))
        features = self.lrelu(self.upconv2(upsample(features, scale_factor=2, mode='nearest')))
        return self.conv_last(self.lrelu(self.HRconv(features)))

class RRDB(torch.nn.Module):
    """Residual-in-Residual Dense Block: three chained dense blocks plus a
    residually scaled skip connection around the whole trio."""

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        # RDB1/RDB2/RDB3 names must match the pretrained state_dict keys.
        self.RDB1 = ResidualDenseBlock(nf, gc)
        self.RDB2 = ResidualDenseBlock(nf, gc)
        self.RDB3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        h = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            h = dense_block(h)
        # Scale the residual branch before adding it back (0.2 factor).
        return h * 0.2 + x

class ResidualDenseBlock(torch.nn.Module):
    """Five-layer dense block.

    Each conv sees the concatenation of the block input and all previous
    layer outputs; the last conv projects back to ``nf`` channels and is
    added to the input with a 0.2 residual scale.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock, self).__init__()
        conv = torch.nn.Conv2d
        # convN takes nf + (N-1)*gc input channels (dense connectivity).
        self.conv1 = conv(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = conv(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = conv(nf + 2*gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = conv(nf + 3*gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = conv(nf + 4*gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Accumulate the input plus every intermediate activation, feeding
        # the growing concatenation to each successive conv.
        feats = [x]
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(layer(torch.cat(feats, 1))))
        residual = self.conv5(torch.cat(feats, 1))
        return residual * 0.2 + x

def load_model() -> torch.nn.Module:
    """Download the pretrained ESRGAN checkpoint (if not cached) and return
    the network on `device` in eval mode.

    Returns:
        The RRDBNet generator with pretrained weights loaded, moved to
        `device` and switched to inference mode.
    """
    if not os.path.exists(MODEL_PATH):
        print("Downloading ESRGAN model...")
        # Download to a temporary name first so an interrupted transfer
        # never leaves a truncated file at MODEL_PATH, which would be
        # mistaken for a valid cache on the next run.
        tmp_path = MODEL_PATH + ".part"
        urlretrieve(MODEL_URL, tmp_path)
        os.replace(tmp_path, MODEL_PATH)

    model = RRDBNet()
    try:
        # weights_only=True prevents executing arbitrary pickled code from
        # the downloaded file (available in torch >= 1.13).
        state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    except TypeError:
        # Older torch without the weights_only parameter.
        state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()

def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a batched float tensor on `device`."""
    # ToTensor maps HWC uint8 [0, 255] pixels to CHW float in [0, 1];
    # unsqueeze adds the batch dimension the network expects.
    tensor = ToTensor()(image)
    return tensor.unsqueeze(0).to(device)

def postprocess_image(tensor: torch.Tensor) -> Image.Image:
    """Convert a batched network output tensor back into a PIL image."""
    # Drop the batch dim, detach from autograd, move to CPU, and clamp to
    # the displayable [0, 1] range before converting back to PIL.
    image_tensor = tensor.squeeze(0).detach().cpu().clamp(0, 1)
    return ToPILImage()(image_tensor)

def validate_image(image: Image.Image):
    """Reject images the app cannot process.

    Raises:
        gr.Error: if the image mode is not RGB/RGBA, or its largest side
            exceeds the MAX_IMAGE_SIZE limit.
    """
    if image.mode not in ("RGB", "RGBA"):
        raise gr.Error("Only RGB/RGBA images supported")
    largest_side = max(image.size)
    if largest_side > max(MAX_IMAGE_SIZE):
        raise gr.Error(f"Max image dimension exceeded ({MAX_IMAGE_SIZE[0]}x{MAX_IMAGE_SIZE[1]})")

def enhance_image(
    input_image: Image.Image,
    scale_factor: float = 2.0
) -> Image.Image:
    """Run ESRGAN on `input_image` and return it at `scale_factor` times
    its original dimensions.

    Args:
        input_image: Image from the Gradio UI (may be None if the user
            submits without uploading).
        scale_factor: Desired output scale relative to the input size.
            The network always upscales 4x; the result is resampled to
            the requested factor with Lanczos.

    Returns:
        The enhanced PIL image.

    Raises:
        gr.Error: on missing/invalid input or any processing failure.
    """
    if input_image is None:
        raise gr.Error("Please provide an input image")

    # Validation errors are already user-facing gr.Error instances;
    # keeping them outside the try avoids re-wrapping them below.
    validate_image(input_image)

    # Convert RGBA to RGB: the network expects 3 input channels.
    if input_image.mode == 'RGBA':
        input_image = input_image.convert('RGB')

    try:
        with torch.no_grad():
            input_tensor = preprocess_image(input_image)
            output_tensor = model(input_tensor)

        result = postprocess_image(output_tensor)
        return result.resize(
            (int(input_image.width * scale_factor),
             int(input_image.height * scale_factor)),
            Image.LANCZOS
        )

    except gr.Error:
        # Already user-facing; propagate as-is instead of double-wrapping.
        raise
    except Exception as e:
        # Surface unexpected failures to the UI with the cause preserved.
        raise gr.Error(f"Processing error: {str(e)}") from e

# Initialize model
# Loaded once at import time so every request reuses the same weights
# (triggers the checkpoint download on first run).
model = load_model()

# Gradio interface
interface = gr.Interface(
    fn=enhance_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        # step=2.0 over the 2.0-4.0 range restricts choices to 2x or 4x.
        gr.Slider(2.0, 4.0, 2.0, step=2.0, label="Scale Factor")
    ],
    outputs=gr.Image(type="pil", label="Enhanced Image"),
    title="🎨 AI Image Enhancer",
    # NOTE(review): assumes examples/example1.jpg exists alongside the app
    # in the deployment — confirm, or the examples gallery will be broken.
    examples=[["examples/example1.jpg", 2.0]],
    css=".gradio-container {max-width: 800px !important}"
)

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable inside a container.
    interface.launch(server_name="0.0.0.0")