Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import numpy | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| from fastapi import FastAPI, File, Form, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse, Response | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from typing import Optional, List, Dict, Any, Union | |
| from pydantic import BaseModel | |
| import uvicorn | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from basicsr.utils.download_util import load_file_from_url | |
| from realesrgan import RealESRGANer | |
| from realesrgan.archs.srvgg_arch import SRVGGNetCompact | |
# Create the FastAPI application; title/description/version feed the
# generated OpenAPI documentation.
app = FastAPI(
    title="Image Enhancement API",
    description="API for enhancing and upscaling images using Real-ESRGAN models",
    version="1.0.0",
)

# Add CORS middleware so the API can be called from pages hosted elsewhere.
# A wildcard origin must not be combined with credentialed requests: the
# CORS specification forbids "Access-Control-Allow-Origin: *" together with
# credentials, and FastAPI documents that allow_origins/methods/headers may
# not be ["*"] when allow_credentials is True. Nothing in this file relies
# on cookies or Authorization headers, so credentials are disabled.
# To support credentials, list the trusted origins explicitly instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, you may want to restrict this
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Create the weights directory if it doesn't exist. It is anchored to this
# module's directory (not the current working directory) because that is
# where process_image downloads model files to.
os.makedirs(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights'),
    exist_ok=True,
)

# Global variable to track image mode ("RGBA" or "RGB"); process_image
# overwrites it with the mode of the image it is currently handling.
img_mode = "RGBA"
# Catalogue of the models this service accepts. Each entry is a dict with
# the keys "name", "description" and "scale", built from the
# (name, description, scale) rows below.
AVAILABLE_MODELS = [
    {"name": model_id, "description": summary, "scale": factor}
    for model_id, summary, factor in (
        ("RealESRGAN_x4plus", "General purpose 4x upscaling model", 4),
        ("RealESRNet_x4plus", "Alternative 4x upscaling model", 4),
        (
            "RealESRGAN_x4plus_anime_6B",
            "Specialized for anime/cartoon images, 4x upscaling",
            4,
        ),
        ("RealESRGAN_x2plus", "2x upscaling model", 2),
        (
            "realesr-general-x4v3",
            "General purpose 4x upscaling model with denoise control",
            4,
        ),
    )
]
# Pydantic models for API documentation
class HealthResponse(BaseModel):
    """Schema with the same two keys as the dicts returned by the health handlers."""

    # Short state keyword; the handlers in this file return "ok" or "healthy".
    status: str
    # Human-readable sentence accompanying the status.
    message: str
class ModelInfo(BaseModel):
    """Schema for one model entry; mirrors the keys of the AVAILABLE_MODELS dicts."""

    # Model identifier, e.g. "RealESRGAN_x4plus".
    name: str
    # One-line summary of what the model is suited for.
    description: str
    # Upscaling factor listed for the model (2 or 4 in AVAILABLE_MODELS).
    scale: int
class ModelsResponse(BaseModel):
    """Schema for a model listing: a single "models" key holding ModelInfo items."""

    # All models offered by the service.
    models: List[ModelInfo]
class ImageProperties(BaseModel):
    """Schema for the size and channel layout of an image."""

    # Width in pixels.
    width: int
    # Height in pixels.
    height: int
    # Channel layout label; process_image emits "RGBA", "RGB" or "Grayscale".
    mode: str
class EnhancementResponse(BaseModel):
    """Schema with the same keys as the dict returned by enhance_image."""

    # Encoded image as text; enhance_image fills this with a base64 string.
    enhanced_image: str
    # Dimensions and mode of the enhanced image.
    properties: ImageProperties
    # Name of the model that produced the result.
    model_used: str
def _select_network(model_name):
    """Return ``(network, netscale, weight_urls)`` for a supported model name.

    Raises:
        HTTPException: 400 when ``model_name`` is not a supported model.
    """
    if model_name == 'RealESRGAN_x4plus':
        network = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        return network, 4, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
        ]
    if model_name == 'RealESRNet_x4plus':
        network = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        return network, 4, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth'
        ]
    if model_name == 'RealESRGAN_x4plus_anime_6B':
        network = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        return network, 4, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'
        ]
    if model_name == 'RealESRGAN_x2plus':
        network = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        return network, 2, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'
        ]
    if model_name == 'realesr-general-x4v3':
        network = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        return network, 4, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
        ]
    raise HTTPException(status_code=400, detail=f"Invalid model name: {model_name}")


def _ensure_weights(weight_urls):
    """Download any missing weight file and return the weights directory.

    Every URL is checked on its own, so a model that needs two files
    (realesr-general-x4v3 and its "wdn" companion) is repaired even when
    only one of them is already on disk.
    """
    weights_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights')
    for url in weight_urls:
        # The file is stored under the last path component of its URL.
        local_file = os.path.join(weights_dir, url.rsplit('/', 1)[-1])
        if not os.path.isfile(local_file):
            load_file_from_url(url=url, model_dir=weights_dir, progress=True, file_name=None)
    return weights_dir


