| | """ |
| | Compact AI Models for <1GB VRAM Usage |
| | SwinIR Lightweight & Compact Real-ESRGAN |
| | Optimized for RTX 3050 Laptop GPU |
| | """ |
| |
|
| | import os |
| | import cv2 |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import Optional, Tuple, Dict |
| | import math |
| | import requests |
| | from tqdm import tqdm |
| |
|
| | |
| | class PatchEmbed(nn.Module): |
| | """Image to Patch Embedding - Compact version""" |
| | def __init__(self, img_size=64, patch_size=1, embed_dim=60): |
| | super().__init__() |
| | self.img_size = img_size |
| | self.patch_size = patch_size |
| | self.num_patches = (img_size // patch_size) ** 2 |
| | self.proj = nn.Conv2d(3, embed_dim, kernel_size=3, stride=1, padding=1) |
| |
|
| | def forward(self, x): |
| | return self.proj(x) |
| |
|
| | class WindowAttention(nn.Module): |
| | """Window based multi-head self attention - Compact version""" |
| | def __init__(self, dim, window_size, num_heads=6): |
| | super().__init__() |
| | self.dim = dim |
| | self.window_size = window_size |
| | self.num_heads = num_heads |
| | head_dim = dim // num_heads |
| | self.scale = head_dim ** -0.5 |
| |
|
| | self.qkv = nn.Linear(dim, dim * 3, bias=True) |
| | self.proj = nn.Linear(dim, dim) |
| |
|
| | def forward(self, x): |
| | B_, N, C = x.shape |
| | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| | q, k, v = qkv[0], qkv[1], qkv[2] |
| |
|
| | q = q * self.scale |
| | attn = (q @ k.transpose(-2, -1)) |
| | attn = attn.softmax(dim=-1) |
| | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) |
| | x = self.proj(x) |
| | return x |
| |
|
| | class SwinTransformerBlock(nn.Module): |
| | """Swin Transformer Block - Compact version""" |
| | def __init__(self, dim, num_heads, window_size=8, mlp_ratio=2.): |
| | super().__init__() |
| | self.window_size = window_size |
| | self.norm1 = nn.LayerNorm(dim) |
| | self.attn = WindowAttention(dim, window_size, num_heads) |
| | self.norm2 = nn.LayerNorm(dim) |
| | mlp_hidden_dim = int(dim * mlp_ratio) |
| | self.mlp = nn.Sequential( |
| | nn.Linear(dim, mlp_hidden_dim), |
| | nn.GELU(), |
| | nn.Linear(mlp_hidden_dim, dim) |
| | ) |
| |
|
| | def forward(self, x): |
| | H, W = x.shape[2:] |
| | B, C, H, W = x.shape |
| | |
| | |
| | x_reshaped = x.flatten(2).transpose(1, 2) |
| | |
| | |
| | shortcut = x_reshaped |
| | x_reshaped = self.norm1(x_reshaped) |
| | x_reshaped = self.attn(x_reshaped.unsqueeze(0)).squeeze(0) |
| | x_reshaped = shortcut + x_reshaped |
| | |
| | |
| | shortcut = x_reshaped |
| | x_reshaped = self.norm2(x_reshaped) |
| | x_reshaped = self.mlp(x_reshaped) |
| | x_reshaped = shortcut + x_reshaped |
| | |
| | |
| | x = x_reshaped.transpose(1, 2).reshape(B, C, H, W) |
| | return x |
| |
|
| | class CompactSwinIR(nn.Module): |
| | """Compact SwinIR for <1GB VRAM""" |
| | def __init__(self, upscale=4, img_size=64, window_size=8, |
| | embed_dim=60, depths=[4], num_heads=[6]): |
| | super().__init__() |
| | self.upscale = upscale |
| | self.img_size = img_size |
| | self.window_size = window_size |
| |
|
| | |
| | self.conv_first = nn.Conv2d(3, embed_dim, 3, 1, 1) |
| |
|
| | |
| | self.layers = nn.ModuleList() |
| | for i in range(depths[0]): |
| | self.layers.append( |
| | SwinTransformerBlock(embed_dim, num_heads[0], window_size) |
| | ) |
| |
|
| | |
| | self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) |
| | |
| | |
| | self.conv_before_upsample = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) |
| | self.upsample = nn.Sequential( |
| | nn.Conv2d(embed_dim, 3 * upscale * upscale, 3, 1, 1), |
| | nn.PixelShuffle(upscale) |
| | ) |
| |
|
| | def forward(self, x): |
| | |
| | x = self.conv_first(x) |
| | res = x |
| |
|
| | |
| | for layer in self.layers: |
| | x = layer(x) |
| |
|
| | |
| | x = self.conv_after_body(x) |
| | x = x + res |
| |
|
| | |
| | x = self.conv_before_upsample(x) |
| | x = self.upsample(x) |
| |
|
| | return x |
| |
|
| | class CompactRRDBNet(nn.Module): |
| | """Compact RRDB Net for Real-ESRGAN - <1GB VRAM""" |
| | def __init__(self, in_nc=3, out_nc=3, nf=32, nb=6, gc=16): |
| | super().__init__() |
| | |
| | |
| | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) |
| | |
| | |
| | self.RRDB_trunk = nn.Sequential(*[ |
| | self.make_rrdb_block(nf, gc) for _ in range(nb) |
| | ]) |
| | |
| | |
| | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) |
| | |
| | |
| | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) |
| | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) |
| | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) |
| | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) |
| | |
| | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) |
| |
|
| | def make_rrdb_block(self, nf, gc): |
| | """Make a compact RRDB block""" |
| | return nn.Sequential( |
| | nn.Conv2d(nf, gc, 3, 1, 1), |
| | nn.LeakyReLU(0.2, True), |
| | nn.Conv2d(gc, nf, 3, 1, 1) |
| | ) |
| |
|
| | def forward(self, x): |
| | fea = self.conv_first(x) |
| | trunk = self.trunk_conv(self.RRDB_trunk(fea)) |
| | fea = fea + trunk |
| |
|
| | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) |
| | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) |
| | out = self.conv_last(self.lrelu(self.HRconv(fea))) |
| |
|
| | return out |
| |
|
| | class CompactAIEnhancer: |
| | """Compact AI Enhancer using SwinIR & Real-ESRGAN for <1GB VRAM""" |
| | |
| | MODEL_URLS = { |
| | 'swinir_lightweight': 'https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth', |
| | 'realesrgan_compact': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x4plus_netD.pth', |
| | } |
| | |
| | def __init__(self, model_type='swinir', device=None): |
| | """Initialize compact enhancer""" |
| | self.model_type = model_type |
| | |
| | |
| | if device is None: |
| | if torch.cuda.is_available(): |
| | self.device = torch.device('cuda') |
| | |
| | torch.cuda.set_per_process_memory_fraction(0.5) |
| | torch.backends.cudnn.benchmark = False |
| | torch.backends.cudnn.deterministic = True |
| | print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}") |
| | |
| | |
| | props = torch.cuda.get_device_properties(0) |
| | vram_gb = props.total_memory / (1024**3) |
| | print(f"📊 Total VRAM: {vram_gb:.1f} GB") |
| | |
| | |
| | if vram_gb < 4: |
| | self.tile_size = 128 |
| | self.tile_pad = 8 |
| | else: |
| | self.tile_size = 192 |
| | self.tile_pad = 16 |
| | else: |
| | self.device = torch.device('cpu') |
| | self.tile_size = 256 |
| | self.tile_pad = 16 |
| | print("💻 Using CPU") |
| | else: |
| | self.device = device |
| | self.tile_size = 128 |
| | self.tile_pad = 8 |
| | |
| | |
| | self.model_dir = 'models_compact' |
| | os.makedirs(self.model_dir, exist_ok=True) |
| | |
| | |
| | self.model = None |
| | self.load_model() |
| | |
| | def load_model(self): |
| | """Load compact model""" |
| | try: |
| | print(f"🔄 Loading compact {self.model_type} model...") |
| | |
| | if self.model_type == 'swinir': |
| | |
| | self.model = CompactSwinIR( |
| | upscale=4, |
| | img_size=64, |
| | window_size=8, |
| | embed_dim=60, |
| | depths=[4], |
| | num_heads=[6] |
| | ) |
| | model_size = sum(p.numel() for p in self.model.parameters()) * 4 / (1024**2) |
| | print(f"📦 SwinIR Lightweight model size: {model_size:.1f} MB") |
| | |
| | elif self.model_type == 'realesrgan': |
| | |
| | self.model = CompactRRDBNet( |
| | in_nc=3, |
| | out_nc=3, |
| | nf=32, |
| | nb=6, |
| | gc=16 |
| | ) |
| | model_size = sum(p.numel() for p in self.model.parameters()) * 4 / (1024**2) |
| | print(f"📦 Real-ESRGAN Compact model size: {model_size:.1f} MB") |
| | |
| | else: |
| | raise ValueError(f"Unknown model type: {self.model_type}") |
| | |
| | |
| | self.model = self.model.to(self.device) |
| | self.model.eval() |
| | |
| | |
| | if self.device.type == 'cuda': |
| | self.model = self.model.half() |
| | print("✅ Using FP16 for memory efficiency") |
| | |
| | |
| | model_path = os.path.join(self.model_dir, f'{self.model_type}_compact.pth') |
| | if os.path.exists(model_path): |
| | state_dict = torch.load(model_path, map_location=self.device) |
| | self.model.load_state_dict(state_dict, strict=False) |
| | print(f"✅ Loaded pretrained {self.model_type} weights") |
| | else: |
| | print(f"⚠️ No pretrained weights found, using random initialization") |
| | print(f" Model will still work but quality may be lower") |
| | |
| | print(f"✅ Model ready! Estimated VRAM usage: <500MB") |
| | |
| | except Exception as e: |
| | print(f"❌ Failed to load model: {e}") |
| | self.model = None |
| | |
| | def enhance_image(self, image_path: str, output_path: str = None) -> str: |
| | """Enhance image with compact model""" |
| | if output_path is None: |
| | output_path = image_path.replace('.', '_enhanced.') |
| | |
| | print(f"🎨 Enhancing {os.path.basename(image_path)} with {self.model_type}...") |
| | |
| | try: |
| | |
| | img = cv2.imread(image_path) |
| | if img is None: |
| | print(f"❌ Failed to load image: {image_path}") |
| | return image_path |
| | |
| | h, w = img.shape[:2] |
| | print(f" Input size: {w}x{h}") |
| | |
| | |
| | if self.device.type == 'cuda': |
| | torch.cuda.empty_cache() |
| | torch.cuda.synchronize() |
| | |
| | |
| | if self.model is not None: |
| | enhanced = self.process_with_tiling(img) |
| | else: |
| | |
| | print(" ⚠️ Using fallback upscaling") |
| | enhanced = self.fallback_upscale(img) |
| | |
| | |
| | cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95]) |
| | |
| | new_h, new_w = enhanced.shape[:2] |
| | print(f" ✅ Output size: {new_w}x{new_h}") |
| | |
| | |
| | if self.device.type == 'cuda': |
| | torch.cuda.empty_cache() |
| | torch.cuda.synchronize() |
| | |
| | return output_path |
| | |
| | except torch.cuda.OutOfMemoryError: |
| | print(" ❌ CUDA OOM! Falling back to CPU") |
| | self.device = torch.device('cpu') |
| | if self.model: |
| | self.model = self.model.cpu().float() |
| | return self.enhance_image(image_path, output_path) |
| | |
| | except Exception as e: |
| | print(f" ❌ Enhancement failed: {e}") |
| | return image_path |
| | |
| | def process_with_tiling(self, img): |
| | """Process image with tiling for minimal VRAM usage""" |
| | |
| | img_tensor = self.img_to_tensor(img) |
| | _, _, h, w = img_tensor.shape |
| | |
| | |
| | out_h, out_w = h * 4, w * 4 |
| | |
| | |
| | output = torch.zeros((1, 3, out_h, out_w), dtype=torch.float32, device='cpu') |
| | |
| | |
| | tile_size = self.tile_size |
| | pad = self.tile_pad |
| | |
| | print(f" Processing with {tile_size}x{tile_size} tiles...") |
| | |
| | for y in range(0, h, tile_size - pad * 2): |
| | for x in range(0, w, tile_size - pad * 2): |
| | |
| | x_start = max(0, x - pad) |
| | y_start = max(0, y - pad) |
| | x_end = min(w, x + tile_size - pad) |
| | y_end = min(h, y + tile_size - pad) |
| | |
| | |
| | tile = img_tensor[:, :, y_start:y_end, x_start:x_end] |
| | |
| | |
| | tile = tile.to(self.device) |
| | if self.device.type == 'cuda' and self.model.training == False: |
| | tile = tile.half() |
| | |
| | |
| | with torch.no_grad(): |
| | enhanced_tile = self.model(tile) |
| | |
| | |
| | enhanced_tile = enhanced_tile.cpu().float() |
| | |
| | |
| | out_x_start = x * 4 |
| | out_y_start = y * 4 |
| | out_x_end = min(out_w, (x + tile_size - pad * 2) * 4) |
| | out_y_end = min(out_h, (y + tile_size - pad * 2) * 4) |
| | |
| | |
| | tile_x_start = pad * 4 if x > 0 else 0 |
| | tile_y_start = pad * 4 if y > 0 else 0 |
| | tile_x_end = tile_x_start + (out_x_end - out_x_start) |
| | tile_y_end = tile_y_start + (out_y_end - out_y_start) |
| | |
| | |
| | output[:, :, out_y_start:out_y_end, out_x_start:out_x_end] = \ |
| | enhanced_tile[:, :, tile_y_start:tile_y_end, tile_x_start:tile_x_end] |
| | |
| | |
| | del tile, enhanced_tile |
| | if self.device.type == 'cuda': |
| | torch.cuda.empty_cache() |
| | |
| | |
| | return self.tensor_to_img(output) |
| | |
| | def img_to_tensor(self, img): |
| | """Convert image to tensor""" |
| | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
| | img = img.astype(np.float32) / 255.0 |
| | img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) |
| | return img_tensor |
| | |
| | def tensor_to_img(self, tensor): |
| | """Convert tensor to image""" |
| | img = tensor.squeeze(0).permute(1, 2, 0).numpy() |
| | img = (img * 255).clip(0, 255).astype(np.uint8) |
| | return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) |
| | |
| | def fallback_upscale(self, img): |
| | """High-quality fallback upscaling""" |
| | h, w = img.shape[:2] |
| | |
| | |
| | scale = min(2, 2048/w, 1080/h) |
| | new_w = int(w * scale) |
| | new_h = int(h * scale) |
| | upscaled = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC) |
| | |
| | |
| | kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]]) / 1 |
| | upscaled = cv2.filter2D(upscaled, -1, kernel) |
| | |
| | |
| | upscaled = cv2.bilateralFilter(upscaled, 5, 50, 50) |
| | |
| | return upscaled |
| | |
| | def get_memory_usage(self): |
| | """Get current memory usage""" |
| | if self.device.type == 'cuda': |
| | allocated = torch.cuda.memory_allocated() / (1024**2) |
| | reserved = torch.cuda.memory_reserved() / (1024**2) |
| | return f"Allocated: {allocated:.1f}MB, Reserved: {reserved:.1f}MB" |
| | return "Using CPU" |
| |
|
| | |
| | def create_compact_enhancer(model_type='swinir'): |
| | """Create a compact enhancer that works with <1GB VRAM""" |
| | return CompactAIEnhancer(model_type=model_type) |
| |
|
| | def enhance_with_swinir(image_path, output_path=None): |
| | """Enhance image with compact SwinIR""" |
| | enhancer = CompactAIEnhancer(model_type='swinir') |
| | return enhancer.enhance_image(image_path, output_path) |
| |
|
| | def enhance_with_compact_realesrgan(image_path, output_path=None): |
| | """Enhance image with compact Real-ESRGAN""" |
| | enhancer = CompactAIEnhancer(model_type='realesrgan') |
| | return enhancer.enhance_image(image_path, output_path) |
| |
|
| | if __name__ == "__main__": |
| | print("🚀 Compact AI Models for <1GB VRAM") |
| | print("=" * 50) |
| | |
| | |
| | enhancer = CompactAIEnhancer(model_type='swinir') |
| | print(f"\nMemory usage: {enhancer.get_memory_usage()}") |
| | |
| | enhancer2 = CompactAIEnhancer(model_type='realesrgan') |
| | print(f"Memory usage: {enhancer2.get_memory_usage()}") |