Spaces:
Sleeping
Sleeping
File size: 5,579 Bytes
b754303 d037fc6 b754303 d037fc6 b754303 d037fc6 b754303 7adb383 b754303 d037fc6 b754303 d037fc6 b754303 d037fc6 b754303 d037fc6 b754303 d037fc6 b754303 d037fc6 b754303 d037fc6 b754303 d037fc6 b754303 d037fc6 b754303 d037fc6 b754303 d037fc6 b754303 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 | import torch
import numpy as np
from PIL import Image
from pathlib import Path
from typing import Optional, Callable
class ImageEnhancer:
    """
    AI image enhancer backed by the Real-ESRGAN x4plus model.

    This class handles:
    - Downloading the RealESRGAN_x4plus weights from the Real-ESRGAN GitHub
      release on first use (cached under ./weights, relative to the CWD)
    - Image preprocessing (any PIL mode -> RGB -> BGR array) and postprocessing
    - Inference on CUDA when available, otherwise on CPU
    - Coarse progress reporting around the (tiled) upscaling call
    """

    # Weights file name and download location used by _load_model.
    _WEIGHTS_FILE = "RealESRGAN_x4plus.pth"
    _WEIGHTS_URL = (
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/"
        "RealESRGAN_x4plus.pth"
    )

    def __init__(self, model_name: str = "RealESRGAN_x4plus"):
        """
        Pick a device and load the upsampler.

        Args:
            model_name: Stored on the instance for reference only.
                NOTE(review): _load_model always loads RealESRGAN_x4plus,
                whatever this value is.
        """
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None  # kept for compatibility; inference goes through self.upsampler
        self.tile_size = 256  # tile edge in pixels; calculate_tiles treats 0 as a single tile
        self._load_model()

    @staticmethod
    def _download_weights(url: str, destination: Path) -> None:
        """
        Download *url* to *destination* without ever leaving a truncated file there.

        The data is written to a sibling ``.part`` file and renamed into place only
        after the download completes. Otherwise an interrupted download would leave
        a partial weights file that passes the ``exists()`` check on the next run.
        """
        import urllib.request

        partial = destination.with_name(destination.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial)
            partial.replace(destination)
        finally:
            # No-op after a successful replace; cleans up after a failed download.
            if partial.exists():
                partial.unlink()

    def _load_model(self):
        """Ensure the weights are on disk, then build the RRDBNet + RealESRGANer upsampler."""
        from realesrgan import RealESRGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet

        model_path = Path("weights")
        model_path.mkdir(exist_ok=True)
        model_file = model_path / self._WEIGHTS_FILE

        if not model_file.exists():
            print("Downloading Real-ESRGAN x4plus model...")
            self._download_weights(self._WEIGHTS_URL, model_file)
            print("Model downloaded successfully!")

        # RRDBNet configuration used with the x4plus checkpoint loaded below.
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )

        self.upsampler = RealESRGANer(
            scale=4,
            model_path=str(model_file),
            model=model,
            tile=self.tile_size,
            tile_pad=10,
            pre_pad=0,
            half=self.device.type != "cpu",  # half precision only off the CPU
            device=self.device,
        )
        print(f"Model loaded on {self.device}")

    def calculate_tiles(self, width: int, height: int) -> int:
        """
        Return the number of tiles an image of ``width`` x ``height`` is split into.

        Uses ceiling division per axis with at least one tile per axis; a
        ``tile_size`` of 0 means the image is processed as a single tile.
        """
        if self.tile_size == 0:
            return 1
        tiles_x = max(1, (width + self.tile_size - 1) // self.tile_size)
        tiles_y = max(1, (height + self.tile_size - 1) // self.tile_size)
        return tiles_x * tiles_y

    def enhance(self, image: Image.Image, scale: int = 4,
                progress_callback: Optional[Callable[[float, str, int, int], None]] = None) -> Image.Image:
        """
        Enhance an image using Real-ESRGAN.

        Args:
            image: PIL Image to enhance. Any mode is accepted: it is converted to RGB
                first, so alpha is discarded and palette / CMYK colours are resolved
                by PIL.
            scale: Upscaling factor (2 or 4), passed to the upsampler as ``outscale``.
            progress_callback: Optional callback
                function(progress%, message, current_step, total_steps), where
                total_steps is the tile count plus two (pre- and post-processing).

        Returns:
            Enhanced RGB PIL Image.
        """
        total_tiles = self.calculate_tiles(image.width, image.height)
        total_steps = total_tiles + 2

        def report(percent: float, message: str, step: int) -> None:
            if progress_callback:
                progress_callback(percent, message, step, total_steps)

        report(10.0, "Preprocessing image...", 1)
        # Let PIL resolve every mode to 8-bit RGB. Slicing the raw array instead
        # turned 'P' images into grayscale palette indices, read CMYK's C/M/Y
        # channels as R/G/B, and passed 'LA' through as a 2-channel array.
        if image.mode != "RGB":
            image = image.convert("RGB")
        img_bgr = np.array(image)[:, :, ::-1]  # RGB -> BGR for the upsampler

        report(15.0, f"Enhancing image ({total_tiles} tiles)...", 1)
        output, _ = self.upsampler.enhance(img_bgr, outscale=scale)

        report(90.0, "Postprocessing...", total_tiles + 1)
        enhanced_image = Image.fromarray(output[:, :, ::-1])  # BGR -> RGB

        report(100.0, "Complete!", total_steps)
        return enhanced_image
class FallbackEnhancer:
    """
    Fallback enhancer using traditional image processing when the AI model is unavailable.

    Upscales with PIL's Lanczos resampling, then applies a mild sharpness boost
    (factor 1.3) and contrast boost (factor 1.1).
    """

    def __init__(self):
        print("Using fallback enhancer (no AI model available)")

    @staticmethod
    def _prepare(image: Image.Image) -> Image.Image:
        """
        Return *image* in a mode that Lanczos resizing and ImageEnhance both handle.

        RGB, RGBA and L are kept as-is. Every other mode ('P', '1', 'LA', 'CMYK', ...)
        is converted to RGBA when it carries transparency, otherwise to RGB.
        - Pillow's resize() forces nearest-neighbour for 'P' and '1' images.
        - Image.filter() (used by ImageEnhance.Sharpness) rejects palette images.
        """
        if image.mode in ("RGB", "RGBA", "L"):
            return image
        has_alpha = "A" in image.getbands() or "transparency" in image.info
        return image.convert("RGBA" if has_alpha else "RGB")

    def enhance(self, image: Image.Image, scale: int = 4,
                progress_callback: Optional[Callable[[float, str, int, int], None]] = None) -> Image.Image:
        """
        Enhance image using traditional upscaling with sharpening and contrast.

        Args:
            image: PIL Image to enhance (any mode; see _prepare).
            scale: Positive upscaling factor; fractional values are rounded to
                whole pixels.
            progress_callback: Optional callback
                function(progress%, message, current_step, total_steps)
                called over 4 steps.

        Returns:
            Enhanced PIL Image.

        Raises:
            ValueError: if ``scale`` is not positive.
        """
        from PIL import ImageEnhance

        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale!r}")

        if progress_callback:
            progress_callback(20.0, "Upscaling image...", 1, 4)
        image = self._prepare(image)
        new_size = (
            max(1, round(image.width * scale)),
            max(1, round(image.height * scale)),
        )
        upscaled = image.resize(new_size, Image.LANCZOS)

        if progress_callback:
            progress_callback(50.0, "Applying sharpening...", 2, 4)
        sharpened = ImageEnhance.Sharpness(upscaled).enhance(1.3)

        if progress_callback:
            progress_callback(75.0, "Adjusting contrast...", 3, 4)
        enhanced = ImageEnhance.Contrast(sharpened).enhance(1.1)

        if progress_callback:
            progress_callback(100.0, "Complete!", 4, 4)
        return enhanced
def get_enhancer():
    """
    Return the best enhancer that can be constructed in this environment.

    An ImageEnhancer (Real-ESRGAN) is attempted first. If constructing it raises
    any exception, the error is printed and a FallbackEnhancer is returned instead.
    Causes include missing packages, a failed weights download or a model load
    error.
    """
    try:
        enhancer = ImageEnhancer()
    except Exception as exc:  # deliberate catch-all: any failure means "use the fallback"
        print(f"Could not load AI model: {exc}")
        enhancer = FallbackEnhancer()
    return enhancer
|