Spaces:
Paused
Paused
File size: 2,279 Bytes
13f9424 d38d1f9 7bb83ed b0c0488 b117429 13f9424 b117429 13f9424 b117429 13f9424 237989a 13f9424 237989a 13f9424 b0c0488 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 | import os
import cv2
from fastapi import FastAPI, File, UploadFile
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi import FastAPI, File, UploadFile, HTTPException
import AnimeGANv3_src
# Create the results directory up front; inference() saves its images there.
os.makedirs('output', exist_ok=True)

# The ASGI application object; the routes below attach to it via decorators.
app = FastAPI()
# Maps the user-facing style name to the one-letter code passed to
# AnimeGANv3_src.Convert. Names not listed here are sent as "U".
_STYLE_CODES = {
    "AnimeGANv3_Arcane": "A",
    "AnimeGANv3_Trump v1.0": "T",
    "AnimeGANv3_Shinkai": "S",
    "AnimeGANv3_PortraitSketch": "P",
    "AnimeGANv3_Hayao": "H",
    "AnimeGANv3_Disney v1.0": "D",
    "AnimeGANv3_JP_face v1.0": "J",
    "AnimeGANv3_Kpop v2.0": "K",
}


def inference(img_path, Style, if_face=None):
    """Stylize the image at *img_path* and save the result under ``output/``.

    Args:
        img_path: Path of the input image. Its extension also selects the
            output file format (``.png`` is used when it has none).
        Style: Style name such as ``"AnimeGANv3_Hayao"``. Unrecognised names
            are passed to the model as code ``"U"``.
        if_face: ``"Yes"`` enables face detection; any other value disables it.

    Returns:
        ``(output, save_path)`` on success, where *output* is the array
        returned by ``AnimeGANv3_src.Convert`` (presumably RGB, since it is
        channel-flipped before ``cv2.imwrite``) and *save_path* is the file
        written. ``(None, None)`` on any failure, so callers can always
        unpack two values.
    """
    print(img_path, Style, if_face)
    try:
        img = cv2.imread(img_path)
        if img is None:
            # imread reports missing/unreadable files by returning None,
            # not by raising.
            print('global exception', f"could not read image {img_path!r}")
            return None, None
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        style_code = _STYLE_CODES.get(Style, "U")
        det_face = if_face == "Yes"
        output = AnimeGANv3_src.Convert(img, style_code, det_face)

        ext = os.path.splitext(img_path)[1] or ".png"
        save_path = f"output/out{ext}"
        # Flip the channel axis back to BGR, the order cv2.imwrite expects.
        if not cv2.imwrite(save_path, output[:, :, ::-1]):
            print('Error', f"could not write {save_path!r}")
            return None, None
        return output, save_path
    except RuntimeError as error:
        print('Error', error)
    except Exception as error:
        print('global exception', error)
    # Every failure path ends here, so the result is always a 2-tuple.
    return None, None
@app.post("/inference/")
async def inference_api(file: UploadFile = File(...), Style: str = Form(...), if_face: str = Form(...)):
    """Stylize an uploaded image and return the rendered file.

    The upload is written to a private temporary file and passed to
    ``inference``; the temporary input is removed afterwards. On success the
    output file is returned as a ``FileResponse``. On failure the body is
    ``{"error": <message>}``.
    """
    import contextlib
    import tempfile

    # Never build a path from the client-controlled filename. Keep only a
    # plain alphanumeric extension, which inference() reuses to pick the
    # output format.
    ext = os.path.splitext(os.path.basename(file.filename or ""))[1]
    if not ext[1:].isalnum():
        ext = ".png"

    img_path = None
    try:
        contents = await file.read()
        with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
            tmp.write(contents)
            img_path = tmp.name
        result = inference(img_path, Style, if_face)
        # inference() signals failure with (None, None) (or a bare None),
        # so check explicitly instead of crashing on unpack/FileResponse(None).
        save_path = result[1] if result else None
        if save_path is None:
            return {"error": "inference failed"}
        return FileResponse(save_path)
    except Exception as e:
        return {"error": str(e)}
    finally:
        # Best-effort cleanup: uploaded inputs must not pile up on disk.
        if img_path is not None:
            with contextlib.suppress(OSError):
                os.remove(img_path)
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end page (``static/index.html``) at the site root."""
    # Resolve from the same relative directory the static mount below uses.
    return FileResponse(path=os.path.join("static", "index.html"), media_type="text/html")


# A mount at "/" matches every path, and Starlette tries routes in
# registration order. It must therefore come after all explicit routes,
# or it shadows them (previously it made index() unreachable).
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|