# upscaler/app.py — AI image upscaler (Hugging Face Space, Gradio front end)
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import os
from huggingface_hub import hf_hub_download
import cv2
import sys
import warnings
import gc
# Silence noisy deprecation/user warnings (torch, gradio) so Space logs stay readable.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
os.environ['PYTHONWARNINGS'] = 'ignore'
# Make helper modules inside the local "models" directory importable.
sys.path.append(os.path.join(os.path.dirname(__file__), 'models'))
class ResidualDenseBlock(nn.Module):
    """ESRGAN residual dense block: five 3x3 convs with dense connections
    and a 0.2-scaled residual back to the input.

    Attribute names (conv1..conv5, lrelu) must stay fixed so pretrained
    Real-ESRGAN state dicts load with strict=True.
    """

    def __init__(self, nf=64, gc=32):
        super(ResidualDenseBlock, self).__init__()
        # Each conv consumes the block input plus every feature map produced so far.
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=True)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=True)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=True)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        residual = self.conv5(torch.cat(feats, 1))
        # Residual scaling (0.2) stabilizes very deep ESRGAN-style networks.
        return residual * 0.2 + x
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block: three chained dense blocks plus a
    0.2-scaled skip connection around the whole trio.

    Attribute names rdb1/rdb2/rdb3 are part of the checkpoint layout.
    """

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(nf, gc)
        self.rdb2 = ResidualDenseBlock(nf, gc)
        self.rdb3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        return out * 0.2 + x
class RRDBNet(nn.Module):
    """ESRGAN-style generator: conv stem, RRDB trunk, then x2 upsampling stages.

    Two bilinear+conv stages give 4x output; a third stage is created when
    ``scale >= 8``.  Attribute names match Real-ESRGAN checkpoints so
    ``load_state_dict(..., strict=True)`` works.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, scale=4):
        super(RRDBNet, self).__init__()
        self.scale = scale
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.body = nn.ModuleList([RRDB(nf, gc) for _ in range(nb)])
        self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if scale >= 8:
            self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def _upsample(self, fea, conv):
        # One x2 stage: bilinear resize followed by conv + LeakyReLU.
        up = F.interpolate(fea, scale_factor=2, mode='bilinear', align_corners=False)
        return self.lrelu(conv(up))

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = fea
        for block in self.body:
            trunk = block(trunk)
        # Global residual: stem features plus the processed trunk.
        fea = fea + self.conv_body(trunk)
        del trunk
        fea = self._upsample(fea, self.conv_up1)
        fea = self._upsample(fea, self.conv_up2)
        if self.scale >= 8:
            fea = self._upsample(fea, self.conv_up3)
        return self.conv_last(self.lrelu(self.conv_hr(fea)))
def hdr_like(img):
    """Apply a mild HDR-ish tone map to a (B, C, H, W) tensor in [0, 1].

    Recenters each channel around 0.5 using its spatial mean, boosts
    contrast by 10%, then applies a 0.85 gamma; result is clamped to [0, 1].
    """
    per_channel_mean = img.mean(dim=(2, 3), keepdim=True)
    recentered = (img - per_channel_mean) * 1.1 + 0.5
    toned = torch.clamp(recentered, 0, 1) ** 0.85
    return torch.clamp(toned, 0, 1)
def sharpen(img, amount=0.15):
    """Unsharp-mask sharpening: img + amount * (img - 3x3 box blur), clamped.

    Fix: ``count_include_pad=False`` keeps the box blur correct at image
    borders.  With the default zero-padding the blur is artificially dark
    at the edges, so the sharpening pass brightened borders into a halo.

    Args:
        img: float tensor in [0, 1], shape (B, C, H, W).
        amount: sharpening strength.

    Returns:
        Sharpened tensor clamped to [0, 1].
    """
    blur = F.avg_pool2d(img, kernel_size=3, stride=1, padding=1,
                        count_include_pad=False)
    sharpened = img + amount * (img - blur)
    return torch.clamp(sharpened, 0, 1)
