Ashrafb commited on
Commit
bef3703
·
verified ·
1 Parent(s): 6c4d9c1

Rename app.py to main.py

Browse files
Files changed (1) hide show
  1. app.py → main.py +29 -22
app.py → main.py RENAMED
@@ -1,10 +1,19 @@
1
- import os
2
-
 
 
 
 
 
 
 
 
3
  import cv2
4
- import gradio as gr
5
  import numpy as np
6
- import onnxruntime as ort
7
  from PIL import Image
 
 
 
8
 
9
  _sess_options = ort.SessionOptions()
10
  _sess_options.intra_op_num_threads = os.cpu_count()
@@ -12,9 +21,7 @@ MODEL_SESS = ort.InferenceSession(
12
  "cartoonizer.onnx", _sess_options, providers=["CPUExecutionProvider"]
13
  )
14
 
15
-
16
- def preprocess_image(image: Image) -> np.ndarray:
17
- image = np.array(image)
18
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
19
 
20
  h, w, c = np.shape(image)
@@ -29,26 +36,26 @@ def preprocess_image(image: Image) -> np.ndarray:
29
  image = image.astype(np.float32) / 127.5 - 1
30
  return np.expand_dims(image, axis=0)
31
 
32
-
33
- def inference(image: np.ndarray) -> Image:
34
  image = preprocess_image(image)
35
  results = MODEL_SESS.run(None, {"input_photo:0": image})
36
  output = (np.squeeze(results[0]) + 1.0) * 127.5
37
  output = np.clip(output, 0, 255).astype(np.uint8)
38
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
39
- return Image.fromarray(output)
40
 
 
41
 
42
- title = "Generate cartoonized images"
43
- article = "Demo of CartoonGAN model (https://systemerrorwang.github.io/White-box-Cartoonization/). \nDemo image is from https://unsplash.com/photos/f0SgAs27BYI."
 
 
 
 
 
44
 
45
- iface = gr.Interface(
46
- inference,
47
- inputs=gr.inputs.Image(type="pil", label="Input Image"),
48
- outputs="image",
49
- title=title,
50
- article=article,
51
- allow_flagging="never",
52
- examples=[["mountain.jpeg"]],
53
- )
54
- iface.launch()
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi import FastAPI, File, UploadFile, Form, Request
3
+ from fastapi.responses import HTMLResponse, FileResponse
4
+ from fastapi.staticfiles import StaticFiles
5
+ from fastapi.templating import Jinja2Templates
6
+ from fastapi import FastAPI, File, UploadFile, HTTPException
7
+ from fastapi.responses import JSONResponse
8
+ from fastapi.responses import StreamingResponse
9
+ from fastapi import FastAPI, File, UploadFile
10
+ from fastapi.responses import StreamingResponse
11
  import cv2
 
12
  import numpy as np
 
13
  from PIL import Image
14
+ import io
15
+
16
+ import onnxruntime as ort
17
 
18
  _sess_options = ort.SessionOptions()
19
  _sess_options.intra_op_num_threads = os.cpu_count()
 
21
  "cartoonizer.onnx", _sess_options, providers=["CPUExecutionProvider"]
22
  )
23
 
24
+ def preprocess_image(image: np.ndarray) -> np.ndarray:
 
 
25
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
26
 
27
  h, w, c = np.shape(image)
 
36
  image = image.astype(np.float32) / 127.5 - 1
37
  return np.expand_dims(image, axis=0)
38
 
39
+ def inference(image: np.ndarray) -> np.ndarray:
 
40
  image = preprocess_image(image)
41
  results = MODEL_SESS.run(None, {"input_photo:0": image})
42
  output = (np.squeeze(results[0]) + 1.0) * 127.5
43
  output = np.clip(output, 0, 255).astype(np.uint8)
44
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
45
+ return output
46
 
47
+ app = FastAPI()
48
 
49
+ @app.post("/cartoonize/")
50
+ async def cartoonize_image(file: UploadFile = File(...)):
51
+ contents = await file.read()
52
+ image = Image.open(io.BytesIO(contents))
53
+ image = np.array(image)
54
+ cartoonized_image = inference(image)
55
+ return StreamingResponse(io.BytesIO(cv2.imencode('.jpg', cartoonized_image)[1]), media_type="image/jpeg")
56
 
57
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")
58
+
59
+ @app.get("/")
60
+ def index() -> FileResponse:
61
+ return FileResponse(path="/app/static/index.html", media_type="text/html")