# IMG_Enhancer / app.py
# Last commit by um41r: "Update app.py" (39cf1a9, verified)
import os
import cv2
import numpy
import base64
from io import BytesIO
from PIL import Image
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel
import uvicorn
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
# Application object served by uvicorn; the title, description and version
# below are what the generated API docs display.
app = FastAPI(
    version="1.0.0",
    title="Image Enhancement API",
    description="API for enhancing and upscaling images using Real-ESRGAN models",
)
# Allow cross-origin requests so pages embedded in other websites can call the API.
# NOTE(review): every origin, method and header is allowed and credentials are
# enabled. Per the CORS spec a literal "*" origin is not honoured for credentialed
# requests, so check how CORSMiddleware treats this combination and restrict
# allow_origins before exposing anything that relies on cookies/auth.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Create the weights directory next to this file. process_image downloads model
# weights into <script dir>/weights, so a CWD-relative 'weights' directory would be
# a different (unused) location whenever the server is started from elsewhere.
os.makedirs(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights'), exist_ok=True)
# Colour mode ("RGB" or "RGBA") of the most recently processed input image;
# process_image overwrites it on every call.
# NOTE(review): this is module-level state shared by all requests, so it must not
# be relied on to describe any particular request.
img_mode = "RGBA"
# Catalogue of supported models, served by GET /models and used to validate the
# `model` form field. Each entry is a dict with "name", "description" and "scale"
# (the network's native upscaling factor).
AVAILABLE_MODELS = [
    {"name": model_name, "description": summary, "scale": factor}
    for model_name, summary, factor in (
        ("RealESRGAN_x4plus", "General purpose 4x upscaling model", 4),
        ("RealESRNet_x4plus", "Alternative 4x upscaling model", 4),
        ("RealESRGAN_x4plus_anime_6B", "Specialized for anime/cartoon images, 4x upscaling", 4),
        ("RealESRGAN_x2plus", "2x upscaling model", 2),
        ("realesr-general-x4v3", "General purpose 4x upscaling model with denoise control", 4),
    )
]
# Pydantic models for API documentation
# Response body of the "/" and "/health" endpoints.
class HealthResponse(BaseModel):
    status: str  # "ok" from "/", "healthy" from "/health"
    message: str  # human-readable status text
# One entry of AVAILABLE_MODELS as returned by GET /models.
class ModelInfo(BaseModel):
    name: str  # value accepted by the `model` form field of the enhancer endpoints
    description: str  # short human-readable summary of the model
    scale: int  # native upscaling factor of the network (2 or 4 in AVAILABLE_MODELS)
# Response body of GET /models: the full model catalogue.
class ModelsResponse(BaseModel):
    models: List[ModelInfo]
# Size and colour mode of an enhanced image, as computed by process_image.
class ImageProperties(BaseModel):
    width: int  # output width in pixels
    height: int  # output height in pixels
    mode: str  # "RGBA", "RGB" or "Grayscale", derived from the output channel count
# Response body of POST /enhancer.
class EnhancementResponse(BaseModel):
    enhanced_image: str  # base64-encoded PNG when the output is RGBA, JPEG otherwise
    properties: ImageProperties  # size/mode of the enhanced image
    model_used: str  # name of the model that produced the image
def _select_model(model_name):
    """Build the network for ``model_name``; return ``(network, netscale, weight_urls)``.

    Raises:
        HTTPException: 400 when ``model_name`` is not a supported model.
    """
    if model_name == 'RealESRGAN_x4plus':
        network = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        return network, 4, ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    if model_name == 'RealESRNet_x4plus':
        network = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        return network, 4, ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    if model_name == 'RealESRGAN_x4plus_anime_6B':
        network = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        return network, 4, ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    if model_name == 'RealESRGAN_x2plus':
        network = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        return network, 2, ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    if model_name == 'realesr-general-x4v3':
        network = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        return network, 4, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
        ]
    raise HTTPException(status_code=400, detail=f"Invalid model name: {model_name}")


def _ensure_model_weights(file_urls):
    """Download any missing weight files into ``<script dir>/weights``; return that dir.

    Every URL is checked on its own, so a partially populated cache (e.g. only one of
    the two realesr-general files) is completed instead of failing later.
    """
    weights_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights')
    for url in file_urls:
        # Assumes load_file_from_url stores the file under the URL's basename when
        # file_name is None (the original code relied on the same naming).
        if not os.path.isfile(os.path.join(weights_dir, os.path.basename(url))):
            load_file_from_url(url=url, model_dir=weights_dir, progress=True, file_name=None)
    return weights_dir


