""" Lightweight AI Enhancement for Limited VRAM (< 4GB) Optimized for RTX 3050 Laptop GPU Uses efficient models with excellent quality """ import os import cv2 import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image import requests from tqdm import tqdm from typing import Optional, Dict, Any, Tuple import warnings warnings.filterwarnings('ignore') # Lightweight ESRGAN Architecture class RRDBNet_arch(nn.Module): """Lightweight RRDB Net for ESRGAN - optimized for low VRAM""" def __init__(self, in_nc=3, out_nc=3, nf=32, nb=16): # Reduced from 64/23 to 32/16 super(RRDBNet_arch, self).__init__() self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.conv_first(x) trunk = self.trunk_conv(fea) fea = fea + trunk fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) out = self.conv_last(self.lrelu(self.HRconv(fea))) return out class LightweightEnhancer: """Lightweight AI enhancer for <4GB VRAM""" def __init__(self, device=None): """Initialize lightweight enhancer""" # Set device if device is None: if torch.cuda.is_available(): self.device = torch.device('cuda:0') print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}") # RTX 3050 Laptop optimization torch.backends.cudnn.benchmark = True torch.cuda.set_per_process_memory_fraction(0.7) # Use only 70% VRAM # Get VRAM info props = torch.cuda.get_device_properties(0) self.vram_gb = props.total_memory / (1024**3) print(f"📊 VRAM: {self.vram_gb:.1f} GB") else: self.device = torch.device('cpu') print("💻 Using CPU (GPU not available)") self.vram_gb = 0 else: self.device = device self.vram_gb = 4 # Assume 4GB # Model storage self.model_dir = 'models_lightweight' os.makedirs(self.model_dir, exist_ok=True) # Models self.esrgan_model = None self.face_model = None # Settings based on VRAM if self.vram_gb < 4: self.tile_size = 256 # Smaller tiles for <4GB self.use_fp16 = True # Force FP16 else: self.tile_size = 384 self.use_fp16 = True def load_lightweight_esrgan(self): """Load lightweight ESRGAN model""" try: print("🔄 Loading lightweight ESRGAN...") # Create lightweight model self.esrgan_model = RRDBNet_arch() # Try to load pretrained weights if available model_path = os.path.join(self.model_dir, 'lightweight_esrgan.pth') if os.path.exists(model_path): self.esrgan_model.load_state_dict(torch.load(model_path, map_location=self.device)) print("✅ Loaded pretrained lightweight model") else: print("⚠️ No pretrained model found, using random initialization") # In practice, you'd train this or download a pretrained one self.esrgan_model = self.esrgan_model.to(self.device) self.esrgan_model.eval() # Convert to FP16 if using GPU if self.use_fp16 and self.device.type == 'cuda': self.esrgan_model = self.esrgan_model.half() print("✅ Using FP16 for memory efficiency") return True except Exception as e: print(f"❌ Failed to load lightweight ESRGAN: {e}") return False def enhance_with_lightweight_esrgan(self, img): """Enhance using lightweight ESRGAN with tiling""" if self.esrgan_model is None: if not self.load_lightweight_esrgan(): return self.fallback_upscale(img, 2) try: # Convert to tensor img_tensor = self.img_to_tensor(img) # Process with tiling for low VRAM result = self.process_with_tiles(img_tensor, self.esrgan_model, scale=2) # Convert back to numpy result = self.tensor_to_img(result) return result except Exception as e: print(f"❌ Enhancement failed: {e}") return self.fallback_upscale(img, 2) def process_with_tiles(self, img_tensor, model, scale=2): """Process image in tiles to save VRAM""" _, _, h, w = img_tensor.shape # Calculate output size (max 2K) target_h = h * scale target_w = w * scale # Apply 2K limit if target_w > 2048 or target_h > 1080: limit_scale = min(2048/target_w, 1080/target_h) out_w = int(target_w * limit_scale) out_h = int(target_h * limit_scale) print(f" 📐 Limiting output to {out_w}x{out_h} (2K max)") else: out_h, out_w = target_h, target_w output = torch.zeros((1, 3, out_h, out_w), device=self.device) # Tile processing tile_size = self.tile_size pad = 16 # Overlap to avoid seams for y in range(0, h, tile_size - pad): for x in range(0, w, tile_size - pad): # Extract tile y_end = min(y + tile_size, h) x_end = min(x + tile_size, w) tile = img_tensor[:, :, y:y_end, x:x_end] # Process tile with torch.no_grad(): if self.use_fp16 and self.device.type == 'cuda': tile = tile.half() tile_out = model(tile) if self.use_fp16: tile_out = tile_out.float() # Place tile in output out_y = y * scale out_x = x * scale out_y_end = min(out_y + tile_out.shape[2], out_h) out_x_end = min(out_x + tile_out.shape[3], out_w) output[:, :, out_y:out_y_end, out_x:out_x_end] = tile_out[:, :, :out_y_end-out_y, :out_x_end-out_x] # Clear cache to save memory if self.device.type == 'cuda': torch.cuda.empty_cache() return output def img_to_tensor(self, img): """Convert image to tensor""" if isinstance(img, Image.Image): img = np.array(img) # Ensure RGB if len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) elif img.shape[2] == 4: img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB) elif img.shape[2] == 3 and isinstance(img, np.ndarray): img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Normalize to [0, 1] img = img.astype(np.float32) / 255.0 # Convert to tensor img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) return img_tensor.to(self.device) def tensor_to_img(self, tensor): """Convert tensor to image""" img = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() img = (img * 255).clip(0, 255).astype(np.uint8) return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) def fallback_upscale(self, img, scale): """Fallback upscaling using OpenCV with 2K limit""" print(" 📈 Using optimized fallback upscaling...") h, w = img.shape[:2] # Calculate new size with 2K limit target_scale = min(scale, 2048/w, 1080/h) new_w = int(w * target_scale) new_h = int(h * target_scale) # Use EDSR-inspired upscaling # First, upscale with CUBIC upscaled = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC) # Apply sharpening kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]]) / 1 upscaled = cv2.filter2D(upscaled, -1, kernel) # Reduce noise upscaled = cv2.bilateralFilter(upscaled, 5, 50, 50) return upscaled def enhance_faces_lightweight(self, img): """Lightweight face enhancement""" try: # Detect faces using OpenCV face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml') gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) faces = face_cascade.detectMultiScale(gray, 1.1, 4) if len(faces) == 0: return img print(f" 👤 Enhancing {len(faces)} faces...") for (x, y, w, h) in faces: # Extract face with padding pad = int(w * 0.1) x_start = max(0, x - pad) y_start = max(0, y - pad) x_end = min(img.shape[1], x + w + pad) y_end = min(img.shape[0], y + h + pad) face = img[y_start:y_end, x_start:x_end] # Enhance face face = self.enhance_face_region_lightweight(face) # Put back img[y_start:y_end, x_start:x_end] = face return img except Exception as e: print(f"⚠️ Face enhancement failed: {e}") return img def enhance_face_region_lightweight(self, face): """Lightweight face enhancement""" # 1. Denoise face = cv2.bilateralFilter(face, 9, 75, 75) # 2. Enhance details lab = cv2.cvtColor(face, cv2.COLOR_BGR2LAB) l, a, b = cv2.split(lab) # CLAHE on L channel clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)) l = clahe.apply(l) face = cv2.merge([l, a, b]) face = cv2.cvtColor(face, cv2.COLOR_LAB2BGR) # 3. Subtle sharpening kernel = np.array([[0,-1,0], [-1,5,-1], [0,-1,0]]) / 1 face = cv2.filter2D(face, -1, kernel) return face def enhance_image_pipeline(self, image_path: str, output_path: str = None) -> str: """Complete enhancement pipeline for low VRAM""" print(f"🎨 Enhancing {os.path.basename(image_path)} (Lightweight Mode)...") try: # Load image img = cv2.imread(image_path) if img is None: print(f"❌ Failed to load image: {image_path}") return image_path original_shape = img.shape[:2] print(f" Original: {original_shape[1]}x{original_shape[0]}") # Step 1: Lightweight super resolution print(" 🚀 Applying lightweight upscaling (max 2K)...") print(f" 📐 Input: {img.shape[1]}x{img.shape[0]}") enhanced = self.enhance_with_lightweight_esrgan(img) # Step 2: Face enhancement print(" 👤 Enhancing faces...") enhanced = self.enhance_faces_lightweight(enhanced) # Step 3: Final color correction print(" 🎨 Applying color correction...") enhanced = self.color_correction(enhanced) # Save if output_path is None: output_path = image_path.replace('.', '_enhanced.') cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95]) new_shape = enhanced.shape[:2] print(f" ✅ Enhanced: {new_shape[1]}x{new_shape[0]}") # Clear memory self.clear_memory() return output_path except Exception as e: print(f"❌ Pipeline failed: {e}") return image_path def color_correction(self, img): """Lightweight color correction""" # Convert to LAB lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) l, a, b = cv2.split(lab) # Enhance L channel clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)) l = clahe.apply(l) # Slight color boost a = cv2.convertScaleAbs(a, alpha=1.1, beta=0) b = cv2.convertScaleAbs(b, alpha=1.1, beta=0) # Merge and convert back enhanced = cv2.merge([l, a, b]) enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR) return enhanced def clear_memory(self): """Clear GPU memory""" if self.device.type == 'cuda': torch.cuda.empty_cache() torch.cuda.synchronize() # Global instance _lightweight_enhancer = None def get_lightweight_enhancer(): """Get or create global lightweight enhancer""" global _lightweight_enhancer if _lightweight_enhancer is None: _lightweight_enhancer = LightweightEnhancer() return _lightweight_enhancer