krau
init commit
7d54ba7 unverified
import pathlib
import pickle
import secrets
import shutil
import tempfile
import threading
import time

import httpx
from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    Header,
    HTTPException,
    UploadFile,
    status,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse

from logger import logger
import common
from config import settings
from sr_queue import listen_distributed_queue, listen_queue
async def verify_token(x_token: str = Header()):
    """Reject requests whose ``X-Token`` header does not match the configured token.

    Installed as an app-wide dependency, so it guards every route.

    Raises:
        HTTPException: 401 when the header differs from ``settings["token"]``
            or when no string token is configured.
    """
    expected = settings.get("token")
    # Constant-time comparison: a plain ``!=`` can return early on the first
    # differing character, letting response timing leak token prefixes.
    # A non-str (or missing) configured token never matched a str header in
    # the original comparison either, so it is still rejected here.
    if not isinstance(expected, str) or not secrets.compare_digest(
        x_token.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid X-Token"
        )
# Application object. Every route registered later inherits the X-Token check
# because ``verify_token`` is installed as an app-wide dependency.
app = FastAPI(
    title="Super Resolution API",
    description="Super Resolution API for Anime and Illustration",
    dependencies=[Depends(verify_token)],
)

# CORS: accept cross-origin calls from any origin, method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# generally discouraged; authentication here is the X-Token header rather
# than cookies, so credentials may be unnecessary — confirm before changing.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
def register_routes():
    """Register the routes shared by every mode.

    * ``GET /`` reports that the service is up and which mode it runs in.
    * ``GET /result/{task_id}`` returns the stored result record of a task.
    * ``GET /result/{task_id}/download`` streams the output image of a task
      whose status is ``"success"``.
    """

    def _load_result(task_id: str) -> dict:
        """Fetch and unpickle the result record stored for *task_id*.

        Raises:
            HTTPException: 404 when nothing is stored under the task id.
        """
        raw = common.redis_client.get(f"{common.RESULT_KEY_PREFIX}{task_id}")
        if raw is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
            )
        # NOTE(review): pickle.loads can execute arbitrary code for a crafted
        # payload; this is only safe while clients cannot write to Redis.
        return pickle.loads(raw)

    @app.get("/")
    async def root():
        return {
            "message": f"Super Resolution API is running as {settings.get('mode', 'single')} mode"
        }

    @app.get("/result/{task_id}")
    async def get_result(task_id: str):
        """Return the stored result record for *task_id* (404 if unknown)."""
        result_data: dict[str, str] = _load_result(task_id)
        return {"result": result_data}

    @app.get("/result/{task_id}/download")
    async def download_result(task_id: str):
        """Send the output image of a finished task.

        Responds 404 for unknown tasks or a missing output file, and 400
        while the task's status is anything other than ``"success"``.
        """
        result_data: dict[str, str] = _load_result(task_id)
        task_status = result_data.get("status")
        if task_status != "success":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Task is {task_status}",
            )
        file_path = pathlib.Path(result_data["path"])
        # is_file(): a directory also "exists" but cannot be served.
        if not file_path.is_file():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
            )
        # FileResponse derives Content-Length from its own stat of the file at
        # send time, so no separately computed (possibly stale) header is set.
        return FileResponse(
            path=file_path,
            filename=file_path.name,
            media_type="image/png",
        )
