Imgenhance / enhancer.py
m9jaex
Add progress tracking to image enhancement processes
d037fc6
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from typing import Optional, Callable
class ImageEnhancer:
    """
    AI Image Enhancer using Real-ESRGAN model.

    This class handles:
    - Automatic download of the RealESRGAN_x4plus weights from the
      Real-ESRGAN GitHub releases page into ``./weights/`` (cached on disk)
    - Image preprocessing (mode normalisation, RGB <-> BGR) and postprocessing
    - GPU/CPU inference (CUDA when available, with half precision on GPU)
    - Coarse-grained progress reporting around the tiled upscale
    """

    def __init__(self, model_name: str = "RealESRGAN_x4plus"):
        # NOTE: model_name is stored for callers, but _load_model currently
        # always loads the RealESRGAN_x4plus architecture and weights.
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Kept for backward compatibility; never populated. The loaded
        # network is owned by self.upsampler.
        self.model = None
        # Tile edge length (input pixels) handed to RealESRGANer; 0 means a
        # single untiled pass (see calculate_tiles).
        self.tile_size = 256
        self._load_model()

    def _load_model(self):
        """Download (if needed) and load the Real-ESRGAN x4plus model.

        Weights are cached at ``weights/RealESRGAN_x4plus.pth`` relative to
        the current working directory. The ready-to-use upsampler is stored
        on ``self.upsampler``.

        Raises:
            ImportError: if ``realesrgan`` / ``basicsr`` are not installed.
            OSError (incl. urllib.error.URLError): if the download fails.
        """
        # Imported lazily so a missing dependency surfaces as an exception
        # from the constructor rather than at module import time.
        from realesrgan import RealESRGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet

        model_path = Path("weights")
        model_path.mkdir(exist_ok=True)
        model_file = model_path / "RealESRGAN_x4plus.pth"
        if not model_file.exists():
            print("Downloading Real-ESRGAN x4plus model...")
            import urllib.request
            url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
            # Download to a sibling ".part" file and only rename it into place
            # once complete, so an interrupted transfer never leaves behind a
            # truncated .pth that the exists() check above would accept.
            partial_file = model_path / (model_file.name + ".part")
            try:
                urllib.request.urlretrieve(url, partial_file)
            except BaseException:
                # Also covers Ctrl-C during a long download; always re-raised.
                if partial_file.exists():
                    partial_file.unlink()
                raise
            partial_file.replace(model_file)
            print("Model downloaded successfully!")

        # Network architecture matching the RealESRGAN_x4plus checkpoint.
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        self.upsampler = RealESRGANer(
            scale=4,
            model_path=str(model_file),
            model=model,
            tile=self.tile_size,
            tile_pad=10,
            pre_pad=0,
            # Half precision only when not running on CPU.
            half=self.device.type != "cpu",
            device=self.device,
        )
        print(f"Model loaded on {self.device}")

    def calculate_tiles(self, width: int, height: int) -> int:
        """Return the number of tiles a ``width`` x ``height`` image is split into.

        Uses ceiling division by ``self.tile_size`` on each axis (at least one
        tile per axis). A tile size of 0 counts as a single tile.
        """
        if self.tile_size == 0:
            return 1
        tiles_x = max(1, (width + self.tile_size - 1) // self.tile_size)
        tiles_y = max(1, (height + self.tile_size - 1) // self.tile_size)
        return tiles_x * tiles_y

    def enhance(self, image: Image.Image, scale: int = 4,
                progress_callback: Optional[Callable[[float, str, int, int], None]] = None) -> Image.Image:
        """
        Enhance an image using Real-ESRGAN.

        Args:
            image: PIL Image to enhance. Any mode PIL can convert to RGB is
                accepted; an alpha channel, if present, is discarded.
            scale: Output upscaling factor relative to the input (e.g. 2 or 4),
                passed to the upsampler as ``outscale``.
            progress_callback: Optional callback
                ``(progress_percent, message, current_step, total_steps)``.
                ``total_steps`` is the tile count plus two. Progress is
                reported before and after the upscale call, not per tile.

        Returns:
            Enhanced RGB PIL Image.
        """
        # Normalise to 3-channel RGB through PIL. This gives the same result as
        # before for "L" (gray replicated to 3 channels) and "RGBA" (alpha
        # dropped), and also turns palette ("P"), "LA", "CMYK", ... into real
        # RGB colours instead of raw indices or a wrong channel count.
        if image.mode != "RGB":
            image = image.convert("RGB")
        # The upsampler takes BGR arrays (OpenCV channel order).
        img_bgr = np.array(image)[:, :, ::-1]

        total_tiles = self.calculate_tiles(image.width, image.height)
        total_steps = total_tiles + 2

        if progress_callback:
            progress_callback(10.0, "Preprocessing image...", 1, total_steps)
            progress_callback(15.0, f"Enhancing image ({total_tiles} tiles)...", 1, total_steps)

        output, _ = self.upsampler.enhance(img_bgr, outscale=scale)

        if progress_callback:
            progress_callback(90.0, "Postprocessing...", total_tiles + 1, total_steps)

        # Back from BGR to RGB. Made contiguous so PIL gets a plain C-order buffer.
        enhanced_image = Image.fromarray(np.ascontiguousarray(output[:, :, ::-1]))

        if progress_callback:
            progress_callback(100.0, "Complete!", total_steps, total_steps)
        return enhanced_image
class FallbackEnhancer:
    """
    Fallback enhancer using traditional image processing when AI model is unavailable.

    Upscales with PIL's Lanczos resampling, then applies a mild sharpness
    boost (factor 1.3) followed by a mild contrast boost (factor 1.1).
    """

    def __init__(self):
        print("Using fallback enhancer (no AI model available)")

    def enhance(self, image: Image.Image, scale: int = 4,
                progress_callback: Optional[Callable[[float, str, int, int], None]] = None) -> Image.Image:
        """
        Enhance image using traditional upscaling with sharpening.

        Args:
            image: PIL Image to enhance.
            scale: Integer factor applied to both width and height.
            progress_callback: Optional callback
                ``(progress_percent, message, current_step, total_steps)``,
                called once for each of the four stages.

        Returns:
            The upscaled, sharpened and contrast-adjusted image. Palette
            ("P"/"PA") inputs come back as RGB/RGBA and bilevel ("1") inputs
            as "L"; other modes are preserved.
        """
        from PIL import ImageEnhance

        # Palette and bilevel images cannot go through this pipeline:
        # Image.filter (used by ImageEnhance.Sharpness) refuses "P" images, and
        # Image.blend (used by ImageEnhance.enhance) rejects palette and "1"
        # images. resize() also degrades to nearest-neighbour for "P"/"1".
        # Convert them to an equivalent continuous-tone mode first.
        if image.mode in ("P", "PA"):
            has_alpha = image.mode == "PA" or "transparency" in image.info
            image = image.convert("RGBA" if has_alpha else "RGB")
        elif image.mode == "1":
            image = image.convert("L")

        if progress_callback:
            progress_callback(20.0, "Upscaling image...", 1, 4)
        new_size = (image.width * scale, image.height * scale)
        upscaled = image.resize(new_size, Image.LANCZOS)

        if progress_callback:
            progress_callback(50.0, "Applying sharpening...", 2, 4)
        sharpened = ImageEnhance.Sharpness(upscaled).enhance(1.3)

        if progress_callback:
            progress_callback(75.0, "Adjusting contrast...", 3, 4)
        enhanced = ImageEnhance.Contrast(sharpened).enhance(1.1)

        if progress_callback:
            progress_callback(100.0, "Complete!", 4, 4)
        return enhanced
def get_enhancer():
    """
    Return the best enhancer available in this environment.

    Tries to build the Real-ESRGAN based ``ImageEnhancer``. If construction
    fails for any reason (missing packages, failed weight download, ...),
    the error is printed and a ``FallbackEnhancer`` is returned instead.
    """
    try:
        chosen = ImageEnhancer()
    except Exception as load_error:
        # Deliberately broad: any failure to set up the AI path should
        # degrade to the PIL-based enhancer rather than crash the caller.
        print(f"Could not load AI model: {load_error}")
        chosen = FallbackEnhancer()
    return chosen