nsave commited on
Commit
d7da59a
·
verified ·
1 Parent(s): fde6b18

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +162 -0
main.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import logging
4
+ import os
5
+ import sys
6
+ from io import BytesIO
7
+ from pathlib import Path
8
+
9
+ import uvicorn
10
+ from config import Config
11
+ from fastapi import FastAPI
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+ from fastapi.staticfiles import StaticFiles
14
+ from PIL import Image
15
+ from pydantic import BaseModel
16
+
17
+
18
+ sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
19
+
20
+ from utils.wrapper import StreamDiffusionWrapper
21
+
22
+
23
+ logger = logging.getLogger("uvicorn")
24
+ PROJECT_DIR = Path(__file__).parent.parent
25
+
26
+
27
+ class PredictInputModel(BaseModel):
28
+ """
29
+ The input model for the /predict endpoint.
30
+ """
31
+
32
+ prompt: str
33
+
34
+
35
+ class PredictResponseModel(BaseModel):
36
+ """
37
+ The response model for the /predict endpoint.
38
+ """
39
+
40
+ base64_image: str
41
+
42
+
43
+ class UpdatePromptResponseModel(BaseModel):
44
+ """
45
+ The response model for the /update_prompt endpoint.
46
+ """
47
+
48
+ prompt: str
49
+
50
+
51
+ class Api:
52
+ def __init__(self, config: Config) -> None:
53
+ """
54
+ Initialize the API.
55
+
56
+ Parameters
57
+ ----------
58
+ config : Config
59
+ The configuration.
60
+ """
61
+ self.config = config
62
+ self.stream_diffusion = StreamDiffusionWrapper(
63
+ mode=config.mode,
64
+ model_id_or_path=config.model_id_or_path,
65
+ lora_dict=config.lora_dict,
66
+ lcm_lora_id=config.lcm_lora_id,
67
+ vae_id=config.vae_id,
68
+ device=config.device,
69
+ dtype=config.dtype,
70
+ acceleration=config.acceleration,
71
+ t_index_list=config.t_index_list,
72
+ warmup=config.warmup,
73
+ use_safety_checker=config.use_safety_checker,
74
+ cfg_type="none",
75
+ )
76
+ self.app = FastAPI()
77
+ self.app.add_api_route(
78
+ "/api/predict",
79
+ self._predict,
80
+ methods=["POST"],
81
+ response_model=PredictResponseModel,
82
+ )
83
+ self.app.add_middleware(
84
+ CORSMiddleware,
85
+ allow_origins=["*"],
86
+ allow_credentials=True,
87
+ allow_methods=["*"],
88
+ allow_headers=["*"],
89
+ )
90
+ self.app.mount("/", StaticFiles(directory="./frontend/dist", html=True), name="public")
91
+
92
+ self._predict_lock = asyncio.Lock()
93
+ self._update_prompt_lock = asyncio.Lock()
94
+
95
+ async def _predict(self, inp: PredictInputModel) -> PredictResponseModel:
96
+ """
97
+ Predict an image and return.
98
+
99
+ Parameters
100
+ ----------
101
+ inp : PredictInputModel
102
+ The input.
103
+
104
+ Returns
105
+ -------
106
+ PredictResponseModel
107
+ The prediction result.
108
+ """
109
+ async with self._predict_lock:
110
+ return PredictResponseModel(base64_image=self._pil_to_base64(self.stream_diffusion(prompt=inp.prompt)))
111
+
112
+ def _pil_to_base64(self, image: Image.Image, format: str = "JPEG") -> bytes:
113
+ """
114
+ Convert a PIL image to base64.
115
+
116
+ Parameters
117
+ ----------
118
+ image : Image.Image
119
+ The PIL image.
120
+
121
+ format : str
122
+ The image format, by default "JPEG".
123
+
124
+ Returns
125
+ -------
126
+ bytes
127
+ The base64 image.
128
+ """
129
+ buffered = BytesIO()
130
+ image.convert("RGB").save(buffered, format=format)
131
+ return base64.b64encode(buffered.getvalue()).decode("ascii")
132
+
133
+ def _base64_to_pil(self, base64_image: str) -> Image.Image:
134
+ """
135
+ Convert a base64 image to PIL.
136
+
137
+ Parameters
138
+ ----------
139
+ base64_image : str
140
+ The base64 image.
141
+
142
+ Returns
143
+ -------
144
+ Image.Image
145
+ The PIL image.
146
+ """
147
+ if "base64," in base64_image:
148
+ base64_image = base64_image.split("base64,")[1]
149
+ return Image.open(BytesIO(base64.b64decode(base64_image))).convert("RGB")
150
+
151
+
152
+ if __name__ == "__main__":
153
+ from config import Config
154
+
155
+ config = Config()
156
+
157
+ uvicorn.run(
158
+ Api(config).app,
159
+ host=config.host,
160
+ port=config.port,
161
+ workers=config.workers,
162
+ )