maple-shaft committed on
Commit
88ef5ec
·
verified ·
1 Parent(s): 31e32d0

Only mv endpoint handler

Browse files
Files changed (1) hide show
  1. handler.py +126 -285
handler.py CHANGED
@@ -1,285 +1,126 @@
1
- # This is a custom handler module for the forked HF repo maple-shaft/zero123plus-v1.2
2
- # Inference Endpoint hosting on HF will require this file and requirements.txt to be uploaded to the repo in the root.
3
-
4
- from typing import Dict, List, Any
5
- import os
6
- import gc
7
- import psutil
8
- import torch
9
- import base64
10
- import io
11
- from PIL import Image
12
- import trimesh
13
- import tempfile
14
- import pymeshlab as ml
15
- from hy3dgen.rembg import BackgroundRemover
16
- from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline
17
- from hy3dgen.texgen import Hunyuan3DPaintPipeline
18
- from diffusers.pipelines.auto_pipeline import AutoPipelineForText2Image
19
- from diffusers import DiffusionPipeline # pyright: ignore[reportPrivateImportUsage]
20
-
21
- def log_ram(tag):
22
- rss = psutil.Process(os.getpid()).memory_info().rss / (1024**3)
23
- print(f"[{tag}] RSS: {rss:.2f} GB", flush=True)
24
-
25
- class HFMultiViewGen:
26
-
27
- def __init__(self,
28
- hf_token: str,
29
- mv_model: str = "maple-shaft/zero123plus-v1.2",
30
- mv_custom_pipeline: str = "sudo-ai/zero123plus-pipeline",
31
- gen_custom_pipeline: str = "",
32
- debug: bool = False):
33
- self.debug = debug
34
- self.hf_token = hf_token
35
- self.mv_model = mv_model
36
- self.mv_custom_pipeline = mv_custom_pipeline
37
-
38
- self.img_to_mesh_model_parent_name = "tencent/Hunyuan3D-2"
39
- self.img_to_mesh_model_name = "tencent/Hunyuan3D-2mv"
40
- self.img_to_mesh_sub_name = "hunyuan3d-dit-v2-mv-turbo"
41
- self.mesh_paint_sub_name = "hunyuan3d-paint-v2-0-turbo"
42
- self.mesh_delight_sub_name = "hunyuan3d-delight-v2-0"
43
- self.mesh_vae_sub_name = "hunyuan3d-vae-v2-0-turbo"
44
-
45
- print(f"torch.cuda.is_available() = {torch.cuda.is_available()}")
46
- torch.cuda.synchronize()
47
- print("GPU SYNC OK", flush=True)
48
-
49
- self.pipe = DiffusionPipeline.from_pretrained(
50
- self.mv_model,
51
- token=self.hf_token,
52
- custom_pipeline=self.mv_custom_pipeline,
53
- torch_dtype=torch.float16,
54
- trust_remote_code=True
55
- )
56
-
57
- self.mesh_pipe: Hunyuan3DDiTFlowMatchingPipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
58
- self.img_to_mesh_model_name,
59
- subfolder=self.img_to_mesh_sub_name,
60
- variant='fp16',
61
- )
62
-
63
- self.tex_pipe = Hunyuan3DPaintPipeline.from_pretrained(
64
- self.img_to_mesh_model_parent_name
65
- )
66
- self.tex_pipe.config.render_size = 1024
67
- self.tex_pipe.config.texture_size = 1024
68
- self.tex_pipe.render.set_default_render_resolution(self.tex_pipe.config.render_size)
69
- self.tex_pipe.render.set_default_texture_resolution(self.tex_pipe.config.texture_size)
70
-
71
- def preprocess_images_for_mesh(self, images: dict[str, Image.Image]) -> dict[str, Image.Image]:
72
- ret = {}
73
- for k, v in images.items():
74
- if v.mode == 'RGB':
75
- rembg = BackgroundRemover()
76
- v = rembg(v)
77
- ret[k] = v.resize((512,512), Image.LANCZOS).convert("RGBA")
78
- return ret
79
-
80
- def free_gpu(self, pipe):
81
- log_ram("before free_gpu")
82
- gc.collect()
83
- pipe.to("cpu")
84
- torch.cuda.empty_cache()
85
- torch.cuda.ipc_collect()
86
- torch.cuda.synchronize()
87
- log_ram("after free_gpu")
88
-
89
- def allocate_gpu(self, pipe):
90
- log_ram("before allocate_gpu")
91
- pipe.to("cuda")
92
- torch.cuda.synchronize()
93
- log_ram("after allocate_gpu")
94
-
95
- def simplify_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh | None:
96
- obj_bytes = mesh.export(file_type="ply")
97
- ms = ml.MeshSet()
98
- tf = None
99
- remeshed_tf = None
100
- try:
101
- tf = tempfile.NamedTemporaryFile(delete=False, suffix=".ply")
102
- tf.write(obj_bytes)
103
- tf.flush()
104
- ms.load_new_mesh(tf.name)
105
- # Step 1: Optional smoothing (to mimic voxel smooth effect)
106
- ms.apply_filter(
107
- "apply_coord_laplacian_smoothing",
108
- stepsmoothnum=3
109
- )
110
-
111
- # Step 2: Uniform resampling for smooth remeshing
112
- # This is the closest PyMeshLab has to Blender's smooth voxel remesh.
113
- ms.apply_filter(
114
- "generate_resampled_uniform_mesh",
115
- cellsize=ml.PureValue(ms.current_mesh().bounding_box().diagonal() / (2 ** 5)), # roughly matches octree depth=5
116
- offset=ml.PureValue(0.0),
117
- multisample=True
118
- )
119
-
120
- # Step 3: Optional shrink/scale adjustment (Blender’s scale=0.9)
121
- #ms.apply_filter("transform_scale_normalize", scalefactor=0.9)
122
-
123
- # Step 4: Remove small disconnected pieces
124
- ms.apply_filter("compute_selection_by_small_disconnected_components_per_face")
125
- ms.apply_filter("meshing_remove_selected_vertices_and_faces")
126
- # Step 5: (Optional) Smooth again to even out voxel edges
127
- ms.apply_filter("apply_coord_taubin_smoothing", stepsmoothnum=10, lambda_=0.5, mu=-0.53)
128
-
129
- remeshed_tf = tempfile.NamedTemporaryFile(delete=False, suffix=".ply")
130
- ms.save_current_mesh(remeshed_tf.name)
131
- remeshed_tf.flush()
132
- remeshed: trimesh.Trimesh = trimesh.load_mesh(remeshed_tf, file_type="ply")
133
- remeshed = remeshed.process(validate=True, merge_norm=True)
134
-
135
- print(f"is_watertight = {remeshed.is_watertight}", flush=True)
136
- print(f"is_volume = {remeshed.is_volume}", flush=True)
137
- print(f"euler_number = {remeshed.euler_number}", flush=True)
138
- return remeshed
139
- except Exception as e:
140
- print(e)
141
- finally:
142
- if tf:
143
- tf.close()
144
- os.remove(tf.name)
145
- del tf
146
- if remeshed_tf:
147
- remeshed_tf.close()
148
- os.remove(remeshed_tf.name)
149
- del remeshed_tf
150
-
151
-
152
- def generate_multiview(self, initial: Image.Image) -> dict[str, Image.Image]:
153
- print(">>> generate_multiview", flush=True)
154
-
155
- self.free_gpu(self.mesh_pipe)
156
- self.free_gpu(self.tex_pipe)
157
- self.allocate_gpu(self.pipe)
158
-
159
- print("allocated second pipe to gpu", flush=True)
160
- # --- prepare image properly ---
161
- img = initial.convert("RGB")
162
-
163
- print("converted the image to RGB", flush=True)
164
-
165
- mv_result : List[Image.Image] = self.pipe(
166
- image=img,
167
- width=640,
168
- height=960,
169
- num_inference_steps=28,
170
- guidance_scale=4.0,
171
- num_images_per_prompt=1
172
- ).images # pyright: ignore[reportCallIssue]
173
-
174
- print("mv_result", repr(mv_result), flush=True)
175
-
176
- # The resulting file comes back as a 2x3 tiled PNG image, we will need to split it into a set of images
177
- tile_w = 320.0 # img.width / 2.0
178
- tile_h = 320.0 # img.height / 3.0
179
- right_tile = (tile_w, 0.0, tile_w * 2.0, tile_h)
180
- back_tile = (tile_w, tile_h, tile_w * 2.0, tile_h * 2.0)
181
- left_tile = (0, tile_h * 2.0, tile_w, tile_h * 3.0)
182
- ret = {
183
- "front": img,
184
- "right": mv_result[0].crop(right_tile),
185
- "back": mv_result[0].crop(back_tile),
186
- "left": mv_result[0].crop(left_tile)
187
- }
188
-
189
- return ret
190
-
191
- def create_mesh(self, images: dict[str, Image.Image]) -> trimesh.Trimesh | None:
192
- print(">>> Entered create_mesh", flush=True)
193
-
194
- self.free_gpu(self.pipe)
195
- self.free_gpu(self.tex_pipe)
196
- self.allocate_gpu(self.mesh_pipe)
197
-
198
- timages = self.preprocess_images_for_mesh(images)
199
-
200
- # Mesh Pipeline
201
- mesh: trimesh.Trimesh = self.mesh_pipe(
202
- image=timages,
203
- num_inference_steps=10,
204
- octree_resolution=120,
205
- num_chunks=2000,
206
- output_type='trimesh'
207
- )[0]
208
- simplified_mesh = self.simplify_mesh(mesh)
209
- return simplified_mesh
210
-
211
- def texture_mesh(self, mesh: trimesh.Trimesh, preprocessed_front_image: Image.Image) -> trimesh.Trimesh | None:
212
- print(">>> call texture_mesh", flush=True)
213
-
214
- self.free_gpu(self.pipe)
215
- self.free_gpu(self.mesh_pipe)
216
- self.allocate_gpu(self.tex_pipe)
217
-
218
- return self.tex_pipe(mesh=mesh, image=preprocessed_front_image)
219
-
220
- class EndpointHandler():
221
- def __init__(self, path=""):
222
- self.hf_token = os.environ["HUGGINGFACE_TOKEN"]
223
- self.hf_gen = HFMultiViewGen(hf_token=self.hf_token)
224
-
225
- def convert(self, fromval: dict[str, Image.Image]) -> dict[str, str]:
226
- ret: dict[str, str] = {}
227
- for k,v in fromval.items():
228
- with io.BytesIO() as bio:
229
- v.save(bio, format="PNG")
230
- ret[k] = base64.b64encode(bio.getvalue()).decode()
231
-
232
- return ret
233
-
234
- def convert_img(self, fromval: str) -> Image.Image:
235
- try:
236
- print(">>> convert_img", flush=True)
237
- with io.BytesIO(base64.b64decode(fromval)) as bio:
238
- return Image.open(bio.getvalue())
239
- except Exception as e:
240
- print("Error", repr(e), flush=True)
241
- raise e
242
-
243
- def convert_mesh(self, fromval: trimesh.Trimesh) -> str | None:
244
- print(">>> call convert_mesh", flush=True)
245
- try:
246
- ret: str | None = None
247
- tf = tempfile.NamedTemporaryFile("w+b", suffix=".glb", delete=False)
248
- tf_name: str = tf.name
249
- fromval.export(tf.name)
250
- tf.flush()
251
- tf.close()
252
- with open(tf_name, "r+b") as f:
253
- ret = base64.b64encode(f.read()).decode()
254
- os.remove(tf.name)
255
- return ret
256
- except Exception as e:
257
- print("Error", repr(e), flush=True)
258
- raise e
259
-
260
- def __call__(self, data: Dict[str, Any]):
261
- print("Entered __call__!!! ", repr(data), flush=True)
262
- ret: dict[str, str] = {}
263
- try:
264
- img_str = data['inputs']
265
- print(f"Initial image: {img_str}", flush=True)
266
- img: Image.Image = self.convert_img(fromval=img_str)
267
- print("Converted to image", repr(img), flush=True)
268
- mv: dict[str, Image.Image] = self.hf_gen.generate_multiview(initial=img)
269
- print(f"Mv Image: {mv}", flush=True)
270
- mesh: trimesh.Trimesh | None = self.hf_gen.create_mesh(images=mv)
271
- print(f"Created to mesh: {mesh}", flush=True)
272
- if not mesh:
273
- raise Exception("No mesh")
274
- mesh = self.hf_gen.texture_mesh(mesh=mesh, preprocessed_front_image=img)
275
- print(f"Textured mesh: {mesh}", flush=True)
276
- if not mesh:
277
- raise Exception("No mesh")
278
- output: str | None = self.convert_mesh(fromval=mesh)
279
- if not output:
280
- raise Exception("No output")
281
- ret["output"] = output
282
- return ret
283
- except Exception as e:
284
- print(e)
285
- raise e
 
