| """ |
| Lightweight AI Enhancement for Limited VRAM (< 4GB) |
| Optimized for RTX 3050 Laptop GPU |
| Uses efficient models with excellent quality |
| """ |
|
|
| import os |
| import cv2 |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
| import requests |
| from tqdm import tqdm |
| from typing import Optional, Dict, Any, Tuple |
| import warnings |
| warnings.filterwarnings('ignore') |
|
|
| |
class RRDBNet_arch(nn.Module):
    """Compact convolutional upsampler used as the ESRGAN stand-in.

    The network is a short chain of 3x3 convolutions: a stem, one residual
    trunk convolution, two nearest-neighbour x2 upsampling stages (so the
    spatial size grows 4x overall), and a two-layer reconstruction head.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        nf: Number of feature channels used by every hidden layer.
        nb: Accepted for signature compatibility; no layer is built from it.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=32, nb=16):
        super().__init__()

        def conv3x3(channels_in, channels_out):
            # Same-size 3x3 convolution: stride 1, padding 1, with bias.
            return nn.Conv2d(
                channels_in, channels_out,
                kernel_size=3, stride=1, padding=1, bias=True,
            )

        # Attribute names and creation order are kept stable so that saved
        # state dicts keep matching this module.
        self.conv_first = conv3x3(in_nc, nf)
        self.trunk_conv = conv3x3(nf, nf)
        self.upconv1 = conv3x3(nf, nf)
        self.upconv2 = conv3x3(nf, nf)
        self.HRconv = conv3x3(nf, nf)
        self.conv_last = conv3x3(nf, out_nc)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return ``x`` upsampled by a factor of four in height and width."""
        shallow = self.conv_first(x)
        features = shallow + self.trunk_conv(shallow)  # residual connection

        # Two identical stages: nearest-neighbour x2, then conv + LeakyReLU.
        for upconv in (self.upconv1, self.upconv2):
            features = F.interpolate(features, scale_factor=2, mode='nearest')
            features = self.lrelu(upconv(features))

        refined = self.lrelu(self.HRconv(features))
        return self.conv_last(refined)