def register_single_sr_route():
    """Register the single-node ``POST /sr`` route on ``app``.

    The handler saves the input image (downloaded from ``url`` when given,
    otherwise the uploaded ``file``) to a temporary file under ``temp_dir``.
    It then appends a job describing that file to the
    ``common.BASE_STREAM_NAME`` Redis stream. The returned ``task_id`` is the
    stream entry id, and the result is read back via ``GET /result/{task_id}``.
    """
    allowed_types = ("image/jpeg", "image/png", "image/webp")

    def _media_type(content_type):
        """Return the lower-cased media type without parameters, or None.

        ``"image/PNG; charset=binary"`` -> ``"image/png"``.
        """
        if not content_type:
            return None
        return content_type.split(";", 1)[0].strip().lower()

    async def _read_input_image(file, url):
        """Return the raw bytes of the requested input image.

        ``url`` takes precedence over ``file`` when both are given.

        Raises:
            HTTPException: 400 when the download does not return 200 or the
                content type is not JPEG/PNG/WebP.
        """
        if url:
            async with httpx.AsyncClient() as client:
                response = await client.get(url)
            if response.status_code != 200:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Failed to download the image",
                )
            if _media_type(response.headers.get("Content-Type")) not in allowed_types:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid image format",
                )
            return response.content
        if _media_type(file.content_type) not in allowed_types:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid image format",
            )
        return await file.read()

    @app.post("/sr")
    async def super_resolution(
        file: UploadFile | None = File(default=None),
        tile_size: int = Form(default=64, ge=32, le=128),
        scale: int = Form(default=4, ge=2, le=8),
        skip_alpha: bool = Form(default=False),
        resize_to: str | None = Form(default=None),
        url: str | None = Form(default=None),
        timeout: int = Form(
            default=common.PROGRESS_TIMEOUT, ge=1, le=common.MAX_ALLOWED_TIMEOUT
        ),
        model: str = Form(default=common.MODEL_NAME_DEFAULT),
    ):
        """Queue a super-resolution job for the given image and return its task id."""
        if file is None and not url:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No file or url provided",
            )

        temp_path = None
        try:
            image_bytes = await _read_input_image(file, url)
            # The file is closed (and therefore flushed) before the job is
            # queued: the queue listener runs on another thread and presumably
            # opens the image by path.
            with tempfile.NamedTemporaryFile(
                dir=settings.get("temp_dir", "./temp"), delete=False
            ) as temp:
                temp_path = pathlib.Path(temp.name)
                temp.write(image_bytes)
        except HTTPException:
            # Client-side problems keep their 4xx status instead of becoming 500.
            raise
        except Exception as e:
            logger.error(f"process image error: {e}")
            if temp_path is not None:
                temp_path.unlink(missing_ok=True)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to process the image",
            ) from e

        entry_id = common.redis_client.xadd(
            common.BASE_STREAM_NAME,
            {
                "data": pickle.dumps(
                    {
                        "input_image": temp_path,
                        "tile_size": tile_size,
                        "scale": scale,
                        "skip_alpha": skip_alpha,
                        "resize_to": resize_to,
                        "timeout": timeout,
                        "model": model,
                    }
                ),
            },
        )
        task_id = entry_id.decode("utf-8")
        # Marked "pending" only when the stream holds entries besides this one.
        # nx=True avoids overwriting a status the listener may already have
        # written for this task.
        if common.redis_client.xlen(common.BASE_STREAM_NAME) > 1:
            common.redis_client.set(
                f"{common.RESULT_KEY_PREFIX}{task_id}",
                pickle.dumps({"status": "pending"}),
                ex=86400,
                nx=True,
            )
        logger.info(f"Task added to queue: {task_id}")
        return {"message": "Success", "task_id": task_id}
