import os
import cv2
import numpy
import base64
import threading
from io import BytesIO
from PIL import Image
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel
import uvicorn
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

# Create FastAPI app
app = FastAPI(
    title="Image Enhancement API",
    description="API for enhancing and upscaling images using Real-ESRGAN models",
    version="1.0.0",
)

# Add CORS middleware for embedding in other websites.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; restrict allow_origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, you may want to restrict this
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Weights are both looked up in and downloaded to <directory of this file>/weights,
# so the server behaves the same regardless of the current working directory.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
WEIGHTS_DIR = os.path.join(ROOT_DIR, 'weights')
os.makedirs(WEIGHTS_DIR, exist_ok=True)

# Legacy module-level name, kept so existing imports of `img_mode` still work.
# It is no longer written per request: per-request state must not live in a
# global once requests are processed off the event loop.
img_mode = "RGBA"

# Inference is CPU/GPU- and memory-heavy; this lock keeps requests processed one
# at a time (as they effectively were when inference blocked the event loop),
# while the event loop itself stays free to answer other requests.
_INFERENCE_LOCK = threading.Lock()

# Models information
AVAILABLE_MODELS = [
    {
        "name": "RealESRGAN_x4plus",
        "description": "General purpose 4x upscaling model",
        "scale": 4,
    },
    {
        "name": "RealESRNet_x4plus",
        "description": "Alternative 4x upscaling model",
        "scale": 4,
    },
    {
        "name": "RealESRGAN_x4plus_anime_6B",
        "description": "Specialized for anime/cartoon images, 4x upscaling",
        "scale": 4,
    },
    {
        "name": "RealESRGAN_x2plus",
        "description": "2x upscaling model",
        "scale": 2,
    },
    {
        "name": "realesr-general-x4v3",
        "description": "General purpose 4x upscaling model with denoise control",
        "scale": 4,
    },
]


# Pydantic models for API documentation
class HealthResponse(BaseModel):
    status: str
    message: str


class ModelInfo(BaseModel):
    name: str
    description: str
    scale: int


class ModelsResponse(BaseModel):
    models: List[ModelInfo]


class ImageProperties(BaseModel):
    width: int
    height: int
    mode: str


class EnhancementResponse(BaseModel):
    enhanced_image: str
    properties: ImageProperties
    model_used: str


def _build_model(model_name):
    """Return ``(network, netscale, weight_urls)`` for a supported model name.

    Raises:
        HTTPException: 400 if ``model_name`` is not one of the known models.
    """
    if model_name == 'RealESRGAN_x4plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        return model, 4, ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    if model_name == 'RealESRNet_x4plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        return model, 4, ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    if model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        return model, 4, ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    if model_name == 'RealESRGAN_x2plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        return model, 2, ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    if model_name == 'realesr-general-x4v3':
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        return model, 4, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
        ]
    raise HTTPException(status_code=400, detail=f"Invalid model name: {model_name}")


def _ensure_weights(model_name, file_url):
    """Make sure every weight file in ``file_url`` is present in WEIGHTS_DIR.

    Each URL is checked individually, so a partially-downloaded model set (for
    example the general model without its ``-wdn`` companion) is completed.

    Returns:
        The absolute path of ``<model_name>.pth`` inside WEIGHTS_DIR.
    """
    for url in file_url:
        # load_file_from_url only downloads when the target file is missing.
        load_file_from_url(url=url, model_dir=WEIGHTS_DIR, progress=True, file_name=None)
    return os.path.join(WEIGHTS_DIR, model_name + '.pth')


def _to_cv2_image(img_data):
    """Convert the input to the cv2 channel order expected by the upsamplers.

    PIL images are first normalized to RGB or RGBA. Modes with transparency
    (RGBA, LA, PA, or palette images with a transparency entry) become BGRA;
    everything else (L, P, 1, CMYK, ...) becomes BGR. numpy arrays are passed
    through unchanged, as in the original implementation (presumably already
    in BGR/BGRA order — confirm for any new callers).
    """
    if not isinstance(img_data, Image.Image):
        return img_data
    has_alpha = img_data.mode in ("RGBA", "LA", "PA") or (
        img_data.mode == "P" and "transparency" in img_data.info
    )
    target_mode = "RGBA" if has_alpha else "RGB"
    if img_data.mode != target_mode:
        img_data = img_data.convert(target_mode)
    img_array = numpy.array(img_data)
    code = cv2.COLOR_RGBA2BGRA if has_alpha else cv2.COLOR_RGB2BGR
    return cv2.cvtColor(img_array, code)


