Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| from typing import Optional, Callable | |
class ImageEnhancer:
    """
    AI image enhancer built on the Real-ESRGAN x4plus model.

    This class handles:
    - Automatic download of the model weights from the Real-ESRGAN GitHub
      releases page into a local ``weights/`` directory
    - Image preprocessing and postprocessing
    - GPU/CPU inference
    - Coarse progress reporting around the (tiled) inference call
    """

    def __init__(self, model_name: str = "RealESRGAN_x4plus"):
        """
        Select the compute device and load the model.

        Args:
            model_name: Label stored on the instance. The weights file, the
                download URL and the network architecture are fixed to
                RealESRGAN_x4plus regardless of this value.
        """
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Kept for backward compatibility: nothing in this class reassigns it;
        # the loaded network is held by ``self.upsampler``.
        self.model = None
        self.tile_size = 256
        self._load_model()

    def _load_model(self):
        """Download the weights if they are missing and build ``self.upsampler``."""
        # Imported lazily so a missing optional dependency surfaces when the
        # enhancer is constructed rather than when this module is imported.
        from realesrgan import RealESRGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet

        model_path = Path("weights")
        model_path.mkdir(exist_ok=True)
        model_file = model_path / "RealESRGAN_x4plus.pth"

        if not model_file.exists():
            print("Downloading Real-ESRGAN x4plus model...")
            import urllib.request
            url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
            # Download under a temporary name and rename only on success.
            # Writing straight to ``model_file`` would leave a truncated file
            # behind after an interrupted download, and the ``exists()`` check
            # above would then accept that broken file on every later run.
            partial_file = model_path / (model_file.name + ".part")
            try:
                urllib.request.urlretrieve(url, partial_file)
                partial_file.replace(model_file)
            finally:
                # After a successful rename this is a no-op; after a failure
                # it removes the incomplete download.
                if partial_file.exists():
                    partial_file.unlink()
            print("Model downloaded successfully!")

        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        self.upsampler = RealESRGANer(
            scale=4,
            model_path=str(model_file),
            model=model,
            tile=self.tile_size,
            tile_pad=10,
            pre_pad=0,
            # Half precision is requested on every device except the CPU.
            half=self.device.type != "cpu",
            device=self.device,
        )
        print(f"Model loaded on {self.device}")

    def calculate_tiles(self, width: int, height: int) -> int:
        """
        Calculate the number of tiles for an image.

        Args:
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            ``ceil(width / tile_size) * ceil(height / tile_size)``, with each
            factor at least 1. Returns 1 when ``tile_size`` is zero or negative.
        """
        # A non-positive tile size is treated as "no tiling". Guarding only
        # against zero would let a negative size yield a bogus count.
        # NOTE(review): presumably this mirrors how the upsampler interprets
        # ``tile`` - confirm against the installed realesrgan version.
        if self.tile_size <= 0:
            return 1
        tiles_x = max(1, (width + self.tile_size - 1) // self.tile_size)
        tiles_y = max(1, (height + self.tile_size - 1) // self.tile_size)
        return tiles_x * tiles_y

    def enhance(self, image: Image.Image, scale: int = 4,
                progress_callback: Optional[Callable[[float, str, int, int], None]] = None) -> Image.Image:
        """
        Enhance an image using Real-ESRGAN.

        Args:
            image: PIL Image to enhance. Any mode that PIL can convert to RGB
                is accepted; an alpha channel is discarded.
            scale: Upscaling factor (2 or 4), forwarded to the upsampler as
                ``outscale``.
            progress_callback: Optional callback function(progress%, message,
                current_step, total_steps).

        Returns:
            Enhanced PIL Image built from the upsampler's output.
        """
        total_tiles = self.calculate_tiles(image.width, image.height)
        total_steps = total_tiles + 2

        def notify(percent, message, step):
            # Progress reporting is optional; skip silently without a callback.
            if progress_callback:
                progress_callback(percent, message, step, total_steps)

        notify(10.0, "Preprocessing image...", 1)
        # Normalise through PIL instead of guessing the layout from the array
        # shape: convert("RGB") replicates grayscale into three channels and
        # drops alpha (as the shape-based code did), and additionally resolves
        # palette, bilevel, LA and CMYK images that the shape test misread.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")
        # Reverse the channel axis: the upsampler is fed BGR-ordered data.
        img_bgr = np.array(rgb_image)[:, :, ::-1]

        notify(15.0, f"Enhancing image ({total_tiles} tiles)...", 1)
        output, _ = self.upsampler.enhance(img_bgr, outscale=scale)

        notify(90.0, "Postprocessing...", total_tiles + 1)
        # Flip the channel axis back before handing the array to PIL.
        output_rgb = output[:, :, ::-1]
        enhanced_image = Image.fromarray(output_rgb)

        notify(100.0, "Complete!", total_tiles + 2)
        return enhanced_image
class FallbackEnhancer:
    """
    Fallback enhancer using traditional image processing when AI model is unavailable.

    Upscales with PIL's Lanczos resampling, then applies a mild sharpness
    boost followed by a mild contrast boost.
    """

    def __init__(self):
        """Announce on stdout that the non-AI fallback path is in use."""
        print("Using fallback enhancer (no AI model available)")

    def enhance(self, image: Image.Image, scale: int = 4,
                progress_callback: Optional[Callable[[float, str, int, int], None]] = None) -> Image.Image:
        """
        Enhance image using traditional upscaling with sharpening.

        Args:
            image: PIL Image to enhance. Palette ("P") and bilevel ("1")
                images are converted to a continuous-tone mode first.
            scale: Multiplier applied to both width and height.
            progress_callback: Optional callback function(progress%, message,
                current_step, total_steps); four steps are reported.

        Returns:
            The upscaled, sharpened and contrast-adjusted PIL Image.
        """
        from PIL import ImageEnhance

        if progress_callback:
            progress_callback(20.0, "Upscaling image...", 1, 4)

        # Pillow documents that resize() always uses nearest-neighbour for
        # modes "1" and "P", which would silently defeat the Lanczos request.
        # NOTE(review): the sharpness/contrast steps below also appear to
        # reject these modes - confirm against the installed Pillow version.
        working = image
        if working.mode == "P":
            # Keep palette transparency by going through RGBA when present.
            target_mode = "RGBA" if "transparency" in working.info else "RGB"
            working = working.convert(target_mode)
        elif working.mode == "1":
            working = working.convert("L")

        new_size = (working.width * scale, working.height * scale)
        upscaled = working.resize(new_size, Image.LANCZOS)

        if progress_callback:
            progress_callback(50.0, "Applying sharpening...", 2, 4)
        # A factor of 1.0 leaves the image unchanged; 1.3 sharpens slightly.
        sharpened = ImageEnhance.Sharpness(upscaled).enhance(1.3)

        if progress_callback:
            progress_callback(75.0, "Adjusting contrast...", 3, 4)
        # Likewise 1.1 is a small contrast increase over the neutral 1.0.
        enhanced = ImageEnhance.Contrast(sharpened).enhance(1.1)

        if progress_callback:
            progress_callback(100.0, "Complete!", 4, 4)
        return enhanced
def get_enhancer():
    """
    Build and return the best enhancer that can be constructed.

    Tries the AI-backed ImageEnhancer first. If constructing it raises any
    Exception, the error is printed and a FallbackEnhancer is returned instead.
    """
    try:
        ai_enhancer = ImageEnhancer()
    except Exception as load_error:
        # Deliberate best-effort: any construction failure degrades to the
        # traditional enhancer rather than propagating to the caller.
        print(f"Could not load AI model: {load_error}")
        return FallbackEnhancer()
    return ai_enhancer