Spaces:
Paused
Paused
| import os | |
| import cv2 | |
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi import FastAPI, File, UploadFile, Form, Request | |
| from fastapi.responses import HTMLResponse, FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| import AnimeGANv3_src | |
# Make sure the directory that inference() saves its results into exists
# before any request is handled; an existing directory is left alone.
os.makedirs('output', exist_ok=True)

# The ASGI application that the upload handler and static site are attached to.
app = FastAPI()
# Style names accepted by `inference`, mapped to the one-letter style codes it
# passes to AnimeGANv3_src.Convert. Any name not listed here uses code "U".
_STYLE_CODES = {
    "AnimeGANv3_Arcane": "A",
    "AnimeGANv3_Trump v1.0": "T",
    "AnimeGANv3_Shinkai": "S",
    "AnimeGANv3_PortraitSketch": "P",
    "AnimeGANv3_Hayao": "H",
    "AnimeGANv3_Disney v1.0": "D",
    "AnimeGANv3_JP_face v1.0": "J",
    "AnimeGANv3_Kpop v2.0": "K",
}


def inference(img_path, Style, if_face=None):
    """Stylize the image at *img_path* and save the result under ``output/``.

    Args:
        img_path: Path of the input image file. The text after its last ``.``
            is reused as the extension of the saved result, so it must be an
            extension ``cv2.imwrite`` can encode.
        Style: A key of ``_STYLE_CODES``; any other value selects style
            code ``"U"``.
        if_face: ``"Yes"`` passes ``True`` as the converter's ``det_face``
            argument; any other value (including the default ``None``)
            passes ``False``.

    Returns:
        ``(output, save_path)``: the array returned by
        ``AnimeGANv3_src.Convert`` and the file it was written to. On any
        failure (unreadable input, conversion error, failed write) the problem
        is printed and ``(None, None)`` is returned, so callers can always
        unpack two values.
    """
    print(img_path, Style, if_face)
    try:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports missing or undecodable files by returning
            # None rather than raising.
            print('Error', f"could not read image {img_path!r}")
            return None, None
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        style_code = _STYLE_CODES.get(Style, "U")
        output = AnimeGANv3_src.Convert(img, style_code, if_face == "Yes")

        # NOTE: every call writes to the same file name (per extension), so
        # concurrent calls overwrite each other's result.
        save_path = f"output/out.{img_path.rsplit('.', 1)[-1]}"
        # Convert's output is presumably RGB like its input; reverse the
        # channel axis to BGR, the order cv2.imwrite expects.
        if not cv2.imwrite(save_path, output[:, :, ::-1]):
            print('Error', f"could not write image {save_path!r}")
            return None, None
        return output, save_path
    except RuntimeError as error:
        print('Error', error)
    except Exception as error:
        print('global exception', error)
    # Only reached after a handled exception: still return a 2-tuple so that
    # `output, save_path = inference(...)` never fails to unpack.
    return None, None
# NOTE(review): no route decorator is visible for this handler, so as written
# it is not registered with `app`; confirm the intended `@app.post(...)` path.
async def inference_api(file: UploadFile = File(...), Style: str = Form(...), if_face: str = Form(...)):
    """Stylize an uploaded image and return the result as a file.

    The upload is written to a temporary ``input.<name>`` file in the working
    directory, passed to ``inference`` together with *Style* and *if_face*,
    and removed afterwards.

    Returns:
        A ``FileResponse`` for the stylized image on success; otherwise a
        ``{"error": <message>}`` dict.
    """
    # Keep only the final path component of the client-supplied name so that
    # separators or ".." in it cannot steer the write to another location.
    img_path = f"input.{os.path.basename(file.filename or '')}"
    try:
        contents = await file.read()
        with open(img_path, "wb") as f:
            f.write(contents)
        # inference() is synchronous and blocks the event loop while it runs;
        # it also reports failure as (None, None) instead of raising.
        _, save_path = inference(img_path, Style, if_face)
        if save_path is None:
            return {"error": "inference failed"}
        return FileResponse(save_path)
    except Exception as e:
        return {"error": str(e)}
    finally:
        # Best-effort removal of the temporary upload; it may not exist if
        # reading or writing it failed.
        try:
            os.remove(img_path)
        except OSError:
            pass
# Serve the ./static directory at the site root; with html=True a request for
# a directory returns its index.html. A Mount at "/" matches every path, and
# Starlette matches routes in registration order, so routes added to `app`
# after this line are shadowed by it.
app.mount(
    path="/",
    app=StaticFiles(directory="static", html=True),
    name="static",
)
# NOTE(review): no route decorator is visible for this handler, so it is not
# registered with `app` as written; a route added after the "/" static mount
# would be shadowed by it anyway. Confirm whether this handler is still needed.
def index() -> FileResponse:
    """Return the site's ``index.html`` page.

    The path is resolved relative to the working directory, the same way the
    ``static`` directory mounted on ``app`` is resolved, so both refer to the
    same file tree regardless of where the app is deployed.
    """
    return FileResponse(path=os.path.join("static", "index.html"), media_type="text/html")