File size: 2,279 Bytes
13f9424
 
d38d1f9
 
 
 
 
 
7bb83ed
b0c0488
b117429
 
13f9424
b117429
13f9424
b117429
13f9424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237989a
13f9424
237989a
 
 
 
 
 
 
 
 
 
 
 
13f9424
 
 
 
 
b0c0488
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import cv2
from fastapi import FastAPI, File, UploadFile
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi import FastAPI, File, UploadFile, HTTPException


import AnimeGANv3_src

# ASGI application object; the route decorators and the static mount below
# register themselves on it.
app = FastAPI()

# inference() writes its results under "output/", so make sure the directory
# exists as soon as the module is imported.
os.makedirs('output', exist_ok=True)

def inference(img_path, Style, if_face=None):
    """Run AnimeGANv3 on the image at *img_path* and save the result.

    Args:
        img_path: Path of the input image on disk. The text after its last
            "." is reused as the extension of the output file.
        Style: Display name of the style (e.g. "AnimeGANv3_Hayao"). Names
            that are not recognised map to the fallback code "U".
        if_face: "Yes" turns face detection on; any other value turns it off.

    Returns:
        ``(output, save_path)`` on success, where ``output`` is the value
        returned by ``AnimeGANv3_src.Convert`` and ``save_path`` is the file
        that was written. ``(None, None)`` on every failure, so callers can
        always unpack the result.
    """
    print(img_path, Style, if_face)

    # Style display name -> single-letter code passed to AnimeGANv3_src.Convert.
    style_codes = {
        "AnimeGANv3_Arcane": "A",
        "AnimeGANv3_Trump v1.0": "T",
        "AnimeGANv3_Shinkai": "S",
        "AnimeGANv3_PortraitSketch": "P",
        "AnimeGANv3_Hayao": "H",
        "AnimeGANv3_Disney v1.0": "D",
        "AnimeGANv3_JP_face v1.0": "J",
        "AnimeGANv3_Kpop v2.0": "K",
    }

    try:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports an unreadable file by returning None rather
            # than raising; fail here with a message that names the cause.
            raise ValueError(f"could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        f = style_codes.get(Style, "U")
        det_face = if_face == "Yes"

        try:
            output = AnimeGANv3_src.Convert(img, f, det_face)
            save_path = f"output/out.{img_path.rsplit('.')[-1]}"
            # Reverse the channel axis again before writing (the input was
            # converted BGR -> RGB above).
            if not cv2.imwrite(save_path, output[:, :, ::-1]):
                raise OSError(f"could not write image: {save_path}")
            return output, save_path
        except RuntimeError as error:
            print('Error', error)
    except Exception as error:
        print('global exception', error)

    # Every failure path ends here, including the RuntimeError branch that
    # previously fell off the end of the function and returned a bare None.
    return None, None


@app.post("/inference/")
async def inference_api(file: UploadFile = File(...), Style: str = Form(...), if_face: str = Form(...)):
    """Stylise an uploaded image and return the resulting file.

    Args:
        file: The uploaded image.
        Style: Style display name, forwarded to ``inference``.
        if_face: "Yes" to enable face detection, forwarded to ``inference``.

    Returns:
        A ``FileResponse`` for the stylised image on success, otherwise a
        JSON object of the form ``{"error": "<message>"}``.
    """
    try:
        contents = await file.read()

        # The filename is chosen by the client: keep only its last path
        # component so it cannot point outside the working directory, and
        # fall back to a fixed name when it is missing or empty.
        raw_name = (file.filename or "").replace("\\", "/")
        safe_name = os.path.basename(raw_name) or "upload"
        img_path = f"input.{safe_name}"
        with open(img_path, "wb") as f:
            f.write(contents)

        # NOTE(review): inference() is a blocking call inside an async handler
        # and it writes to a fixed output path, so requests are effectively
        # processed one at a time. Moving it to a thread pool would need
        # per-request output names first.
        result = inference(img_path, Style, if_face)
        save_path = result[1] if result else None
        if save_path is None:
            # inference() logs the underlying cause; report a clear failure
            # instead of letting FileResponse(None) raise an opaque TypeError.
            return {"error": "inference failed"}

        return FileResponse(save_path)
    except Exception as e:
        return {"error": str(e)}

@app.get("/")
def index() -> FileResponse:
    """Serve the front-end landing page for requests to the site root."""
    # Use the same relative "static" directory as the mount below, rather than
    # a hard-coded absolute path that only exists in one deployment layout.
    return FileResponse(path="static/index.html", media_type="text/html")


# A mount at "/" matches every path and routes are tried in registration
# order, so this catch-all must come after the explicit routes; otherwise
# index() above is never reached.
app.mount("/", StaticFiles(directory="static", html=True), name="static")