# test111 / main.py
# Author: killbill007
# Commit message: "Upload 1398 files"
# Commit: 35cdf61 (verified)
from fastapi import FastAPI, Request, Response, HTTPException, Query, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import io
from scripts.db_helper import insert_ban , search_ban
import uuid
from typing import Optional
import numpy as np
import cv2
import base64
from scripts.comfyui import ComfyUI
from scripts.liveportrait import LP
import os
# Process-wide singletons: assigned in startup_event, reset to None in shutdown_event.
ui = None  # ComfyUI() instance once the app has started
lp = None  # LP (LivePortrait) client bound to ui.port once the app has started
def generate_response(response_message, response_status, uuid_code, metadata):
    """Build the JSON envelope that every endpoint in this module returns.

    Args:
        response_message: Human-readable message describing the outcome.
        response_status: Boolean success flag.
        uuid_code: Value placed under ``data["UUID"]``.
        metadata: Arbitrary payload placed under ``data["metadata"]``.

    Returns:
        A dict of the form
        ``{"response_message", "response_status", "data": {"UUID", "metadata"}}``.
    """
    data_section = dict(UUID=uuid_code, metadata=metadata)
    return dict(
        response_message=response_message,
        response_status=response_status,
        data=data_section,
    )
def image_to_base64(image):
    """Encode an OpenCV image array as a base64 PNG string.

    Args:
        image: Image array accepted by ``cv2.imencode``.

    Returns:
        The PNG bytes, base64-encoded and decoded to a UTF-8 ``str``.

    Raises:
        ValueError: If OpenCV reports that PNG encoding failed. Previously the
            success flag was discarded and a meaningless string was returned.
    """
    ok, buffer = cv2.imencode('.png', image)
    if not ok:
        raise ValueError("cv2.imencode failed to encode image as PNG")
    return base64.b64encode(buffer.tobytes()).decode('utf-8')
def base64_to_image(base64_string):
    """Decode a base64-encoded image file into an OpenCV image array.

    The decoded bytes are handed to ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``.
    Its result is returned unchanged, including the ``None`` OpenCV yields
    when the bytes are not a decodable image.
    """
    raw_bytes = base64.b64decode(base64_string)
    return cv2.imdecode(np.frombuffer(raw_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
# The ASGI application. Startup/shutdown hooks, CORS, the ban-list middleware
# and the /genrate_video route are all registered on this instance below.
app = FastAPI()
@app.on_event("startup")
async def startup_event():
    """Create the ComfyUI manager and the shared LP client when the app starts.

    ``ui.UI()`` is called up to ``max_attempts`` times and the loop stops at the
    first truthy result. If every attempt fails, a warning is printed. Startup
    then continues, as before, so the retry/restart logic in the request handler
    can still recover.
    """
    global ui, lp
    print("Application is starting up...")
    ui = ComfyUI()
    max_attempts = 10
    for _ in range(max_attempts):
        if ui.UI():
            break
    else:
        # Previously this case passed silently; make it visible in the logs.
        print(f"Warning: ComfyUI did not report ready after {max_attempts} attempts")
    lp = LP(port=ui.port)
@app.on_event("shutdown")
async def shutdown_event():
    """Drop the module-level ComfyUI and LP references when the app stops."""
    global ui, lp
    print("Application is shutting down...")
    ui = lp = None
# API-key authentication is currently disabled (see the commented-out check in
# custom_middleware). If it is re-enabled, load the key from configuration
# (e.g. an environment variable) rather than hard-coding it in this file.
# NOTE(review): the key previously committed here was identical to the DB
# password in db_params; treat it as leaked and rotate it.

# Wide-open CORS policy: any origin, method and header, with credentials allowed.
# NOTE(review): restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# PostgreSQL connection settings passed to search_ban, insert_ban and
# get_request_data. Every value can be overridden through an environment
# variable, so credentials do not have to live in source control. The
# fallbacks preserve the previous hard-coded configuration.
# NOTE(review): this fallback password has been committed to the repository.
# Rotate it, and drop the fallback once all deployments set API_DB_PASSWORD.
db_params = {
    'dbname': os.environ.get('API_DB_NAME', 'API_DB'),
    'user': os.environ.get('API_DB_USER', 'postgres'),
    'password': os.environ.get('API_DB_PASSWORD', '4b95dfe8-4644-46ce-a4fe-648d6d4860a4'),
    'host': os.environ.get('API_DB_HOST', '44.208.52.100'),
    'port': os.environ.get('API_DB_PORT', '5432'),
}
async def custom_middleware(request: Request, call_next):
    """HTTP middleware that enforces the IP ban list.

    - If ``search_ban`` reports the client IP as banned, the request gets a 401
      without reaching any route.
    - If the downstream response is 404 or 405, the client IP is recorded via
      ``insert_ban`` and the response is replaced with a generic 401.
    - If the client address is unknown (``request.client is None``), there is
      no IP to check or ban, so the request is passed through unchanged.
    """
    client = request.client
    if client is None:
        # Starlette sets request.client to None when the server does not report
        # a peer address; dereferencing .host would raise AttributeError.
        return await call_next(request)
    client_ip = client.host

    if search_ban(db_params, client_ip):
        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})

    # API-key check is currently disabled. To re-enable it, define `api_key` and restore:
    # request_api_key = request.headers.get("api_key")
    # if request_api_key != api_key:
    #     insert_ban(db_params, client_ip, "Access using wrong api_key")
    #     return JSONResponse(status_code=401, content={"detail": "Unauthorized: Invalid API Key"})

    response = await call_next(request)
    if response.status_code in (404, 405):
        # Unknown route or disallowed method: ban the caller and hide the real status.
        insert_ban(db_params, client_ip, "Access using unapproved method/endpoint")
        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
    return response