def _to_bgr_array(img_data):
    """Convert a PIL image or numpy array to a 3-D BGR/BGRA array for OpenCV.

    PIL images that are not already RGB/RGBA (palette, grayscale, CMYK, ...)
    are converted first; alpha is kept when the mode carries it.
    """
    if isinstance(img_data, Image.Image):
        keeps_alpha = img_data.mode in ("RGBA", "LA", "PA") or (
            img_data.mode == "P" and "transparency" in img_data.info
        )
        if keeps_alpha:
            return cv2.cvtColor(numpy.array(img_data.convert("RGBA")), cv2.COLOR_RGBA2BGRA)
        return cv2.cvtColor(numpy.array(img_data.convert("RGB")), cv2.COLOR_RGB2BGR)
    # Already a numpy array; assumed to be in OpenCV channel order (as before).
    if img_data.ndim == 2:
        # Single-channel input: expand so the channel axis always exists.
        return cv2.cvtColor(img_data, cv2.COLOR_GRAY2BGR)
    return img_data


async def process_image(img_data, model_name, denoise_strength, face_enhance, outscale):
    """Real-ESRGAN function to restore (and upscale) images.

    Args:
        img_data: Input image, either a ``PIL.Image.Image`` or a numpy array.
        model_name: One of the model names handled by ``_select_network``.
        denoise_strength: Blend weight between the regular and the "wdn"
            weights; only used for ``realesr-general-x4v3`` when not equal to 1.
        face_enhance: When true, run the image through GFPGAN with the
            Real-ESRGAN upsampler as its background upsampler.
        outscale: Requested output scale factor.

    Returns:
        ``(output_img, properties)`` where ``output_img`` is a numpy array in
        RGB/RGBA channel order and ``properties`` is a dict with the keys
        ``"width"``, ``"height"`` and ``"mode"``.

    Raises:
        HTTPException: 400 for an unknown model name, 500 when enhancement
            fails with a ``RuntimeError``.
    """
    global img_mode

    model, netscale, file_url = _select_network(model_name)

    # Download model if not already available
    weights_dir = _ensure_weights(file_url)
    model_path = os.path.join(weights_dir, model_name + '.pth')

    # Handle denoise strength for realesr-general-x4v3
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        # Built with os.path.join rather than str.replace on the full path,
        # so a directory name can never be rewritten by accident.
        wdn_model_path = os.path.join(weights_dir, 'realesr-general-wdn-x4v3.pth')
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    # Initialize upsampler
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=False,
        gpu_id=None,
    )

    # Initialize face enhancer if needed
    face_enhancer = None
    if face_enhance:
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler)

    # Convert input image to CV2 format
    img = _to_bgr_array(img_data)
    # Kept for backward compatibility: the module-level variable still
    # records the mode of the image being processed.
    img_mode = "RGBA" if img.shape[2] == 4 else "RGB"

    try:
        # Process image
        if face_enhancer is not None:
            _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
        else:
            output, _ = upsampler.enhance(img, outscale=outscale)
    except RuntimeError as error:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(error)}") from error

    # Convert back to RGB(A) based on what the enhancer actually returned,
    # so a channel-count mismatch with the input cannot break the conversion.
    channels = output.shape[2] if output.ndim > 2 else 1
    if channels == 4:
        output_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
    elif channels == 3:
        output_img = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
    else:
        output_img = output

    # Get image properties
    height, width = output_img.shape[:2]
    properties = {
        "width": width,
        "height": height,
        "mode": "RGBA" if channels == 4 else "RGB" if channels == 3 else "Grayscale"
    }
    return output_img, properties
# Root endpoint for health check - important for Spaces
# NOTE(review): no route decorator is visible on this handler in this
# file; confirm it is registered on `app`.
async def read_root():
    """Report that the image enhancement API process is up."""
    payload = dict(status="ok", message="Image Enhancement API is running")
    return payload
# NOTE(review): no route decorator is visible on this handler in this file;
# confirm it is registered on `app` (a sibling docstring refers to /enhancer).
async def enhance_image(
    image: UploadFile = File(..., description="Image file to enhance"),
    model: str = Form("RealESRGAN_x4plus", description="Model name to use for enhancement"),
    denoise_strength: float = Form(0.5, description="Denoise strength (0-1)"),
    outscale: int = Form(4, description="Output scale factor"),
    face_enhance: bool = Form(False, description="Enable face enhancement")
):
    """
    Enhance and upscale an image using Real-ESRGAN models.

    - **image**: Upload an image file (PNG, JPG, etc.)
    - **model**: Select a model from the available options
    - **denoise_strength**: Control the denoising strength (only for realesr-general-x4v3)
    - **outscale**: Control the output resolution scaling
    - **face_enhance**: Enable face enhancement using GFPGAN

    Returns the enhanced image as a base64 string along with image properties.

    Responds with HTTP 400 for an unknown model, an out-of-range parameter or
    an upload that cannot be decoded as an image, and HTTP 500 for any other
    failure during processing.
    """
    try:
        # Validate model name
        valid_models = [m["name"] for m in AVAILABLE_MODELS]
        if model not in valid_models:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model. Choose from: {', '.join(valid_models)}"
            )

        # Validate other parameters
        if not (0 <= denoise_strength <= 1):
            raise HTTPException(status_code=400, detail="Denoise strength must be between 0 and 1")
        if not (1 <= outscale <= 8):
            raise HTTPException(status_code=400, detail="Outscale must be between 1 and 8")

        # Read the image file. A file that cannot be decoded is the client's
        # mistake, so it is reported as 400 rather than a generic 500.
        contents = await image.read()
        try:
            img = Image.open(BytesIO(contents))
            img.load()  # force decoding now so truncated files fail here
        except (OSError, Image.DecompressionBombError) as exc:
            raise HTTPException(
                status_code=400,
                detail="Uploaded file is not a valid image"
            ) from exc

        # Process image
        output_img, properties = await process_image(img, model, denoise_strength, face_enhance, outscale)

        # Convert to PIL Image and then to base64.
        # PNG keeps the alpha channel; everything else is sent as JPEG.
        output_pil = Image.fromarray(output_img)
        buffer = BytesIO()
        if properties["mode"] == "RGBA":
            output_pil.save(buffer, format="PNG")
        else:
            output_pil.save(buffer, format="JPEG", quality=95)

        # Encode to base64
        img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')

        # Return response
        return {
            "enhanced_image": img_str,
            "properties": properties,
            "model_used": model
        }
    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e