def _to_rgb_output(output):
    """Convert an upsampler result from cv2 order back to RGB(A).

    The conversion is chosen from the output's actual channel count rather than
    the input's mode, because face enhancement may change the channel count.
    Single-channel (2-D) output is returned unchanged.
    """
    if output.ndim == 3 and output.shape[2] == 4:
        return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
    if output.ndim == 3:
        return cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
    return output


def _image_properties(output_img):
    """Return the width/height/mode dictionary reported to API clients."""
    height, width = output_img.shape[:2]
    channels = output_img.shape[2] if output_img.ndim > 2 else 1
    return {
        "width": width,
        "height": height,
        "mode": "RGBA" if channels == 4 else "RGB" if channels == 3 else "Grayscale",
    }


def _process_image_sync(img_data, model_name, denoise_strength, face_enhance, outscale):
    """Blocking implementation of :func:`process_image`; runs in a worker thread."""
    with _INFERENCE_LOCK:
        model, netscale, file_url = _build_model(model_name)
        model_path = _ensure_weights(model_name, file_url)

        # Blend the regular and "wdn" weights of realesr-general-x4v3 to
        # control the denoise strength (deep network interpolation).
        dni_weight = None
        if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
            wdn_model_path = os.path.join(WEIGHTS_DIR, 'realesr-general-wdn-x4v3.pth')
            model_path = [model_path, wdn_model_path]
            dni_weight = [denoise_strength, 1 - denoise_strength]

        upsampler = RealESRGANer(
            scale=netscale,
            model_path=model_path,
            dni_weight=dni_weight,
            model=model,
            tile=0,
            tile_pad=10,
            pre_pad=10,
            half=False,
            gpu_id=None,
        )

        face_enhancer = None
        if face_enhance:
            # Imported lazily so gfpgan is only required when faces are enhanced.
            from gfpgan import GFPGANer
            face_enhancer = GFPGANer(
                model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
                upscale=outscale,
                arch='clean',
                channel_multiplier=2,
                bg_upsampler=upsampler,
            )

        img = _to_cv2_image(img_data)

        try:
            if face_enhancer is not None:
                _, _, output = face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = upsampler.enhance(img, outscale=outscale)
        except RuntimeError as error:
            raise HTTPException(status_code=500, detail=f"Processing error: {str(error)}") from error

        output_img = _to_rgb_output(output)
        return output_img, _image_properties(output_img)


async def process_image(img_data, model_name, denoise_strength, face_enhance, outscale):
    """Real-ESRGAN function to restore (and upscale) images.

    The heavy work runs in FastAPI's threadpool so the event loop stays
    responsive; inference itself is serialized by ``_INFERENCE_LOCK``.

    Args:
        img_data: PIL image, or a numpy array already in cv2 channel order.
        model_name: One of the names in AVAILABLE_MODELS.
        denoise_strength: Weight in [0, 1]; only used by realesr-general-x4v3.
        face_enhance: Whether to run GFPGAN face restoration.
        outscale: Final upscaling factor.

    Returns:
        ``(output_img, properties)``: the RGB/RGBA numpy result and its
        width/height/mode dictionary.

    Raises:
        HTTPException: 400 for an unknown model, 500 for processing errors.
    """
    return await run_in_threadpool(
        _process_image_sync, img_data, model_name, denoise_strength, face_enhance, outscale
    )


def _validate_params(model, denoise_strength, outscale):
    """Validate request parameters shared by both enhancement endpoints.

    Raises:
        HTTPException: 400 describing the first invalid parameter.
    """
    valid_models = [m["name"] for m in AVAILABLE_MODELS]
    if model not in valid_models:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model. Choose from: {', '.join(valid_models)}",
        )
    # Written as "not (a <= x <= b)" so NaN is rejected as well.
    if not (0 <= denoise_strength <= 1):
        raise HTTPException(status_code=400, detail="Denoise strength must be between 0 and 1")
    if not (1 <= outscale <= 8):
        raise HTTPException(status_code=400, detail="Outscale must be between 1 and 8")