1
+ from typing import Dict, List, Any
2
+ import os
3
+ import torch
4
+ from PIL import Image
5
+ import dotenv
6
+ import base64
7
+ import io
8
+ from diffusers import DiffusionPipeline # pyright: ignore[reportPrivateImportUsage]
9
+
10
+ dotenv.load_dotenv()
11
+
12
def convert_b64_to_image(from_str: str) -> Image.Image:
    """Decode a base64-encoded PNG payload into a fully loaded PIL image.

    Raises whatever ``base64``/PIL raises on malformed input, after logging it.
    """
    print(">>> call convert_b64_to_image", flush=True)
    try:
        raw: bytes = base64.b64decode(from_str)
        with io.BytesIO(raw) as stream:
            image = Image.open(stream, formats=["PNG"])
            # Force pixel data into memory now, before the BytesIO closes.
            image.load()
        return image
    except Exception as e:
        print(e, flush=True)
        raise e
25
def convert_image_to_b64(from_img: Image.Image) -> str:
    """Serialize a PIL image to PNG and return it as a base64 ASCII string."""
    print(">>> call convert_image_to_b64", flush=True)
    try:
        with io.BytesIO() as sink:
            from_img.save(sink, format="PNG")
            png_bytes = sink.getvalue()
        return base64.b64encode(png_bytes).decode("utf-8")
    except Exception as e:
        print(e, flush=True)
        raise e
