LejobuildYT committed on
Commit
c4eb67e
·
verified ·
1 Parent(s): 63b49f9

Update hy3dgen/texgen/pipelines.py

Browse files
Files changed (1) hide show
  1. hy3dgen/texgen/pipelines.py +119 -89
hy3dgen/texgen/pipelines.py CHANGED
@@ -24,16 +24,29 @@ from pathlib import Path
24
  from .differentiable_renderer.mesh_render import MeshRender
25
  from .utils.dehighlight_utils import Light_Shadow_Remover
26
  from .utils.multiview_utils import Multiview_Diffusion_Net
27
- from .utils.imagesuper_utils import Image_Super_Net
28
  from .utils.uv_warp_utils import mesh_uv_wrap
29
 
30
  logger = logging.getLogger(__name__)
31
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- class Hunyuan3DTexGenConfig:
34
 
 
35
  def __init__(self, light_remover_ckpt_path, multiview_ckpt_path):
36
- self.device = 'cuda'
 
 
 
37
  self.light_remover_ckpt_path = light_remover_ckpt_path
38
  self.multiview_ckpt_path = multiview_ckpt_path
39
 
@@ -52,84 +65,94 @@ class Hunyuan3DPaintPipeline:
52
  def from_pretrained(cls, model_path):
53
  original_model_path = model_path
54
  print(f"原始路径 original_model_path: {model_path}")
 
55
  if not os.path.exists(model_path):
56
- print(f"存在原始路径: {model_path}")
57
- # try local path
58
  base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
59
  model_path = os.path.expanduser(os.path.join(base_dir, model_path))
60
 
61
- print(f"基础路径 base_dir: {base_dir}")
62
- print(f"模型路径 model_path: {model_path}")
63
-
64
  delight_model_path = os.path.join(model_path, 'hunyuan3d-delight-v2-0')
65
  multiview_model_path = os.path.join(model_path, 'hunyuan3d-paint-v2-0')
66
 
67
- print(f"路径 delight_model_path: {delight_model_path}")
68
- print(f"路径 multiview_model_path: {multiview_model_path}")
69
-
70
  if not os.path.exists(delight_model_path) or not os.path.exists(multiview_model_path):
71
  try:
72
  import huggingface_hub
73
- # download from huggingface
74
- model_path = huggingface_hub.snapshot_download(repo_id=original_model_path,
75
- allow_patterns=["hunyuan3d-delight-v2-0/*"])
76
- print(f"下载的 model_path 1 : {model_path}")
77
- snapshot_path = Path(model_path)
78
- for path in snapshot_path.rglob("*"):
79
- print(path.relative_to(snapshot_path))
80
- model_path = huggingface_hub.snapshot_download(repo_id=original_model_path,
81
- allow_patterns=["hunyuan3d-paint-v2-0/*"])
82
- print(f"下载的 model_path 2 : {model_path}")
83
- snapshot_path = Path(model_path)
84
- for path in snapshot_path.rglob("*"):
85
- print(path.relative_to(snapshot_path))
86
-
87
-
88
  delight_model_path = os.path.join(model_path, 'hunyuan3d-delight-v2-0')
89
  multiview_model_path = os.path.join(model_path, 'hunyuan3d-paint-v2-0')
90
-
91
- print(f"路径 delight_model_path : {delight_model_path}")
92
- print(f"路径 multiview_model_path : {multiview_model_path}")
93
- print(f"路径 delight_model_path 是否存在: {os.path.exists(delight_model_path)}")
94
- print(f"路径 multiview_model_path 是否存在: {os.path.exists(multiview_model_path)}")
95
 
96
  return cls(Hunyuan3DTexGenConfig(delight_model_path, multiview_model_path))
 
97
  except Exception as e:
98
  print("构造 Hunyuan3DPaintPipeline 实例时出错:", e)
99
- import traceback
100
- traceback.print_exc()
101
  raise
102
- # except ImportError:
103
- # logger.warning(
104
- # "You need to install HuggingFace Hub to load models from the hub."
105
- # )
106
- # raise RuntimeError(f"Model path {model_path} not found")
107
  else:
108
  return cls(Hunyuan3DTexGenConfig(delight_model_path, multiview_model_path))
109
 
110
- raise FileNotFoundError(f"Model path {original_model_path} not found and we could not find it at huggingface")
 
111
 
112
  def __init__(self, config):
113
  self.config = config
114
  self.models = {}
 
115
  self.render = MeshRender(
116
  default_resolution=self.config.render_size,
117
- texture_size=self.config.texture_size)
 
118
 
119
  self.load_models()
120
 
 
 
 
 
121
  def load_models(self):
122
- # empty cude cache
123
- torch.cuda.empty_cache()
124
- # Load model
 
 
 
 
125
  self.models['delight_model'] = Light_Shadow_Remover(self.config)
126
  self.models['multiview_model'] = Multiview_Diffusion_Net(self.config)
127
  # self.models['super_model'] = Image_Super_Net(self.config)
128
 
129
- def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
130
- self.models['delight_model'].pipeline.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
131
- self.models['multiview_model'].pipeline.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def render_normal_multiview(self, camera_elevs, camera_azims, use_abs_coor=True):
134
  normal_maps = []
135
  for elev, azim in zip(camera_elevs, camera_azims):
@@ -139,23 +162,28 @@ class Hunyuan3DPaintPipeline:
139
 
140
  return normal_maps
141
 
 
142
  def render_position_multiview(self, camera_elevs, camera_azims):
143
  position_maps = []
144
  for elev, azim in zip(camera_elevs, camera_azims):
145
  position_map = self.render.render_position(
146
  elev, azim, return_type='pl')
147
  position_maps.append(position_map)
148
-
149
  return position_maps
150
 
 
151
  def bake_from_multiview(self, views, camera_elevs,
152
  camera_azims, view_weights, method='graphcut'):
153
  project_textures, project_weighted_cos_maps = [], []
154
  project_boundary_maps = []
 
155
  for view, camera_elev, camera_azim, weight in zip(
156
- views, camera_elevs, camera_azims, view_weights):
 
157
  project_texture, project_cos_map, project_boundary_map = self.render.back_project(
158
- view, camera_elev, camera_azim)
 
 
159
  project_cos_map = weight * (project_cos_map ** self.config.bake_exp)
160
  project_textures.append(project_texture)
161
  project_weighted_cos_maps.append(project_cos_map)
@@ -166,8 +194,10 @@ class Hunyuan3DPaintPipeline:
166
  project_textures, project_weighted_cos_maps)
167
  else:
168
  raise f'no method {method}'
 
169
  return texture, ori_trust_map > 1E-8
170
 
 
171
  def texture_inpaint(self, texture, mask):
172
 
173
  texture_np = self.render.uv_inpaint(texture, mask)
@@ -175,39 +205,36 @@ class Hunyuan3DPaintPipeline:
175
 
176
  return texture
177
 
 
178
  def recenter_image(self, image, border_ratio=0.2):
179
  if image.mode == 'RGB':
180
  return image
181
  elif image.mode == 'L':
182
- image = image.convert('RGB')
183
- return image
184
 
185
- alpha_channel = np.array(image)[:, :, 3]
186
- non_zero_indices = np.argwhere(alpha_channel > 0)
187
- if non_zero_indices.size == 0:
188
- raise ValueError("Image is fully transparent")
189
 
190
- min_row, min_col = non_zero_indices.min(axis=0)
191
- max_row, max_col = non_zero_indices.max(axis=0)
192
 
193
- cropped_image = image.crop((min_col, min_row, max_col + 1, max_row + 1))
194
 
195
- width, height = cropped_image.size
196
- border_width = int(width * border_ratio)
197
- border_height = int(height * border_ratio)
198
 
199
- new_width = width + 2 * border_width
200
- new_height = height + 2 * border_height
 
201
 
202
- square_size = max(new_width, new_height)
 
203
 
204
- new_image = Image.new('RGBA', (square_size, square_size), (255, 255, 255, 0))
205
 