async def _read_upload_image(upload):
    """Read and fully decode an uploaded image.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded.
    """
    contents = await upload.read()
    try:
        img = Image.open(BytesIO(contents))
        # Image.open is lazy; load() surfaces truncated/corrupt data here.
        img.load()
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        raise HTTPException(status_code=400, detail=f"Could not read image: {exc}") from exc
    return img


def _encode_image(output_img, properties):
    """Encode the result as PNG (when it has alpha) or JPEG.

    Returns:
        ``(encoded_bytes, media_type)``.
    """
    output_pil = Image.fromarray(output_img)
    buffer = BytesIO()
    if properties["mode"] == "RGBA":
        output_pil.save(buffer, format="PNG")
        media_type = "image/png"
    else:
        output_pil.save(buffer, format="JPEG", quality=95)
        media_type = "image/jpeg"
    return buffer.getvalue(), media_type


# Root endpoint for health check - important for Spaces
@app.get("/", response_model=HealthResponse)
async def read_root():
    """Check if the image enhancement API is running."""
    return {"status": "ok", "message": "Image Enhancement API is running"}


@app.post("/enhancer", response_model=EnhancementResponse, summary="Enhance and upscale an image")
async def enhance_image(
    image: UploadFile = File(..., description="Image file to enhance"),
    model: str = Form("RealESRGAN_x4plus", description="Model name to use for enhancement"),
    denoise_strength: float = Form(0.5, description="Denoise strength (0-1)"),
    outscale: int = Form(4, description="Output scale factor"),
    face_enhance: bool = Form(False, description="Enable face enhancement"),
):
    """
    Enhance and upscale an image using Real-ESRGAN models.

    - **image**: Upload an image file (PNG, JPG, etc.)
    - **model**: Select a model from the available options
    - **denoise_strength**: Control the denoising strength (only for realesr-general-x4v3)
    - **outscale**: Control the output resolution scaling
    - **face_enhance**: Enable face enhancement using GFPGAN

    Returns the enhanced image as a base64 string along with image properties.
    """
    try:
        _validate_params(model, denoise_strength, outscale)
        img = await _read_upload_image(image)

        output_img, properties = await process_image(img, model, denoise_strength, face_enhance, outscale)

        encoded, _ = _encode_image(output_img, properties)
        img_str = base64.b64encode(encoded).decode('utf-8')

        return {
            "enhanced_image": img_str,
            "properties": properties,
            "model_used": model,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e


# Add a direct image endpoint that returns the actual image instead of base64
@app.post("/enhancer/image", summary="Enhance and return the image directly")
async def enhance_image_direct(
    image: UploadFile = File(..., description="Image file to enhance"),
    model: str = Form("RealESRGAN_x4plus", description="Model name to use for enhancement"),
    denoise_strength: float = Form(0.5, description="Denoise strength (0-1)"),
    outscale: int = Form(4, description="Output scale factor"),
    face_enhance: bool = Form(False, description="Enable face enhancement"),
):
    """
    Enhance and upscale an image, returning the actual image file directly.

    This endpoint works like /enhancer but returns the image directly instead of base64 encoded.
    This is useful for direct image display or download.
    """
    try:
        # Same validation as /enhancer, so both endpoints accept the same inputs.
        _validate_params(model, denoise_strength, outscale)
        img = await _read_upload_image(image)

        output_img, properties = await process_image(img, model, denoise_strength, face_enhance, outscale)

        encoded, media_type = _encode_image(output_img, properties)
        return Response(content=encoded, media_type=media_type)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e


@app.get("/health", response_model=HealthResponse, summary="Check server health")
async def health_check():
    """Check if the image enhancement server is running."""
    return {"status": "healthy", "message": "Image enhancement server is running"}


@app.get("/models", response_model=ModelsResponse, summary="List available models")
async def list_models():
    """Get a list of all available enhancement models with descriptions."""
    return {"models": AVAILABLE_MODELS}


# Add startup event to print server info
@app.on_event("startup")
async def startup_event():
    print("🚀 Image Enhancement API is starting up!")
    print(f"📚 Available models: {', '.join(m['name'] for m in AVAILABLE_MODELS)}")
    print("📋 API documentation available at /docs or /redoc")


if __name__ == "__main__":
    # Run server with Uvicorn on port 7860 for Hugging Face Spaces.
    # The "app:app" import string assumes this file is named app.py; reload=True
    # enables file watching, which is usually disabled in production.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=True)