def process_with_tiling(model, img_tensor, tile_size=160, tile_overlap=32):
    """Run ``model`` over ``img_tensor`` tile by tile to bound peak memory.

    The output scale and channel count are probed from a sample tile rather
    than assumed, so models whose input was pixel-unshuffled (in_nc == 12)
    still get a correctly sized output canvas.  NOTE: overlapping tile
    regions are overwritten (last tile wins) rather than blended, so faint
    seams are possible on difficult content.  Tone mapping and sharpening
    are applied to the assembled result.

    Fix: removed the unused local ``scale = model.scale``.

    Args:
        model: scale-changing module; called under ``torch.no_grad()``.
        img_tensor: (B, C, H, W) float tensor already on the target device.
        tile_size: tile edge length in input pixels.
        tile_overlap: overlap between neighbouring tiles in input pixels.

    Returns:
        Upscaled tensor clamped to [0, 1].
    """
    device = img_tensor.device
    b, c, h, w = img_tensor.shape
    if device.type == 'cpu':
        # Smaller tiles keep CPU memory pressure low.
        tile_size = min(tile_size, 128)
        tile_overlap = 16
    if h <= tile_size and w <= tile_size:
        # Small enough to process in a single pass.
        with torch.no_grad():
            output = model(img_tensor)
        output = hdr_like(output)
        output = sharpen(output, amount=0.15)
        return torch.clamp(output, 0, 1)
    # Probe the model once to learn its output channels and scale factors.
    sample_tile = img_tensor[:, :, :min(tile_size, h), :min(tile_size, w)]
    with torch.no_grad():
        sample_output = model(sample_tile)
    output_channels = sample_output.shape[1]
    sample_scale_h = sample_output.shape[2] / sample_tile.shape[2]
    sample_scale_w = sample_output.shape[3] / sample_tile.shape[3]
    del sample_tile, sample_output
    output_h = int(h * sample_scale_h)
    output_w = int(w * sample_scale_w)
    output = torch.zeros((b, output_channels, output_h, output_w), device=device)
    stride = tile_size - tile_overlap
    tiles_h = (h - 1) // stride + 1
    tiles_w = (w - 1) // stride + 1
    print(f"🔲 Processing {tiles_h}x{tiles_w} = {tiles_h*tiles_w} tiles")
    print(f" Input: {c}ch {h}x{w} → Output: {output_channels}ch {output_h}x{output_w}")
    for i in range(0, h, stride):
        for j in range(0, w, stride):
            h_start = i
            h_end = min(i + tile_size, h)
            w_start = j
            w_end = min(j + tile_size, w)
            tile = img_tensor[:, :, h_start:h_end, w_start:w_end]
            with torch.no_grad():
                tile_output = model(tile)
            # Edge tiles may be smaller than tile_size; paste the actual size.
            actual_h = tile_output.shape[2]
            actual_w = tile_output.shape[3]
            output_h_start = int(h_start * sample_scale_h)
            output_w_start = int(w_start * sample_scale_w)
            output[:, :, output_h_start:output_h_start+actual_h,
                   output_w_start:output_w_start+actual_w] = tile_output
            del tile, tile_output
            # Periodically reclaim memory so long runs don't accumulate.
            if ((i // stride) * tiles_w + (j // stride)) % 4 == 0:
                gc.collect()
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
    output = hdr_like(output)
    output = sharpen(output, amount=0.15)
    return torch.clamp(output, 0, 1)
MODELS = {
"Classical SR x8 (DIV2K)": {
"repo": "deepinv/swinir",
"filename": "001_classicalSR_DIV2K_s48w8_SwinIR-M_x8.pth",
"scale": 8,
"task": "super-resolution",
"type": "swinir"
},
"Lightweight SR x2 (DIV2K)": {
"repo": "deepinv/swinir",
"filename": "002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.pth",
"scale": 2,
"task": "super-resolution",
"type": "swinir"
},
"Lightweight SR x3 (DIV2K)": {
"repo": "deepinv/swinir",
"filename": "002_lightweightSR_DIV2K_s64w8_SwinIR-S_x3.pth",
"scale": 3,
"task": "super-resolution",
"type": "swinir"
},
"Lightweight SR x4 (DIV2K)": {
"repo": "deepinv/swinir",
"filename": "002_lightweightSR_DIV2K_s64w8_SwinIR-S_x4.pth",
"scale": 4,
"task": "super-resolution",
"type": "swinir"
},
"πŸ”₯ Real-ESRGAN x2 (Best for 2x)": {
"repo": "ai-forever/Real-ESRGAN",
"filename": "RealESRGAN_x2.pth",
"scale": 2,
"task": "real-sr",
"type": "realesrgan"
},
"πŸ”₯ Real-ESRGAN x4 (Best for 4x)": {
"repo": "ai-forever/Real-ESRGAN",
"filename": "RealESRGAN_x4.pth",
"scale": 4,
"task": "real-sr",
"type": "realesrgan"
},
"πŸ”₯ Real-ESRGAN x8 (Best for 8x)": {
"repo": "ai-forever/Real-ESRGAN",
"filename": "RealESRGAN_x8.pth",
"scale": 8,
"task": "real-sr",
"type": "realesrgan"
},
}
model_cache = {}
def setup_directories():
    """Ensure the working directories for model weights and temp files exist."""
    for folder in ("models", "temp"):
        os.makedirs(folder, exist_ok=True)
def download_all_models():
    """Fetch every checkpoint in MODELS into ./models, skipping existing files.

    Returns:
        (downloaded, failed): count of models available on disk after the
        run, and the list of display names that failed to download.
    """
    print("🚀 Starting model download...")
    setup_directories()
    ok = 0
    failed = []
    for name, info in MODELS.items():
        target = os.path.join("models", info["filename"])
        if os.path.exists(target):
            print(f"✓ Already exists: {info['filename']}")
            ok += 1
            continue
        print(f"⬇️ Downloading: {name}...")
        try:
            hf_hub_download(
                repo_id=info["repo"],
                filename=info["filename"],
                local_dir="models",
                local_dir_use_symlinks=False
            )
        except Exception as exc:
            print(f"❌ Failed to download {name}: {str(exc)}")
            failed.append(name)
        else:
            print(f"✅ Downloaded: {info['filename']}")
            ok += 1
    print(f"\n📊 Download Summary: ✅ {ok}/{len(MODELS)}")
    return ok, failed
def load_realesrgan_model(model_path, device, scale=4):
    """Build an RRDBNet from a Real-ESRGAN checkpoint and move it to ``device``.

    Channel counts are inferred from the checkpoint tensors so both the
    plain 3-channel and the pixel-unshuffled 12-channel variants load.
    Returns the eval-mode model, or None if loading fails.
    """
    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from trusted sources.
        ckpt = torch.load(model_path, map_location=device, weights_only=False)
        if 'params_ema' in ckpt:
            state_dict = ckpt['params_ema']
        elif 'params' in ckpt:
            state_dict = ckpt['params']
        else:
            state_dict = ckpt
        # Infer I/O channel counts directly from the stored conv weights.
        first_w = state_dict.get('conv_first.weight')
        in_nc = first_w.shape[1] if first_w is not None else 3
        last_w = state_dict.get('conv_last.weight')
        out_nc = last_w.shape[0] if last_w is not None else 3
        model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=64, nb=23, gc=32, scale=scale)
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        model = model.to(device)
        print(f"✅ Model loaded: {scale}x | In:{in_nc}ch | Out:{out_nc}ch")
        return model
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return None
def process_with_realesrgan(image, model_path, device, scale=4):
    """Upscale a PIL image with a Real-ESRGAN checkpoint.

    Fix: normalize the input with ``image.convert('RGB')`` so grayscale,
    palette, and RGBA uploads all become 3-channel arrays.  The previous
    grayscale-only cv2 conversion crashed on 4-channel RGBA input.

    Handles the 12-channel (pixel-unshuffle) checkpoint variant by padding
    to even dimensions and unshuffling before inference.

    Returns:
        The enhanced PIL image, or None on failure.
    """
    try:
        model = load_realesrgan_model(model_path, device, scale)
        if model is None:
            return None
        in_nc = model.conv_first.weight.shape[1]
        # Any input mode (L, LA, P, RGBA, ...) -> 3-channel RGB float in [0, 1].
        img = np.array(image.convert('RGB')).astype(np.float32) / 255.0
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        img = img.unsqueeze(0).to(device)
        img = torch.clamp(img, 0, 1)
        print(f"📥 Input: {img.shape}")
        if in_nc == 12:
            # 12-channel checkpoints expect the input pixel-unshuffled by 2,
            # which requires even spatial dimensions.
            b, c, h, w = img.shape
            pad_h = (2 - h % 2) % 2
            pad_w = (2 - w % 2) % 2
            if pad_h > 0 or pad_w > 0:
                img = F.pad(img, (0, pad_w, 0, pad_h), mode='replicate')
                print(f"🔧 Padded: {img.shape}")
            img = F.pixel_unshuffle(img, 2)
            print(f"🔄 Pixel unshuffle: {img.shape}")
        output = process_with_tiling(model, img, tile_size=160, tile_overlap=32)
        print(f"📤 Output: {output.shape}")
        # Tensor (1, C, H, W) -> uint8 HWC array for PIL.
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = np.transpose(output, (1, 2, 0))
        output = (output * 255.0).round().astype(np.uint8)
        del model, img
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        gc.collect()
        return Image.fromarray(output)
    except Exception as e:
        print(f"❌ Processing error: {e}")
        import traceback
        traceback.print_exc()
        return None
def load_model(model_path, device):
    """Load and cache a checkpoint's raw state dict (not an instantiated model).

    Despite the name, this returns the weights mapping only; results are
    memoized in ``model_cache`` keyed by path so repeat requests skip disk
    I/O.  Returns None on failure.
    """
    cached = model_cache.get(model_path)
    if cached is not None:
        return cached
    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from trusted sources.
        ckpt = torch.load(model_path, map_location=device, weights_only=False)
        for key in ('params_ema', 'params'):
            if key in ckpt:
                state_dict = ckpt[key]
                break
        else:
            state_dict = ckpt
        model_cache[model_path] = state_dict
        return state_dict
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
def process_image_simple(image, scale, task_type):
    """Classical (non-neural) fallback used for the non-ESRGAN model entries.

    Upscaling tasks use Lanczos resampling; 'denoise' and 'jpeg' tasks use
    OpenCV filters; any other task returns the image unchanged.
    """
    img = np.array(image)
    if img.ndim == 2:
        # Grayscale -> 3 channels so the filters below behave uniformly.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    h, w = img.shape[:2]
    if task_type in ("super-resolution", "real-sr") and scale > 1:
        result = cv2.resize(img, (w * scale, h * scale), interpolation=cv2.INTER_LANCZOS4)
    elif task_type == "denoise":
        result = cv2.fastNlMeansDenoisingColored(img, None, 10, 10, 7, 21)
    elif task_type == "jpeg":
        result = cv2.bilateralFilter(img, 9, 75, 75)
    else:
        result = img
    return Image.fromarray(result)
def upscale_image(image, model_name, output_format="png"):
    """Gradio click handler: enhance ``image`` with the selected model.

    Fix: the unused ``state_dict`` local binding was dropped; the
    ``load_model`` call is kept for its checkpoint-caching side effect.

    Args:
        image: PIL image from the UI (None if nothing was uploaded).
        model_name: key into MODELS.
        output_format: display-only label echoed in the info panel.

    Returns:
        (result_image, info_text); result_image is None on error.
    """
    if image is None:
        return None, "❌ Please upload an image first!"
    model_info = MODELS[model_name]
    model_path = os.path.join("models", model_info["filename"])
    if not os.path.exists(model_path):
        return None, f"❌ Model not found: {model_info['filename']}\nPlease restart the app."
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if model_info["type"] == "realesrgan":
            print(f"🔥 Processing with Real-ESRGAN {model_info['scale']}x...")
            result_image = process_with_realesrgan(
                image,
                model_path,
                device,
                scale=model_info["scale"]
            )
            if result_image is None:
                return None, "❌ Error processing with Real-ESRGAN"
        else:
            # Warm the checkpoint cache; the classical path below does not
            # consume the weights itself.
            load_model(model_path, device)
            result_image = process_image_simple(
                image,
                model_info["scale"],
                model_info["task"]
            )
        info = f"✅ Model: {model_name}\n"
        info += f"🎯 Type: {model_info['type'].upper()}\n"
        info += f"📊 Task: {model_info['task']}\n"
        info += f"🔢 Scale: {model_info['scale']}x\n"
        info += f"📏 Input: {image.size[0]}x{image.size[1]}\n"
        info += f"📏 Output: {result_image.size[0]}x{result_image.size[1]}\n"
        info += f"💻 Device: {device}\n"
        info += f"📁 Format: {output_format.upper()}"
        return result_image, info
    except Exception as e:
        import traceback
        error_msg = f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
        return None, error_msg
def get_model_status():
    """Build a markdown summary of which model files are present on disk."""
    entries = ["📊 **Model Status:**\n"]
    available = 0
    for model_name, model_info in MODELS.items():
        model_path = os.path.join("models", model_info["filename"])
        if os.path.exists(model_path):
            size_mb = os.path.getsize(model_path) / (1024 * 1024)
            emoji = "🔥" if model_info["type"] == "realesrgan" else "✅"
            entries.append(f"{emoji} {model_name} ({size_mb:.1f} MB)")
            available += 1
        else:
            entries.append(f"❌ {model_name} (Not downloaded)")
    entries.append(f"\n**Total: {available}/{len(MODELS)} models available**")
    return "\n".join(entries)
# --- Startup: announce the app and pre-download all weights before the UI builds ---
print("="*60)
print("🎨 AI Image Upscaler - Optimized Edition")
print("="*60)
downloaded_count, failed_models = download_all_models()
print("="*60)
# --- Gradio UI: three tabs (Process Image / Model Status / About) ---
with gr.Blocks(title="AI Image Upscaler", theme=gr.themes.Soft()) as demo:
    # Decorative header banner.
    gr.HTML("""
<div style="text-align: center; padding: 2rem 0;">
<h1 style="font-size: 2.5rem; background: linear-gradient(90deg, #667eea, #764ba2); -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin-bottom: 0.5rem;">
🚀 AI Image Upscaler
</h1>
<p style="font-size: 1.1rem; color: #666;">
Enhanced with HDR-like processing & Smart Tiling
</p>
</div>
""")
    with gr.Tabs():
        # Main workflow: upload -> choose model -> enhance.
        with gr.Tab("🖼️ Process Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(
                        label="📤 Upload Your Image",
                        type="pil",
                        sources=["upload", "clipboard"],
                        height=400
                    )
                    model_dropdown = gr.Dropdown(
                        choices=list(MODELS.keys()),
                        value="🔥 Real-ESRGAN x4 (Best for 4x)",
                        label="🎯 Choose AI Model"
                    )
                    # NOTE(review): the chosen format is only echoed in the
                    # info panel; the result is returned as a PIL image.
                    output_format = gr.Radio(
                        choices=["png", "jpeg", "webp"],
                        value="png",
                        label="💾 Output Format"
                    )
                    process_btn = gr.Button(
                        "✨ Enhance Image Now",
                        variant="primary",
                        size="lg"
                    )
                    gr.Markdown("""
### 💡 New Features
- 🎨 **HDR-like tone mapping**
- 🔪 **Smart sharpening**
- 🔲 **Optimized tiling**
- 🚀 **Better memory management**
""")
                with gr.Column(scale=1):
                    output_image = gr.Image(
                        label="✨ Enhanced Result",
                        type="pil",
                        height=400
                    )
                    output_info = gr.Textbox(
                        label="📊 Processing Details",
                        lines=15
                    )
        # Read-only view of which checkpoints are on disk.
        with gr.Tab("📊 Model Status"):
            gr.Markdown("## 🤖 Available AI Models")
            status_text = gr.Textbox(
                label="Model Status",
                value=get_model_status(),
                lines=25,
                interactive=False
            )
            refresh_btn = gr.Button("🔄 Refresh Status", variant="secondary")
            refresh_btn.click(fn=get_model_status, outputs=status_text)
        with gr.Tab("ℹ️ About"):
            gr.Markdown(f"""
## About This App
### 📈 Statistics
- **Models Available:** {downloaded_count}/{len(MODELS)}
- **Device:** {'🚀 GPU (CUDA)' if torch.cuda.is_available() else '💻 CPU'}
- **PyTorch:** {torch.__version__}
- **Gradio:** {gr.__version__}
### ✨ Optimizations
- Bilinear upsampling for smooth results
- HDR-like tone mapping for better contrast
- Smart sharpening (DSLR look)
- Memory-efficient tiling for large images
- Automatic garbage collection
### 🎯 Supported Models
1. **Real-ESRGAN 🔥** - Best for real photos (2x, 4x, 8x)
2. **SwinIR** - Lightweight super-resolution (2x, 3x, 4x, 8x)
### 📚 Model Sources
- **SwinIR:** [deepinv/swinir](https://huggingface.co/deepinv/swinir)
- **Real-ESRGAN:** [ai-forever/Real-ESRGAN](https://huggingface.co/ai-forever/Real-ESRGAN)
---
Made with ❤️ using Gradio and PyTorch
""")
    # Wire the enhance button; api_name exposes this as a named API endpoint.
    process_btn.click(
        fn=upscale_image,
        inputs=[input_image, model_dropdown, output_format],
        outputs=[output_image, output_info],
        api_name="upscale"
    )
# Launch only when run directly as a script.
if __name__ == "__main__":
    demo.launch()