|
|
class LightweightEnhancer:
    """Image enhancement pipeline sized for GPUs with little VRAM.

    Upscaling goes through a tiled pass of ``RRDBNet_arch`` when pretrained
    weights are found in ``self.model_dir`` and through OpenCV resizing
    otherwise; afterwards Haar-cascade face regions are touched up and a LAB
    colour correction is applied.  Images are exchanged as OpenCV-style BGR
    arrays; PIL images are accepted wherever an input image is converted.
    """

    def __init__(self, device=None):
        """Pick a device and derive the tiling settings.

        Args:
            device: ``None`` to auto-select (``cuda:0`` when available, else
                CPU), or anything ``torch.device`` accepts (a string or a
                ``torch.device``).  When given explicitly, 4 GB of VRAM is
                assumed instead of being queried.
        """
        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device('cuda:0')
                print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}")

                torch.backends.cudnn.benchmark = True
                # Cap this process at 70% of the card's memory.
                torch.cuda.set_per_process_memory_fraction(0.7)

                props = torch.cuda.get_device_properties(0)
                self.vram_gb = props.total_memory / (1024**3)
                print(f"📊 VRAM: {self.vram_gb:.1f} GB")
            else:
                self.device = torch.device('cpu')
                print("💻 Using CPU (GPU not available)")
                self.vram_gb = 0
        else:
            # Normalise strings such as "cuda" so that ``self.device.type``
            # works everywhere below.
            self.device = torch.device(device)
            self.vram_gb = 4

        self.model_dir = 'models_lightweight'
        os.makedirs(self.model_dir, exist_ok=True)

        self.esrgan_model = None
        self.face_model = None
        self._face_cascade = None  # created lazily on first face pass

        # Smaller tiles on cards below 4 GB.  FP16 is requested in both cases
        # and only takes effect on CUDA devices.
        self.tile_size = 256 if self.vram_gb < 4 else 384
        self.use_fp16 = True

    def load_lightweight_esrgan(self):
        """Load the pretrained upscaling network into ``self.esrgan_model``.

        Returns:
            ``True`` when pretrained weights were loaded.  ``False`` when the
            weight file is missing or loading failed; ``self.esrgan_model`` is
            then left as ``None`` so callers use the OpenCV fallback instead of
            a randomly initialised network, whose output would be noise.
        """
        try:
            print("🔄 Loading lightweight ESRGAN...")

            model_path = os.path.join(self.model_dir, 'lightweight_esrgan.pth')
            if not os.path.exists(model_path):
                print("⚠️ No pretrained model found, falling back to OpenCV upscaling")
                return False

            model = RRDBNet_arch()
            # NOTE(security): torch.load unpickles the file; only load weight
            # files from a trusted source (consider ``weights_only=True`` on
            # PyTorch versions that support it).
            state_dict = torch.load(model_path, map_location=self.device)
            model.load_state_dict(state_dict)
            print("✅ Loaded pretrained lightweight model")

            model = model.to(self.device)
            model.eval()

            if self.use_fp16 and self.device.type == 'cuda':
                model = model.half()
                print("✅ Using FP16 for memory efficiency")

            # Publish the model only once it is fully set up.
            self.esrgan_model = model
            return True

        except Exception as e:
            self.esrgan_model = None
            print(f"❌ Failed to load lightweight ESRGAN: {e}")
            return False

    def enhance_with_lightweight_esrgan(self, img):
        """Upscale ``img`` 2x with the network, or with OpenCV on any failure.

        Args:
            img: BGR ``numpy`` array or PIL image.

        Returns:
            The upscaled image as a BGR ``uint8`` array.
        """
        if self.esrgan_model is None and not self.load_lightweight_esrgan():
            return self.fallback_upscale(img, 2)

        try:
            img_tensor = self.img_to_tensor(img)
            result = self.process_with_tiles(img_tensor, self.esrgan_model, scale=2)
            return self.tensor_to_img(result)
        except Exception as e:
            print(f"❌ Enhancement failed: {e}")
            return self.fallback_upscale(img, 2)

    def process_with_tiles(self, img_tensor, model, scale=2, max_size=(2048, 1080)):
        """Run ``model`` over ``img_tensor`` tile by tile.

        Args:
            img_tensor: Float tensor shaped ``(1, 3, H, W)``.
            model: Callable mapping a tile tensor to its upscaled version.
            scale: Integer factor between input and output size.  A tile
                output whose size is not ``scale`` times the tile is
                resampled to that size, so models with another native factor
                still produce a consistent mosaic.
            max_size: ``(width, height)`` ceiling for the output.  When the
                scaled image would exceed it, the input is shrunk first so
                the whole picture (not a crop of it) fits.

        Returns:
            A float32 tensor on ``self.device`` shaped
            ``(1, 3, out_h, out_w)``.
        """
        _, _, h, w = img_tensor.shape
        max_w, max_h = max_size

        target_h = h * scale
        target_w = w * scale

        if target_w > max_w or target_h > max_h:
            limit_scale = min(max_w / target_w, max_h / target_h)
            out_w = max(1, int(target_w * limit_scale))
            out_h = max(1, int(target_h * limit_scale))
            print(f" 📐 Limiting output to {out_w}x{out_h} (2K max)")

            # Shrink the input up front: the result then shows the complete
            # image, and fewer tiles have to be processed.  Ceil division
            # keeps the upscaled size >= the requested size; the surplus
            # (less than `scale` pixels) is trimmed at the end.
            h = -(-out_h // scale)
            w = -(-out_w // scale)
            img_tensor = F.interpolate(img_tensor.float(), size=(h, w), mode='area')
        else:
            out_h, out_w = target_h, target_w

        output = torch.zeros((1, 3, h * scale, w * scale), device=self.device)

        tile_size = max(1, int(self.tile_size))
        pad = min(16, tile_size // 2)   # overlap between neighbouring tiles
        step = tile_size - pad          # always >= 1
        margin = pad // 2               # overlap rows/cols left to the earlier tile
        on_cuda = self.device.type == 'cuda'
        use_half = self.use_fp16 and on_cuda

        with torch.no_grad():
            for y in range(0, h, step):
                y_end = min(y + tile_size, h)
                # Border pixels of a tile lack context, so interior tiles
                # skip their first `margin` rows/cols and keep what the
                # previous tile already wrote there.
                top = margin if y > 0 else 0

                for x in range(0, w, step):
                    x_end = min(x + tile_size, w)
                    left = margin if x > 0 else 0

                    tile = img_tensor[:, :, y:y_end, x:x_end]
                    if use_half:
                        tile = tile.half()

                    tile_out = model(tile).float()

                    expected_h = (y_end - y) * scale
                    expected_w = (x_end - x) * scale
                    if tile_out.shape[2] != expected_h or tile_out.shape[3] != expected_w:
                        tile_out = F.interpolate(
                            tile_out, size=(expected_h, expected_w),
                            mode='bilinear', align_corners=False,
                        )

                    output[
                        :, :,
                        (y + top) * scale:y_end * scale,
                        (x + left) * scale:x_end * scale,
                    ] = tile_out[:, :, top * scale:, left * scale:]

                    if on_cuda:
                        torch.cuda.empty_cache()

                    if x_end >= w:
                        break  # this tile reached the right edge
                if y_end >= h:
                    break  # this row of tiles reached the bottom edge

        return output[:, :, :out_h, :out_w]

    @staticmethod
    def _to_bgr_array(img):
        """Return ``img`` as a contiguous 3-channel BGR ``numpy`` array."""
        if isinstance(img, np.ndarray):
            arr = img
        elif isinstance(img, Image.Image):
            # PIL delivers RGB; reverse the channel axis to get BGR.
            arr = np.array(img.convert('RGB'))[:, :, ::-1]
        else:
            raise TypeError(f"Unsupported image type: {type(img).__name__}")

        if arr.ndim == 2:
            arr = np.repeat(arr[:, :, None], 3, axis=2)
        elif arr.ndim == 3 and arr.shape[2] == 1:
            arr = np.repeat(arr, 3, axis=2)
        elif arr.ndim == 3 and arr.shape[2] == 4:
            arr = arr[:, :, :3]  # BGRA -> BGR: drop alpha
        elif arr.ndim != 3 or arr.shape[2] != 3:
            raise ValueError(f"Unsupported image shape: {arr.shape}")

        return np.ascontiguousarray(arr)

    def img_to_tensor(self, img):
        """Convert an image to a ``(1, 3, H, W)`` RGB float tensor in [0, 1].

        Args:
            img: ``numpy`` array in OpenCV layout (grayscale, BGR or BGRA) or
                a PIL image.  Values are divided by 255.

        Returns:
            A float32 tensor on ``self.device``.
        """
        bgr = self._to_bgr_array(img)
        # Reverse channels (BGR -> RGB); copy so torch sees positive strides.
        rgb = np.ascontiguousarray(bgr[:, :, ::-1]).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
        return img_tensor.to(self.device)

    def tensor_to_img(self, tensor):
        """Convert a ``(1, 3, H, W)`` RGB tensor in [0, 1] to a BGR uint8 array."""
        img = tensor.detach().squeeze(0).permute(1, 2, 0).float().cpu().numpy()
        # Replace non-finite values (possible after FP16 overflow) before the
        # integer cast.
        img = np.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0)
        # Round instead of truncating so a uint8 -> float -> uint8 round trip
        # is lossless.
        img = np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)
        return np.ascontiguousarray(img[:, :, ::-1])  # RGB -> BGR

    def fallback_upscale(self, img, scale):
        """Resize with OpenCV, then sharpen and denoise.

        Args:
            img: BGR ``numpy`` array or PIL image.
            scale: Requested factor; reduced as needed so the result stays
                within 2048x1080 (which shrinks images already larger).

        Returns:
            The processed image as a 3-channel BGR array.
        """
        print(" 📈 Using optimized fallback upscaling...")

        img = self._to_bgr_array(img)
        h, w = img.shape[:2]
        if h == 0 or w == 0:
            return img  # nothing to resize

        target_scale = min(scale, 2048 / w, 1080 / h)
        new_w = max(1, int(w * target_scale))
        new_h = max(1, int(h * target_scale))

        # Cubic interpolation for enlarging, area averaging for shrinking.
        interpolation = cv2.INTER_CUBIC if target_scale >= 1 else cv2.INTER_AREA
        upscaled = cv2.resize(img, (new_w, new_h), interpolation=interpolation)

        # 3x3 sharpening kernel whose weights sum to 1.
        kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]]) / 1
        upscaled = cv2.filter2D(upscaled, -1, kernel)

        upscaled = cv2.bilateralFilter(upscaled, 5, 50, 50)

        return upscaled

    def _get_face_cascade(self):
        """Return the frontal-face Haar cascade, loading it on first use."""
        cascade = getattr(self, '_face_cascade', None)
        if cascade is None:
            cascade = cv2.CascadeClassifier(
                cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
            )
            if cascade.empty():
                raise RuntimeError("Haar cascade file could not be loaded")
            self._face_cascade = cascade
        return cascade

    def enhance_faces_lightweight(self, img):
        """Detect frontal faces and enhance each padded face region.

        Args:
            img: BGR ``uint8`` array.  It is modified in place.

        Returns:
            ``img`` itself; on any error a message is printed and the image
            is returned in whatever state it had reached.
        """
        try:
            face_cascade = self._get_face_cascade()
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            # Positional arguments: scaleFactor=1.1, minNeighbors=4.
            faces = face_cascade.detectMultiScale(gray, 1.1, 4)

            if len(faces) == 0:
                return img

            print(f" 👤 Enhancing {len(faces)} faces...")

            for (x, y, w, h) in faces:
                # Grow the box by 10% of its width, clamped to the image.
                pad = int(w * 0.1)
                x_start = max(0, x - pad)
                y_start = max(0, y - pad)
                x_end = min(img.shape[1], x + w + pad)
                y_end = min(img.shape[0], y + h + pad)

                face = img[y_start:y_end, x_start:x_end]
                img[y_start:y_end, x_start:x_end] = self.enhance_face_region_lightweight(face)

            return img

        except Exception as e:
            print(f"⚠️ Face enhancement failed: {e}")
            return img

    def enhance_face_region_lightweight(self, face):
        """Smooth, contrast-equalise and sharpen one BGR face crop."""
        face = cv2.bilateralFilter(face, 9, 75, 75)

        # CLAHE on the lightness channel only.
        lab = cv2.cvtColor(face, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l = clahe.apply(l)
        face = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR)

        # Mild 3x3 sharpening kernel whose weights sum to 1.
        kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]) / 1
        face = cv2.filter2D(face, -1, kernel)

        return face

    @staticmethod
    def _default_output_path(image_path):
        """Return ``<name>_enhanced<ext>`` next to ``image_path``.

        Only the final extension is touched; ``.jpg`` is used when the input
        has no extension so the source file is never overwritten.
        """
        root, ext = os.path.splitext(image_path)
        return f"{root}_enhanced{ext or '.jpg'}"

    def enhance_image_pipeline(self, image_path: str, output_path: Optional[str] = None) -> str:
        """Upscale, face-enhance and colour-correct one image file.

        Args:
            image_path: Path of the image to read.
            output_path: Where to write the result; defaults to the input
                name with ``_enhanced`` inserted before the extension.

        Returns:
            The output path on success, or ``image_path`` when the image
            could not be read, processed or written.
        """
        print(f"🎨 Enhancing {os.path.basename(image_path)} (Lightweight Mode)...")

        try:
            img = cv2.imread(image_path)
            if img is None:
                print(f"❌ Failed to load image: {image_path}")
                return image_path

            original_shape = img.shape[:2]
            print(f" Original: {original_shape[1]}x{original_shape[0]}")

            print(" 🚀 Applying lightweight upscaling (max 2K)...")
            print(f" 📐 Input: {img.shape[1]}x{img.shape[0]}")
            enhanced = self.enhance_with_lightweight_esrgan(img)

            print(" 👤 Enhancing faces...")
            enhanced = self.enhance_faces_lightweight(enhanced)

            print(" 🎨 Applying color correction...")
            enhanced = self.color_correction(enhanced)

            if output_path is None:
                output_path = self._default_output_path(image_path)

            # imwrite reports failure through its return value.
            if not cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95]):
                print(f"❌ Failed to write image: {output_path}")
                return image_path

            new_shape = enhanced.shape[:2]
            print(f" ✅ Enhanced: {new_shape[1]}x{new_shape[0]}")

            return output_path

        except Exception as e:
            print(f"❌ Pipeline failed: {e}")
            return image_path

        finally:
            # Release cached GPU memory on success and on failure alike,
            # without letting a cleanup error replace the return value.
            try:
                self.clear_memory()
            except RuntimeError as e:
                print(f"⚠️ Could not clear GPU memory: {e}")

    @staticmethod
    def _boost_chroma(channel, gain):
        """Scale an 8-bit LAB a/b channel around its neutral value of 128."""
        centered = channel.astype(np.float32) - 128.0
        return np.clip(np.rint(centered * gain + 128.0), 0, 255).astype(np.uint8)

    def color_correction(self, img):
        """Equalise lightness contrast and raise saturation by 10%.

        Args:
            img: BGR ``uint8`` array.

        Returns:
            A new BGR ``uint8`` array.
        """
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l = clahe.apply(l)

        # In 8-bit LAB the a/b channels are stored with a +128 offset, so the
        # gain is applied around 128; scaling the raw values would shift
        # neutral greys towards a colour cast.
        a = self._boost_chroma(a, 1.1)
        b = self._boost_chroma(b, 1.1)

        enhanced = cv2.merge([l, a, b])
        enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)

        return enhanced

    def clear_memory(self):
        """Release cached CUDA memory; does nothing on non-CUDA devices."""
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
| |
# Module-wide enhancer instance, created on first request.
_lightweight_enhancer = None


def get_lightweight_enhancer():
    """Return the shared ``LightweightEnhancer``, building it on first call."""
    global _lightweight_enhancer
    enhancer = _lightweight_enhancer
    if enhancer is None:
        enhancer = LightweightEnhancer()
        _lightweight_enhancer = enhancer
    return enhancer