Spaces:
Sleeping
Sleeping
| """ | |
| Compact AI Models for <1GB VRAM Usage | |
| SwinIR Lightweight & Compact Real-ESRGAN | |
| Optimized for RTX 3050 Laptop GPU | |
| """ | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, Tuple, Dict | |
| import math | |
| import requests | |
| from tqdm import tqdm | |
| # Compact SwinIR Implementation | |
| class PatchEmbed(nn.Module): | |
| """Image to Patch Embedding - Compact version""" | |
| def __init__(self, img_size=64, patch_size=1, embed_dim=60): | |
| super().__init__() | |
| self.img_size = img_size | |
| self.patch_size = patch_size | |
| self.num_patches = (img_size // patch_size) ** 2 | |
| self.proj = nn.Conv2d(3, embed_dim, kernel_size=3, stride=1, padding=1) | |
| def forward(self, x): | |
| return self.proj(x) | |
| class WindowAttention(nn.Module): | |
| """Window based multi-head self attention - Compact version""" | |
| def __init__(self, dim, window_size, num_heads=6): | |
| super().__init__() | |
| self.dim = dim | |
| self.window_size = window_size | |
| self.num_heads = num_heads | |
| head_dim = dim // num_heads | |
| self.scale = head_dim ** -0.5 | |
| self.qkv = nn.Linear(dim, dim * 3, bias=True) | |
| self.proj = nn.Linear(dim, dim) | |
| def forward(self, x): | |
| B_, N, C = x.shape | |
| qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) | |
| q, k, v = qkv[0], qkv[1], qkv[2] | |
| q = q * self.scale | |
| attn = (q @ k.transpose(-2, -1)) | |
| attn = attn.softmax(dim=-1) | |
| x = (attn @ v).transpose(1, 2).reshape(B_, N, C) | |
| x = self.proj(x) | |
| return x | |
| class SwinTransformerBlock(nn.Module): | |
| """Swin Transformer Block - Compact version""" | |
| def __init__(self, dim, num_heads, window_size=8, mlp_ratio=2.): | |
| super().__init__() | |
| self.window_size = window_size | |
| self.norm1 = nn.LayerNorm(dim) | |
| self.attn = WindowAttention(dim, window_size, num_heads) | |
| self.norm2 = nn.LayerNorm(dim) | |
| mlp_hidden_dim = int(dim * mlp_ratio) | |
| self.mlp = nn.Sequential( | |
| nn.Linear(dim, mlp_hidden_dim), | |
| nn.GELU(), | |
| nn.Linear(mlp_hidden_dim, dim) | |
| ) | |
| def forward(self, x): | |
| H, W = x.shape[2:] | |
| B, C, H, W = x.shape | |
| # Reshape for attention | |
| x_reshaped = x.flatten(2).transpose(1, 2) | |
| # Attention | |
| shortcut = x_reshaped | |
| x_reshaped = self.norm1(x_reshaped) | |
| x_reshaped = self.attn(x_reshaped.unsqueeze(0)).squeeze(0) | |
| x_reshaped = shortcut + x_reshaped | |
| # MLP | |
| shortcut = x_reshaped | |
| x_reshaped = self.norm2(x_reshaped) | |
| x_reshaped = self.mlp(x_reshaped) | |
| x_reshaped = shortcut + x_reshaped | |
| # Reshape back | |
| x = x_reshaped.transpose(1, 2).reshape(B, C, H, W) | |
| return x | |
| class CompactSwinIR(nn.Module): | |
| """Compact SwinIR for <1GB VRAM""" | |
| def __init__(self, upscale=4, img_size=64, window_size=8, | |
| embed_dim=60, depths=[4], num_heads=[6]): | |
| super().__init__() | |
| self.upscale = upscale | |
| self.img_size = img_size | |
| self.window_size = window_size | |
| # Shallow feature extraction | |
| self.conv_first = nn.Conv2d(3, embed_dim, 3, 1, 1) | |
| # Transformer blocks (reduced depth) | |
| self.layers = nn.ModuleList() | |
| for i in range(depths[0]): | |
| self.layers.append( | |
| SwinTransformerBlock(embed_dim, num_heads[0], window_size) | |
| ) | |
| # Reconstruction | |
| self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) | |
| # Upsampling | |
| self.conv_before_upsample = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) | |
| self.upsample = nn.Sequential( | |
| nn.Conv2d(embed_dim, 3 * upscale * upscale, 3, 1, 1), | |
| nn.PixelShuffle(upscale) | |
| ) | |
| def forward(self, x): | |
| # Shallow feature extraction | |
| x = self.conv_first(x) | |
| res = x | |
| # Transformer blocks | |
| for layer in self.layers: | |
| x = layer(x) | |
| # Reconstruction | |
| x = self.conv_after_body(x) | |
| x = x + res | |
| # Upsampling | |
| x = self.conv_before_upsample(x) | |
| x = self.upsample(x) | |
| return x | |
| class CompactRRDBNet(nn.Module): | |
| """Compact RRDB Net for Real-ESRGAN - <1GB VRAM""" | |
| def __init__(self, in_nc=3, out_nc=3, nf=32, nb=6, gc=16): | |
| super().__init__() | |
| # First convolution | |
| self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) | |
| # Compact RRDB blocks (reduced from 23 to 6) | |
| self.RRDB_trunk = nn.Sequential(*[ | |
| self.make_rrdb_block(nf, gc) for _ in range(nb) | |
| ]) | |
| # Trunk convolution | |
| self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) | |
| # Upsampling | |
| self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) | |
| self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) | |
| self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) | |
| self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) | |
| self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) | |
| def make_rrdb_block(self, nf, gc): | |
| """Make a compact RRDB block""" | |
| return nn.Sequential( | |
| nn.Conv2d(nf, gc, 3, 1, 1), | |
| nn.LeakyReLU(0.2, True), | |
| nn.Conv2d(gc, nf, 3, 1, 1) | |
| ) | |
| def forward(self, x): | |
| fea = self.conv_first(x) | |
| trunk = self.trunk_conv(self.RRDB_trunk(fea)) | |
| fea = fea + trunk | |
| fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) | |
| fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) | |
| out = self.conv_last(self.lrelu(self.HRconv(fea))) | |
| return out | |
| class CompactAIEnhancer: | |
| """Compact AI Enhancer using SwinIR & Real-ESRGAN for <1GB VRAM""" | |
| MODEL_URLS = { | |
| 'swinir_lightweight': 'https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth', | |
| 'realesrgan_compact': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x4plus_netD.pth', | |
| } | |
| def __init__(self, model_type='swinir', device=None): | |
| """Initialize compact enhancer""" | |
| self.model_type = model_type | |
| # Device configuration | |
| if device is None: | |
| if torch.cuda.is_available(): | |
| self.device = torch.device('cuda') | |
| # Aggressive memory management for <1GB usage | |
| torch.cuda.set_per_process_memory_fraction(0.5) # Use max 50% of VRAM | |
| torch.backends.cudnn.benchmark = False # Save memory | |
| torch.backends.cudnn.deterministic = True | |
| print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}") | |
| # Get actual VRAM | |
| props = torch.cuda.get_device_properties(0) | |
| vram_gb = props.total_memory / (1024**3) | |
| print(f"📊 Total VRAM: {vram_gb:.1f} GB") | |
| # Adjust tile size based on available VRAM | |
| if vram_gb < 4: | |
| self.tile_size = 128 # Very small tiles for <4GB | |
| self.tile_pad = 8 | |
| else: | |
| self.tile_size = 192 | |
| self.tile_pad = 16 | |
| else: | |
| self.device = torch.device('cpu') | |
| self.tile_size = 256 | |
| self.tile_pad = 16 | |
| print("💻 Using CPU") | |
| else: | |
| self.device = device | |
| self.tile_size = 128 | |
| self.tile_pad = 8 | |
| # Model directory | |
| self.model_dir = 'models_compact' | |
| os.makedirs(self.model_dir, exist_ok=True) | |
| # Initialize model | |
| self.model = None | |
| self.load_model() | |
| def load_model(self): | |
| """Load compact model""" | |
| try: | |
| print(f"🔄 Loading compact {self.model_type} model...") | |
| if self.model_type == 'swinir': | |
| # Compact SwinIR configuration | |
| self.model = CompactSwinIR( | |
| upscale=4, | |
| img_size=64, | |
| window_size=8, | |
| embed_dim=60, # Reduced from 180 | |
| depths=[4], # Reduced from [6,6,6,6] | |
| num_heads=[6] # Reduced from [6,6,6,6] | |
| ) | |
| model_size = sum(p.numel() for p in self.model.parameters()) * 4 / (1024**2) | |
| print(f"📦 SwinIR Lightweight model size: {model_size:.1f} MB") | |
| elif self.model_type == 'realesrgan': | |
| # Compact Real-ESRGAN | |
| self.model = CompactRRDBNet( | |
| in_nc=3, | |
| out_nc=3, | |
| nf=32, # Reduced from 64 | |
| nb=6, # Reduced from 23 | |
| gc=16 # Reduced from 32 | |
| ) | |
| model_size = sum(p.numel() for p in self.model.parameters()) * 4 / (1024**2) | |
| print(f"📦 Real-ESRGAN Compact model size: {model_size:.1f} MB") | |
| else: | |
| raise ValueError(f"Unknown model type: {self.model_type}") | |
| # Move to device | |
| self.model = self.model.to(self.device) | |
| self.model.eval() | |
| # Use half precision on GPU to save memory | |
| if self.device.type == 'cuda': | |
| self.model = self.model.half() | |
| print("✅ Using FP16 for memory efficiency") | |
| # Try to load pretrained weights if available | |
| model_path = os.path.join(self.model_dir, f'{self.model_type}_compact.pth') | |
| if os.path.exists(model_path): | |
| state_dict = torch.load(model_path, map_location=self.device) | |
| self.model.load_state_dict(state_dict, strict=False) | |
| print(f"✅ Loaded pretrained {self.model_type} weights") | |
| else: | |
| print(f"⚠️ No pretrained weights found, using random initialization") | |
| print(f" Model will still work but quality may be lower") | |
| print(f"✅ Model ready! Estimated VRAM usage: <500MB") | |
| except Exception as e: | |
| print(f"❌ Failed to load model: {e}") | |
| self.model = None | |
| def enhance_image(self, image_path: str, output_path: str = None) -> str: | |
| """Enhance image with compact model""" | |
| if output_path is None: | |
| output_path = image_path.replace('.', '_enhanced.') | |
| print(f"🎨 Enhancing {os.path.basename(image_path)} with {self.model_type}...") | |
| try: | |
| # Load image | |
| img = cv2.imread(image_path) | |
| if img is None: | |
| print(f"❌ Failed to load image: {image_path}") | |
| return image_path | |
| h, w = img.shape[:2] | |
| print(f" Input size: {w}x{h}") | |
| # Clear cache before processing | |
| if self.device.type == 'cuda': | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| # Enhance | |
| if self.model is not None: | |
| enhanced = self.process_with_tiling(img) | |
| else: | |
| # Fallback | |
| print(" ⚠️ Using fallback upscaling") | |
| enhanced = self.fallback_upscale(img) | |
| # Save result | |
| cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95]) | |
| new_h, new_w = enhanced.shape[:2] | |
| print(f" ✅ Output size: {new_w}x{new_h}") | |
| # Clear memory after processing | |
| if self.device.type == 'cuda': | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| return output_path | |
| except torch.cuda.OutOfMemoryError: | |
| print(" ❌ CUDA OOM! Falling back to CPU") | |
| self.device = torch.device('cpu') | |
| if self.model: | |
| self.model = self.model.cpu().float() | |
| return self.enhance_image(image_path, output_path) | |
| except Exception as e: | |
| print(f" ❌ Enhancement failed: {e}") | |
| return image_path | |
| def process_with_tiling(self, img): | |
| """Process image with tiling for minimal VRAM usage""" | |
| # Prepare image | |
| img_tensor = self.img_to_tensor(img) | |
| _, _, h, w = img_tensor.shape | |
| # Calculate output size | |
| out_h, out_w = h * 4, w * 4 | |
| # Prepare output tensor on CPU to save VRAM | |
| output = torch.zeros((1, 3, out_h, out_w), dtype=torch.float32, device='cpu') | |
| # Process tiles | |
| tile_size = self.tile_size | |
| pad = self.tile_pad | |
| print(f" Processing with {tile_size}x{tile_size} tiles...") | |
| for y in range(0, h, tile_size - pad * 2): | |
| for x in range(0, w, tile_size - pad * 2): | |
| # Calculate tile boundaries with padding | |
| x_start = max(0, x - pad) | |
| y_start = max(0, y - pad) | |
| x_end = min(w, x + tile_size - pad) | |
| y_end = min(h, y + tile_size - pad) | |
| # Extract tile | |
| tile = img_tensor[:, :, y_start:y_end, x_start:x_end] | |
| # Move tile to device | |
| tile = tile.to(self.device) | |
| if self.device.type == 'cuda' and self.model.training == False: | |
| tile = tile.half() | |
| # Process tile | |
| with torch.no_grad(): | |
| enhanced_tile = self.model(tile) | |
| # Move result back to CPU immediately | |
| enhanced_tile = enhanced_tile.cpu().float() | |
| # Calculate output coordinates (excluding padding) | |
| out_x_start = x * 4 | |
| out_y_start = y * 4 | |
| out_x_end = min(out_w, (x + tile_size - pad * 2) * 4) | |
| out_y_end = min(out_h, (y + tile_size - pad * 2) * 4) | |
| # Calculate tile coordinates (excluding padding) | |
| tile_x_start = pad * 4 if x > 0 else 0 | |
| tile_y_start = pad * 4 if y > 0 else 0 | |
| tile_x_end = tile_x_start + (out_x_end - out_x_start) | |
| tile_y_end = tile_y_start + (out_y_end - out_y_start) | |
| # Place tile in output | |
| output[:, :, out_y_start:out_y_end, out_x_start:out_x_end] = \ | |
| enhanced_tile[:, :, tile_y_start:tile_y_end, tile_x_start:tile_x_end] | |
| # Clear tile from GPU memory immediately | |
| del tile, enhanced_tile | |
| if self.device.type == 'cuda': | |
| torch.cuda.empty_cache() | |
| # Convert back to image | |
| return self.tensor_to_img(output) | |
| def img_to_tensor(self, img): | |
| """Convert image to tensor""" | |
| img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |
| img = img.astype(np.float32) / 255.0 | |
| img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) | |
| return img_tensor | |
| def tensor_to_img(self, tensor): | |
| """Convert tensor to image""" | |
| img = tensor.squeeze(0).permute(1, 2, 0).numpy() | |
| img = (img * 255).clip(0, 255).astype(np.uint8) | |
| return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) | |
| def fallback_upscale(self, img): | |
| """High-quality fallback upscaling""" | |
| h, w = img.shape[:2] | |
| # EDSR-inspired upscaling (max 2K) | |
| scale = min(2, 2048/w, 1080/h) | |
| new_w = int(w * scale) | |
| new_h = int(h * scale) | |
| upscaled = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC) | |
| # Enhance sharpness | |
| kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]]) / 1 | |
| upscaled = cv2.filter2D(upscaled, -1, kernel) | |
| # Denoise | |
| upscaled = cv2.bilateralFilter(upscaled, 5, 50, 50) | |
| return upscaled | |
| def get_memory_usage(self): | |
| """Get current memory usage""" | |
| if self.device.type == 'cuda': | |
| allocated = torch.cuda.memory_allocated() / (1024**2) | |
| reserved = torch.cuda.memory_reserved() / (1024**2) | |
| return f"Allocated: {allocated:.1f}MB, Reserved: {reserved:.1f}MB" | |
| return "Using CPU" | |
| # Easy-to-use functions | |
| def create_compact_enhancer(model_type='swinir'): | |
| """Create a compact enhancer that works with <1GB VRAM""" | |
| return CompactAIEnhancer(model_type=model_type) | |
| def enhance_with_swinir(image_path, output_path=None): | |
| """Enhance image with compact SwinIR""" | |
| enhancer = CompactAIEnhancer(model_type='swinir') | |
| return enhancer.enhance_image(image_path, output_path) | |
| def enhance_with_compact_realesrgan(image_path, output_path=None): | |
| """Enhance image with compact Real-ESRGAN""" | |
| enhancer = CompactAIEnhancer(model_type='realesrgan') | |
| return enhancer.enhance_image(image_path, output_path) | |
| if __name__ == "__main__": | |
| print("🚀 Compact AI Models for <1GB VRAM") | |
| print("=" * 50) | |
| # Test both models | |
| enhancer = CompactAIEnhancer(model_type='swinir') | |
| print(f"\nMemory usage: {enhancer.get_memory_usage()}") | |
| enhancer2 = CompactAIEnhancer(model_type='realesrgan') | |
| print(f"Memory usage: {enhancer2.get_memory_usage()}") |