def register_master():
    """Register the master-mode routes on ``app``.

    * ``POST /register``: a worker stores ``"<worker_url>|<worker_token>"``
      under its id. The entry expires after ``worker_expire`` seconds
      (default 120) unless it is refreshed.
    * ``GET /workers``: list the currently registered workers.
    * ``POST /sr``: split the input image into one tile per registered worker
      and send each tile to that worker's ``/sr``. Then queue a job on
      ``common.DISTRIBUTED_STREAM_NAME``. Clients can use it exactly like the
      single-mode ``/sr`` route.
    """
    allowed_types = ("image/jpeg", "image/png", "image/webp")

    def _media_type(content_type):
        """Return the lower-cased media type without parameters, or None."""
        if not content_type:
            return None
        return content_type.split(";", 1)[0].strip().lower()

    async def _read_input_image(file, url):
        """Return the raw bytes of the requested input image (``url`` first).

        Raises:
            HTTPException: 400 when the download does not return 200 or the
                content type is not JPEG/PNG/WebP.
        """
        if url:
            async with httpx.AsyncClient() as client:
                response = await client.get(url)
            if response.status_code != 200:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Failed to download the image",
                )
            if _media_type(response.headers.get("Content-Type")) not in allowed_types:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid image format",
                )
            return response.content
        if _media_type(file.content_type) not in allowed_types:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid image format",
            )
        return await file.read()

    async def _dispatch_tiles(input_path, worker_keys, form):
        """Split *input_path* into one tile per worker and have each worker process its tile.

        Returns:
            dict mapping each worker's Redis key to the JSON body returned by
            its ``/sr``, with the ``tile_info`` that worker was sent added.

        Raises:
            RuntimeError: a worker registration vanished, a health check failed,
                or a worker rejected its tile. Errors from the image helpers
                and httpx propagate unchanged.
        """
        save_dir = pathlib.Path(settings.get("output_dir", "./output")) / input_path.name
        width, height = common.get_image_size(input_path)
        tiles = common.split_image(
            input_path,
            save_dir,
            common.calculate_grid(width, height, len(worker_keys)),
        )
        responses = {}
        async with httpx.AsyncClient() as client:
            for index, worker_key in enumerate(worker_keys):
                registration = common.redis_client.get(worker_key)
                if registration is None:
                    # The registration expired after the key list was read.
                    raise RuntimeError(
                        f"Worker {worker_key.decode('utf-8')} is not available"
                    )
                # Stored as "<url>|<token>" by /register; split once so a "|"
                # inside the token cannot break the unpacking.
                worker_url, token = registration.decode("utf-8").split("|", 1)
                headers = {"X-Token": token}
                # assumes split_image yields at least one tile per worker, in
                # worker order — TODO confirm against common.calculate_grid.
                tile_info = tiles[index]
                health = await client.get(worker_url + "/", headers=headers)
                if health.status_code != 200:
                    raise RuntimeError(f"Worker {worker_url} is not available")
                # ``filpath`` (sic) is the attribute name the tile objects use.
                with open(tile_info.filpath, "rb") as tile_file:
                    result = await client.post(
                        url=f"{worker_url}/sr",
                        files={"file": tile_file},
                        data=form,
                        headers=headers,
                    )
                if result.status_code != 200:
                    raise RuntimeError(
                        f"Worker {worker_url} failed to process the image: {result.text}"
                    )
                body = result.json()
                body["tile_info"] = tile_info
                responses[worker_key] = body
        return responses

    @app.post("/register")
    async def register_worker(
        worker_id: str = Form(...),
        worker_url: str = Form(...),
        worker_token: str = Form(...),
    ):
        """Store (or refresh) a worker registration with an expiry."""
        try:
            common.redis_client.set(
                f"{common.WORKER_KEY_PREFIX}{worker_id}",
                f"{worker_url}|{worker_token}",
                ex=settings.get("worker_expire", 120),
            )
        except Exception as e:
            logger.error(f"Failed to register worker: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to register worker",
            ) from e
        return {"message": "Success"}

    @app.get("/workers")
    async def get_workers():
        """List the registered workers as ``{"id", "data"}`` records."""
        workers = []
        for key in common.redis_client.keys(f"{common.WORKER_KEY_PREFIX}*"):
            data = common.redis_client.get(key)
            if data is None:
                # Expired between KEYS and GET; skip instead of crashing.
                continue
            workers.append(
                {
                    # Strip the key prefix rather than splitting on "_", so
                    # ids that themselves contain "_" are returned intact.
                    "id": key.decode("utf-8").removeprefix(common.WORKER_KEY_PREFIX),
                    # NOTE(review): this is the raw "url|token" string, so
                    # worker tokens are visible to any holder of the master token.
                    "data": data.decode("utf-8"),
                }
            )
        return {"workers": workers}

    @app.post("/sr")
    async def super_resolution(
        file: UploadFile | None = File(default=None),
        tile_size: int = Form(default=64, ge=32, le=128),
        scale: int = Form(default=4, ge=2, le=8),
        skip_alpha: bool = Form(default=False),
        resize_to: str | None = Form(default=None),
        url: str | None = Form(default=None),
        timeout: int = Form(
            default=common.PROGRESS_TIMEOUT, ge=1, le=common.MAX_ALLOWED_TIMEOUT
        ),
        model: str = Form(default=common.MODEL_NAME_DEFAULT),
    ):
        """Split the input image across registered workers and queue the merge job.

        From the client's point of view this route is compatible with the
        single-mode ``/sr`` route.
        """
        if file is None and not url:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No file or url provided",
            )
        # Fail fast before downloading anything.
        if not common.redis_client.keys(f"{common.WORKER_KEY_PREFIX}*"):
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="No available worker",
            )

        input_path = None
        try:
            image_bytes = await _read_input_image(file, url)
            # Closing the file flushes it; the image helpers open it by path.
            with tempfile.NamedTemporaryFile(
                dir=settings.get("temp_dir", "./temp"), delete=False
            ) as input_temp:
                input_path = pathlib.Path(input_temp.name)
                input_temp.write(image_bytes)
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"process image error: {e}")
            if input_path is not None:
                input_path.unlink(missing_ok=True)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to process the image",
            ) from e

        # Re-read the worker list: registrations may have expired during the download.
        workers = common.redis_client.keys(f"{common.WORKER_KEY_PREFIX}*")
        if not workers:
            input_path.unlink(missing_ok=True)
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="No available worker",
            )

        form = {
            "tile_size": tile_size,
            "scale": scale,
            "skip_alpha": skip_alpha,
            "resize_to": resize_to,
            "timeout": timeout,
            "model": model,
        }
        try:
            worker_response = await _dispatch_tiles(input_path, workers, form)
        except Exception as e:
            logger.error(f"error: {e}")
            # No job will reference the input, so do not leave it behind.
            input_path.unlink(missing_ok=True)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to process the image: {e}",
            ) from e

        entry_id = common.redis_client.xadd(
            common.DISTRIBUTED_STREAM_NAME,
            {
                "data": pickle.dumps(
                    {
                        "input_image": input_path,
                        "worker_response": worker_response,
                        "scale": scale,
                    }
                )
            },
        )
        task_id = entry_id.decode("utf-8")
        # nx=True: never overwrite a status the listener may already have set.
        common.redis_client.set(
            f"{common.RESULT_KEY_PREFIX}{task_id}",
            pickle.dumps({"status": "pending"}),
            ex=86400,
            nx=True,
        )
        return {"message": "Success", "task_id": task_id}