app.middleware('http')(custom_middleware)
from scripts.db_helper import get_request_data
from scripts.s3 import upload_to_s3, generate_presigned_url
@app.post("/genrate_video")
async def genrate_video(request_id: uuid.UUID):
    """Run the LP pipeline for a stored request and return a presigned S3 URL.

    Steps:
    1. Look up the request via ``get_request_data``.
    2. Run ``LP.LP`` against the ComfyUI server. If the first attempt raises,
       restart ComfyUI when it is unresponsive and retry once.
    3. Upload the produced file to S3 and delete the local files reported by
       the pipeline.
    4. Return a presigned URL for the upload.

    Args:
        request_id: Identifier passed (as a string) to ``get_request_data`` and
            used as the S3 key prefix.

    Returns:
        A ``generate_response`` envelope:
        - ``'Incorrect Request'`` (status False) if no request data is found;
        - ``'Video Generation Failed'`` (status False) if the pipeline returns None;
        - ``'Video Genrated'`` (status True) with ``[presigned_url]`` as metadata.

    Raises:
        Any exception from the retry attempt is propagated (FastAPI answers 500).

    NOTE(review): this is an ``async def`` that calls LP and S3 helpers without
    awaiting them; if those are blocking, the worker's event loop is blocked
    for the whole generation.
    """
    request_id = str(request_id)
    request_data = get_request_data(
        db_params=db_params,
        request_id=request_id,
    )
    if request_data is None:
        return generate_response('Incorrect Request', False, '', [])
    print("update request status to processing")

    # NOTE(review): request_data[0] is passed straight to LP, and request_data[1]
    # names an .mp4 under /home/ubuntu/AI/videos. Confirm the schema in
    # scripts.db_helper.get_request_data.
    input_video_path = f'/home/ubuntu/AI/videos/{request_data[1]}.mp4'

    video = None
    try:
        print(1)
        lp = LP(port=ui.port)
        video = lp.LP(request_id, request_data[0], input_video_path)
        print(2)
    except Exception:
        # `except Exception` (not a bare except) so cancellation / KeyboardInterrupt
        # still propagate. Restart ComfyUI if it is unresponsive, then retry once
        # with a fresh client built from the current ui.port.
        print(3)
        if not ui.is_server_ready(f"http://127.0.0.1:{ui.port}", 5):
            ui.restart_if_down()
        lp = LP(port=ui.port)
        print(4)
        video = lp.LP(request_id, request_data[0], input_video_path)

    if video is None:
        # Nothing to upload. Report failure instead of indexing into None, which
        # previously crashed (and server_url was unbound if the first try succeeded).
        if ui.is_server_ready(f"http://127.0.0.1:{ui.port}", 5):
            print("update request status to error")
        return generate_response('Video Generation Failed', False, '', [])

    bucket = 'app-faceanimate-s3'
    # As used here: video[0] is a file name under ComfyUI's output directory and
    # video[1] is an iterable of local paths to remove after upload.
    local_file_path = f'/home/ubuntu/AI/ComfyUI/output/{video[0]}'
    s3_key = f'{request_id}/video.mp4'
    upload_to_s3(bucket, local_file_path, s3_key)  # TODO: check result and retry
    for video_path in video[1]:
        os.remove(video_path)
    video_url = generate_presigned_url(bucket, s3_key)  # TODO: check result and retry
    print("update request status to processed")
    return generate_response('Video Genrated', True, '', [video_url])
if __name__ == '__main__':
    import uvicorn
    from scripts.comfyui import setup_database

    workers = 2
    # setup_database receives the worker count before uvicorn starts.
    # Presumably it provisions per-worker state; confirm in scripts.comfyui.
    setup_database(workers)
    # Options tried previously and left disabled: limit_max_requests=200, timeout_keep_alive=5.
    uvicorn.run("main:app", host='0.0.0.0', port=3000, workers=workers)