Spaces:
Sleeping
Sleeping
| import pathlib | |
| import pickle | |
| import shutil | |
| import tempfile | |
| import threading | |
| import time | |
| import httpx | |
| from fastapi import ( | |
| Depends, | |
| FastAPI, | |
| File, | |
| Form, | |
| Header, | |
| HTTPException, | |
| UploadFile, | |
| status, | |
| ) | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from logger import logger | |
| import common | |
| from config import settings | |
| from sr_queue import listen_distributed_queue, listen_queue | |
async def verify_token(x_token: str = Header()):
    """Reject any request whose ``X-Token`` header does not match ``settings["token"]``.

    Installed as an app-wide dependency on ``app``, so it guards every route.

    Raises:
        HTTPException(401): the header does not match, or no string token is
            configured (a missing token therefore locks the API rather than
            opening it).
    """
    import secrets

    expected = settings.get("token")
    # compare_digest runs in time independent of how many leading characters
    # match, so the token cannot be recovered byte-by-byte via response timing.
    # Encoding to bytes avoids compare_digest's TypeError on non-ASCII str.
    if not isinstance(expected, str) or not secrets.compare_digest(
        x_token.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid X-Token"
        )
# The application object. Every route inherits the X-Token check through the
# app-level ``dependencies`` list.
app = FastAPI(
    title="Super Resolution API",
    description="Super Resolution API for Anime and Illustration",
    dependencies=[Depends(verify_token)],
)

# Wide-open CORS: any origin, method and header is accepted, and credentials
# are allowed. Access control relies on the X-Token dependency above.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
def register_routes():
    """Define the handlers common to every mode: status, task result, result download.

    NOTE(review): no route decorators (``@app.get(...)``) appear in this
    listing, so as written these nested handlers are never attached to ``app``;
    presumably the decorator lines were lost — confirm.
    """

    def _load_result(task_id: str) -> dict[str, str]:
        # Fetch and unpickle the stored result record for ``task_id``.
        # Raises a 404 HTTPException when no record exists.
        # NOTE: pickle.loads trusts whatever is stored under this Redis key.
        raw = common.redis_client.get(f"{common.RESULT_KEY_PREFIX}{task_id}")
        if raw is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
            )
        return pickle.loads(raw)

    async def root():
        # Report which mode ("single" when unset) this process runs in.
        mode = settings.get('mode', 'single')
        return {"message": f"Super Resolution API is running as {mode} mode"}

    async def get_result(task_id: str):
        # Return the raw result record, whatever its status.
        return {"result": _load_result(task_id)}

    async def download_result(task_id: str):
        # Stream the finished image. This is a 400 unless the task status is
        # "success", and a 404 if the recorded file no longer exists.
        record = _load_result(task_id)
        task_status = record["status"]
        if task_status != "success":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Task is {task_status}",
            )
        target = pathlib.Path(record["path"])
        if not target.exists():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
            )
        size = target.stat().st_size
        return FileResponse(
            path=target,
            filename=target.name,
            headers={"Content-Length": str(size)},
            media_type="image/png",
        )
