"""
Ultra Compact Image Enhancer for Extreme Memory Constraints.

Designed for an RTX 3050 Laptop GPU with a strict <1GB VRAM budget.
"""
|
|
| import os |
| import cv2 |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional |
| import gc |
|
|
class UltraCompactESRGAN(nn.Module):
    """Tiny residual CNN that upsamples an image by ``scale`` via PixelShuffle.

    Layout: conv1 -> ReLU -> conv2 -> ReLU -> conv3, then the conv1 activation
    is added to the conv3 output (skip connection), then a 3x3 conv produces
    ``3 * scale**2`` channels which ``PixelShuffle(scale)`` rearranges into an
    image ``scale`` times larger. Every conv is 3x3, stride 1, padding 1, so
    the spatial size is unchanged until the final shuffle.

    With the defaults (scale=2, num_feat=24) the network has 13,692
    parameters, i.e. a few tens of kilobytes of weights; activations, not
    weights, dominate memory use.

    Input is expected to be a batched ``(N, 3, H, W)`` float tensor (the first
    conv takes 3 channels); output is ``(N, 3, H * scale, W * scale)``.

    Args:
        scale: Integer upscaling factor, >= 1.
        num_feat: Number of hidden feature channels. Defaults to 24, the
            value that used to be hard-coded; weights saved with one value
            cannot be loaded into a model built with another.

    Raises:
        ValueError: If ``scale`` or ``num_feat`` is smaller than 1.
    """

    def __init__(self, scale: int = 2, num_feat: int = 24):
        super().__init__()
        if scale < 1:
            raise ValueError(f"scale must be >= 1, got {scale}")
        if num_feat < 1:
            raise ValueError(f"num_feat must be >= 1, got {num_feat}")
        self.scale = scale

        nf = num_feat
        # Attribute names are kept stable so existing state_dicts still load.
        self.conv1 = nn.Conv2d(3, nf, 3, 1, 1)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv3 = nn.Conv2d(nf, nf, 3, 1, 1)

        # Sub-pixel upsampling: predict scale**2 values per output channel,
        # then PixelShuffle moves them into a (scale*H, scale*W) grid.
        self.upscale = nn.Sequential(
            nn.Conv2d(nf, 3 * scale * scale, 3, 1, 1),
            nn.PixelShuffle(scale),
        )

        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Upscale ``x`` by ``self.scale`` along both spatial dimensions."""
        feat = self.act(self.conv1(x))
        body = self.conv3(self.act(self.conv2(feat)))
        # Residual add: conv1 features plus the refined conv3 features.
        return self.upscale(feat + body)
|
|
class MemorySafeEnhancer:
    """Tiled 2x image enhancer that keeps PyTorch's VRAM usage small.

    On CUDA, PyTorch's caching allocator is capped (see ``_vram_fraction``)
    and images go through ``UltraCompactESRGAN`` one small tile at a time in
    fp16. Images larger than 2048 px on either side, or runs where the model
    failed to load, use an OpenCV-only CPU upscaler instead. Results are
    shrunk, if needed, to fit within 2048x1080.

    NOTE(review): ``_load_model`` builds the network with freshly initialised
    weights and no checkpoint is loaded in this class, so the model path does
    not produce a trained super-resolution result unless weights are loaded
    elsewhere — confirm.
    """

    # Largest fraction of total VRAM PyTorch's allocator may reserve.
    _VRAM_FRACTION = 0.3
    # Absolute allocator budget (1 GiB); the cap is the smaller of the two.
    _VRAM_BUDGET_BYTES = 1024 ** 3
    # Inputs with either side above this skip the model and use the CPU path.
    _MAX_MODEL_SIDE = 2048
    # Output size limit ("2K"): width x height.
    _MAX_OUT_W = 2048
    _MAX_OUT_H = 1080

    def __init__(self, tile_size: int = 64, scale: int = 2, tile_pad: int = 4):
        """Pick a device and load the model.

        Args:
            tile_size: Side length, in input pixels, of each tile sent to the model.
            scale: Upscaling factor of the model.
            tile_pad: Pixels of neighbouring context fed to the model around
                each tile and cropped away afterwards. 4 covers the four
                stacked 3x3 convolutions of ``UltraCompactESRGAN``, so tile
                borders match full-image inference and no seams appear.
        """
        self.device = self._setup_device()
        self.model = None
        self.tile_size = tile_size
        self.scale = scale
        self.tile_pad = tile_pad

        self._load_model()

    @classmethod
    def _vram_fraction(cls, total_bytes: int) -> float:
        """Fraction of ``total_bytes`` to allow: min(30%, 1 GiB / total)."""
        if total_bytes <= 0:
            return cls._VRAM_FRACTION
        return min(cls._VRAM_FRACTION, cls._VRAM_BUDGET_BYTES / total_bytes)

    def _setup_device(self):
        """Return the CUDA device with a capped allocator, or the CPU.

        NOTE: ``set_per_process_memory_fraction`` limits PyTorch's caching
        allocator; per the PyTorch docs, CUDA context overhead is not part of
        that budget.
        """
        if not torch.cuda.is_available():
            print("💻 Using CPU")
            return torch.device('cpu')

        torch.cuda.empty_cache()
        torch.cuda.synchronize()

        total_bytes = torch.cuda.get_device_properties(0).total_memory
        fraction = self._vram_fraction(total_bytes)
        torch.cuda.set_per_process_memory_fraction(fraction)

        print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}")
        total = total_bytes / (1024**3)
        print(f"📊 Total VRAM: {total:.1f}GB, Using max: {total*fraction:.1f}GB")
        return torch.device('cuda')

    def _load_model(self):
        """Build ``UltraCompactESRGAN`` on ``self.device`` (fp16 on CUDA).

        On any failure ``self.model`` is left as ``None`` and enhancement
        falls back to the CPU upscaler.
        """
        try:
            print("🔄 Loading ultra-compact model...")

            model = UltraCompactESRGAN(scale=self.scale).eval()
            if self.device.type == 'cuda':
                model = model.half()
            self.model = model.to(self.device)

            # element_size() is 2 bytes in fp16 and 4 in fp32, so the size
            # shown is correct on both GPU and CPU.
            model_bytes = sum(p.numel() * p.element_size() for p in self.model.parameters())
            print(f"✅ Model loaded: {model_bytes / (1024**2):.1f}MB")

        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            self.model = None

    @staticmethod
    def _default_output_path(image_path: str) -> str:
        """Insert ``_enhanced`` before the file extension of ``image_path``.

        Only the final extension is touched, so dots in directory names or
        earlier in the file name are preserved. A path without an extension
        gets ``.png`` so ``cv2.imwrite`` can choose an encoder and the input
        is never overwritten.
        """
        root, ext = os.path.splitext(image_path)
        return f"{root}_enhanced{ext or '.png'}"

    def enhance_image(self, image_path: str, output_path: Optional[str] = None) -> str:
        """Upscale the image at ``image_path`` and write the result to disk.

        Args:
            image_path: Path of the image to read with ``cv2.imread``.
            output_path: Destination file. Defaults to ``image_path`` with
                ``_enhanced`` inserted before the extension.

        Returns:
            The path that was written, or ``image_path`` if the input could
            not be read or no output could be written.
        """
        if output_path is None:
            output_path = self._default_output_path(image_path)

        print(f"🎨 Enhancing {os.path.basename(image_path)}...")

        try:
            img = cv2.imread(image_path)
            if img is None:
                print("❌ Failed to read image")
                return image_path

            h, w = img.shape[:2]
            print(f" Input: {w}x{h}")

            if h > self._MAX_MODEL_SIDE or w > self._MAX_MODEL_SIDE:
                print(" ⚠️ Large image, using CPU fallback")
                enhanced = self._cpu_upscale(img)
            elif self.model is not None:
                enhanced = self._enhance_with_model(img)
            else:
                enhanced = self._cpu_upscale(img)

            enhanced = self._limit_to_2k(enhanced)

            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95]):
                raise OSError(f"cv2.imwrite could not write {output_path}")

            new_h, new_w = enhanced.shape[:2]
            print(f" ✅ Output: {new_w}x{new_h}")
            return output_path

        except Exception as e:
            print(f" ❌ Enhancement failed: {e}")
            return self._fallback_enhance(image_path, output_path)

        finally:
            # Free memory on success and failure alike (e.g. after a CUDA OOM).
            self._cleanup_memory()

    def _limit_to_2k(self, img: np.ndarray) -> np.ndarray:
        """Shrink ``img``, keeping its aspect ratio, to fit within 2048x1080."""
        h, w = img.shape[:2]
        if w <= self._MAX_OUT_W and h <= self._MAX_OUT_H:
            return img
        ratio = min(self._MAX_OUT_W / w, self._MAX_OUT_H / h)
        new_w = max(1, int(w * ratio))
        new_h = max(1, int(h * ratio))
        resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
        print(f" 📐 Resizing from {w}x{h} to {new_w}x{new_h} (2K limit)")
        return resized

    def _fallback_enhance(self, image_path: str, output_path: str) -> str:
        """Last-resort CPU upscale after a failure.

        Returns ``output_path`` if it was written, otherwise ``image_path``.
        """
        try:
            img = cv2.imread(image_path)
            if img is not None and cv2.imwrite(output_path, self._cpu_upscale(img)):
                return output_path
        except Exception as e:
            print(f" ❌ CPU fallback failed: {e}")
        return image_path

    def _enhance_with_model(self, img: np.ndarray) -> np.ndarray:
        """Upscale ``img`` tile by tile through the model.

        Each ``tile_size`` x ``tile_size`` tile is fed to the model together
        with up to ``tile_pad`` pixels of surrounding image context. The
        context is cropped off the model output, so neighbouring tiles join
        without seams. Every tile is written to the output, including thin
        tiles at the right/bottom edges. A tile whose inference or placement
        raises is replaced by a bicubic upscale of that tile.

        Args:
            img: BGR uint8 image of shape (H, W, 3), as read by ``cv2.imread``.

        Returns:
            uint8 array of shape (H * scale, W * scale, 3).
        """
        h, w = img.shape[:2]
        s = self.scale
        tile = self.tile_size
        pad = self.tile_pad

        output = np.zeros((h * s, w * s, 3), dtype=np.uint8)
        failed = 0

        print(f" Processing {tile}x{tile} tiles...")

        for y in range(0, h, tile):
            y_end = min(y + tile, h)
            for x in range(0, w, tile):
                x_end = min(x + tile, w)

                # Widen the tile by `pad` pixels of real context on every side
                # (clamped to the image) so the convs see true neighbours.
                py0, px0 = max(y - pad, 0), max(x - pad, 0)
                py1, px1 = min(y_end + pad, h), min(x_end + pad, w)

                out_region = (slice(y * s, y_end * s), slice(x * s, x_end * s))
                try:
                    patch_out = self._process_single_tile(img[py0:py1, px0:px1])
                    top, left = (y - py0) * s, (x - px0) * s
                    output[out_region] = patch_out[top:top + (y_end - y) * s,
                                                   left:left + (x_end - x) * s]
                except Exception:
                    failed += 1
                    output[out_region] = cv2.resize(
                        img[y:y_end, x:x_end],
                        ((x_end - x) * s, (y_end - y) * s),
                        interpolation=cv2.INTER_CUBIC,
                    )

        if failed:
            print(f" ⚠️ {failed} tile(s) fell back to bicubic upscaling")

        # Release cached allocator blocks once all tiles are done.
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()

        return output

    def _process_single_tile(self, tile: np.ndarray) -> np.ndarray:
        """Run the model on one BGR uint8 tile; return the upscaled BGR uint8 tile."""
        # Reversing the channel axis converts BGR to RGB (same as
        # cv2.COLOR_BGR2RGB for 3-channel input). ascontiguousarray copies away
        # the negative stride, which torch.from_numpy rejects.
        rgb = np.ascontiguousarray(tile[:, :, ::-1], dtype=np.float32) / 255.0

        tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).to(self.device)
        if self.device.type == 'cuda':
            tensor = tensor.half()  # the model is stored in fp16 on CUDA

        with torch.no_grad():
            out = self.model(tensor)

        # Move to the CPU before casting to fp32, so no fp32 copy lives on the GPU.
        arr = out.squeeze(0).permute(1, 2, 0).cpu().float().numpy()
        del tensor, out

        # Round rather than truncate; truncation biased every pixel downward.
        arr = np.clip(np.rint(arr * 255.0), 0, 255).astype(np.uint8)
        return np.ascontiguousarray(arr[:, :, ::-1])

    def _cpu_upscale(self, img: np.ndarray) -> np.ndarray:
        """CPU-only upscale: bicubic/Lanczos blend plus a 3x3 sharpen.

        The scale is ``self.scale``, reduced so the result fits within
        2048x1080. It can go below 1 (a downscale) for large inputs.
        """
        print(" 📈 Using CPU upscaling...")

        h, w = img.shape[:2]
        factor = min(self.scale, self._MAX_OUT_W / w, self._MAX_OUT_H / h)
        # max(1, ...) prevents a zero-sized target for extreme aspect ratios.
        size = (max(1, int(w * factor)), max(1, int(h * factor)))

        cubic = cv2.resize(img, size, interpolation=cv2.INTER_CUBIC)
        lanczos = cv2.resize(img, size, interpolation=cv2.INTER_LANCZOS4)
        blended = cv2.addWeighted(cubic, 0.5, lanczos, 0.5, 0)

        sharpen = np.array([[0, -1, 0],
                            [-1, 5, -1],
                            [0, -1, 0]], dtype=np.float32)
        return cv2.filter2D(blended, -1, sharpen)

    def _cleanup_memory(self):
        """Run the garbage collector and, on CUDA, release cached GPU memory."""
        gc.collect()
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def get_memory_usage(self):
        """Return a human-readable string with allocated/reserved CUDA memory."""
        if self.device.type == 'cuda':
            allocated = torch.cuda.memory_allocated() / (1024**2)
            reserved = torch.cuda.memory_reserved() / (1024**2)
            return f"Allocated: {allocated:.1f}MB, Reserved: {reserved:.1f}MB"
        return "Using CPU"
|
|
| |
# Lazily created, process-wide enhancer shared by every caller of this module.
_memory_safe_enhancer: Optional["MemorySafeEnhancer"] = None


def get_memory_safe_enhancer():
    """Return the shared MemorySafeEnhancer, constructing it on the first call.

    Later calls reuse the same instance, so device setup and model loading
    happen only once per process.
    """
    global _memory_safe_enhancer
    if _memory_safe_enhancer is not None:
        return _memory_safe_enhancer
    _memory_safe_enhancer = MemorySafeEnhancer()
    return _memory_safe_enhancer
|
|
| |
def enhance_image_safe(image_path: str, output_path: Optional[str] = None) -> str:
    """Enhance ``image_path`` using the shared low-VRAM enhancer.

    Thin convenience wrapper: delegates to ``enhance_image`` on the instance
    returned by ``get_memory_safe_enhancer()`` and returns its result.
    """
    return get_memory_safe_enhancer().enhance_image(image_path, output_path)