# FastAPI service: image enhancement / upscaling via Real-ESRGAN.
# (Removed non-code residue from a web-page copy: Spaces status lines,
# commit hashes, and a line-number gutter that would break the module.)
import os
import cv2
import numpy
import base64
from io import BytesIO
from PIL import Image
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel
import uvicorn
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
# Create the FastAPI app; title/description/version feed the OpenAPI docs.
app = FastAPI(
    title="Image Enhancement API",
    description="API for enhancing and upscaling images using Real-ESRGAN models",
    version="1.0.0"
)

# Add CORS middleware so the API can be called from pages on other origins
# (required when embedding the service in other websites).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, you may want to restrict this
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create the local weights cache directory if it doesn't exist;
# model .pth files are downloaded here on first use.
os.makedirs('weights', exist_ok=True)

# Global variable to track image mode.
# NOTE(review): this is mutable module-level state written by process_image;
# under concurrent requests two coroutines can clobber each other's value.
# Prefer tracking the mode locally inside the processing function.
img_mode = "RGBA"

# Registry of supported models; used for request validation and the
# /models listing endpoint. "scale" is the model's native upscale factor.
AVAILABLE_MODELS = [
    {
        "name": "RealESRGAN_x4plus",
        "description": "General purpose 4x upscaling model",
        "scale": 4
    },
    {
        "name": "RealESRNet_x4plus",
        "description": "Alternative 4x upscaling model",
        "scale": 4
    },
    {
        "name": "RealESRGAN_x4plus_anime_6B",
        "description": "Specialized for anime/cartoon images, 4x upscaling",
        "scale": 4
    },
    {
        "name": "RealESRGAN_x2plus",
        "description": "2x upscaling model",
        "scale": 2
    },
    {
        "name": "realesr-general-x4v3",
        "description": "General purpose 4x upscaling model with denoise control",
        "scale": 4
    }
]
# Pydantic models for API documentation
class HealthResponse(BaseModel):
    """Response schema for the health-check endpoints ("/" and "/health")."""
    status: str   # e.g. "ok" or "healthy"
    message: str  # human-readable status description
class ModelInfo(BaseModel):
    """One entry of AVAILABLE_MODELS as exposed by the /models endpoint."""
    name: str         # model identifier accepted by the /enhancer endpoints
    description: str  # short human-readable summary
    scale: int        # the model's native upscale factor (2 or 4)
class ModelsResponse(BaseModel):
    """Response schema for GET /models: the list of available models."""
    models: List[ModelInfo]
class ImageProperties(BaseModel):
    """Dimensions and colour mode of the enhanced output image."""
    width: int   # output width in pixels
    height: int  # output height in pixels
    mode: str    # "RGBA", "RGB" or "Grayscale" (see process_image)
class EnhancementResponse(BaseModel):
    """Response schema for POST /enhancer."""
    enhanced_image: str          # base64-encoded PNG (RGBA) or JPEG (RGB)
    properties: ImageProperties  # dimensions/mode of the enhanced image
    model_used: str              # the model name that produced the output