def register_single_sr_route():
    """Define the single-node ``super_resolution`` handler.

    The handler stores the input image under ``temp_dir`` and enqueues a task on
    ``common.BASE_STREAM_NAME``. The upscaling itself is done by whatever
    consumes that stream (presumably ``listen_queue``, started in ``__main__``).

    NOTE(review): no route decorator is visible in this listing, so as written
    the handler is never attached to ``app``. The master posts tiles to
    ``{worker_url}/sr``, so presumably ``@app.post("/sr")`` was lost — confirm.
    """
    allowed_types = ("image/jpeg", "image/png", "image/webp")

    def _media_type(content_type: str | None) -> str:
        # Compare only the bare, lower-cased MIME type so values such as
        # "image/png; charset=binary" or "Image/PNG" are accepted too.
        return (content_type or "").split(";", 1)[0].strip().lower()

    async def super_resolution(
        file: UploadFile | None = File(default=None),
        tile_size: int = Form(default=64, ge=32, le=128),
        scale: int = Form(default=4, ge=2, le=8),
        skip_alpha: bool = Form(default=False),
        resize_to: str | None = Form(default=None),
        url: str | None = Form(default=None),
        timeout: int = Form(
            default=common.PROGRESS_TIMEOUT, ge=1, le=common.MAX_ALLOWED_TIMEOUT
        ),
        model: str = Form(default=common.MODEL_NAME_DEFAULT),
    ):
        """Save the input image (upload or ``url``) and enqueue a super-resolution task.

        ``url`` takes precedence over ``file`` when both are given.

        Returns:
            ``{"message": "Success", "task_id": <stream entry id>}``, or
            ``{"message": "Failed to download the image"}`` when ``url``
            answers with a non-200 status.

        Raises:
            HTTPException(400): neither ``file`` nor ``url`` was given, or the
                content type is not JPEG/PNG/WebP.
            HTTPException(500): the image could not be fetched or saved.
        """
        if file is None and not url:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No file or url provided",
            )

        temp = tempfile.NamedTemporaryFile(
            dir=settings.get("temp_dir", "./temp"), delete=False
        )
        temp_path = pathlib.Path(temp.name)
        saved = False
        try:
            if url:
                # NOTE(review): user-supplied URLs are fetched server-side with
                # no restriction (SSRF risk) — consider an allow-list.
                async with httpx.AsyncClient() as client:
                    response = await client.get(url)
                if response.status_code != 200:
                    return {"message": "Failed to download the image"}
                if _media_type(response.headers.get("Content-Type")) not in allowed_types:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="Invalid image format",
                    )
                temp.write(response.content)
            else:
                if _media_type(file.content_type) not in allowed_types:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="Invalid image format",
                    )
                temp.write(file.file.read())
            saved = True
        except HTTPException:
            # Deliberate 4xx responses must reach the client unchanged instead
            # of being converted into a generic 500 below.
            raise
        except Exception as e:
            logger.error(f"process image error: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to process the image",
            ) from e
        finally:
            # Closing flushes the write buffer. The enqueued task carries only
            # the path, so the bytes must be on disk before xadd. Inputs that
            # were not saved are deleted (delete=False would otherwise leak them).
            temp.close()
            if not saved:
                temp_path.unlink(missing_ok=True)

        entry_id = common.redis_client.xadd(
            common.BASE_STREAM_NAME,
            {
                "data": pickle.dumps(
                    {
                        "input_image": temp_path,
                        "tile_size": tile_size,
                        "scale": scale,
                        "skip_alpha": skip_alpha,
                        "resize_to": resize_to,
                        "timeout": timeout,
                        "model": model,
                    }
                ),
            },
        )
        task_id = entry_id.decode("utf-8")
        # Only mark the task "pending" when other entries are already queued
        # ahead of it; a lone entry is left for the consumer to record.
        if common.redis_client.xlen(common.BASE_STREAM_NAME) > 1:
            common.redis_client.set(
                f"{common.RESULT_KEY_PREFIX}{task_id}",
                pickle.dumps({"status": "pending"}),
                ex=86400,
            )
        logger.info(f"Task added to queue: {task_id}")
        return {"message": "Success", "task_id": task_id}
