""" Sprite Image Enhancement Module Uses Real-ESRGAN for high-quality upscaling """ import cv2 import numpy as np import torch from PIL import Image import os class SpriteProcessor: """Processor for enhancing sprite sheet images""" def __init__(self): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.model = None self._load_model() def _load_model(self): """Load Real-ESRGAN model""" try: from realesrgan import RealESRGANer from basicsr.archs.rrdbnet_arch import RRDBNet # Create model model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) # Initialize Real-ESRGAN model_path = "weights/RealESRGAN_x4plus.pth" if os.path.exists(model_path): self.model = RealESRGANer( scale=4, model_path=model_path, model=model, tile=0, pre_pad=0, half=False, device=self.device ) else: print("Warning: Real-ESRGAN model not found, using fallback enhancement") self.model = None except Exception as e: print(f"Error loading Real-ESRGAN: {e}") self.model = None def enhance_image(self, image: np.ndarray, scale: int = 4) -> np.ndarray: """ Enhance image quality using Real-ESRGAN or fallback methods Args: image: Input image (BGR or BGRA) scale: Upscaling factor (2 or 4) Returns: Enhanced image """ # Handle alpha channel has_alpha = len(image.shape) == 3 and image.shape[2] == 4 if has_alpha: # Separate alpha channel bgr = image[:, :, :3] alpha = image[:, :, 3] else: bgr = image alpha = None # Enhance RGB channels if self.model is not None and scale > 1: try: # Convert BGR to RGB for the model rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) # Apply Real-ESRGAN enhanced_rgb, _ = self.model.enhance(rgb, outscale=scale) # Convert back to BGR enhanced_bgr = cv2.cvtColor(enhanced_rgb, cv2.COLOR_RGB2BGR) except Exception as e: print(f"Real-ESRGAN failed, using fallback: {e}") enhanced_bgr = self._fallback_enhance(bgr, scale) else: enhanced_bgr = self._fallback_enhance(bgr, scale) # Enhance alpha channel if present if alpha is not None and scale > 1: enhanced_alpha = cv2.resize(alpha, None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST) # Merge channels enhanced_image = cv2.merge([enhanced_bgr, enhanced_alpha]) else: enhanced_image = enhanced_bgr return enhanced_image def _fallback_enhance(self, image: np.ndarray, scale: int) -> np.ndarray: """ Fallback enhancement using OpenCV Args: image: Input BGR image scale: Upscaling factor Returns: Enhanced image """ # Resize with high-quality interpolation new_width = int(image.shape[1] * scale) new_height = int(image.shape[0] * scale) enhanced = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_CUBIC) # Apply sharpening kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]]) enhanced = cv2.filter2D(enhanced, -1, kernel) # Denoise enhanced = cv2.fastNlMeansDenoisingColored(enhanced, None, 5, 5, 7, 21) return enhanced def sharpen_image(self, image: np.ndarray, strength: float = 1.0) -> np.ndarray: """ Apply sharpening filter Args: image: Input image strength: Sharpening strength Returns: Sharpened image """ kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]]) * strength sharpened = cv2.filter2D(image, -1, kernel) return sharpened def remove_blur(self, image: np.ndarray) -> np.ndarray: """ Reduce blur using deconvolution Args: image: Input image Returns: Deblurred image """ # Create a point spread function (PSF) psf_size = 5 psf = np.ones((psf_size, psf_size)) / (psf_size ** 2) # Simple deconvolution (Wiener filter approximation) result = image.copy() for i in range(3): # For each channel channel = image[:, :, i].astype(np.float32) / 255.0 # FFT psf_fft = np.fft.fft2(psf, s=channel.shape) channel_fft = np.fft.fft2(channel) # Wiener deconvolution K = 0.01 # Noise to signal ratio deconv_fft = channel_fft * np.conj(psf_fft) / (np.abs(psf_fft) ** 2 + K) # Inverse FFT deconv = np.fft.ifft2(deconv_fft).real # Clip and convert back deconv = np.clip(deconv * 255, 0, 255).astype(np.uint8) result[:, :, i] = deconv return result def enhance_contrast(self, image: np.ndarray) -> np.ndarray: """ Enhance contrast using CLAHE Args: image: Input image Returns: Contrast-enhanced image """ lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB) l, a, b = cv2.split(lab) clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) l = clahe.apply(l) enhanced = cv2.merge([l, a, b]) enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR) return enhanced