206
- paste_x = (square_size - new_width) // 2 + border_width
207
- paste_y = (square_size - new_height) // 2 + border_height
208
-
209
- new_image.paste(cropped_image, (paste_x, paste_y))
210
- return new_image
211
 
212
  @torch.no_grad()
213
  def __call__(self, mesh, image):
@@ -219,39 +246,42 @@ class Hunyuan3DPaintPipeline:
219
 
220
  image_prompt = self.recenter_image(image_prompt)
221
 
 
222
  image_prompt = self.models['delight_model'](image_prompt)
223
 
224
  mesh = mesh_uv_wrap(mesh)
225
-
226
  self.render.load_mesh(mesh)
227
 
228
- selected_camera_elevs, selected_camera_azims, selected_view_weights = \
229
- self.config.candidate_camera_elevs, self.config.candidate_camera_azims, self.config.candidate_view_weights
 
230
 
231
- normal_maps = self.render_normal_multiview(
232
- selected_camera_elevs, selected_camera_azims, use_abs_coor=True)
233
- position_maps = self.render_position_multiview(
234
- selected_camera_elevs, selected_camera_azims)
235
 
236
- camera_info = [(((azim // 30) + 9) % 12) // {-20: 1, 0: 1, 20: 1, -90: 3, 90: 3}[
237
- elev] + {-20: 0, 0: 12, 20: 24, -90: 36, 90: 40}[elev] for azim, elev in
238
- zip(selected_camera_azims, selected_camera_elevs)]
239
- multiviews = self.models['multiview_model'](image_prompt, normal_maps + position_maps, camera_info)
 
 
 
 
 
 
240
 
241
  for i in range(len(multiviews)):
242
- # multiviews[i] = self.models['super_model'](multiviews[i])
243
  multiviews[i] = multiviews[i].resize(
244
- (self.config.render_size, self.config.render_size))
 
245
 
246
- texture, mask = self.bake_from_multiview(multiviews,
247
- selected_camera_elevs, selected_camera_azims, selected_view_weights,
248
- method=self.config.merge_method)
249
 
250
  mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
251
 
252
  texture = self.texture_inpaint(texture, mask_np)
253
-
254
  self.render.set_texture(texture)
255
- textured_mesh = self.render.save_mesh()
256
 
257
- return textured_mesh
 
24
  from .differentiable_renderer.mesh_render import MeshRender
25
  from .utils.dehighlight_utils import Light_Shadow_Remover
26
  from .utils.multiview_utils import Multiview_Diffusion_Net
27
+ # from .utils.imagesuper_utils import Image_Super_Net
28
  from .utils.uv_warp_utils import mesh_uv_wrap
29
 
30
  logger = logging.getLogger(__name__)
31
 
32
# -------------------------------------------
# Device Selection (Global clean handling)
# -------------------------------------------

def get_best_device():
    """Pick the best available torch device string.

    Preference order: CUDA GPU, then Apple-Silicon MPS, then CPU.

    Returns:
        str: one of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # Guard the MPS probe: torch.backends.mps does not exist on older
    # torch builds, where bare attribute access would raise AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
42
 
 
43
 
44
class Hunyuan3DTexGenConfig:
    """Checkpoint-path and device configuration for the texture pipeline.

    NOTE(review): only __init__ is visible in this hunk; attributes read
    elsewhere (render_size, texture_size, bake_exp, candidate_camera_*,
    merge_method) are presumably defined on this class too — confirm
    against the full file.
    """

    def __init__(self, light_remover_ckpt_path, multiview_ckpt_path):

        # Old: self.device = 'cuda'
        # Device is now chosen dynamically (cuda -> mps -> cpu).
        self.device = get_best_device()

        # Checkpoint directory for the delight (light/shadow removal) model.
        self.light_remover_ckpt_path = light_remover_ckpt_path
        # Checkpoint directory for the multiview diffusion model.
        self.multiview_ckpt_path = multiview_ckpt_path
 
 
65
  def from_pretrained(cls, model_path):
66
  original_model_path = model_path
67
  print(f"原始路径 original_model_path: {model_path}")
68
+
69
  if not os.path.exists(model_path):
70
+
71
+ print(f"不存在原始路径: {model_path}")
72
  base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
73
  model_path = os.path.expanduser(os.path.join(base_dir, model_path))
74
 
 
 
 
75
  delight_model_path = os.path.join(model_path, 'hunyuan3d-delight-v2-0')
76
  multiview_model_path = os.path.join(model_path, 'hunyuan3d-paint-v2-0')
77
 
 
 
 
78
  if not os.path.exists(delight_model_path) or not os.path.exists(multiview_model_path):
79
  try:
80
  import huggingface_hub
81
+
82
+ model_path = huggingface_hub.snapshot_download(
83
+ repo_id=original_model_path,
84
+ allow_patterns=["hunyuan3d-delight-v2-0/*"]
85
+ )
86
+ model_path = huggingface_hub.snapshot_download(
87
+ repo_id=original_model_path,
88
+ allow_patterns=["hunyuan3d-paint-v2-0/*"]
89
+ )
90
+
 
 
 
 
 
91
  delight_model_path = os.path.join(model_path, 'hunyuan3d-delight-v2-0')
92
  multiview_model_path = os.path.join(model_path, 'hunyuan3d-paint-v2-0')
 
 
 
 
 
93
 
94
  return cls(Hunyuan3DTexGenConfig(delight_model_path, multiview_model_path))
95
+
96
  except Exception as e:
97
  print("构造 Hunyuan3DPaintPipeline 实例时出错:", e)
 
 
98
  raise
99
+
 
 
 
 
100
  else:
101
  return cls(Hunyuan3DTexGenConfig(delight_model_path, multiview_model_path))
102
 
103
+ raise FileNotFoundError(f"Model path {original_model_path} not found and Hub download failed.")
104
+
105
 
106
  def __init__(self, config):
107
  self.config = config
108
  self.models = {}
109
+
110
  self.render = MeshRender(
111
  default_resolution=self.config.render_size,
112
+ texture_size=self.config.texture_size
113
+ )
114
 
115
  self.load_models()
116
 
117
+
118
+ # -------------------------------------------
119
+ # Load Models — Dynamic CUDA handling
120
+ # -------------------------------------------
121
  def load_models(self):
122
+
123
+ # Originally forced CUDA:
124
+ # torch.cuda.empty_cache()
125
+
126
+ if torch.cuda.is_available():
127
+ torch.cuda.empty_cache()
128
+
129
  self.models['delight_model'] = Light_Shadow_Remover(self.config)
130
  self.models['multiview_model'] = Multiview_Diffusion_Net(self.config)
131
  # self.models['super_model'] = Image_Super_Net(self.config)
132
 
 
 
 
133
 
134
+ def enable_model_cpu_offload(
135
+ self,
136
+ gpu_id: Optional[int] = None,
137
+ device: Union[torch.device, str] = None
138
+ ):
139
+ if device is None:
140
+ device = self.config.device
141
+
142
+ if hasattr(self.models['delight_model'], "pipeline"):
143
+ self.models['delight_model'].pipeline.enable_model_cpu_offload(
144
+ gpu_id=gpu_id, device=device
145
+ )
146
+
147
+ if hasattr(self.models['multiview_model'], "pipeline"):
148
+ self.models['multiview_model'].pipeline.enable_model_cpu_offload(
149
+ gpu_id=gpu_id, device=device
150
+ )
151
+
152
+
153
+ # -------------------------------------------
154
+ # Rendering functions unchanged
155
+ # -------------------------------------------
156
  def render_normal_multiview(self, camera_elevs, camera_azims, use_abs_coor=True):
157
  normal_maps = []
158
  for elev, azim in zip(camera_elevs, camera_azims):
 
162
 
163
  return normal_maps
164
 
165
+
166
  def render_position_multiview(self, camera_elevs, camera_azims):
167
  position_maps = []
168
  for elev, azim in zip(camera_elevs, camera_azims):
169
  position_map = self.render.render_position(
170
  elev, azim, return_type='pl')
171
  position_maps.append(position_map)
 
172
  return position_maps
173
 
174
+
175
  def bake_from_multiview(self, views, camera_elevs,
176
  camera_azims, view_weights, method='graphcut'):
177
  project_textures, project_weighted_cos_maps = [], []
178
  project_boundary_maps = []
179
+
180
  for view, camera_elev, camera_azim, weight in zip(
181
+ views, camera_elevs, camera_azims, view_weights
182
+ ):
183
  project_texture, project_cos_map, project_boundary_map = self.render.back_project(
184
+ view, camera_elev, camera_azim
185
+ )
186
+
187
  project_cos_map = weight * (project_cos_map ** self.config.bake_exp)
188
  project_textures.append(project_texture)
189
  project_weighted_cos_maps.append(project_cos_map)
 
194
  project_textures, project_weighted_cos_maps)
195
  else:
196
  raise f'no method {method}'
197
+
198
  return texture, ori_trust_map > 1E-8
199
 
200
+
201
  def texture_inpaint(self, texture, mask):
202
 
203
  texture_np = self.render.uv_inpaint(texture, mask)
 
205
 
206
  return texture
207
 
208
+
209
  def recenter_image(self, image, border_ratio=0.2):
210
  if image.mode == 'RGB':
211
  return image
212
  elif image.mode == 'L':
213
+ return image.convert('RGB')
 
214
 
215
+ alpha = np.array(image)[:, :, 3]
216
+ non_zero = np.argwhere(alpha > 0)
217
+ if non_zero.size == 0:
218
+ raise ValueError("Image fully transparent")
219
 
220
+ min_row, min_col = non_zero.min(axis=0)
221
+ max_row, max_col = non_zero.max(axis=0)
222
 
223
+ cropped = image.crop((min_col, min_row, max_col + 1, max_row + 1))
224
 
225
+ w, h = cropped.size
226
+ bw = int(w * border_ratio)
227
+ bh = int(h * border_ratio)
228
 
229
+ new_w = w + 2 * bw
230
+ new_h = h + 2 * bh
231
+ sq = max(new_w, new_h)
232
 
233
+ new_img = Image.new('RGBA', (sq, sq), (255, 255, 255, 0))
234
+ new_img.paste(cropped, ((sq - new_w) // 2 + bw, (sq - new_h) // 2 + bh))
235
 
236
+ return new_img
237
 
 
 
 
 
 
238
 
239
  @torch.no_grad()
240
  def __call__(self, mesh, image):
 
246
 
247
  image_prompt = self.recenter_image(image_prompt)
248
 
249
+ # delight
250
  image_prompt = self.models['delight_model'](image_prompt)
251
 
252
  mesh = mesh_uv_wrap(mesh)
 
253
  self.render.load_mesh(mesh)
254
 
255
+ elevs = self.config.candidate_camera_elevs
256
+ azims = self.config.candidate_camera_azims
257
+ weights = self.config.candidate_view_weights
258
 
259
+ normal_maps = self.render_normal_multiview(elevs, azims)
260
+ position_maps = self.render_position_multiview(elevs, azims)
 
 
261
 
262
+ camera_info = [
263
+ (((azim // 30) + 9) % 12) //
264
+ {-20: 1, 0: 1, 20: 1, -90: 3, 90: 3}[elev] +
265
+ {-20: 0, 0: 12, 20: 24, -90: 36, 90: 40}[elev]
266
+ for azim, elev in zip(azims, elevs)
267
+ ]
268
+
269
+ multiviews = self.models['multiview_model'](
270
+ image_prompt, normal_maps + position_maps, camera_info
271
+ )
272
 
273
  for i in range(len(multiviews)):
 
274
  multiviews[i] = multiviews[i].resize(
275
+ (self.config.render_size, self.config.render_size)
276
+ )
277
 
278
+ texture, mask = self.bake_from_multiview(
279
+ multiviews, elevs, azims, weights, method=self.config.merge_method
280
+ )
281
 
282
  mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
283
 
284
  texture = self.texture_inpaint(texture, mask_np)
 
285
  self.render.set_texture(texture)
 
286
 
287
+ return self.render.save_mesh()