def register_master():
    """Define the master-mode handlers: worker registration, worker listing, distributed SR.

    Workers are stored in Redis under ``f"{WORKER_KEY_PREFIX}{worker_id}"`` with
    value ``"url|token"`` and a TTL of ``worker_expire`` seconds (default 120).

    NOTE(review): no route decorators appear in this listing, so as written the
    nested handlers are never attached to ``app``. Slaves POST to
    ``{master_url}/register``, so presumably ``register_worker`` was mounted
    there — confirm.
    """
    allowed_types = ("image/jpeg", "image/png", "image/webp")

    def _media_type(content_type: str | None) -> str:
        # Compare only the bare, lower-cased MIME type so values such as
        # "image/png; charset=binary" are accepted too.
        return (content_type or "").split(";", 1)[0].strip().lower()

    def _worker_id(key: bytes) -> str:
        # Recover exactly the worker_id used at registration by stripping the
        # key prefix; splitting on "_" truncated ids that contain underscores.
        return key.decode("utf-8").removeprefix(common.WORKER_KEY_PREFIX)

    async def register_worker(
        worker_id: str = Form(...),
        worker_url: str = Form(...),
        worker_token: str = Form(...),
    ):
        """Create or refresh a worker entry; re-registration renews its TTL."""
        try:
            common.redis_client.set(
                f"{common.WORKER_KEY_PREFIX}{worker_id}",
                f"{worker_url}|{worker_token}",
                ex=settings.get("worker_expire", 120),
            )
        except Exception as e:
            logger.error(f"Failed to register worker: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to register worker",
            )
        return {"message": "Success"}

    async def get_workers():
        """List registered workers as ``{"id": ..., "data": "url|token"}``."""
        workers = []
        for key in common.redis_client.keys(f"{common.WORKER_KEY_PREFIX}*"):
            value = common.redis_client.get(key)
            if value is None:
                # The entry expired between KEYS and GET; skip it instead of
                # failing the whole request.
                continue
            workers.append({"id": _worker_id(key), "data": value.decode("utf-8")})
        return {"workers": workers}

    async def super_resolution(
        file: UploadFile | None = File(default=None),
        tile_size: int = Form(default=64, ge=32, le=128),
        scale: int = Form(default=4, ge=2, le=8),
        skip_alpha: bool = Form(default=False),
        resize_to: str | None = Form(default=None),
        url: str | None = Form(default=None),
        timeout: int = Form(
            default=common.PROGRESS_TIMEOUT, ge=1, le=common.MAX_ALLOWED_TIMEOUT
        ),
        model: str = Form(default=common.MODEL_NAME_DEFAULT),
    ):
        """
        Split the input image into tiles and dispatch them to the workers stored
        in Redis (one tile per registered worker).
        For clients, this /sr route is compatible with the single mode.

        Returns ``{"message": "Success", "task_id": ...}`` once every worker has
        accepted its tile, or ``{"message": "Failed to download the image"}``
        when ``url`` answers with a non-200 status.

        Raises:
            HTTPException(400): no input, or unsupported content type.
            HTTPException(503): no worker is registered.
            HTTPException(500): the image could not be saved, or a worker was
                unavailable or rejected its tile.
        """
        if file is None and not url:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No file or url provided",
            )
        # Cheap early check so no download happens when nobody can process it.
        if not common.redis_client.keys(f"{common.WORKER_KEY_PREFIX}*"):
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="No available worker",
            )

        input_temp = tempfile.NamedTemporaryFile(
            dir=settings.get("temp_dir", "./temp"), delete=False
        )
        input_path = pathlib.Path(input_temp.name)
        saved = False
        try:
            if url:
                # NOTE(review): user-supplied URLs are fetched server-side with
                # no restriction (SSRF risk) — consider an allow-list.
                async with httpx.AsyncClient() as client:
                    download = await client.get(url)
                if download.status_code != 200:
                    return {"message": "Failed to download the image"}
                if _media_type(download.headers.get("Content-Type")) not in allowed_types:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="Invalid image format",
                    )
                input_temp.write(download.content)
            else:
                if _media_type(file.content_type) not in allowed_types:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="Invalid image format",
                    )
                input_temp.write(file.file.read())
            saved = True
        except HTTPException:
            # Keep deliberate 4xx responses instead of converting them to 500.
            raise
        except Exception as e:
            logger.error(f"process image error: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to process the image",
            ) from e
        finally:
            # Close (and so flush) before get_image_size/split_image read
            # input_path. Small inputs would otherwise still sit in the write
            # buffer. Unsaved inputs are removed.
            input_temp.close()
            if not saved:
                input_path.unlink(missing_ok=True)

        # Re-read the registry: workers may have expired during the download.
        workers = common.redis_client.keys(f"{common.WORKER_KEY_PREFIX}*")
        if not workers:
            input_path.unlink(missing_ok=True)
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="No available worker",
            )

        try:
            save_dir = pathlib.Path(settings.get("output_dir", "./output")) / input_path.name
            origin_width, origin_height = common.get_image_size(input_path)
            # Assumes split_image yields one tile per worker for this grid
            # (indexed by worker position below) — TODO confirm in common.
            origin_tiles_info = common.split_image(
                input_path,
                save_dir,
                common.calculate_grid(origin_width, origin_height, len(workers)),
            )
            worker_response = {}
            async with httpx.AsyncClient() as client:
                for index, worker_key in enumerate(workers):
                    worker = common.redis_client.get(worker_key)
                    if worker is None:
                        raise RuntimeError(
                            f"Worker {_worker_id(worker_key)} registration expired"
                        )
                    # maxsplit=1: the URL comes first, and everything after the
                    # first "|" belongs to the token.
                    worker_url, token = worker.decode("utf-8").split("|", 1)
                    tile_info = origin_tiles_info[index]

                    health = await client.get(worker_url + "/", headers={"X-Token": token})
                    if health.status_code != 200:
                        raise RuntimeError(f"Worker {worker_url} is not available")

                    with open(tile_info.filpath, "rb") as tile_file:
                        accepted = await client.post(
                            url=f"{worker_url}/sr",
                            files={"file": tile_file},
                            data={
                                "tile_size": tile_size,
                                "scale": scale,
                                "skip_alpha": skip_alpha,
                                "resize_to": resize_to,
                                "timeout": timeout,
                                "model": model,
                            },
                            headers={"X-Token": token},
                        )
                    if accepted.status_code != 200:
                        raise RuntimeError(
                            f"Worker {worker_url} failed to process the image: {accepted.text}"
                        )
                    entry = accepted.json()
                    entry["tile_info"] = tile_info
                    worker_response[worker_key] = entry
        except Exception as e:
            logger.error(f"error: {e}")
            # No task will be queued, so the saved input is no longer needed.
            input_path.unlink(missing_ok=True)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to process the image: {e}",
            ) from e

        entry_id = common.redis_client.xadd(
            common.DISTRIBUTED_STREAM_NAME,
            {
                "data": pickle.dumps(
                    {
                        "input_image": input_path,
                        "worker_response": worker_response,
                        "scale": scale,
                    }
                )
            },
        )
        task_id = entry_id.decode("utf-8")
        common.redis_client.set(
            f"{common.RESULT_KEY_PREFIX}{task_id}",
            pickle.dumps({"status": "pending"}),
            ex=86400,
        )
        return {"message": "Success", "task_id": task_id}