# Add a direct image endpoint that returns the actual image instead of base64
# NOTE(review): no route decorator is visible on this handler in this file;
# confirm it is registered on `app`.
async def enhance_image_direct(
    image: UploadFile = File(..., description="Image file to enhance"),
    model: str = Form("RealESRGAN_x4plus", description="Model name to use for enhancement"),
    denoise_strength: float = Form(0.5, description="Denoise strength (0-1)"),
    outscale: int = Form(4, description="Output scale factor"),
    face_enhance: bool = Form(False, description="Enable face enhancement")
):
    """
    Enhance and upscale an image, returning the actual image file directly.

    This endpoint works like /enhancer but returns the image directly instead of base64 encoded.
    This is useful for direct image display or download.

    The body is PNG (`image/png`) when the result has an alpha channel and
    JPEG (`image/jpeg`) otherwise. Responds with HTTP 400 for an unknown
    model, an out-of-range parameter or an upload that cannot be decoded as
    an image, and HTTP 500 for any other failure during processing.
    """
    try:
        # Validate model name
        valid_models = [m["name"] for m in AVAILABLE_MODELS]
        if model not in valid_models:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model. Choose from: {', '.join(valid_models)}"
            )

        # Validate other parameters (same limits as enhance_image)
        if not (0 <= denoise_strength <= 1):
            raise HTTPException(status_code=400, detail="Denoise strength must be between 0 and 1")
        if not (1 <= outscale <= 8):
            raise HTTPException(status_code=400, detail="Outscale must be between 1 and 8")

        # Read the image file. A file that cannot be decoded is the client's
        # mistake, so it is reported as 400 rather than a generic 500.
        contents = await image.read()
        try:
            img = Image.open(BytesIO(contents))
            img.load()  # force decoding now so truncated files fail here
        except (OSError, Image.DecompressionBombError) as exc:
            raise HTTPException(
                status_code=400,
                detail="Uploaded file is not a valid image"
            ) from exc

        # Process image
        output_img, properties = await process_image(img, model, denoise_strength, face_enhance, outscale)

        # Convert to PIL Image and serialise it into an in-memory buffer
        output_pil = Image.fromarray(output_img)
        buffer = BytesIO()
        if properties["mode"] == "RGBA":
            output_pil.save(buffer, format="PNG")
            media_type = "image/png"
        else:
            output_pil.save(buffer, format="JPEG", quality=95)
            media_type = "image/jpeg"

        # Return the image directly; getvalue() returns the whole buffer
        # regardless of the current position, so no seek is needed.
        return Response(content=buffer.getvalue(), media_type=media_type)
    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e
# NOTE(review): no route decorator is visible on this handler in this file;
# confirm it is registered on `app`.
async def health_check():
    """Report that the image enhancement server is alive."""
    report = {}
    report["status"] = "healthy"
    report["message"] = "Image enhancement server is running"
    return report
# NOTE(review): no route decorator is visible on this handler in this file;
# confirm it is registered on `app`.
async def list_models():
    """Return the catalogue of enhancement models under the "models" key."""
    catalogue = AVAILABLE_MODELS
    return dict(models=catalogue)
# Add startup event to print server info
# NOTE(review): no startup-hook decorator is visible on this function in
# this file; confirm it is registered on `app`.
async def startup_event():
    """Print a short banner listing the available models and the docs URLs.

    The leading characters of the original messages were mis-encoded
    (mojibake) and have been removed so the output is clean on any console.
    """
    model_names = ', '.join(m['name'] for m in AVAILABLE_MODELS)
    print("Image Enhancement API is starting up!")
    print(f"Available models: {model_names}")
    print("API documentation available at /docs or /redoc")
if __name__ == "__main__":
    # Script entry point: serve with Uvicorn on all interfaces. The port comes
    # from the PORT environment variable and falls back to 7860, the port
    # used by Hugging Face Spaces.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=listen_port,
        reload=True,
    )