async def process_image(img_data, model_name, denoise_strength, face_enhance, outscale):
    """Real-ESRGAN function to restore (and upscale) images.

    Args:
        img_data: A ``PIL.Image.Image`` or an already-decoded BGR(A) numpy array.
        model_name: One of the model names listed in AVAILABLE_MODELS.
        denoise_strength: Blend weight in [0, 1] between the denoising ("wdn")
            and standard weights; only honoured by ``realesr-general-x4v3``.
        face_enhance: When True, additionally run GFPGAN face restoration.
        outscale: Final upscale factor applied to the output image.

    Returns:
        Tuple ``(output_img, properties)`` where ``output_img`` is an RGB(A)
        numpy array and ``properties`` is a dict with width/height/mode.

    Raises:
        HTTPException: 400 for an unknown model name, 500 when inference fails.
    """
    # Select network architecture, its native scale, and the weight URL(s).
    if model_name == 'RealESRGAN_x4plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif model_name == 'RealESRNet_x4plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif model_name == 'realesr-general-x4v3':
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]
    else:
        raise HTTPException(status_code=400, detail=f"Invalid model name: {model_name}")

    # Download weights on first use. With file_name=None the file keeps the
    # name from the URL; the last URL is the primary model, so model_path
    # ends up pointing at it.
    model_path = os.path.join('weights', model_name + '.pth')
    if not os.path.isfile(model_path):
        ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
        for url in file_url:
            model_path = load_file_from_url(
                url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)

    # Optional denoise control: DNI-interpolate between the denoising ("wdn")
    # and standard weights of realesr-general-x4v3.
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    # Initialize the upsampler (CPU-friendly settings: no tiling, full precision).
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=False,
        gpu_id=None
    )

    # GFPGAN is heavy and optional, so import it lazily.
    if face_enhance:
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler)

    # Normalise the input to an OpenCV BGR(A) array. The alpha flag is kept
    # in a LOCAL variable: the previous implementation wrote a module-level
    # global, which races under concurrent requests.
    if isinstance(img_data, Image.Image):
        if img_data.mode not in ("RGB", "RGBA"):
            # Palette/grayscale/CMYK inputs would break the colour
            # conversions below; normalise them to RGB first.
            img_data = img_data.convert("RGB")
        img_array = numpy.array(img_data)
        if img_data.mode == "RGBA":
            img_mode = "RGBA"
            img = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGRA)
        else:
            img_mode = "RGB"
            img = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
    else:
        # Already a numpy array; guard 2-D (grayscale) arrays, which have no
        # channel axis and would make shape[2] raise.
        img = img_data
        if img.ndim == 3 and img.shape[2] == 4:
            img_mode = "RGBA"
        else:
            img_mode = "RGB"

    try:
        # Run the enhancement pipeline (GFPGAN wraps the upsampler when
        # face enhancement is enabled).
        if face_enhance:
            _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
        else:
            output, _ = upsampler.enhance(img, outscale=outscale)
    except RuntimeError as error:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(error)}")

    # Convert back from OpenCV BGR(A) to RGB(A) for PIL/serialization.
    if img_mode == "RGBA":
        output_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
    else:
        output_img = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)

    # Describe the output for the API response.
    height, width = output_img.shape[:2]
    channels = output_img.shape[2] if len(output_img.shape) > 2 else 1
    properties = {
        "width": width,
        "height": height,
        "mode": "RGBA" if channels == 4 else "RGB" if channels == 3 else "Grayscale"
    }
    return output_img, properties
# Root endpoint for health check - important for Spaces
@app.get("/", response_model=HealthResponse)
async def read_root():
    """Root health check — hosting platforms (e.g. Spaces) probe this URL."""
    payload = {"status": "ok", "message": "Image Enhancement API is running"}
    return payload
@app.post("/enhancer", response_model=EnhancementResponse, summary="Enhance and upscale an image")
async def enhance_image(
    image: UploadFile = File(..., description="Image file to enhance"),
    model: str = Form("RealESRGAN_x4plus", description="Model name to use for enhancement"),
    denoise_strength: float = Form(0.5, description="Denoise strength (0-1)"),
    outscale: int = Form(4, description="Output scale factor"),
    face_enhance: bool = Form(False, description="Enable face enhancement")
):
    """
    Enhance and upscale an image using Real-ESRGAN models.

    - **image**: Upload an image file (PNG, JPG, etc.)
    - **model**: Select a model from the available options
    - **denoise_strength**: Control the denoising strength (only for realesr-general-x4v3)
    - **outscale**: Control the output resolution scaling
    - **face_enhance**: Enable face enhancement using GFPGAN

    Returns the enhanced image as a base64 string along with image properties.
    """
    try:
        # Validate the model name so callers get a helpful 400, not a 500.
        valid_models = [m["name"] for m in AVAILABLE_MODELS]
        if model not in valid_models:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model. Choose from: {', '.join(valid_models)}"
            )
        # Validate numeric parameters.
        if not (0 <= denoise_strength <= 1):
            raise HTTPException(status_code=400, detail="Denoise strength must be between 0 and 1")
        if not (1 <= outscale <= 8):
            raise HTTPException(status_code=400, detail="Outscale must be between 1 and 8")

        # Read and decode the upload. img.load() forces the (lazy) decode now,
        # so corrupt/non-image files surface as a 400 here instead of a 500
        # deep inside the processing pipeline.
        contents = await image.read()
        try:
            img = Image.open(BytesIO(contents))
            img.load()
        except Exception:
            raise HTTPException(status_code=400, detail="Uploaded file is not a valid image")

        # Run the Real-ESRGAN pipeline.
        output_img, properties = await process_image(img, model, denoise_strength, face_enhance, outscale)

        # Encode the result: PNG preserves alpha; JPEG is smaller for opaque output.
        output_pil = Image.fromarray(output_img)
        buffer = BytesIO()
        if properties["mode"] == "RGBA":
            output_pil.save(buffer, format="PNG")
        else:
            output_pil.save(buffer, format="JPEG", quality=95)

        # Base64-encode for the JSON response body.
        img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')

        return {
            "enhanced_image": img_str,
            "properties": properties,
            "model_used": model
        }
    except HTTPException:
        # Bare re-raise preserves the original traceback and status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")