def register_slave():
    """Run as a worker: expose the single-node handler and heartbeat to the master.

    A daemon thread POSTs this worker's id, URL and token to
    ``{master_url}/register`` immediately, then every ``register_interval``
    seconds (default 30). Failures are logged and retried on the next tick.
    """
    register_single_sr_route()

    def _announce():
        # One registration attempt. Settings are re-read on every call.
        with httpx.Client() as client:
            reply = client.post(
                url=f"{settings.get('master_url')}/register",
                data={
                    "worker_id": settings.get("worker_id"),
                    "worker_url": settings.get("worker_url"),
                    "worker_token": settings.get("token"),
                },
                headers={"X-Token": settings.get("master_token")},
            )
            if reply.status_code != 200:
                logger.error(f"Failed to register to master: {reply.text}")

    def _heartbeat_forever():
        while True:
            try:
                _announce()
            except Exception as err:
                logger.error(f"Registration error: {err}")
            finally:
                time.sleep(settings.get("register_interval", 30))

    threading.Thread(target=_heartbeat_forever, daemon=True).start()
if __name__ == "__main__":
    import uvicorn

    mode = settings.get("mode", "single")
    temp_dir = pathlib.Path(settings.get("temp_dir", "./temp"))
    # Create the temp dir before any handler or consumer thread can use it.
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    temp_dir.mkdir(parents=True, exist_ok=True)

    register_routes()
    if mode == "single":
        register_single_sr_route()
    elif mode == "master":
        register_master()
        threading.Thread(target=listen_distributed_queue, daemon=True).start()
    else:
        # NOTE(review): any other value (including a typo) runs as a slave.
        register_slave()

    # Single and slave nodes both consume the local task stream.
    if mode != "master":
        threading.Thread(target=listen_queue, daemon=True).start()

    try:
        uvicorn.run(
            app,
            host=settings.get("host", "0.0.0.0"),
            port=settings.get("port", 39721),
        )
    except KeyboardInterrupt:
        pass
    finally:
        logger.info("Shutting down")
        # NOTE(review): the base stream is deleted in master mode too — confirm
        # master and workers never share a Redis instance.
        common.redis_client.delete(common.BASE_STREAM_NAME)
        if mode == "master":
            common.redis_client.delete(common.DISTRIBUTED_STREAM_NAME)
        # ignore_errors: a missing or partly removed temp dir must not turn a
        # clean shutdown into a traceback.
        shutil.rmtree(temp_dir, ignore_errors=True)