def register_slave():
    """Set up worker ("slave") mode.

    Exposes the same ``POST /sr`` route as single mode. It also starts a
    daemon thread that keeps announcing this worker to the master, sleeping
    ``register_interval`` seconds (default 30) after every attempt.
    """
    register_single_sr_route()

    def _announce_once():
        """POST this worker's id, URL and token to the master's ``/register``."""
        with httpx.Client() as client:
            reply = client.post(
                url=f"{settings.get('master_url')}/register",
                data={
                    "worker_id": settings.get("worker_id"),
                    "worker_url": settings.get("worker_url"),
                    "worker_token": settings.get("token"),
                },
                headers={"X-Token": settings.get("master_token")},
            )
            if reply.status_code != 200:
                logger.error(f"Failed to register to master: {reply.text}")

    def _heartbeat_forever():
        """Announce in an endless loop; failures are logged, never fatal."""
        while True:
            try:
                _announce_once()
            except Exception as e:
                logger.error(f"Registration error: {e}")
            finally:
                time.sleep(settings.get("register_interval", 30))

    threading.Thread(target=_heartbeat_forever, daemon=True).start()
if __name__ == "__main__":
    import uvicorn

    mode = settings.get("mode", "single")
    temp_dir = pathlib.Path(settings.get("temp_dir", "./temp"))
    # Create the temp dir before any route or background thread can use it.
    # exist_ok replaces the racy exists()/mkdir() pair.
    temp_dir.mkdir(parents=True, exist_ok=True)

    register_routes()
    if mode == "single":
        register_single_sr_route()
    elif mode == "master":
        register_master()
        threading.Thread(target=listen_distributed_queue, daemon=True).start()
    else:
        register_slave()
    # Single and slave nodes process jobs themselves; the master only merges.
    if mode != "master":
        threading.Thread(target=listen_queue, daemon=True).start()

    try:
        uvicorn.run(
            app,
            host=settings.get("host", "0.0.0.0"),
            port=settings.get("port", 39721),
        )
    except KeyboardInterrupt:
        pass
    finally:
        logger.info("Shutting down")
        common.redis_client.delete(common.BASE_STREAM_NAME)
        if mode == "master":
            common.redis_client.delete(common.DISTRIBUTED_STREAM_NAME)
        shutil.rmtree(temp_dir)