3dcre commited on
Commit
f34eb87
·
verified ·
1 Parent(s): 59783c0

Delete server.py

Browse files
Files changed (1) hide show
  1. server.py +0 -273
server.py DELETED
@@ -1,273 +0,0 @@
1
- """
2
- A model worker executes the model.
3
- """
4
- import argparse
5
- import asyncio
6
- import base64
7
- import logging
8
- import logging.handlers
9
- import os
10
- import sys
11
- import tempfile
12
- import threading
13
- import traceback
14
- import uuid
15
- from io import BytesIO
16
-
17
- import torch
18
- import trimesh
19
- import uvicorn
20
- from PIL import Image
21
- from fastapi import FastAPI, Request
22
- from fastapi.responses import JSONResponse, FileResponse
23
-
24
- from hy3dgen.rembg import BackgroundRemover
25
- from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FloaterRemover, DegenerateFaceRemover, FaceReducer
26
- from hy3dgen.texgen import Hunyuan3DPaintPipeline
27
- from hy3dgen.text2image import HunyuanDiTPipeline
28
-
29
- LOGDIR = '.'
30
-
31
- server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
32
- moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
33
-
34
- handler = None
35
-
36
-
37
- def build_logger(logger_name, logger_filename):
38
- global handler
39
-
40
- formatter = logging.Formatter(
41
- fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
42
- datefmt="%Y-%m-%d %H:%M:%S",
43
- )
44
-
45
- # Set the format of root handlers
46
- if not logging.getLogger().handlers:
47
- logging.basicConfig(level=logging.INFO)
48
- logging.getLogger().handlers[0].setFormatter(formatter)
49
-
50
- # Redirect stdout and stderr to loggers
51
- stdout_logger = logging.getLogger("stdout")
52
- stdout_logger.setLevel(logging.INFO)
53
- sl = StreamToLogger(stdout_logger, logging.INFO)
54
- sys.stdout = sl
55
-
56
- stderr_logger = logging.getLogger("stderr")
57
- stderr_logger.setLevel(logging.ERROR)
58
- sl = StreamToLogger(stderr_logger, logging.ERROR)
59
- sys.stderr = sl
60
-
61
- # Get logger
62
- logger = logging.getLogger(logger_name)
63
- logger.setLevel(logging.INFO)
64
-
65
- # Add a file handler for all loggers
66
- if handler is None:
67
- os.makedirs(LOGDIR, exist_ok=True)
68
- filename = os.path.join(LOGDIR, logger_filename)
69
- handler = logging.handlers.TimedRotatingFileHandler(
70
- filename, when='D', utc=True, encoding='UTF-8')
71
- handler.setFormatter(formatter)
72
-
73
- for name, item in logging.root.manager.loggerDict.items():
74
- if isinstance(item, logging.Logger):
75
- item.addHandler(handler)
76
-
77
- return logger
78
-
79
-
80
- class StreamToLogger(object):
81
- """
82
- Fake file-like stream object that redirects writes to a logger instance.
83
- """
84
-
85
- def __init__(self, logger, log_level=logging.INFO):
86
- self.terminal = sys.stdout
87
- self.logger = logger
88
- self.log_level = log_level
89
- self.linebuf = ''
90
-
91
- def __getattr__(self, attr):
92
- return getattr(self.terminal, attr)
93
-
94
- def write(self, buf):
95
- temp_linebuf = self.linebuf + buf
96
- self.linebuf = ''
97
- for line in temp_linebuf.splitlines(True):
98
- # From the io.TextIOWrapper docs:
99
- # On output, if newline is None, any '\n' characters written
100
- # are translated to the system default line separator.
101
- # By default sys.stdout.write() expects '\n' newlines and then
102
- # translates them so this is still cross platform.
103
- if line[-1] == '\n':
104
- self.logger.log(self.log_level, line.rstrip())
105
- else:
106
- self.linebuf += line
107
-
108
- def flush(self):
109
- if self.linebuf != '':
110
- self.logger.log(self.log_level, self.linebuf.rstrip())
111
- self.linebuf = ''
112
-
113
-
114
- def pretty_print_semaphore(semaphore):
115
- if semaphore is None:
116
- return "None"
117
- return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
118
-
119
-
120
- SAVE_DIR = 'gradio_cache'
121
- os.makedirs(SAVE_DIR, exist_ok=True)
122
-
123
- worker_id = str(uuid.uuid4())[:6]
124
- logger = build_logger("controller", f"{SAVE_DIR}/controller.log")
125
-
126
-
127
- def load_image_from_base64(image):
128
- return Image.open(BytesIO(base64.b64decode(image)))
129
-
130
-
131
- class ModelWorker:
132
- def __init__(self, model_path='tencent/Hunyuan3D-2', device='cuda'):
133
- self.model_path = model_path
134
- self.worker_id = worker_id
135
- self.device = device
136
- logger.info(f"Loading the model {model_path} on worker {worker_id} ...")
137
-
138
- self.rembg = BackgroundRemover()
139
- self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path, device=device)
140
- self.pipeline_t2i = HunyuanDiTPipeline('Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled',
141
- device=device)
142
- self.pipeline_tex = Hunyuan3DPaintPipeline.from_pretrained(model_path)
143
-
144
- def get_queue_length(self):
145
- if model_semaphore is None:
146
- return 0
147
- else:
148
- return args.limit_model_concurrency - model_semaphore._value + (len(
149
- model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
150
-
151
- def get_status(self):
152
- return {
153
- "speed": 1,
154
- "queue_length": self.get_queue_length(),
155
- }
156
-
157
- @torch.inference_mode()
158
- def generate(self, uid, params):
159
- if 'image' in params:
160
- image = params["image"]
161
- image = load_image_from_base64(image)
162
- else:
163
- if 'text' in params:
164
- text = params["text"]
165
- image = self.pipeline_t2i(text)
166
- else:
167
- raise ValueError("No input image or text provided")
168
-
169
- image = self.rembg(image)
170
- params['image'] = image
171
-
172
- if 'mesh' in params:
173
- mesh = trimesh.load(BytesIO(base64.b64decode(params["mesh"])), file_type='glb')
174
- else:
175
- seed = params.get("seed", 1234)
176
- params['generator'] = torch.Generator(self.device).manual_seed(seed)
177
- params['octree_resolution'] = params.get("octree_resolution", 256)
178
- params['num_inference_steps'] = params.get("num_inference_steps", 30)
179
- params['guidance_scale'] = params.get('guidance_scale', 7.5)
180
- params['mc_algo'] = 'mc'
181
- mesh = self.pipeline(**params)[0]
182
-
183
- if params.get('texture', False):
184
- mesh = FloaterRemover()(mesh)
185
- mesh = DegenerateFaceRemover()(mesh)
186
- mesh = FaceReducer()(mesh, max_facenum=params.get('face_count', 40000))
187
- mesh = self.pipeline_tex(mesh, image)
188
-
189
- with tempfile.NamedTemporaryFile(suffix='.glb', delete=False) as temp_file:
190
- mesh.export(temp_file.name)
191
- mesh = trimesh.load(temp_file.name)
192
- temp_file.close()
193
- os.unlink(temp_file.name)
194
- save_path = os.path.join(SAVE_DIR, f'{str(uid)}.glb')
195
- mesh.export(save_path)
196
-
197
- torch.cuda.empty_cache()
198
- return save_path, uid
199
-
200
-
201
- app = FastAPI()
202
-
203
-
204
- @app.post("/generate")
205
- async def generate(request: Request):
206
- logger.info("Worker generating...")
207
- params = await request.json()
208
- uid = uuid.uuid4()
209
- try:
210
- file_path, uid = worker.generate(uid, params)
211
- return FileResponse(file_path)
212
- except ValueError as e:
213
- traceback.print_exc()
214
- print("Caught ValueError:", e)
215
- ret = {
216
- "text": server_error_msg,
217
- "error_code": 1,
218
- }
219
- return JSONResponse(ret, status_code=404)
220
- except torch.cuda.CudaError as e:
221
- print("Caught torch.cuda.CudaError:", e)
222
- ret = {
223
- "text": server_error_msg,
224
- "error_code": 1,
225
- }
226
- return JSONResponse(ret, status_code=404)
227
- except Exception as e:
228
- print("Caught Unknown Error", e)
229
- traceback.print_exc()
230
- ret = {
231
- "text": server_error_msg,
232
- "error_code": 1,
233
- }
234
- return JSONResponse(ret, status_code=404)
235
-
236
-
237
- @app.post("/send")
238
- async def generate(request: Request):
239
- logger.info("Worker send...")
240
- params = await request.json()
241
- uid = uuid.uuid4()
242
- threading.Thread(target=worker.generate, args=(uid, params,)).start()
243
- ret = {"uid": str(uid)}
244
- return JSONResponse(ret, status_code=200)
245
-
246
-
247
- @app.get("/status/{uid}")
248
- async def status(uid: str):
249
- save_file_path = os.path.join(SAVE_DIR, f'{uid}.glb')
250
- print(save_file_path, os.path.exists(save_file_path))
251
- if not os.path.exists(save_file_path):
252
- response = {'status': 'processing'}
253
- return JSONResponse(response, status_code=200)
254
- else:
255
- base64_str = base64.b64encode(open(save_file_path, 'rb').read()).decode()
256
- response = {'status': 'completed', 'model_base64': base64_str}
257
- return JSONResponse(response, status_code=200)
258
-
259
-
260
- if __name__ == "__main__":
261
- parser = argparse.ArgumentParser()
262
- parser.add_argument("--host", type=str, default="0.0.0.0")
263
- parser.add_argument("--port", type=int, default=8081)
264
- parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2')
265
- parser.add_argument("--device", type=str, default="cuda")
266
- parser.add_argument("--limit-model-concurrency", type=int, default=5)
267
- args = parser.parse_args()
268
- logger.info(f"args: {args}")
269
-
270
- model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
271
-
272
- worker = ModelWorker(model_path=args.model_path, device=args.device)
273
- uvicorn.run(app, host=args.host, port=args.port, log_level="info")