"""Image upscaling with Real-ESRGAN, plus a Pillow-only fallback enhancer."""

from pathlib import Path
from typing import Callable, Optional, Union

import numpy as np
import torch
from PIL import Image

# Signature of the optional progress hook:
# (percent_complete, message, current_step, total_steps).
ProgressCallback = Callable[[float, str, int, int], None]


def _notify(
    progress_callback: Optional[ProgressCallback],
    percent: float,
    message: str,
    step: int,
    total_steps: int,
) -> None:
    """Invoke *progress_callback* with the given status if one was supplied."""
    if progress_callback:
        progress_callback(percent, message, step, total_steps)


class ImageEnhancer:
    """
    AI image enhancer built on the Real-ESRGAN x4plus model.

    This class handles:
    - Downloading the model weights from the Real-ESRGAN GitHub releases
      page into a local ``weights`` directory when they are not present
    - Image preprocessing and postprocessing
    - GPU/CPU inference (CUDA is used when ``torch`` reports it available)
    - Coarse progress reporting before and after inference (the callback is
      not invoked per tile)
    """

    _WEIGHTS_DIR = Path("weights")
    _WEIGHTS_FILENAME = "RealESRGAN_x4plus.pth"
    _WEIGHTS_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"

    def __init__(self, model_name: str = "RealESRGAN_x4plus"):
        """
        Create the enhancer and load the model immediately.

        Args:
            model_name: Stored on the instance. NOTE: the loader currently
                always uses the x4plus weights and architecture regardless
                of this value.

        Raises:
            Exception: Whatever the import, download or model construction
                in ``_load_model`` raises (e.g. ``ImportError`` when the
                ``realesrgan`` / ``basicsr`` packages are missing).
        """
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Kept for backward compatibility; the network itself is handed to
        # the upsampler created in _load_model.
        self.model = None
        self.upsampler = None
        self.tile_size = 256
        self._load_model()

    def _download_weights(self, destination: Path) -> None:
        """
        Download the model weights to *destination*.

        The data is written to a sibling ``.part`` file and renamed into
        place only after the download finished, so an interrupted download
        never leaves a truncated file that later passes the existence check.
        """
        import urllib.request

        partial = destination.with_name(destination.name + ".part")
        try:
            urllib.request.urlretrieve(self._WEIGHTS_URL, partial)
            partial.replace(destination)
        finally:
            # Only still present when the download or rename failed.
            if partial.exists():
                partial.unlink()

    def _load_model(self):
        """Download (if needed) and load the Real-ESRGAN model."""
        # Imported lazily so that a missing optional dependency surfaces as
        # an exception from the constructor rather than at module import.
        from realesrgan import RealESRGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet

        model_path = self._WEIGHTS_DIR
        model_path.mkdir(exist_ok=True)
        model_file = model_path / self._WEIGHTS_FILENAME

        if not model_file.exists():
            print("Downloading Real-ESRGAN x4plus model...")
            self._download_weights(model_file)
            print("Model downloaded successfully!")

        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )

        self.upsampler = RealESRGANer(
            scale=4,
            model_path=str(model_file),
            model=model,
            tile=self.tile_size,
            tile_pad=10,
            pre_pad=0,
            # Half precision is requested on every non-CPU device.
            half=self.device.type != "cpu",
            device=self.device,
        )

        print(f"Model loaded on {self.device}")

    def calculate_tiles(self, width: int, height: int) -> int:
        """
        Calculate the number of tiles for an image.

        Args:
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            ``ceil(width / tile_size) * ceil(height / tile_size)`` with each
            factor being at least 1, or 1 when ``tile_size`` is 0.
        """
        if self.tile_size == 0:
            return 1

        tiles_x = max(1, (width + self.tile_size - 1) // self.tile_size)
        tiles_y = max(1, (height + self.tile_size - 1) // self.tile_size)
        return tiles_x * tiles_y

    def enhance(
        self,
        image: Image.Image,
        scale: int = 4,
        progress_callback: Optional[ProgressCallback] = None,
    ) -> Image.Image:
        """
        Enhance an image using Real-ESRGAN.

        Args:
            image: PIL Image to enhance. Any mode is accepted; it is
                converted to RGB first, so an alpha channel is dropped.
            scale: Upscaling factor (2 or 4)
            progress_callback: Optional callback
                function(progress%, message, current_step, total_steps)

        Returns:
            Enhanced PIL Image (RGB).

        Raises:
            ValueError: If ``scale`` is not positive.
        """
        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale}")

        # The tile count is only used to size the progress step counter.
        total_tiles = self.calculate_tiles(image.width, image.height)
        total_steps = total_tiles + 2

        _notify(progress_callback, 10.0, "Preprocessing image...", 1, total_steps)

        # Normalise every input mode (palette, grayscale, LA, RGBA, CMYK, 1-bit)
        # to three 8-bit channels instead of guessing from the array shape.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")
        # The upsampler is fed BGR; copy so the array is contiguous.
        img_bgr = np.ascontiguousarray(np.asarray(rgb_image)[:, :, ::-1])

        _notify(
            progress_callback,
            15.0,
            f"Enhancing image ({total_tiles} tiles)...",
            1,
            total_steps,
        )

        output, _ = self.upsampler.enhance(img_bgr, outscale=scale)

        _notify(
            progress_callback, 90.0, "Postprocessing...", total_tiles + 1, total_steps
        )

        # Flip the channel order back to RGB for PIL.
        output_rgb = np.ascontiguousarray(output[:, :, ::-1])
        enhanced_image = Image.fromarray(output_rgb)

        _notify(progress_callback, 100.0, "Complete!", total_steps, total_steps)

        return enhanced_image


class FallbackEnhancer:
    """
    Fallback enhancer using traditional image processing when the AI model
    is unavailable.

    Uses PIL's Lanczos resampling for upscaling, followed by a mild
    sharpness and contrast boost.
    """

    def __init__(self):
        print("Using fallback enhancer (no AI model available)")

    def enhance(
        self,
        image: Image.Image,
        scale: int = 4,
        progress_callback: Optional[ProgressCallback] = None,
    ) -> Image.Image:
        """
        Enhance image using traditional upscaling with sharpening.

        Args:
            image: PIL Image to enhance. Palette ("P") images are converted
                to RGB and 1-bit ("1") images to grayscale first; other
                modes are processed as they are.
            scale: Upscaling factor applied to both width and height.
            progress_callback: Optional callback
                function(progress%, message, current_step, total_steps)

        Returns:
            The upscaled, sharpened and contrast-adjusted PIL Image.

        Raises:
            ValueError: If ``scale`` is not positive.
        """
        from PIL import ImageEnhance

        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale}")

        _notify(progress_callback, 20.0, "Upscaling image...", 1, 4)

        # Pillow always resizes "P" and "1" images with NEAREST and refuses
        # to filter palette images, so convert those modes up front.
        if image.mode == "P":
            image = image.convert("RGB")
        elif image.mode == "1":
            image = image.convert("L")

        new_size = (image.width * scale, image.height * scale)
        upscaled = image.resize(new_size, Image.LANCZOS)

        _notify(progress_callback, 50.0, "Applying sharpening...", 2, 4)

        sharpened = ImageEnhance.Sharpness(upscaled).enhance(1.3)

        _notify(progress_callback, 75.0, "Adjusting contrast...", 3, 4)

        enhanced = ImageEnhance.Contrast(sharpened).enhance(1.1)

        _notify(progress_callback, 100.0, "Complete!", 4, 4)

        return enhanced


def get_enhancer() -> Union[ImageEnhancer, FallbackEnhancer]:
    """
    Factory function to get the best available enhancer.

    Returns:
        An ``ImageEnhancer`` when it can be constructed; otherwise the
        failure is printed and a ``FallbackEnhancer`` is returned.
    """
    # Broad catch is deliberate: any failure to set up the AI model
    # (missing packages, download error, bad weights) selects the fallback.
    try:
        return ImageEnhancer()
    except Exception as e:
        print(f"Could not load AI model: {e}")
        return FallbackEnhancer()