Ashrafb commited on
Commit
13f9424
·
verified ·
1 Parent(s): 85324f2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +70 -12
main.py CHANGED
@@ -1,16 +1,74 @@
1
- import cv2,os
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import AnimeGANv3_src
3
- if __name__ == '__main__':
4
 
5
- f = 'A'
6
- input_imgs_path = r'../../v3-usa\dataset\USA\val'
7
- # input_imgs_path = r'/mnt/data/xinchen/v3-usa/dataset/USA/val'
8
- output_path = 'AnimeGANv3_usa_64_output'
9
- # img = cv2.imread(os.path.join(input_imgs_path, os.listdir(input_imgs_path)[0]))
10
- img = cv2.imread(os.path.join(input_imgs_path, 'jp_16.png'))
11
- out = AnimeGANv3_src.Convert(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), f, True)
12
- # cv2.imshow('d', cv2.cvtColor(out, cv2.COLOR_BGR2RGB))
13
- # cv2.waitKey(0)
14
- cv2.imwrite('a.jpg', cv2.cvtColor(out, cv2.COLOR_BGR2RGB))
15
 
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ from fastapi import FastAPI, File, UploadFile
4
+ from fastapi import FastAPI, File, UploadFile, Form, Request
5
+ from fastapi.responses import HTMLResponse, FileResponse
6
+ from fastapi.staticfiles import StaticFiles
7
+ from fastapi.templating import Jinja2Templates
8
+ from fastapi import FastAPI, File, UploadFile, HTTPException
9
+ from fastapi.responses import JSONResponse
10
+ from fastapi.responses import StreamingResponse
11
+ from fastapi import FastAPI, File, UploadFile
12
+ from fastapi.responses import FileResponse
13
+ from pydantic import BaseModel
14
+ import shutil
15
  import AnimeGANv3_src
 
16
 
17
+ app = FastAPI()
 
 
 
 
 
 
 
 
 
18
 
19
+ os.makedirs('output', exist_ok=True)
20
 
21
+ class InferenceRequest(BaseModel):
22
+ img_path: str
23
+ Style: str
24
+ if_face: str
25
+
26
+ def inference(img_path, Style, if_face=None):
27
+ print(img_path, Style, if_face)
28
+ try:
29
+ img = cv2.imread(img_path)
30
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
31
+ if Style == "AnimeGANv3_Arcane":
32
+ f = "A"
33
+ elif Style == "AnimeGANv3_Trump v1.0":
34
+ f = "T"
35
+ elif Style == "AnimeGANv3_Shinkai":
36
+ f = "S"
37
+ elif Style == "AnimeGANv3_PortraitSketch":
38
+ f = "P"
39
+ elif Style == "AnimeGANv3_Hayao":
40
+ f = "H"
41
+ elif Style == "AnimeGANv3_Disney v1.0":
42
+ f = "D"
43
+ elif Style == "AnimeGANv3_JP_face v1.0":
44
+ f = "J"
45
+ elif Style == "AnimeGANv3_Kpop v2.0":
46
+ f = "K"
47
+ else:
48
+ f = "U"
49
+
50
+ try:
51
+ det_face = True if if_face=="Yes" else False
52
+ output = AnimeGANv3_src.Convert(img, f, det_face)
53
+ save_path = f"output/out.{img_path.rsplit('.')[-1]}"
54
+ cv2.imwrite(save_path, output[:, :, ::-1])
55
+ return output, save_path
56
+ except RuntimeError as error:
57
+ print('Error', error)
58
+ except Exception as error:
59
+ print('global exception', error)
60
+ return None, None
61
+
62
+ @app.post("/inference/")
63
+ async def inference_api(request: InferenceRequest):
64
+ img_path = request.img_path
65
+ Style = request.Style
66
+ if_face = request.if_face
67
+ output, save_path = inference(img_path, Style, if_face)
68
+ return {"output": output, "save_path": save_path}
69
+
70
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")
71
+
72
+ @app.get("/")
73
+ def index() -> FileResponse:
74
+ return FileResponse(path="/app/static/index.html", media_type="text/html")