import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from PIL import Image
import torchvision.transforms.functional as TF

# --- 1. MODEL ARCHITECTURE ---
class PureResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode='replicate')
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode='replicate')
        self.res_scale = 1.0  # EDSR-style residual scaling; 1.0 leaves the residual unscaled

    def forward(self, x):
        res = self.conv1(x)
        res = self.act(res)
        res = self.conv2(res)
        return x + (res * self.res_scale)

class FastEDSR(nn.Module):
    def __init__(self, scale_factor=2, num_blocks=8, channels=64):
        super().__init__()
        self.scale_factor = scale_factor
        
        self.head = nn.Conv2d(3, channels, kernel_size=3, padding=1, padding_mode='replicate')
        self.body = nn.Sequential(*[PureResBlock(channels) for _ in range(num_blocks)])
        self.tail = nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode='replicate')
        # Sub-pixel upsampling: the conv emits 3 * r^2 channels and PixelShuffle
        # rearranges (N, 3*r^2, H, W) into (N, 3, r*H, r*W)
        self.sub_pixel = nn.Sequential(
            nn.Conv2d(channels, 3 * (scale_factor ** 2), kernel_size=3, padding=1, padding_mode='replicate'),
            nn.PixelShuffle(scale_factor)
        )

    def forward(self, x):
        # Global residual learning: the network predicts detail on top of a bicubic upscale
        base_upscaled = F.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False)
        f0 = self.head(x)
        f_body = self.body(f0)
        f_body = self.tail(f_body)
        f_out = f0 + f_body
        details = self.sub_pixel(f_out)
        return base_upscaled + details
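
# A minimal shape sanity check (an illustrative sketch, not called by the app):
# even an untrained FastEDSR should map (N, 3, H, W) to (N, 3, 2H, 2W)
# when scale_factor=2.
def _shape_sanity_check():
    m = FastEDSR(scale_factor=2, num_blocks=8, channels=64).eval()
    x = torch.rand(1, 3, 64, 64)
    with torch.no_grad():
        y = m(x)
    assert y.shape == (1, 3, 128, 128), f"unexpected output shape: {y.shape}"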

# --- 2. INITIALIZATION ---
device = torch.device('cpu') 
model = FastEDSR(scale_factor=2, num_blocks=8, channels=64)

# Load the trained weights (weights_only=True is safe for a plain state_dict
# and avoids unpickling arbitrary objects; requires PyTorch >= 1.13)
model_path = "FastEDSR_x2_31dB.pth"
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()

def calc_psnr(pred, target):
    # PSNR in dB for tensors in [0, 1]; identical images are capped at 100 dB
    mse = torch.mean((pred - target) ** 2)
    if mse == 0:
        return 100.0
    return 10 * torch.log10(1.0 / mse).item()
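
# Worked example: a uniform per-pixel error of 0.1 on the [0, 1] scale gives
# MSE = 0.01, so PSNR = 10 * log10(1 / 0.01) = 20 dB. As a quick check:
#   calc_psnr(torch.full((1, 3, 8, 8), 0.6), torch.full((1, 3, 8, 8), 0.5))  # ~20.0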

# --- 3. INFERENCE FUNCTIONS ---

def standard_upscale(img):
    if img is None: return None, ""
        
    max_input_dim = 2048 
    w, h = img.size
    
    if w > max_input_dim or h > max_input_dim:
        gr.Warning(f"Input image exceeded the 2K ({max_input_dim}px) limit. It has been proportionally downscaled to ensure the 4K output fits in server memory and constraints.")
        scale = max_input_dim / max(w, h)
        w, h = int(w * scale), int(h * scale)
        img = img.resize((w, h), Image.BICUBIC)

    img = img.convert('RGB')
    input_tensor = TF.to_tensor(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output_tensor = model(input_tensor)

    output_tensor = output_tensor.squeeze(0).clamp(0, 1)
    output_img = TF.to_pil_image(output_tensor)
    
    new_w, new_h = output_img.size
    
    details = (
        f"### Resolution Details\n"
        f"- **Before:** {w} x {h} ({w * h:,} pixels)\n\n"
        f"- **After:** {new_w} x {new_h} ({new_w * new_h:,} pixels)"
    )
    
    return output_img, details
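
# For inputs above the 2K cap, an alternative to downscaling is tiled inference:
# upscale fixed-size crops independently and stitch the results. A minimal sketch
# (hypothetical helper, not wired into the UI; without overlap blending, tile
# borders can show seams):
def upscale_tiled(img_tensor, tile=512):
    _, _, h, w = img_tensor.shape
    out = torch.zeros(1, 3, h * 2, w * 2)
    with torch.no_grad():
        for top in range(0, h, tile):
            for left in range(0, w, tile):
                patch = img_tensor[:, :, top:top + tile, left:left + tile]
                ph, pw = patch.shape[2], patch.shape[3]
                out[:, :, top * 2:(top + ph) * 2, left * 2:(left + pw) * 2] = model(patch)
    return out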

def benchmark_upscale(hr_img):
    if hr_img is None: return "", None, None
    
    hr_img = hr_img.convert('RGB')
    w, h = hr_img.size
    
    # Trim to even dimensions so the ground truth is exactly 2x the derived LR image
    w = w - (w % 2)
    h = h - (h % 2)
    hr_img = hr_img.crop((0, 0, w, h))
    
    max_hr_dim = 4096 
    if w > max_hr_dim or h > max_hr_dim:
        gr.Warning(f"Ground truth image exceeded the 4K ({max_hr_dim}px) limit. It has been proportionally downscaled.")
        scale = max_hr_dim / max(w, h)
        w, h = int(w * scale), int(h * scale)
        w = w - (w % 2)
        h = h - (h % 2)
        hr_img = hr_img.resize((w, h), Image.BICUBIC)

    # Simulate the degraded input: bicubic-downscale the ground truth by 2x
    lr_w, lr_h = w // 2, h // 2
    lr_img = hr_img.resize((lr_w, lr_h), Image.BICUBIC)
    
    lr_tensor = TF.to_tensor(lr_img).unsqueeze(0).to(device)
    hr_tensor = TF.to_tensor(hr_img).unsqueeze(0).to(device)

    with torch.no_grad():
        pred_tensor = model(lr_tensor).clamp(0, 1)

    psnr = calc_psnr(pred_tensor, hr_tensor)
    pred_img = TF.to_pil_image(pred_tensor.squeeze(0))
    
    # Nearest-neighbor upsample so the slider shows raw LR pixels at matched size
    lr_slider_img = lr_img.resize((w, h), Image.NEAREST)
    
    details = (
        f"### Benchmark Results\n"
        f"- **PSNR:** {psnr:.2f} dB\n\n"
        f"- **Low-Res Input:** {lr_w} x {lr_h} ({lr_w * lr_h:,} pixels)\n\n"
        f"- **Model Output & Ground Truth:** {w} x {h} ({w * h:,} pixels)"
    )
    
    # Gradio's native ImageSlider expects a tuple of (image1, image2)
    return details, (lr_slider_img, pred_img), (hr_img, pred_img)
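
# For context, a plain bicubic baseline over the same LR/HR pair could be
# reported alongside the model's PSNR (a hypothetical helper, not wired into
# the UI; FastEDSR is expected to score higher):
def bicubic_baseline_psnr(lr_img, hr_tensor):
    w, h = lr_img.size
    bicubic = lr_img.resize((w * 2, h * 2), Image.BICUBIC)
    return calc_psnr(TF.to_tensor(bicubic).unsqueeze(0).to(device), hr_tensor)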

# --- 4. GRADIO UI ---
with gr.Blocks() as app:
    gr.Markdown(
        """
        # ⚡ FastEDSR 2x Image Upscaler
        Upload an image to enhance and upscale it by 2x. Supports up to 4K resolution output.

        For more information about the model, its training, and usage, or to run a local demo, visit our [website](https://infinitode.netlify.app).
        
        This model was trained on DIV2K:
        ```
        @inproceedings{Agustsson2017,
          title={NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study},
          author={Agustsson, Eirikur and Timofte, Radu},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition Workshops},
          year={2017}
        }
        ```
        """
    )
    
    with gr.Tabs():
        # TAB 1: STANDARD
        with gr.TabItem("⚡ Standard Upscaling"):
            gr.Markdown("Directly upscale any low-resolution image. Inputs over 2K (2048px) will be scaled down to prevent memory limits.")
            with gr.Row():
                with gr.Column():
                    std_input = gr.Image(type="pil", label="Low Resolution Input")
                    std_btn = gr.Button("Upscale Image", variant="primary")
                with gr.Column():
                    std_output = gr.Image(type="pil", label="2x High Resolution Output")
                    std_details = gr.Markdown()
                    
            std_btn.click(fn=standard_upscale, inputs=std_input, outputs=[std_output, std_details])

        # TAB 2: BENCHMARK
        with gr.TabItem("📊 Benchmark Mode"):
            gr.Markdown("Upload a high-quality image (up to 4K). The app will compress it to 2x lower resolution, upscale it using FastEDSR, and measure the PSNR quality against the original. It will also generate side-by-side comparisons for you to view.")
            with gr.Row():
                with gr.Column():
                    bm_input = gr.Image(type="pil", label="Ground Truth (High Res) Image")
                    bm_btn = gr.Button("Run Benchmark", variant="primary")
                    bm_details = gr.Markdown()
                with gr.Column():
                    gr.Markdown("### Low-Res vs. Model Prediction")
                    slider_lr_pred = gr.ImageSlider(label="Left: Pixelated Low-Res | Right: FastEDSR")
                    
                    gr.Markdown("### Ground Truth vs. Model Prediction")
                    slider_hr_pred = gr.ImageSlider(label="Left: Original HR | Right: FastEDSR")
                    
            bm_btn.click(fn=benchmark_upscale, inputs=bm_input, outputs=[bm_details, slider_lr_pred, slider_hr_pred])

if __name__ == "__main__":
    app.launch()