Spaces:
Paused
Paused
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
| from gradio_client import Client, handle_file | |
| from fastapi.responses import JSONResponse | |
| import tempfile | |
| import os | |
| import uuid | |
| import asyncio | |
| from contextlib import asynccontextmanager | |
| from concurrent.futures import ThreadPoolExecutor | |
| import logging | |
# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
app = FastAPI()
# Gradio Space identifier ("owner/name") handed to gradio_client.Client.
# Despite the variable name this is not a full URL.
GRADIO_API_URL = "jallenjia/Change-Clothes-AI"
# Thread pool that runs the blocking gradio_client calls so the asyncio event
# loop stays responsive.
# NOTE: max_workers=None does NOT remove the worker limit. ThreadPoolExecutor
# then uses its built-in default of min(32, os.cpu_count() + 4) threads
# (Python 3.8+); further submissions queue until a worker becomes free.
executor = ThreadPoolExecutor(max_workers=None)
# Context manager for temporary files
@asynccontextmanager
async def temp_file_manager(file_content: bytes, suffix: str = ".png"):
    """Write ``file_content`` to a named temporary file and yield its path.

    Intended for ``async with temp_file_manager(data) as path:``.

    Args:
        file_content: Raw bytes to write to the file.
        suffix: File-name suffix. The default is ".png" whatever the actual
            content is.

    Yields:
        Filesystem path of the closed, fully written temporary file.

    The file is removed when the ``async with`` body exits, including when it
    raises. A failure to delete is logged, not raised.
    The write itself is a plain synchronous file write.
    """
    temp_file = tempfile.NamedTemporaryFile(
        delete=False, suffix=suffix, prefix=f"tryon_{uuid.uuid4()}_"
    )
    try:
        # The ``with`` guarantees the handle is closed even if the write
        # fails; delete=False keeps the file on disk after close() so other
        # code can open it by path.
        with temp_file:
            temp_file.write(file_content)
        yield temp_file.name
    finally:
        try:
            if os.path.exists(temp_file.name):
                os.unlink(temp_file.name)
        except Exception as e:
            # Best-effort cleanup: never let a delete failure mask the
            # request's real outcome.
            logger.error(f"Failed to delete temp file {temp_file.name}: {e}")
# Run Gradio API call in thread pool to avoid blocking
async def run_gradio_predict(
    background_path: str,
    garm_img_path: str,
    garment_des: str,
    is_checked: bool,
    is_checked_crop: bool,
    denoise_steps: int,
    seed: int,
    category: str
):
    """Call the Space's ``/tryon`` endpoint on the shared thread pool.

    All gradio_client work (building the client and calling ``predict``)
    runs on ``executor``, so the event loop is not blocked by network I/O.

    Args:
        background_path: Local path of the background image. It is sent as
            the ``background`` entry of the ``dict`` argument, with no layers
            and no composite.
        garm_img_path: Local path of the garment image.
        garment_des: Text description of the garment.
        is_checked, is_checked_crop, denoise_steps, seed, category:
            Forwarded unchanged to the endpoint. Their meaning is defined by
            the remote Space.

    Returns:
        Whatever ``Client.predict`` returns; it is not inspected here.

    Raises:
        HTTPException: status 500 wrapping any error raised by the client.
    """
    # get_running_loop() is the preferred call from inside a coroutine.
    loop = asyncio.get_running_loop()

    def _predict():
        # Constructing Client contacts the remote Space, so it is done here
        # on the worker thread rather than on the event loop.
        client = Client(GRADIO_API_URL)
        return client.predict(
            dict={
                "background": handle_file(background_path),
                "layers": [],
                "composite": None
            },
            garm_img=handle_file(garm_img_path),
            garment_des=garment_des,
            is_checked=is_checked,
            is_checked_crop=is_checked_crop,
            denoise_steps=denoise_steps,
            seed=seed,
            category=category,
            api_name="/tryon"
        )

    try:
        return await loop.run_in_executor(executor, _predict)
    except Exception as e:
        # logger.exception keeps the traceback; chaining preserves the cause.
        logger.exception(f"Gradio API error: {e}")
        raise HTTPException(status_code=500, detail=f"Gradio API error: {str(e)}") from e
async def tryon(
    background: UploadFile = File(...),
    garm_img: UploadFile = File(...),
    garment_des: str = Form("navy blue polo shirt"),
    is_checked: bool = Form(True),
    is_checked_crop: bool = Form(False),
    denoise_steps: int = Form(30),
    seed: int = Form(42),
    category: str = Form("upper_body")
):
    """Run a virtual try-on of ``garm_img`` onto ``background``.

    Both uploads are written to temporary files, passed to
    ``run_gradio_predict``, and removed afterwards.

    Args:
        background: Uploaded background image (must declare an image/* type).
        garm_img: Uploaded garment image (must declare an image/* type).
        garment_des, is_checked, is_checked_crop, denoise_steps, seed,
        category: Form fields forwarded unchanged to the Gradio endpoint.

    Returns:
        JSONResponse ``{"result": str(result)}``.

    Raises:
        HTTPException: 400 if either upload is missing or has a non-image
            content type; HTTPExceptions from downstream pass through
            unchanged; any other error becomes a 500.

    NOTE(review): no route decorator is visible on this function; confirm it
    is registered on ``app``.
    """
    try:
        # content_type may be None when the client sends no Content-Type.
        # Treat that as "not an image" (400) instead of crashing into a 500.
        for upload in (background, garm_img):
            if not (upload.content_type or "").startswith("image/"):
                raise HTTPException(status_code=400, detail="Only image files are allowed")

        background_content = await background.read()
        garm_img_content = await garm_img.read()

        # Temporary files with unique names, deleted when the block exits
        async with temp_file_manager(background_content, ".png") as background_path:
            async with temp_file_manager(garm_img_content, ".png") as garm_img_path:
                result = await run_gradio_predict(
                    background_path=background_path,
                    garm_img_path=garm_img_path,
                    garment_des=garment_des,
                    is_checked=is_checked,
                    is_checked_crop=is_checked_crop,
                    denoise_steps=denoise_steps,
                    seed=seed,
                    category=category
                )
                return JSONResponse(content={"result": str(result)})
    except HTTPException:
        # Already carries the intended status code; re-raise untouched.
        raise
    except Exception as e:
        logger.exception(f"Request failed: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
# Shutdown thread pool gracefully
def shutdown_event():
    """Drain and stop the shared Gradio worker pool.

    Waits (``wait=True``) until every call already submitted to ``executor``
    has finished, then logs that the pool is down.

    NOTE(review): nothing in this code registers this function with ``app``
    (no decorator, lifespan handler or ``add_event_handler`` call is
    visible); confirm it is actually wired up as a shutdown hook.
    """
    pool = executor
    pool.shutdown(wait=True)
    done_message = "Thread pool shut down"
    logger.info(done_message)