| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from PIL import Image |
| import os |
| from huggingface_hub import hf_hub_download |
| import cv2 |
| import sys |
| import warnings |
| import gc |
|
|
# Silence noisy deprecation/user warnings (torch, huggingface_hub, etc.),
# both in this process and in any child processes via PYTHONWARNINGS.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
os.environ['PYTHONWARNINGS'] = 'ignore'


# Make a local "models" directory importable for any bundled model code.
sys.path.append(os.path.join(os.path.dirname(__file__), 'models'))
|
|
class ResidualDenseBlock(nn.Module):
    """ESRGAN-style dense block: five 3x3 convs with dense connections and a
    residual scaled by 0.2.

    Each conv sees the block input concatenated with all previous growth
    feature maps (`gc` channels each); conv5 projects back to `nf` channels.
    """

    def __init__(self, nf=64, gc=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=True)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=True)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=True)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Accumulate the input plus each activated growth map, feeding the
        # growing concatenation into the next conv.
        feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        residual = self.conv5(torch.cat(feats, 1))
        # Residual scaling (0.2) stabilizes very deep RRDB stacks.
        return residual * 0.2 + x
|
|
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block: three dense blocks plus an outer
    residual connection scaled by 0.2."""

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(nf, gc)
        self.rdb2 = ResidualDenseBlock(nf, gc)
        self.rdb3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        y = x
        for dense in (self.rdb1, self.rdb2, self.rdb3):
            y = dense(y)
        return y * 0.2 + x
|
|
class RRDBNet(nn.Module):
    """ESRGAN-style generator: first conv, RRDB trunk, then x2 upsampling stages.

    Two bilinear x2 stages give a 4x spatial gain; a third stage is added when
    `scale` >= 8. NOTE(review): for the x2 checkpoints the caller pixel-unshuffles
    the input first (see process_with_realesrgan), so the net's 4x still yields
    an effective 2x — the forward itself never upsamples fewer than 2 stages.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, scale=4):
        super(RRDBNet, self).__init__()
        self.scale = scale

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.body = nn.ModuleList([RRDB(nf, gc) for _ in range(nb)])
        self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if scale >= 8:
            # Third x2 stage only exists for 8x checkpoints.
            self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def _upsample(self, fea, conv):
        # One x2 stage: bilinear resize -> conv -> leaky ReLU.
        up = F.interpolate(fea, scale_factor=2, mode='bilinear', align_corners=False)
        return self.lrelu(conv(up))

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = fea
        for block in self.body:
            trunk = block(trunk)
        # Long skip around the whole trunk.
        fea = fea + self.conv_body(trunk)
        del trunk

        fea = self._upsample(fea, self.conv_up1)
        fea = self._upsample(fea, self.conv_up2)
        if self.scale >= 8:
            fea = self._upsample(fea, self.conv_up3)

        return self.conv_last(self.lrelu(self.conv_hr(fea)))
|
|
def hdr_like(img):
    """Simple HDR-ish tone mapping for a (B, C, H, W) float tensor in [0, 1].

    Recenters each channel's brightness around 0.5, boosts contrast by 10%,
    clamps, then lifts shadows with a 0.85 gamma. Returns values in [0, 1].
    """
    # Per-(batch, channel) spatial mean; keepdim so it broadcasts back.
    centered = (img - img.mean(dim=(2, 3), keepdim=True)) * 1.1 + 0.5
    toned = torch.clamp(centered, 0, 1) ** 0.85
    return torch.clamp(toned, 0, 1)
|
|
def sharpen(img, amount=0.15):
    """Unsharp-mask sharpening: boost the difference from a 3x3 box blur.

    NOTE: F.avg_pool2d counts the zero padding by default, so pixels at the
    border blur toward zero and get sharpened slightly more.
    Returns a tensor clamped to [0, 1].
    """
    low_pass = F.avg_pool2d(img, kernel_size=3, stride=1, padding=1)
    boosted = img + amount * (img - low_pass)
    return torch.clamp(boosted, 0, 1)
