# NOTE(review): removed non-Python scraper residue that preceded this line
# (Hugging Face Spaces page chrome: "Sleeping" status, file size, and a
# line-number gutter) — it made the file unparseable as Python.
import gc
import os
import time

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageFilter
from torchvision import transforms

from RRDBNet_arch import RRDBNet  # Ensure this file is in the same directory
# -------------------------
# Load ESRGAN Model (from root directory)
# -------------------------
@torch.no_grad()
def load_model():
    """Build the RRDB ESRGAN x4 network and load its pretrained weights.

    Loads ``RRDB_ESRGAN_x4.pth`` from the current directory onto the CPU
    and returns the network in eval mode.
    """
    net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
    weights = torch.load("RRDB_ESRGAN_x4.pth", map_location="cpu")
    net.load_state_dict(weights, strict=True)
    return net.eval()  # eval() returns the module itself


# Module-level singleton used by the inference pipeline below.
model = load_model()
# -------------------------
# Utility Functions
# -------------------------
def preprocess(img_pil):
    """Turn a PIL image into a normalized 4-D tensor of shape (1, C, H, W).

    Pixel values end up in [-1, 1], matching what the ESRGAN model expects.
    """
    tensor = transforms.ToTensor()(img_pil)                 # CHW float in [0, 1]
    tensor = transforms.Normalize((0.5,), (0.5,))(tensor)   # [0, 1] -> [-1, 1]
    return tensor.unsqueeze(0)                              # prepend batch dim
def postprocess(tensor):
    """Map a model output tensor in [-1, 1] back to an 8-bit PIL image."""
    out = tensor.squeeze().detach().cpu()
    out = out.mul(0.5).add(0.5).clamp(0, 1)  # invert Normalize((0.5,), (0.5,))
    return transforms.ToPILImage()(out)
def fuse_images(img1, img2):
    """Resize both inputs to 384x384 and average them pixel-wise (50/50 blend)."""
    size = (384, 384)
    first = img1.resize(size, Image.LANCZOS)
    second = img2.resize(size, Image.LANCZOS)
    return Image.blend(first, second, alpha=0.5)
def sharpen_image(image: Image.Image) -> Image.Image:
    """Apply a mild unsharp mask to boost perceived detail in *image*."""
    unsharp = ImageFilter.UnsharpMask(radius=1.5, percent=150, threshold=1)
    return image.filter(unsharp)
def upscale_to_8k(img: Image.Image, size=(8000, 8000)) -> Image.Image:
    """Resample *img* to *size* using Lanczos interpolation.

    Args:
        img: Source image.
        size: Target ``(width, height)``; defaults to the fixed 8000x8000
            output this app advertises. Parameterized so callers can reuse
            the helper for other resolutions without behavior change for
            existing call sites.

    Returns:
        The resized image.
    """
    return img.resize(size, Image.LANCZOS)
# -------------------------
# Inference Pipeline
# -------------------------
def esrgan_pipeline(img1, img2, _):
    """Fuse two uploads, super-resolve with ESRGAN, then upscale to 8000x8000.

    Args:
        img1, img2: PIL images from the Gradio inputs (``None`` when missing).
        _: Ignored value from the fixed-resolution radio control.

    Returns:
        ``(esrgan_output, final_8k_image, status_message)``; the image slots
        are ``None`` when either input is missing.
    """
    # Gradio passes None for an empty image widget; check identity, not
    # truthiness (PIL images should not be used in boolean context).
    if img1 is None or img2 is None:
        return None, None, "Please upload two valid images."

    img1 = img1.convert("RGB")
    img2 = img2.convert("RGB")
    fused_img = fuse_images(img1, img2)

    start = time.time()
    with torch.no_grad():
        input_tensor = preprocess(fused_img)
        sr_output = model(input_tensor)
    base_output = postprocess(sr_output)

    # Drop intermediate buffers; empty_cache() is a no-op without CUDA.
    gc.collect()
    torch.cuda.empty_cache()

    upscaled_img = upscale_to_8k(base_output)
    final_img = sharpen_image(upscaled_img)
    elapsed = time.time() - start

    # Crude sharpness proxy: variance of grayscale pixel intensities.
    # BUGFIX: torch.tensor() cannot ingest a PIL image (and torch.var needs
    # a floating tensor), so convert through NumPy as float32 first.
    gray = np.asarray(base_output.convert("L"), dtype=np.float32)
    sharpness_score = torch.from_numpy(gray).var().item()

    # BUGFIX: the original f-string was split across two source lines
    # (unterminated string literal) and its check-mark emoji was mojibake.
    msg = f"✅ Done in {elapsed:.2f}s | Sharpness: {sharpness_score:.2f}"
    return base_output, final_img, msg
# -------------------------
# Gradio UI
# -------------------------
# NOTE(review): the original UI strings were mojibake (UTF-8 emoji and arrows
# decoded through a Greek codepage, e.g. "Γ" for "×", "β" for "✅"/"→");
# restored to the most likely intended glyphs — confirm against the live app.
with gr.Blocks(title="8000x8000 ESRGAN Ultra-HD Super-Resolution") as demo:
    gr.Markdown("## 🧠 ESRGAN Ultra-HD Image Upscaler (8000 × 8000 Output)")
    gr.Markdown(
        "Upload **two low-res images** → Fuse → ESRGAN → "
        "Final **8000 × 8000** enhanced image with sharpening."
    )
    with gr.Row():
        with gr.Column():
            img_input1 = gr.Image(type="pil", label="Low-Res Image 1")
            img_input2 = gr.Image(type="pil", label="Low-Res Image 2")
            # Single fixed choice; kept so the pipeline signature (3 inputs)
            # and the UI stay unchanged.
            dummy_resolution = gr.Radio(
                ["8000x8000"], value="8000x8000", label="Output Resolution (Fixed)"
            )
            run_button = gr.Button("🚀 Run ESRGAN")
        with gr.Column():
            output_esrgan = gr.Image(label="🧠 ESRGAN Output")
            output_final = gr.Image(label="🖼️ Final Enhanced Output (8000 × 8000)")
            result_text = gr.Textbox(label="📋 Output Log")
    gr.Markdown("---")
    gr.HTML(
        "<div style='text-align: center; font-size: 16px;'>"
        "Made with ❤️ by <b>CodeKarma</b> as a part of "
        "<b>Bharatiya Antariksh Hackathon 2025</b>"
        "</div>"
    )
    run_button.click(
        fn=esrgan_pipeline,
        inputs=[img_input1, img_input2, dummy_resolution],
        outputs=[output_esrgan, output_final, result_text],
    )

# -------------------------
# Launch
# -------------------------
if __name__ == "__main__":
    demo.launch()