3dcre commited on
Commit
266260c
·
verified ·
1 Parent(s): 9a787da

Create api_server.py

Browse files
Files changed (1) hide show
  1. api_server.py +273 -0
api_server.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import argparse
5
+ import asyncio
6
+ import base64
7
+ import logging
8
+ import logging.handlers
9
+ import os
10
+ import sys
11
+ import tempfile
12
+ import threading
13
+ import traceback
14
+ import uuid
15
+ from io import BytesIO
16
+
17
+ import torch
18
+ import trimesh
19
+ import uvicorn
20
+ from PIL import Image
21
+ from fastapi import FastAPI, Request
22
+ from fastapi.responses import JSONResponse, FileResponse
23
+
24
+ from hy3dgen.rembg import BackgroundRemover
25
+ from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FloaterRemover, DegenerateFaceRemover, FaceReducer
26
+ from hy3dgen.texgen import Hunyuan3DPaintPipeline
27
+ from hy3dgen.text2image import HunyuanDiTPipeline
28
+
29
+ LOGDIR = '.'
30
+
31
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
32
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
33
+
34
+ handler = None
35
+
36
+
37
+ def build_logger(logger_name, logger_filename):
38
+ global handler
39
+
40
+ formatter = logging.Formatter(
41
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
42
+ datefmt="%Y-%m-%d %H:%M:%S",
43
+ )
44
+
45
+ # Set the format of root handlers
46
+ if not logging.getLogger().handlers:
47
+ logging.basicConfig(level=logging.INFO)
48
+ logging.getLogger().handlers[0].setFormatter(formatter)
49
+
50
+ # Redirect stdout and stderr to loggers
51
+ stdout_logger = logging.getLogger("stdout")
52
+ stdout_logger.setLevel(logging.INFO)
53
+ sl = StreamToLogger(stdout_logger, logging.INFO)
54
+ sys.stdout = sl
55
+
56
+ stderr_logger = logging.getLogger("stderr")
57
+ stderr_logger.setLevel(logging.ERROR)
58
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
59
+ sys.stderr = sl
60
+
61
+ # Get logger
62
+ logger = logging.getLogger(logger_name)
63
+ logger.setLevel(logging.INFO)
64
+
65
+ # Add a file handler for all loggers
66
+ if handler is None:
67
+ os.makedirs(LOGDIR, exist_ok=True)
68
+ filename = os.path.join(LOGDIR, logger_filename)
69
+ handler = logging.handlers.TimedRotatingFileHandler(
70
+ filename, when='D', utc=True, encoding='UTF-8')
71
+ handler.setFormatter(formatter)
72
+
73
+ for name, item in logging.root.manager.loggerDict.items():
74
+ if isinstance(item, logging.Logger):
75
+ item.addHandler(handler)
76
+
77
+ return logger
78
+
79
+
80
+ class StreamToLogger(object):
81
+ """
82
+ Fake file-like stream object that redirects writes to a logger instance.
83
+ """
84
+
85
+ def __init__(self, logger, log_level=logging.INFO):
86
+ self.terminal = sys.stdout
87
+ self.logger = logger
88
+ self.log_level = log_level
89
+ self.linebuf = ''
90
+
91
+ def __getattr__(self, attr):
92
+ return getattr(self.terminal, attr)
93
+
94
+ def write(self, buf):
95
+ temp_linebuf = self.linebuf + buf
96
+ self.linebuf = ''
97
+ for line in temp_linebuf.splitlines(True):
98
+ # From the io.TextIOWrapper docs:
99
+ # On output, if newline is None, any '\n' characters written
100
+ # are translated to the system default line separator.
101
+ # By default sys.stdout.write() expects '\n' newlines and then
102
+ # translates them so this is still cross platform.
103
+ if line[-1] == '\n':
104
+ self.logger.log(self.log_level, line.rstrip())
105
+ else:
106
+ self.linebuf += line
107
+
108
+ def flush(self):
109
+ if self.linebuf != '':
110
+ self.logger.log(self.log_level, self.linebuf.rstrip())
111
+ self.linebuf = ''
112
+
113
+
114
+ def pretty_print_semaphore(semaphore):
115
+ if semaphore is None:
116
+ return "None"
117
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
118
+
119
+
120
+ SAVE_DIR = 'gradio_cache'
121
+ os.makedirs(SAVE_DIR, exist_ok=True)
122
+
123
+ worker_id = str(uuid.uuid4())[:6]
124
+ logger = build_logger("controller", f"{SAVE_DIR}/controller.log")
125
+
126
+
127
+ def load_image_from_base64(image):
128
+ return Image.open(BytesIO(base64.b64decode(image)))
129
+
130
+
131
+ class ModelWorker:
132
+ def __init__(self, model_path='tencent/Hunyuan3D-2', device='cuda'):
133
+ self.model_path = model_path
134
+ self.worker_id = worker_id
135
+ self.device = device
136
+ logger.info(f"Loading the model {model_path} on worker {worker_id} ...")
137
+
138
+ self.rembg = BackgroundRemover()
139
+ self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path, device=device)
140
+ self.pipeline_t2i = HunyuanDiTPipeline('Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled',
141
+ device=device)
142
+ self.pipeline_tex = Hunyuan3DPaintPipeline.from_pretrained(model_path)
143
+
144
+ def get_queue_length(self):
145
+ if model_semaphore is None:
146
+ return 0
147
+ else:
148
+ return args.limit_model_concurrency - model_semaphore._value + (len(
149
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
150
+
151
+ def get_status(self):
152
+ return {
153
+ "speed": 1,
154
+ "queue_length": self.get_queue_length(),
155
+ }
156
+
157
+ @torch.inference_mode()
158
+ def generate(self, uid, params):
159
+ if 'image' in params:
160
+ image = params["image"]
161
+ image = load_image_from_base64(image)
162
+ else:
163
+ if 'text' in params:
164
+ text = params["text"]
165
+ image = self.pipeline_t2i(text)
166
+ else:
167
+ raise ValueError("No input image or text provided")
168
+
169
+ image = self.rembg(image)
170
+ params['image'] = image
171
+
172
+ if 'mesh' in params:
173
+ mesh = trimesh.load(BytesIO(base64.b64decode(params["mesh"])), file_type='glb')
174
+ else:
175
+ seed = params.get("seed", 1234)
176
+ params['generator'] = torch.Generator(self.device).manual_seed(seed)
177
+ params['octree_resolution'] = params.get("octree_resolution", 256)
178
+ params['num_inference_steps'] = params.get("num_inference_steps", 30)
179
+ params['guidance_scale'] = params.get('guidance_scale', 7.5)
180
+ params['mc_algo'] = 'mc'
181
+ mesh = self.pipeline(**params)[0]
182
+
183
+ if params.get('texture', False):
184
+ mesh = FloaterRemover()(mesh)
185
+ mesh = DegenerateFaceRemover()(mesh)
186
+ mesh = FaceReducer()(mesh, max_facenum=params.get('face_count', 40000))
187
+ mesh = self.pipeline_tex(mesh, image)
188
+
189
+ with tempfile.NamedTemporaryFile(suffix='.glb', delete=False) as temp_file:
190
+ mesh.export(temp_file.name)
191
+ mesh = trimesh.load(temp_file.name)
192
+ temp_file.close()
193
+ os.unlink(temp_file.name)
194
+ save_path = os.path.join(SAVE_DIR, f'{str(uid)}.glb')
195
+ mesh.export(save_path)
196
+
197
+ torch.cuda.empty_cache()
198
+ return save_path, uid
199
+
200
+
201
+ app = FastAPI()
202
+
203
+
204
+ @app.post("/generate")
205
+ async def generate(request: Request):
206
+ logger.info("Worker generating...")
207
+ params = await request.json()
208
+ uid = uuid.uuid4()
209
+ try:
210
+ file_path, uid = worker.generate(uid, params)
211
+ return FileResponse(file_path)
212
+ except ValueError as e:
213
+ traceback.print_exc()
214
+ print("Caught ValueError:", e)
215
+ ret = {
216
+ "text": server_error_msg,
217
+ "error_code": 1,
218
+ }
219
+ return JSONResponse(ret, status_code=404)
220
+ except torch.cuda.CudaError as e:
221
+ print("Caught torch.cuda.CudaError:", e)
222
+ ret = {
223
+ "text": server_error_msg,
224
+ "error_code": 1,
225
+ }
226
+ return JSONResponse(ret, status_code=404)
227
+ except Exception as e:
228
+ print("Caught Unknown Error", e)
229
+ traceback.print_exc()
230
+ ret = {
231
+ "text": server_error_msg,
232
+ "error_code": 1,
233
+ }
234
+ return JSONResponse(ret, status_code=404)
235
+
236
+
237
+ @app.post("/send")
238
+ async def generate(request: Request):
239
+ logger.info("Worker send...")
240
+ params = await request.json()
241
+ uid = uuid.uuid4()
242
+ threading.Thread(target=worker.generate, args=(uid, params,)).start()
243
+ ret = {"uid": str(uid)}
244
+ return JSONResponse(ret, status_code=200)
245
+
246
+
247
+ @app.get("/status/{uid}")
248
+ async def status(uid: str):
249
+ save_file_path = os.path.join(SAVE_DIR, f'{uid}.glb')
250
+ print(save_file_path, os.path.exists(save_file_path))
251
+ if not os.path.exists(save_file_path):
252
+ response = {'status': 'processing'}
253
+ return JSONResponse(response, status_code=200)
254
+ else:
255
+ base64_str = base64.b64encode(open(save_file_path, 'rb').read()).decode()
256
+ response = {'status': 'completed', 'model_base64': base64_str}
257
+ return JSONResponse(response, status_code=200)
258
+
259
+
260
+ if __name__ == "__main__":
261
+ parser = argparse.ArgumentParser()
262
+ parser.add_argument("--host", type=str, default="0.0.0.0")
263
+ parser.add_argument("--port", type=int, default=8081)
264
+ parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2')
265
+ parser.add_argument("--device", type=str, default="cuda")
266
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
267
+ args = parser.parse_args()
268
+ logger.info(f"args: {args}")
269
+
270
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
271
+
272
+ worker = ModelWorker(model_path=args.model_path, device=args.device)
273
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")