|
|
def process_with_tiling(model, img_tensor, tile_size=160, tile_overlap=32):
    """Run `model` over `img_tensor` (B, C, H, W), tiling large inputs to bound memory.

    Small inputs go through in a single forward pass. Larger inputs are cut
    into overlapping tiles that are processed independently and pasted into an
    output canvas; in overlap regions later tiles simply overwrite earlier
    ones (no feathering/blending). The output channel count and spatial scale
    are measured empirically with a probe forward pass, so any fixed-scale
    model works. Post-processing (hdr_like + sharpen) is applied at the end.
    Returns a tensor clamped to [0, 1].

    Fix: removed an unused local that read `model.scale` — the actual scale is
    always derived from the probe pass below.
    """
    device = img_tensor.device
    b, c, h, w = img_tensor.shape

    # Smaller tiles and overlap on CPU to keep latency and memory reasonable.
    if device.type == 'cpu':
        tile_size = min(tile_size, 128)
        tile_overlap = 16

    # Fast path: the whole image fits in one tile.
    if h <= tile_size and w <= tile_size:
        with torch.no_grad():
            output = model(img_tensor)
        output = hdr_like(output)
        output = sharpen(output, amount=0.15)
        return torch.clamp(output, 0, 1)

    # Probe the model with one tile to learn output channels and true scale.
    sample_tile = img_tensor[:, :, :min(tile_size, h), :min(tile_size, w)]
    with torch.no_grad():
        sample_output = model(sample_tile)

    output_channels = sample_output.shape[1]
    sample_scale_h = sample_output.shape[2] / sample_tile.shape[2]
    sample_scale_w = sample_output.shape[3] / sample_tile.shape[3]
    del sample_tile, sample_output

    output_h = int(h * sample_scale_h)
    output_w = int(w * sample_scale_w)

    output = torch.zeros((b, output_channels, output_h, output_w), device=device)

    stride = tile_size - tile_overlap
    tiles_h = (h - 1) // stride + 1
    tiles_w = (w - 1) // stride + 1

    print(f"π² Processing {tiles_h}x{tiles_w} = {tiles_h*tiles_w} tiles")
    print(f"   Input: {c}ch {h}x{w} β Output: {output_channels}ch {output_h}x{output_w}")

    for i in range(0, h, stride):
        for j in range(0, w, stride):
            h_end = min(i + tile_size, h)
            w_end = min(j + tile_size, w)

            tile = img_tensor[:, :, i:h_end, j:w_end]

            with torch.no_grad():
                tile_output = model(tile)

            actual_h = tile_output.shape[2]
            actual_w = tile_output.shape[3]

            # Paste at the scaled origin of this tile; overlapping pixels are
            # overwritten by whichever tile is processed last.
            out_i = int(i * sample_scale_h)
            out_j = int(j * sample_scale_w)
            output[:, :, out_i:out_i + actual_h, out_j:out_j + actual_w] = tile_output

            del tile, tile_output

            # Periodic cleanup keeps peak memory down on long runs.
            if ((i // stride) * tiles_w + (j // stride)) % 4 == 0:
                gc.collect()
                if device.type == 'cuda':
                    torch.cuda.empty_cache()

    output = hdr_like(output)
    output = sharpen(output, amount=0.15)

    return torch.clamp(output, 0, 1)
|
|
# Registry of downloadable checkpoints, keyed by the label shown in the UI
# dropdown. NOTE(review): the labels contain mojibake emoji bytes ("π₯");
# they must stay byte-identical to the dropdown choices/default below, so
# they are intentionally left untouched.
# Each entry: HF repo id, checkpoint filename, nominal upscale factor,
# task tag (informational), and processing path ("realesrgan" runs the
# RRDBNet pipeline; "swinir" currently falls back to OpenCV resizing).
MODELS = {
    "Classical SR x8 (DIV2K)": {
        "repo": "deepinv/swinir",
        "filename": "001_classicalSR_DIV2K_s48w8_SwinIR-M_x8.pth",
        "scale": 8,
        "task": "super-resolution",
        "type": "swinir"
    },
    "Lightweight SR x2 (DIV2K)": {
        "repo": "deepinv/swinir",
        "filename": "002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.pth",
        "scale": 2,
        "task": "super-resolution",
        "type": "swinir"
    },
    "Lightweight SR x3 (DIV2K)": {
        "repo": "deepinv/swinir",
        "filename": "002_lightweightSR_DIV2K_s64w8_SwinIR-S_x3.pth",
        "scale": 3,
        "task": "super-resolution",
        "type": "swinir"
    },
    "Lightweight SR x4 (DIV2K)": {
        "repo": "deepinv/swinir",
        "filename": "002_lightweightSR_DIV2K_s64w8_SwinIR-S_x4.pth",
        "scale": 4,
        "task": "super-resolution",
        "type": "swinir"
    },
    "π₯ Real-ESRGAN x2 (Best for 2x)": {
        "repo": "ai-forever/Real-ESRGAN",
        "filename": "RealESRGAN_x2.pth",
        "scale": 2,
        "task": "real-sr",
        "type": "realesrgan"
    },
    "π₯ Real-ESRGAN x4 (Best for 4x)": {
        "repo": "ai-forever/Real-ESRGAN",
        "filename": "RealESRGAN_x4.pth",
        "scale": 4,
        "task": "real-sr",
        "type": "realesrgan"
    },
    "π₯ Real-ESRGAN x8 (Best for 8x)": {
        "repo": "ai-forever/Real-ESRGAN",
        "filename": "RealESRGAN_x8.pth",
        "scale": 8,
        "task": "real-sr",
        "type": "realesrgan"
    },
}
|
|
# Cache of loaded checkpoint state dicts keyed by file path (see load_model).
model_cache = {}
|
|
def setup_directories():
    """Ensure the checkpoint and scratch directories exist (idempotent)."""
    for folder in ("models", "temp"):
        os.makedirs(folder, exist_ok=True)
|
|
def download_all_models():
    """Download every checkpoint listed in MODELS into ./models.

    Existing files are skipped (counted as downloaded). Returns a tuple
    (downloaded, failed): the number of checkpoints present after the run and
    the list of model names whose download raised.

    Fix: the deprecated `local_dir_use_symlinks` kwarg is no longer passed to
    hf_hub_download — recent huggingface_hub releases warn about it and ignore
    it; `local_dir` alone already materializes real files.
    """
    print("π Starting model download...")
    setup_directories()

    downloaded = 0
    failed = []

    for model_name, model_info in MODELS.items():
        model_path = os.path.join("models", model_info["filename"])

        if os.path.exists(model_path):
            print(f"β Already exists: {model_info['filename']}")
            downloaded += 1
            continue

        try:
            print(f"β¬οΈ Downloading: {model_name}...")
            hf_hub_download(
                repo_id=model_info["repo"],
                filename=model_info["filename"],
                local_dir="models",
            )
            print(f"✅ Downloaded: {model_info['filename']}")
            downloaded += 1
        except Exception as e:
            print(f"β Failed to download {model_name}: {str(e)}")
            failed.append(model_name)

    print(f"\nπ Download Summary: ✅ {downloaded}/{len(MODELS)}")
    return downloaded, failed
|
|
def load_realesrgan_model(model_path, device, scale=4):
    """Build an RRDBNet and load Real-ESRGAN weights from `model_path`.

    Unwraps checkpoints stored under 'params_ema' or 'params'. Input/output
    channel counts are inferred from the state dict itself, so x2 checkpoints
    (12 input channels, fed pixel-unshuffled data by the caller) load with the
    right architecture. Returns the model in eval mode on `device`, or None
    if anything fails.
    """
    try:
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints fetched from trusted repos.
        raw = torch.load(model_path, map_location=device, weights_only=False)
        for wrapper_key in ('params_ema', 'params'):
            if wrapper_key in raw:
                state_dict = raw[wrapper_key]
                break
        else:
            state_dict = raw

        # Infer channel counts from the first/last conv weights when present.
        first_w = state_dict.get('conv_first.weight')
        in_nc = first_w.shape[1] if first_w is not None else 3
        last_w = state_dict.get('conv_last.weight')
        out_nc = last_w.shape[0] if last_w is not None else 3

        model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=64, nb=23, gc=32, scale=scale)
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        model = model.to(device)

        print(f"✅ Model loaded: {scale}x | In:{in_nc}ch | Out:{out_nc}ch")
        return model
    except Exception as e:
        print(f"β Error loading model: {e}")
        return None
|
|
def process_with_realesrgan(image, model_path, device, scale=4):
    """Upscale a PIL `image` with a Real-ESRGAN checkpoint.

    Loads the model, normalizes the image to a (1, 3, H, W) float tensor in
    [0, 1], pixel-unshuffles for 12-input-channel (x2) checkpoints, runs tiled
    inference, and converts the result back to a PIL image. Returns None on
    any failure (errors are printed, not raised).
    """
    try:
        model = load_realesrgan_model(model_path, device, scale)
        if model is None:
            return None

        # Input channel count the network was built with (12 for x2 models).
        in_nc = model.conv_first.weight.shape[1]

        img = np.array(image).astype(np.float32) / 255.0
        if len(img.shape) == 2:
            # Grayscale upload: replicate to 3 channels.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            # Fix: RGBA uploads previously crashed (network expects 3
            # channels) — drop the alpha channel.
            img = img[:, :, :3]

        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        img = img.unsqueeze(0).to(device)
        img = torch.clamp(img, 0, 1)

        print(f"π₯ Input: {img.shape}")

        if in_nc == 12:
            # x2 checkpoints consume a pixel-unshuffled image: pad H/W to even
            # sizes, then fold each 2x2 block into channels (3 -> 12).
            b, c, h, w = img.shape
            pad_h = (2 - h % 2) % 2
            pad_w = (2 - w % 2) % 2

            if pad_h > 0 or pad_w > 0:
                img = F.pad(img, (0, pad_w, 0, pad_h), mode='replicate')
                print(f"π§ Padded: {img.shape}")

            img = F.pixel_unshuffle(img, 2)
            print(f"π Pixel unshuffle: {img.shape}")

        output = process_with_tiling(model, img, tile_size=160, tile_overlap=32)

        print(f"π€ Output: {output.shape}")

        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = np.transpose(output, (1, 2, 0))
        output = (output * 255.0).round().astype(np.uint8)

        # Free the model and input before returning; empty the CUDA cache so
        # successive requests don't accumulate allocator reserves.
        del model, img
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        gc.collect()

        return Image.fromarray(output)
    except Exception as e:
        print(f"β Processing error: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
def load_model(model_path, device):
    """Load a checkpoint's state dict, memoized in the global `model_cache`.

    Unwraps 'params_ema'/'params' containers the same way
    load_realesrgan_model does. Returns the state dict, or None on failure
    (the error is printed).
    """
    cached = model_cache.get(model_path)
    if cached is not None:
        return cached

    try:
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)

        for wrapper_key in ('params_ema', 'params'):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break
        else:
            state_dict = checkpoint

        model_cache[model_path] = state_dict
        return state_dict
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
|
|
def process_image_simple(image, scale, task_type):
    """Fallback (non-neural) processing via OpenCV.

    - "super-resolution"/"real-sr" with scale > 1: Lanczos resize
    - "denoise": colored non-local-means denoising
    - "jpeg": bilateral filter (artifact smoothing)
    - anything else: returned unchanged
    Returns a PIL image.
    """
    img = np.array(image)

    # Promote grayscale to 3 channels so every branch sees RGB.
    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    h, w = img.shape[:2]

    if task_type in ("super-resolution", "real-sr") and scale > 1:
        output = cv2.resize(img, (w * scale, h * scale), interpolation=cv2.INTER_LANCZOS4)
    else:
        filters = {
            "denoise": lambda a: cv2.fastNlMeansDenoisingColored(a, None, 10, 10, 7, 21),
            "jpeg": lambda a: cv2.bilateralFilter(a, 9, 75, 75),
        }
        apply_filter = filters.get(task_type)
        output = apply_filter(img) if apply_filter is not None else img

    return Image.fromarray(output)
|
|
def upscale_image(image, model_name, output_format="png"):
    """Gradio click handler: enhance `image` with the selected model.

    Returns (result_image, info_text). On any failure returns (None, message).
    `output_format` is informational only — it is echoed into the info text;
    the actual encoding of the returned PIL image is decided by Gradio.
    """
    if image is None:
        return None, "β Please upload an image first!"

    model_info = MODELS[model_name]
    model_path = os.path.join("models", model_info["filename"])

    if not os.path.exists(model_path):
        return None, f"β Model not found: {model_info['filename']}\nPlease restart the app."

    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if model_info["type"] == "realesrgan":
            print(f"π₯ Processing with Real-ESRGAN {model_info['scale']}x...")
            result_image = process_with_realesrgan(
                image,
                model_path,
                device,
                scale=model_info["scale"]
            )

            if result_image is None:
                return None, "β Error processing with Real-ESRGAN"
        else:
            # Fix: this branch previously also called load_model() and kept
            # the state dict in a variable that was never used — the SwinIR
            # path falls back to classical OpenCV processing, so loading the
            # checkpoint was pure overhead (and held it in the cache forever).
            result_image = process_image_simple(
                image,
                model_info["scale"],
                model_info["task"]
            )

        info = f"✅ Model: {model_name}\n"
        info += f"π― Type: {model_info['type'].upper()}\n"
        info += f"π Task: {model_info['task']}\n"
        info += f"π’ Scale: {model_info['scale']}x\n"
        info += f"π Input: {image.size[0]}x{image.size[1]}\n"
        info += f"π Output: {result_image.size[0]}x{result_image.size[1]}\n"
        info += f"π» Device: {device}\n"
        info += f"π Format: {output_format.upper()}"

        return result_image, info

    except Exception as e:
        import traceback
        error_msg = f"β Error: {str(e)}\n\n{traceback.format_exc()}"
        return None, error_msg
|
|
def get_model_status():
    """Build a markdown summary of which checkpoints exist in ./models.

    Each present model is listed with its on-disk size; missing ones are
    flagged. Ends with a total count. Returns the summary string.
    """
    lines = ["π **Model Status:**", ""]
    available = 0

    for model_name, model_info in MODELS.items():
        model_path = os.path.join("models", model_info["filename"])
        if not os.path.exists(model_path):
            lines.append(f"β {model_name} (Not downloaded)")
            continue
        size_mb = os.path.getsize(model_path) / (1024 * 1024)
        marker = "π₯" if model_info["type"] == "realesrgan" else "✅"
        lines.append(f"{marker} {model_name} ({size_mb:.1f} MB)")
        available += 1

    lines += ["", f"**Total: {available}/{len(MODELS)} models available**"]
    return "\n".join(lines)
|
|
# Startup: fetch any missing checkpoints before the UI is built, so the
# status tab and model dropdown reflect what is actually on disk.
print("="*60)
print("π¨ AI Image Upscaler - Optimized Edition")
print("="*60)
downloaded_count, failed_models = download_all_models()
print("="*60)
|
|
# ---- Gradio UI ----------------------------------------------------------
# NOTE(review): several labels below contain mojibake emoji bytes (e.g.
# "π₯"); the dropdown choices/default must match the MODELS keys exactly,
# so they are intentionally left untouched.
with gr.Blocks(title="AI Image Upscaler", theme=gr.themes.Soft()) as demo:

    # Header banner (pure presentation).
    gr.HTML("""
    <div style="text-align: center; padding: 2rem 0;">
        <h1 style="font-size: 2.5rem; background: linear-gradient(90deg, #667eea, #764ba2); -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin-bottom: 0.5rem;">
            π AI Image Upscaler
        </h1>
        <p style="font-size: 1.1rem; color: #666;">
            Enhanced with HDR-like processing & Smart Tiling
        </p>
    </div>
    """)

    with gr.Tabs():
        # Main tab: upload, pick a model, run, view result + details.
        with gr.Tab("πΌοΈ Process Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(
                        label="π€ Upload Your Image",
                        type="pil",
                        sources=["upload", "clipboard"],
                        height=400
                    )

                    # Choices come straight from the MODELS registry keys.
                    model_dropdown = gr.Dropdown(
                        choices=list(MODELS.keys()),
                        value="π₯ Real-ESRGAN x4 (Best for 4x)",
                        label="π― Choose AI Model"
                    )

                    # Informational only — echoed into the details panel.
                    output_format = gr.Radio(
                        choices=["png", "jpeg", "webp"],
                        value="png",
                        label="πΎ Output Format"
                    )

                    process_btn = gr.Button(
                        "β¨ Enhance Image Now",
                        variant="primary",
                        size="lg"
                    )

                    gr.Markdown("""
                    ### π‘ New Features
                    - π¨ **HDR-like tone mapping**
                    - πͺ **Smart sharpening**
                    - π² **Optimized tiling**
                    - π **Better memory management**
                    """)

                with gr.Column(scale=1):
                    output_image = gr.Image(
                        label="β¨ Enhanced Result",
                        type="pil",
                        height=400
                    )

                    output_info = gr.Textbox(
                        label="π Processing Details",
                        lines=15
                    )

        # Status tab: shows which checkpoints are on disk.
        with gr.Tab("π Model Status"):
            gr.Markdown("## π€ Available AI Models")

            status_text = gr.Textbox(
                label="Model Status",
                value=get_model_status(),
                lines=25,
                interactive=False
            )

            refresh_btn = gr.Button("π Refresh Status", variant="secondary")
            refresh_btn.click(fn=get_model_status, outputs=status_text)

        # About tab: static info rendered once at build time (f-string).
        with gr.Tab("βΉοΈ About"):
            gr.Markdown(f"""
            ## About This App

            ### π Statistics
            - **Models Available:** {downloaded_count}/{len(MODELS)}
            - **Device:** {'π GPU (CUDA)' if torch.cuda.is_available() else 'π» CPU'}
            - **PyTorch:** {torch.__version__}
            - **Gradio:** {gr.__version__}

            ### β¨ Optimizations
            - Bilinear upsampling for smooth results
            - HDR-like tone mapping for better contrast
            - Smart sharpening (DSLR look)
            - Memory-efficient tiling for large images
            - Automatic garbage collection

            ### π― Supported Models
            1. **Real-ESRGAN π₯** - Best for real photos (2x, 4x, 8x)
            2. **SwinIR** - Lightweight super-resolution (2x, 3x, 4x, 8x)

            ### π Model Sources
            - **SwinIR:** [deepinv/swinir](https://huggingface.co/deepinv/swinir)
            - **Real-ESRGAN:** [ai-forever/Real-ESRGAN](https://huggingface.co/ai-forever/Real-ESRGAN)

            ---
            Made with β€οΈ using Gradio and PyTorch
            """)

    # Wire the button to the handler; also exposed as API endpoint "upscale".
    process_btn.click(
        fn=upscale_image,
        inputs=[input_image, model_dropdown, output_format],
        outputs=[output_image, output_info],
        api_name="upscale"
    )
|
|
if __name__ == "__main__":
    # Launch the Gradio server (blocking call; default host/port).
    demo.launch()