36
class HFMultiViewGen:
    """Hosts the zero123plus multi-view diffusion pipeline for endpoint inference.

    Loads the pipeline once at construction (in fp16, on CUDA) and exposes
    :meth:`generate_multiview` to turn a single front image into a small set
    of side/back views.
    """

    def __init__(self,
                 hf_token: str,
                 mv_model: str = "maple-shaft/zero123plus-v1.2",
                 mv_custom_pipeline: str = "sudo-ai/zero123plus-pipeline",
                 gen_custom_pipeline: str = "",
                 repo_dir: str = "/repository",
                 debug: bool = False):
        """
        Args:
            hf_token: HuggingFace access token used to download the model.
            mv_model: Hub repo id of the multi-view diffusion weights.
            mv_custom_pipeline: Hub repo id of the custom pipeline code.
            gen_custom_pipeline: Unused here; kept for interface compatibility.
            repo_dir: Local cache directory for downloaded weights.
            debug: Flag stored for callers; not consulted in this class.
        """
        self.debug = debug
        self.hf_token = hf_token
        self.mv_model = mv_model
        self.mv_custom_pipeline = mv_custom_pipeline
        self.repo_dir = repo_dir

        print(f"torch.cuda.is_available() = {torch.cuda.is_available()}")
        # torch.cuda.synchronize() raises on a CPU-only host; guard it so the
        # failure (if any) happens at the clearer .to("cuda") call below.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            print("GPU SYNC OK", flush=True)

        # BUG FIX: from_pretrained takes `torch_dtype`, not `dtype` — the
        # previous revision used `torch_dtype`; with `dtype` the fp16 request
        # is silently dropped and the weights load in fp32 (double memory).
        # `trust_remote_code=True` is required to execute the hub-hosted
        # custom pipeline code.
        self.pipe = DiffusionPipeline.from_pretrained(
            self.mv_model,
            cache_dir=self.repo_dir,
            token=self.hf_token,
            custom_pipeline=self.mv_custom_pipeline,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        ).to("cuda")

    def generate_multiview(self, initial: Image.Image) -> dict[str, Image.Image]:
        """Generate right/back/left views for a single front-facing image.

        Args:
            initial: Source image; converted to RGB before conditioning.

        Returns:
            Mapping of view name ("front", "right", "back", "left") to a PIL
            image. "front" is the RGB-converted input itself.
        """
        print(">>> generate_multiview", flush=True)

        # The pipeline expects an RGB conditioning image.
        img = initial.convert("RGB")
        print("converted the image to RGB", flush=True)

        mv_result: List[Image.Image] = self.pipe(
            image=img,
            width=640,
            height=960,
            num_inference_steps=28,
            guidance_scale=4.0,
            num_images_per_prompt=1
        ).images  # pyright: ignore[reportCallIssue]

        print("mv_result", repr(mv_result), flush=True)

        # The result is a single 640x960 image tiled as a 2-column x 3-row
        # grid of 320x320 views; crop out the tiles we need.
        tile_w = 320.0  # img.width / 2.0
        tile_h = 320.0  # img.height / 3.0
        right_tile = (tile_w, 0.0, tile_w * 2.0, tile_h)
        back_tile = (tile_w, tile_h, tile_w * 2.0, tile_h * 2.0)
        left_tile = (0, tile_h * 2.0, tile_w, tile_h * 3.0)
        return {
            "front": img,
            "right": mv_result[0].crop(right_tile),
            "back": mv_result[0].crop(back_tile),
            "left": mv_result[0].crop(left_tile),
        }