def _prepare_input(img_data):
    """Return ``(array, mode)`` with the image in OpenCV channel order.

    ``mode`` is "RGBA" when the array carries an alpha channel, otherwise "RGB".
    """
    if isinstance(img_data, Image.Image):
        if img_data.mode not in ("RGB", "RGBA"):
            # Palette, grayscale, LA, CMYK, 16-bit, ... images yield arrays that the
            # RGB(A)->BGR(A) conversions below reject; normalise them first and keep
            # transparency whenever the source has any.
            bands = set(img_data.getbands())
            has_alpha = bool(bands & {"A", "a"}) or "transparency" in img_data.info
            img_data = img_data.convert("RGBA" if has_alpha else "RGB")
        array = numpy.array(img_data)
        if img_data.mode == "RGBA":
            return cv2.cvtColor(array, cv2.COLOR_RGBA2BGRA), "RGBA"
        return cv2.cvtColor(array, cv2.COLOR_RGB2BGR), "RGB"
    # Arrays are passed through unconverted (presumably already BGR/BGRA order);
    # a 2-D grayscale array gets a channel axis so shape[2] is always valid.
    if img_data.ndim == 2:
        img_data = cv2.cvtColor(img_data, cv2.COLOR_GRAY2BGR)
    return img_data, ("RGBA" if img_data.shape[2] == 4 else "RGB")


async def process_image(img_data, model_name, denoise_strength, face_enhance, outscale):
    """Real-ESRGAN function to restore (and upscale) images.

    Args:
        img_data: A PIL ``Image`` (any mode) or a numpy array that is passed to the
            enhancer without channel reordering.
        model_name: One of the names handled by ``_select_model``.
        denoise_strength: Only used by ``realesr-general-x4v3``: weight given to the
            regular weights, the remainder going to the WDN weights (1 = no blending).
        face_enhance: Run GFPGAN face restoration, using the upsampler for the background.
        outscale: Final upscaling factor passed to the enhancer.

    Returns:
        ``(output_img, properties)``: ``output_img`` is a numpy array in RGB or RGBA
        order, ``properties`` a dict with ``width``, ``height`` and ``mode``.

    Raises:
        HTTPException: 400 for an unknown model, 500 if inference raises RuntimeError.
    """
    global img_mode
    model, netscale, file_urls = _select_model(model_name)
    weights_dir = _ensure_model_weights(file_urls)
    model_path = os.path.join(weights_dir, model_name + '.pth')

    # realesr-general-x4v3 can blend its regular and WDN weights via dni_weight.
    # The WDN path is built from the directory rather than by str.replace on the
    # full path, which could also rewrite matching directory names.
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        model_path = [model_path, os.path.join(weights_dir, 'realesr-general-wdn-x4v3.pth')]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=False,
        gpu_id=None,
    )

    face_enhancer = None
    if face_enhance:
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler)

    img, mode = _prepare_input(img_data)
    # Still published for backward compatibility; nothing below reads the global,
    # so concurrent requests cannot affect each other's colour conversion.
    img_mode = mode

    try:
        if face_enhancer is not None:
            _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
        else:
            output, _ = upsampler.enhance(img, outscale=outscale)
    except RuntimeError as error:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(error)}") from error

    # Convert back to RGB(A) based on what the enhancer actually returned.
    if output.ndim == 2:
        output_img = cv2.cvtColor(output, cv2.COLOR_GRAY2RGB)
    elif output.shape[2] == 4:
        output_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
    else:
        output_img = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)

    height, width = output_img.shape[:2]
    channels = output_img.shape[2] if len(output_img.shape) > 2 else 1
    properties = {
        "width": width,
        "height": height,
        "mode": "RGBA" if channels == 4 else "RGB" if channels == 3 else "Grayscale"
    }
    return output_img, properties
# The root path doubles as a health check (noted as important for Hugging Face Spaces).
@app.get("/", response_model=HealthResponse)
async def read_root():
    """Report that the image enhancement API is up and accepting requests."""
    payload = dict(status="ok", message="Image Enhancement API is running")
    return payload
