# app.py — author: dev1461
# Commit 3dbf5d4 "Create app.py" (2.69 kB)
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
# ---------------------------
# MODEL ARCHITECTURE
# ---------------------------
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a ReLU between them, plus an identity skip.

    Both convolutions use stride 1 and padding 1 and keep the channel count,
    so ``block(x)`` has the same shape as ``x`` and can be added back to it.

    Args:
        channels: Number of feature channels in and out.
    """

    def __init__(self, channels):
        super().__init__()
        # The attribute name ``block`` and the layer order determine the
        # state_dict keys (block.0.*, block.2.*) that checkpoints are loaded with.
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + block(x)``."""
        residual = self.block(x)
        return x + residual
class Generator(nn.Module):
    """Fully convolutional 3-channel -> 3-channel network.

    Layout: a 3x3 stem convolution to 64 channels, three ``ResidualBlock``s,
    then a 3x3 convolution back to 3 channels followed by a sigmoid, so the
    output lies in (0, 1). Every convolution is stride 1 / padding 1, so the
    output keeps the input's spatial size.

    The attribute names ``entry``, ``res_blocks`` and ``exit`` define the
    state_dict keys used when loading checkpoints and must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.entry = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.res_blocks = nn.Sequential(*(ResidualBlock(64) for _ in range(3)))
        self.exit = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map a 3-channel input tensor to a 3-channel output of the same spatial size."""
        features = self.res_blocks(self.entry(x))
        return self.exit(features)
# ---------------------------
# LOAD MODEL
# ---------------------------
# Inference runs on the CPU.
device = torch.device("cpu")
model = Generator().to(device)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code when loaded. It is the
# default from PyTorch 2.6 on; passing it explicitly also protects older
# versions and silences the FutureWarning that 2.4/2.5 emit when it is omitted.
checkpoint = torch.load(
    "final_sr_model_v3.pth", map_location=device, weights_only=True
)
# The generator weights are stored under the "generator" key of the checkpoint.
model.load_state_dict(checkpoint["generator"])
# Switch to eval mode for inference (no dropout/batch-norm layers exist in
# Generator today, but this keeps inference correct if any are added).
model.eval()
# ---------------------------
# TRANSFORM
# ---------------------------
# (height, width) every upload is resized to before inference; the aspect
# ratio is not preserved.
_INPUT_SIZE = (128, 128)

# Resize, then convert the PIL image to a float tensor with values in [0, 1].
transform = transforms.Compose([transforms.Resize(_INPUT_SIZE), transforms.ToTensor()])
# ---------------------------
# INFERENCE FUNCTION
# ---------------------------
def enhance_image(input_image):
    """Run the generator on one uploaded image and return the result as uint8 RGB.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user clicks the button without uploading anything.

    Returns:
        An ``(H, W, 3)`` ``np.uint8`` array (128x128 with the module-level
        ``transform``), or ``None`` if no image was given.
    """
    # Gradio passes None for an empty image input; return an empty output
    # instead of failing on .convert().
    if input_image is None:
        return None

    # Normalize RGBA / grayscale / palette uploads to 3 channels.
    img = input_image.convert("RGB")
    input_tensor = transform(img).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        output = model(input_tensor)

    # Drop only the batch axis (a bare squeeze() would also remove a size-1
    # spatial axis), then move channels last for NumPy / Gradio.
    output_img = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round instead of truncating when scaling to 8-bit, and clip first,
    # because casting out-of-range floats to uint8 is not well defined.
    output_img = np.clip(np.rint(output_img * 255.0), 0, 255).astype(np.uint8)
    return output_img
def _save_png_for_download(image_array):
    """Write an enhanced image array to a temporary PNG and return its path.

    Returns ``None`` (which clears the download button) when there is no image.
    """
    import tempfile  # only needed by this UI helper

    if image_array is None:
        return None
    # delete=False: the file must outlive this call so Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        Image.fromarray(image_array).save(tmp, format="PNG")
    return tmp.name


with gr.Blocks() as demo:
    gr.Markdown("# 🔍 Image Super Resolution")
    input_img = gr.Image(type="pil", label="Upload Image")
    output_img = gr.Image(type="numpy", label="Enhanced Image")
    btn = gr.Button("Enhance Image")
    # DownloadButton serves a file path (it has no `data` argument), so it
    # starts empty and is filled once the enhanced image has been written out.
    download_btn = gr.DownloadButton(label="Download Enhanced Image")
    # Keep the enhance endpoint unchanged; export the PNG only after it succeeds.
    btn.click(fn=enhance_image, inputs=input_img, outputs=output_img).success(
        fn=_save_png_for_download, inputs=output_img, outputs=download_btn
    )

# Launch only when run as a script, so importing this module (e.g. by
# `gradio app.py` reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()
# ---------------------------
# ALTERNATIVE SIMPLE UI (gr.Interface)
# ---------------------------
# The Blocks app `demo` above is the UI that gets launched. This Interface
# wraps the same enhance_image function without a download button. It is
# built but deliberately NOT launched here: a second .launch() would either
# never run (the first launch blocks when the file is executed as a script)
# or start a second server. Call `interface.launch()` manually to use it
# instead of `demo`.
interface = gr.Interface(
    fn=enhance_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="numpy", label="Enhanced Image"),
    title="🔍 Super Resolution App",
    description="Upload a low-quality image and enhance it using deep learning",
    allow_flagging="never",
)