98
class EndpointHandler():
    """HF Inference Endpoint entry point: base64 PNG in, multi-view base64 PNGs out."""

    def __init__(self, path=""):
        # Both env vars are required; KeyError here intentionally fails startup.
        self.hf_token = os.environ["HUGGINGFACE_TOKEN"]
        self.repo_dir = os.environ["HF_HUB_CACHE"]
        self.hf_gen = HFMultiViewGen(hf_token=self.hf_token, repo_dir=self.repo_dir)

    def convert(self, fromval: dict[str, Image.Image]) -> dict[str, str]:
        """Encode every image in the mapping as a base64 PNG string."""
        return {k: convert_image_to_b64(v) for k, v in fromval.items()}

    def __call__(self, data: Dict[str, Any]):
        """Handle one inference request.

        Args:
            data: Request payload; ``data['inputs']`` must be a base64 PNG.

        Returns:
            ``{"output": {view_name: base64_png, ...}}``.
        """
        print("Entered __call__!!! ", repr(data), flush=True)
        try:
            img_str = data['inputs']
            # Log only a prefix — the full base64 payload can be megabytes
            # and would flood the endpoint logs.
            print(f"Initial image: {img_str[:64]}...", flush=True)
            img: Image.Image = convert_b64_to_image(img_str)
            print("Converted to image", repr(img), flush=True)
            mv: dict[str, Image.Image] = self.hf_gen.generate_multiview(initial=img)
            print(f"Mv Image: {mv}", flush=True)
            return {"output": self.convert(mv)}
        except Exception as e:
            print(e, flush=True)
            # Bare raise preserves the original traceback without adding a frame.
            raise