# Add a direct image endpoint that returns the actual image instead of base64
@app.post("/enhancer/image", summary="Enhance and return the image directly")
async def enhance_image_direct(
    image: UploadFile = File(..., description="Image file to enhance"),
    model: str = Form("RealESRGAN_x4plus", description="Model name to use for enhancement"),
    denoise_strength: float = Form(0.5, description="Denoise strength (0-1)"),
    outscale: int = Form(4, description="Output scale factor"),
    face_enhance: bool = Form(False, description="Enable face enhancement")
):
    """
    Enhance and upscale an image, returning the actual image file directly.

    This endpoint works like /enhancer but returns the image directly instead
    of base64 encoded. This is useful for direct image display or download.
    """
    try:
        # Validate the model name so callers get a helpful 400, not a 500.
        valid_models = [m["name"] for m in AVAILABLE_MODELS]
        if model not in valid_models:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model. Choose from: {', '.join(valid_models)}"
            )
        # Validate numeric parameters (previously missing here, though the
        # sibling /enhancer endpoint enforced the same limits).
        if not (0 <= denoise_strength <= 1):
            raise HTTPException(status_code=400, detail="Denoise strength must be between 0 and 1")
        if not (1 <= outscale <= 8):
            raise HTTPException(status_code=400, detail="Outscale must be between 1 and 8")

        # Read and eagerly decode the upload so corrupt/non-image files
        # surface as a 400 instead of a 500 inside processing.
        contents = await image.read()
        try:
            img = Image.open(BytesIO(contents))
            img.load()
        except Exception:
            raise HTTPException(status_code=400, detail="Uploaded file is not a valid image")

        # Run the Real-ESRGAN pipeline.
        output_img, properties = await process_image(img, model, denoise_strength, face_enhance, outscale)

        # Serialize: PNG preserves alpha; JPEG is smaller for opaque output.
        output_pil = Image.fromarray(output_img)
        buffer = BytesIO()
        image_format = "PNG" if properties["mode"] == "RGBA" else "JPEG"
        if image_format == "PNG":
            output_pil.save(buffer, format="PNG")
            media_type = "image/png"
        else:
            output_pil.save(buffer, format="JPEG", quality=95)
            media_type = "image/jpeg"
        buffer.seek(0)

        # Return the raw image bytes with the matching content type.
        return Response(content=buffer.getvalue(), media_type=media_type)
    except HTTPException:
        # Bare re-raise preserves the original traceback and status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")
@app.get("/health", response_model=HealthResponse, summary="Check server health")
async def health_check():
    """Liveness probe: confirms the enhancement server is up."""
    payload = {"status": "healthy", "message": "Image enhancement server is running"}
    return payload
@app.get("/models", response_model=ModelsResponse, summary="List available models")
async def list_models():
    """Return every supported enhancement model with its description and scale."""
    return {"models": AVAILABLE_MODELS}
# Add startup event to print server info
@app.on_event("startup")
async def startup_event():
    """Log basic service information once the app has started.

    The original print strings contained mojibake ("๐" — mis-encoded emoji
    bytes); they are replaced with clean ASCII so logs render everywhere.
    """
    print("Image Enhancement API is starting up!")
    print(f"Available models: {', '.join(m['name'] for m in AVAILABLE_MODELS)}")
    print("API documentation available at /docs or /redoc")
if __name__ == "__main__":
    # Run server with Uvicorn; defaults to port 7860 (the port Hugging Face
    # Spaces expects) but honours a PORT environment variable override.
    port = int(os.environ.get("PORT", 7860))
    # NOTE(review): reload=True starts the development auto-reloader;
    # consider disabling it for production deployments.
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=True)