@app.post("/enhancer", response_model=EnhancementResponse, summary="Enhance and upscale an image")
async def enhance_image(
    image: UploadFile = File(..., description="Image file to enhance"),
    model: str = Form("RealESRGAN_x4plus", description="Model name to use for enhancement"),
    denoise_strength: float = Form(0.5, description="Denoise strength (0-1)"),
    outscale: int = Form(4, description="Output scale factor"),
    face_enhance: bool = Form(False, description="Enable face enhancement")
):
    """
    Enhance and upscale an image using Real-ESRGAN models.

    - **image**: Upload an image file (PNG, JPG, etc.)
    - **model**: Select a model from the available options
    - **denoise_strength**: Control the denoising strength (only for realesr-general-x4v3)
    - **outscale**: Control the output resolution scaling
    - **face_enhance**: Enable face enhancement using GFPGAN

    Returns the enhanced image as a base64 string along with image properties.
    The image is PNG-encoded when it has an alpha channel and JPEG-encoded otherwise.
    Responds with 400 for an unknown model, out-of-range parameters or an unreadable
    image file, and with 500 when processing fails.
    """
    try:
        valid_models = [m["name"] for m in AVAILABLE_MODELS]
        if model not in valid_models:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model. Choose from: {', '.join(valid_models)}"
            )
        # Chained comparisons are also False for NaN, so NaN is rejected too.
        if not (0 <= denoise_strength <= 1):
            raise HTTPException(status_code=400, detail="Denoise strength must be between 0 and 1")
        if not (1 <= outscale <= 8):
            raise HTTPException(status_code=400, detail="Outscale must be between 1 and 8")

        contents = await image.read()
        try:
            img = Image.open(BytesIO(contents))
            # Image.open is lazy: decode now so corrupt or truncated uploads fail here.
            img.load()
        except (OSError, Image.DecompressionBombError) as err:
            # A bad upload is a client error, not a server failure.
            raise HTTPException(status_code=400, detail="Invalid or unsupported image file") from err

        output_img, properties = await process_image(img, model, denoise_strength, face_enhance, outscale)

        output_pil = Image.fromarray(output_img)
        buffer = BytesIO()
        # PNG preserves the alpha channel; JPEG keeps RGB payloads small.
        if properties["mode"] == "RGBA":
            output_pil.save(buffer, format="PNG")
        else:
            output_pil.save(buffer, format="JPEG", quality=95)
        img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')

        return {
            "enhanced_image": img_str,
            "properties": properties,
            "model_used": model
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e
# Add a direct image endpoint that returns the actual image instead of base64
@app.post("/enhancer/image", summary="Enhance and return the image directly")
async def enhance_image_direct(
    image: UploadFile = File(..., description="Image file to enhance"),
    model: str = Form("RealESRGAN_x4plus", description="Model name to use for enhancement"),
    denoise_strength: float = Form(0.5, description="Denoise strength (0-1)"),
    outscale: int = Form(4, description="Output scale factor"),
    face_enhance: bool = Form(False, description="Enable face enhancement")
):
    """
    Enhance and upscale an image, returning the actual image file directly.

    This endpoint works like /enhancer but returns the image directly instead of base64 encoded.
    This is useful for direct image display or download.

    The body is `image/png` when the result has an alpha channel and `image/jpeg` otherwise.
    Responds with 400 for an unknown model, out-of-range parameters or an unreadable
    image file, and with 500 when processing fails.
    """
    try:
        valid_models = [m["name"] for m in AVAILABLE_MODELS]
        if model not in valid_models:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model. Choose from: {', '.join(valid_models)}"
            )
        # These range checks were missing here although /enhancer performs them;
        # without them an arbitrarily large outscale reaches the upscaler.
        if not (0 <= denoise_strength <= 1):
            raise HTTPException(status_code=400, detail="Denoise strength must be between 0 and 1")
        if not (1 <= outscale <= 8):
            raise HTTPException(status_code=400, detail="Outscale must be between 1 and 8")

        contents = await image.read()
        try:
            img = Image.open(BytesIO(contents))
            # Image.open is lazy: decode now so corrupt or truncated uploads fail here.
            img.load()
        except (OSError, Image.DecompressionBombError) as err:
            # A bad upload is a client error, not a server failure.
            raise HTTPException(status_code=400, detail="Invalid or unsupported image file") from err

        output_img, properties = await process_image(img, model, denoise_strength, face_enhance, outscale)

        output_pil = Image.fromarray(output_img)
        buffer = BytesIO()
        # PNG preserves the alpha channel; JPEG keeps RGB payloads small.
        if properties["mode"] == "RGBA":
            output_pil.save(buffer, format="PNG")
            media_type = "image/png"
        else:
            output_pil.save(buffer, format="JPEG", quality=95)
            media_type = "image/jpeg"

        return Response(content=buffer.getvalue(), media_type=media_type)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e
@app.get("/health", response_model=HealthResponse, summary="Check server health")
async def health_check():
    """Return a fixed payload confirming the enhancement server is alive."""
    return dict(status="healthy", message="Image enhancement server is running")
@app.get("/models", response_model=ModelsResponse, summary="List available models")
async def list_models():
    """Return every supported enhancement model with its description and native scale."""
    catalogue = AVAILABLE_MODELS
    return dict(models=catalogue)
# Add startup event to print server info
@app.on_event("startup")
async def startup_event():
    """Print a short startup banner listing the available models and docs URLs.

    The emoji were previously stored as mojibake (UTF-8 bytes mis-decoded as a
    Windows code page), so the log showed strings like "๐Ÿš€" instead of 🚀.
    """
    print("🚀 Image Enhancement API is starting up!")
    print(f"📚 Available models: {', '.join(m['name'] for m in AVAILABLE_MODELS)}")
    print("📋 API documentation available at /docs or /redoc")
if __name__ == "__main__":
    # Hugging Face Spaces serves on port 7860; a PORT environment variable overrides it.
    # NOTE(review): reload=True is a development setting (auto-restart on code changes).
    server_options = {
        "host": "0.0.0.0",
        "port": int(os.environ.get("PORT", 7860)),
        "reload": True,
    }
    # Import string assumes this file is importable as module "app".
    uvicorn.run("app:app", **server_options)