LejobuildYT commited on
Commit
f83328a
·
verified ·
1 Parent(s): 9fa3d8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -207
app.py CHANGED
@@ -1,27 +1,28 @@
1
  import os
2
- import spaces
3
  import random
4
  import shutil
5
- import gradio as gr
6
  from glob import glob
7
  from pathlib import Path
8
- import uuid
9
  import argparse
 
10
  import torch
 
11
  import uvicorn
12
  from fastapi import FastAPI
13
  from fastapi.staticfiles import StaticFiles
 
14
  import trimesh
15
  from transformers import AutoProcessor, AutoModelForImageClassification
16
- from PIL import Image
17
 
 
18
  parser = argparse.ArgumentParser()
19
  parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2mini')
20
  parser.add_argument("--subfolder", type=str, default='hunyuan3d-dit-v2-mini-turbo')
21
  parser.add_argument("--texgen_model_path", type=str, default='tencent/Hunyuan3D-2')
22
  parser.add_argument('--port', type=int, default=7860)
23
  parser.add_argument('--host', type=str, default='0.0.0.0')
24
- parser.add_argument('--device', type=str, default='cuda')
25
  parser.add_argument('--mc_algo', type=str, default='mc')
26
  parser.add_argument('--cache_path', type=str, default='gradio_cache')
27
  parser.add_argument('--enable_t23d', action='store_true')
@@ -30,94 +31,80 @@ parser.add_argument('--enable_flashvdm', action='store_true')
30
  parser.add_argument('--compile', action='store_true')
31
  parser.add_argument('--low_vram_mode', action='store_true')
32
  args = parser.parse_args()
33
- args.enable_flashvdm = True
34
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  SAVE_DIR = args.cache_path
36
  os.makedirs(SAVE_DIR, exist_ok=True)
37
-
38
  CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
39
 
40
  HTML_HEIGHT = 500
41
  HTML_WIDTH = 500
 
42
 
43
- # -------------------- NSFW 检测模型加载 --------------------
44
  nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
45
- nsfw_model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection").to(args.device)
46
- # -----------------------------------------------------------
47
-
48
 
 
49
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
50
  if randomize_seed:
51
  seed = random.randint(0, MAX_SEED)
52
  return seed
53
 
54
-
55
  def gen_save_folder(max_size=200):
56
  os.makedirs(SAVE_DIR, exist_ok=True)
57
-
58
- # 获取所有文件夹路径
59
  dirs = [f for f in Path(SAVE_DIR).iterdir() if f.is_dir()]
60
-
61
- # 如果文件夹数量超过 max_size,删除创建时间最久的文件夹
62
  if len(dirs) >= max_size:
63
- # 按创建时间排序,最久的排在前面
64
  oldest_dir = min(dirs, key=lambda x: x.stat().st_ctime)
65
  shutil.rmtree(oldest_dir)
66
  print(f"Removed the oldest folder: {oldest_dir}")
67
-
68
- # 生成一个新的 uuid 文件夹名称
69
  new_folder = os.path.join(SAVE_DIR, str(uuid.uuid4()))
70
  os.makedirs(new_folder, exist_ok=True)
71
  print(f"Created new folder: {new_folder}")
72
-
73
  return new_folder
74
 
75
  def export_mesh(mesh, save_folder, textured=False, type='glb'):
76
- if textured:
77
- path = os.path.join(save_folder, f'textured_mesh.{type}')
78
- else:
79
- path = os.path.join(save_folder, f'white_mesh.{type}')
80
- if type not in ['glb', 'obj']:
81
- mesh.export(path)
82
- else:
83
- mesh.export(path, include_normals=textured)
84
  return path
85
 
86
  def build_model_viewer_html(save_folder, height=660, width=790, textured=False):
87
- # Remove first folder from path to make relative path
88
- if textured:
89
- related_path = f"./textured_mesh.glb"
90
- template_name = './assets/modelviewer-textured-template.html'
91
- output_html_path = os.path.join(save_folder, f'textured_mesh.html')
92
- else:
93
- related_path = f"./white_mesh.glb"
94
- template_name = './assets/modelviewer-template.html'
95
- output_html_path = os.path.join(save_folder, f'white_mesh.html')
96
-
97
- offset = 50 if textured else 10
98
  with open(os.path.join(CURRENT_DIR, template_name), 'r', encoding='utf-8') as f:
99
  template_html = f.read()
100
 
 
 
 
 
 
101
  with open(output_html_path, 'w', encoding='utf-8') as f:
102
- template_html = template_html.replace('#height#', f'{height - offset}')
103
- template_html = template_html.replace('#width#', f'{width}')
104
- template_html = template_html.replace('#src#', f'{related_path}/')
105
  f.write(template_html)
106
 
107
  rel_path = os.path.relpath(output_html_path, SAVE_DIR)
108
  iframe_tag = f'<iframe src="/static/{rel_path}" height="{height}" width="100%" frameborder="0"></iframe>'
109
- print(
110
- f'Find html file {output_html_path}, {os.path.exists(output_html_path)}, relative HTML path is /static/{rel_path}')
111
 
112
- return f"""
113
- <div style='height: {height}; width: 100%;'>
114
- {iframe_tag}
115
- </div>
116
- """
117
 
118
-
119
- from hy3dgen.shapegen import FaceReducer, FloaterRemover, DegenerateFaceRemover, MeshSimplifier, \
120
- Hunyuan3DDiTFlowMatchingPipeline
121
  from hy3dgen.shapegen.pipelines import export_to_trimesh
122
  from hy3dgen.rembg import BackgroundRemover
123
 
@@ -126,8 +113,9 @@ i23d_worker = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
126
  args.model_path,
127
  subfolder=args.subfolder,
128
  use_safetensors=True,
129
- device=args.device,
130
  )
 
131
  if args.enable_flashvdm:
132
  mc_algo = 'mc' if args.device in ['cpu', 'mps'] else args.mc_algo
133
  i23d_worker.enable_flashvdm(mc_algo=mc_algo)
@@ -138,66 +126,39 @@ floater_remove_worker = FloaterRemover()
138
  degenerate_face_remove_worker = DegenerateFaceRemover()
139
  face_reduce_worker = FaceReducer()
140
 
141
-
142
  def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
143
- """Returns True if image is NSFW"""
144
- # inputs = nsfw_processor(images=image, return_tensors="pt").to(args.device)
145
- # with torch.no_grad():
146
- # outputs = nsfw_model(**inputs)
147
- # probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
148
- # nsfw_score = probs[0][1].item() # label 1 = NSFW
149
- nsfw_score = 0 # label 1 = NSFW
150
  return nsfw_score > threshold
151
 
 
 
152
 
153
-
154
- progress=gr.Progress()
155
-
156
- # @spaces.GPU(duration=40)
157
- def _gen_shape_on_gpu(
158
  image=None,
159
- steps=10, # 50
160
- guidance_scale=7.5, # 7.5
161
  seed=1234,
162
- octree_resolution=128, # 256
163
- num_chunks=50000, # 200000
164
- target_face_num=2500, # 10000
165
  randomize_seed: bool = False,
166
  ):
167
- progress(0,desc="Starting")
168
-
169
- def callback(step_idx, timestep, outputs):
170
- progress_value = ((step_idx+1.0)/steps)*(0.5/1.0)
171
- progress(progress_value, desc=f"Mesh generating, {step_idx + 1}/{steps} steps")
172
-
173
 
174
  if image is None:
175
- error_info = {
176
- "error": "Please provide either a caption or an image.",
177
- "status": "failed",
178
- }
179
- return None,None,None,None,error_info
180
-
181
- rgbImage = image.convert('RGB')
182
-
183
- # NSFW 检测
184
- if nsfw_model and nsfw_processor:
185
- if detect_nsfw(rgbImage):
186
- error_info = {
187
- "error": "The input image contains NSFW content and cannot be used. Please provide a different image and try again.",
188
- "status": "failed",
189
- }
190
- return None,None,None,None,error_info
191
 
192
  seed = int(randomize_seed_fn(seed, randomize_seed))
193
- octree_resolution = int(octree_resolution)
194
  save_folder = gen_save_folder()
195
- # 先移除背景
196
- image = rmbg_worker(rgbImage)
 
197
 
198
- # 生成模型
199
- generator = torch.Generator()
200
- generator = generator.manual_seed(int(seed))
201
  outputs = i23d_worker(
202
  image=image,
203
  num_inference_steps=steps,
@@ -206,94 +167,43 @@ def _gen_shape_on_gpu(
206
  octree_resolution=octree_resolution,
207
  num_chunks=num_chunks,
208
  output_type='mesh',
209
- callback=callback,
210
  callback_steps=1
211
  )
212
 
213
  mesh = export_to_trimesh(outputs)[0]
214
-
215
  path = export_mesh(mesh, save_folder, textured=False)
216
 
217
- # model_viewer_html = build_model_viewer_html(save_folder, height=HTML_HEIGHT, width=HTML_WIDTH)
218
-
219
- # return model_viewer_html, path
220
-
221
- if args.low_vram_mode:
222
- torch.cuda.empty_cache()
223
-
224
- if path is None:
225
- error_info = {
226
- "error": "'Please generate a mesh first.'",
227
- "status": "failed",
228
- }
229
- return None,None,None,None,error_info
230
-
231
- # 简化模型
232
- print(f'exporting {path}')
233
- print(f'reduce face to {target_face_num}')
234
-
235
  mesh = trimesh.load(path)
236
- progress(0.5,desc="Optimizing mesh")
237
-
238
  mesh = floater_remove_worker(mesh)
239
  mesh = degenerate_face_remove_worker(mesh)
240
- progress(0.6,desc="Reducing mesh faces")
241
  mesh = face_reduce_worker(mesh, target_face_num)
242
- save_folder = gen_save_folder()
243
 
244
- progress(0.9,desc="Converting format")
245
- file_type = "obj"
246
- sourceObjPath = export_mesh(mesh, save_folder, textured=False, type=file_type)
247
- rel_objPath = os.path.relpath(sourceObjPath, SAVE_DIR)
248
- objPath = "/static/"+rel_objPath
249
-
250
- # for preview
251
  save_folder = gen_save_folder()
252
- _ = export_mesh(mesh, save_folder, textured=False)
253
  model_viewer_html = build_model_viewer_html(save_folder, height=HTML_HEIGHT, width=HTML_WIDTH, textured=False)
254
 
 
 
 
 
 
255
 
256
- glbPath = os.path.join(save_folder, f'white_mesh.glb')
257
- rel_glbPath = os.path.relpath(glbPath, SAVE_DIR)
258
- glbPath = "/static/"+rel_glbPath
259
-
260
-
261
- progress(1,desc="Complete")
262
- info = {
263
- "status": "success"
264
- }
265
- return model_viewer_html, gr.update(value=sourceObjPath, interactive=True), glbPath, objPath, info
266
 
267
-
268
- def gen_shape(
269
- image=None,
270
- steps=50,
271
- guidance_scale=7.5,
272
- seed=1234,
273
- octree_resolution=256,
274
- num_chunks=50000, # 2000000
275
- target_face_num=2500, # 10000
276
- randomize_seed: bool = False,
277
- ):
278
- # 调用 GPU 函数
279
- html_export_mesh,file_export,glbPath_output,objPath_output, info = _gen_shape_on_gpu(
280
- image,
281
- steps,
282
- guidance_scale,
283
- seed,
284
- octree_resolution,
285
- num_chunks,
286
- target_face_num,
287
- randomize_seed
288
- )
289
- # 如果出错,抛出异常
290
  if info["status"] == "failed":
291
  raise gr.Error(info["error"])
292
- return html_export_mesh, file_export, glbPath_output, objPath_output
293
-
 
294
  def get_example_img_list():
295
- print('Loading example img list ...')
296
  return sorted(glob('./assets/example_images/**/*.png', recursive=True))
 
297
  example_imgs = get_example_img_list()
298
 
299
  HTML_OUTPUT_PLACEHOLDER = f"""
@@ -303,72 +213,63 @@ HTML_OUTPUT_PLACEHOLDER = f"""
303
  </div>
304
  </div>
305
  """
306
- MAX_SEED = 1e7
307
 
 
308
  title = "## AI 3D Model Generator"
309
- description = "Our Image-to-3D Generator transforms your 2D photos into stunning, AI generated 3D models—ready for games, AR/VR, or 3D printing. Our AI 3D Modeling is based on Hunyuan 2.0. Check more in [imgto3d.ai](https://www.imgto3d.ai)."
310
 
311
  with gr.Blocks().queue() as demo:
312
  gr.Markdown(title)
313
  gr.Markdown(description)
314
  with gr.Row():
315
  with gr.Column(scale=3):
316
- gr.Markdown("#### Image Prompt")
317
  image = gr.Image(sources=["upload"], label='Image', type='pil', image_mode='RGBA', height=290)
318
  gen_button = gr.Button(value='Generate Shape', variant='primary')
319
  with gr.Accordion("Advanced Options", open=False):
320
- with gr.Column():
321
- seed = gr.Slider(
322
- label="Seed",
323
- minimum=0,
324
- maximum=MAX_SEED,
325
- step=1,
326
- value=1234,
327
- min_width=100,
328
- )
329
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
330
- with gr.Column():
331
- num_steps = gr.Slider(maximum=100, minimum=1, value=5, step=1, label='Inference Steps')
332
- octree_resolution = gr.Slider(maximum=512, minimum=16, value=128, label='Octree Resolution')
333
- with gr.Column():
334
- cfg_scale = gr.Slider(maximum=20.0, minimum=1.0, value=5.5, step=0.1, label='Guidance Scale')
335
- num_chunks = gr.Slider(maximum=50000, minimum=1000, value=2000, label='Number of Chunks') # old maximum=5000000
336
- target_face_num = gr.Slider(maximum=1000000, minimum=100, value=2500, label='Target Face Number') # old maximum=1000000
337
-
338
  with gr.Column(scale=6):
339
- gr.Markdown("#### Generated Mesh")
340
  html_export_mesh = gr.HTML(HTML_OUTPUT_PLACEHOLDER, label='Output')
341
  file_export = gr.DownloadButton(label="Download", variant='primary', interactive=False)
342
- with gr.Row():
343
- objPath_output = gr.Text(label="Obj Path",interactive=False)
344
- glbPath_output = gr.Text(label="Glb Path",interactive=False)
345
-
346
  with gr.Column(scale=3):
347
- gr.Markdown("#### Image Examples")
348
- gr.Examples(examples=example_imgs, inputs=[image],
349
- label=None, examples_per_page=18)
350
-
351
  gen_button.click(
352
  fn=gen_shape,
353
- inputs=[image,num_steps,cfg_scale,seed,octree_resolution,num_chunks,target_face_num, randomize_seed],
354
- outputs=[html_export_mesh,file_export, glbPath_output, objPath_output]
355
- )
356
 
 
357
  if __name__ == "__main__":
358
- # https://discuss.huggingface.co/t/how-to-serve-an-html-file/33921/2
359
- # create a FastAPI app
 
 
360
  app = FastAPI()
361
- # create a static directory to store the static files
362
  static_dir = Path(SAVE_DIR).absolute()
363
  static_dir.mkdir(parents=True, exist_ok=True)
364
  app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static")
365
  shutil.copytree('./assets/env_maps', os.path.join(static_dir, 'env_maps'), dirs_exist_ok=True)
366
 
367
- if args.low_vram_mode:
 
368
  torch.cuda.empty_cache()
369
 
 
 
 
 
 
 
370
  app = gr.mount_gradio_app(app, demo, path="/")
371
- # demo.launch()
372
- from spaces import zero
373
- zero.startup()
374
  uvicorn.run(app, host=args.host, port=args.port)
 
1
  import os
 
2
  import random
3
  import shutil
4
+ import uuid
5
  from glob import glob
6
  from pathlib import Path
 
7
  import argparse
8
+
9
  import torch
10
+ import gradio as gr
11
  import uvicorn
12
  from fastapi import FastAPI
13
  from fastapi.staticfiles import StaticFiles
14
+ from PIL import Image
15
  import trimesh
16
  from transformers import AutoProcessor, AutoModelForImageClassification
 
17
 
18
+ # -------------------- Argumente --------------------
19
  parser = argparse.ArgumentParser()
20
  parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2mini')
21
  parser.add_argument("--subfolder", type=str, default='hunyuan3d-dit-v2-mini-turbo')
22
  parser.add_argument("--texgen_model_path", type=str, default='tencent/Hunyuan3D-2')
23
  parser.add_argument('--port', type=int, default=7860)
24
  parser.add_argument('--host', type=str, default='0.0.0.0')
25
+ parser.add_argument('--device', type=str, default=None)
26
  parser.add_argument('--mc_algo', type=str, default='mc')
27
  parser.add_argument('--cache_path', type=str, default='gradio_cache')
28
  parser.add_argument('--enable_t23d', action='store_true')
 
31
  parser.add_argument('--compile', action='store_true')
32
  parser.add_argument('--low_vram_mode', action='store_true')
33
  args = parser.parse_args()
 
34
 
35
+ # -------------------- Device Setup --------------------
36
+ if args.device is None:
37
+ if torch.cuda.is_available():
38
+ args.device = "cuda"
39
+ elif torch.backends.mps.is_available(): # macOS GPU
40
+ args.device = "mps"
41
+ else:
42
+ args.device = "cpu"
43
+
44
+ print(f"Using device: {args.device}")
45
+
46
+ # -------------------- Pfade --------------------
47
  SAVE_DIR = args.cache_path
48
  os.makedirs(SAVE_DIR, exist_ok=True)
 
49
  CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
50
 
51
  HTML_HEIGHT = 500
52
  HTML_WIDTH = 500
53
+ MAX_SEED = int(1e7)
54
 
55
+ # -------------------- NSFW Modell --------------------
56
  nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
57
+ nsfw_model = AutoModelForImageClassification.from_pretrained(
58
+ "Falconsai/nsfw_image_detection"
59
+ ).to(args.device)
60
 
61
+ # -------------------- Hilfsfunktionen --------------------
62
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
63
  if randomize_seed:
64
  seed = random.randint(0, MAX_SEED)
65
  return seed
66
 
 
67
  def gen_save_folder(max_size=200):
68
  os.makedirs(SAVE_DIR, exist_ok=True)
 
 
69
  dirs = [f for f in Path(SAVE_DIR).iterdir() if f.is_dir()]
 
 
70
  if len(dirs) >= max_size:
 
71
  oldest_dir = min(dirs, key=lambda x: x.stat().st_ctime)
72
  shutil.rmtree(oldest_dir)
73
  print(f"Removed the oldest folder: {oldest_dir}")
 
 
74
  new_folder = os.path.join(SAVE_DIR, str(uuid.uuid4()))
75
  os.makedirs(new_folder, exist_ok=True)
76
  print(f"Created new folder: {new_folder}")
 
77
  return new_folder
78
 
79
  def export_mesh(mesh, save_folder, textured=False, type='glb'):
80
+ filename = f'textured_mesh.{type}' if textured else f'white_mesh.{type}'
81
+ path = os.path.join(save_folder, filename)
82
+ mesh.export(path, include_normals=textured)
 
 
 
 
 
83
  return path
84
 
85
  def build_model_viewer_html(save_folder, height=660, width=790, textured=False):
86
+ related_path = "textured_mesh.glb" if textured else "white_mesh.glb"
87
+ template_name = './assets/modelviewer-textured-template.html' if textured else './assets/modelviewer-template.html'
88
+ output_html_path = os.path.join(save_folder, f"{'textured' if textured else 'white'}_mesh.html")
89
+
 
 
 
 
 
 
 
90
  with open(os.path.join(CURRENT_DIR, template_name), 'r', encoding='utf-8') as f:
91
  template_html = f.read()
92
 
93
+ offset = 50 if textured else 10
94
+ template_html = template_html.replace('#height#', str(height - offset))
95
+ template_html = template_html.replace('#width#', str(width))
96
+ template_html = template_html.replace('#src#', f'./{related_path}/')
97
+
98
  with open(output_html_path, 'w', encoding='utf-8') as f:
 
 
 
99
  f.write(template_html)
100
 
101
  rel_path = os.path.relpath(output_html_path, SAVE_DIR)
102
  iframe_tag = f'<iframe src="/static/{rel_path}" height="{height}" width="100%" frameborder="0"></iframe>'
 
 
103
 
104
+ return f"<div style='height: {height}; width: 100%;'>{iframe_tag}</div>"
 
 
 
 
105
 
106
+ # -------------------- Hy3Dgen Worker --------------------
107
+ from hy3dgen.shapegen import FaceReducer, FloaterRemover, DegenerateFaceRemover, MeshSimplifier, Hunyuan3DDiTFlowMatchingPipeline
 
108
  from hy3dgen.shapegen.pipelines import export_to_trimesh
109
  from hy3dgen.rembg import BackgroundRemover
110
 
 
113
  args.model_path,
114
  subfolder=args.subfolder,
115
  use_safetensors=True,
116
+ device=args.device
117
  )
118
+
119
  if args.enable_flashvdm:
120
  mc_algo = 'mc' if args.device in ['cpu', 'mps'] else args.mc_algo
121
  i23d_worker.enable_flashvdm(mc_algo=mc_algo)
 
126
  degenerate_face_remove_worker = DegenerateFaceRemover()
127
  face_reduce_worker = FaceReducer()
128
 
129
+ # -------------------- NSFW Detection --------------------
130
  def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
131
+ nsfw_score = 0 # Placeholder, optional: implement actual detection
 
 
 
 
 
 
132
  return nsfw_score > threshold
133
 
134
+ # -------------------- Mesh Generation --------------------
135
+ progress = gr.Progress()
136
 
137
+ def _gen_shape(
 
 
 
 
138
  image=None,
139
+ steps=10,
140
+ guidance_scale=7.5,
141
  seed=1234,
142
+ octree_resolution=128,
143
+ num_chunks=50000,
144
+ target_face_num=2500,
145
  randomize_seed: bool = False,
146
  ):
147
+ progress(0, desc="Starting")
 
 
 
 
 
148
 
149
  if image is None:
150
+ return None, None, None, None, {"error": "Please provide an image.", "status": "failed"}
151
+
152
+ rgb_image = image.convert('RGB')
153
+ if detect_nsfw(rgb_image):
154
+ return None, None, None, None, {"error": "NSFW content detected.", "status": "failed"}
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  seed = int(randomize_seed_fn(seed, randomize_seed))
 
157
  save_folder = gen_save_folder()
158
+ image = rmbg_worker(rgb_image)
159
+
160
+ generator = torch.Generator(device=args.device).manual_seed(seed)
161
 
 
 
 
162
  outputs = i23d_worker(
163
  image=image,
164
  num_inference_steps=steps,
 
167
  octree_resolution=octree_resolution,
168
  num_chunks=num_chunks,
169
  output_type='mesh',
170
+ callback=lambda step_idx, timestep, out: progress(((step_idx+1)/steps)*0.5, desc=f"Mesh generating {step_idx+1}/{steps}"),
171
  callback_steps=1
172
  )
173
 
174
  mesh = export_to_trimesh(outputs)[0]
 
175
  path = export_mesh(mesh, save_folder, textured=False)
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  mesh = trimesh.load(path)
178
+ progress(0.5, desc="Optimizing mesh")
 
179
  mesh = floater_remove_worker(mesh)
180
  mesh = degenerate_face_remove_worker(mesh)
181
+ progress(0.6, desc="Reducing faces")
182
  mesh = face_reduce_worker(mesh, target_face_num)
 
183
 
 
 
 
 
 
 
 
184
  save_folder = gen_save_folder()
185
+ source_obj_path = export_mesh(mesh, save_folder, textured=False, type="obj")
186
  model_viewer_html = build_model_viewer_html(save_folder, height=HTML_HEIGHT, width=HTML_WIDTH, textured=False)
187
 
188
+ glb_path = export_mesh(mesh, save_folder, textured=False, type="glb")
189
+ rel_glb_path = os.path.relpath(glb_path, SAVE_DIR)
190
+ glb_path = "/static/" + rel_glb_path
191
+ rel_obj_path = os.path.relpath(source_obj_path, SAVE_DIR)
192
+ obj_path = "/static/" + rel_obj_path
193
 
194
+ progress(1, desc="Complete")
195
+ return model_viewer_html, gr.update(value=source_obj_path, interactive=True), glb_path, obj_path, {"status": "success"}
 
 
 
 
 
 
 
 
196
 
197
+ def gen_shape(*args, **kwargs):
198
+ html, file_export, glb_path, obj_path, info = _gen_shape(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  if info["status"] == "failed":
200
  raise gr.Error(info["error"])
201
+ return html, file_export, glb_path, obj_path
202
+
203
+ # -------------------- Beispielbilder --------------------
204
  def get_example_img_list():
 
205
  return sorted(glob('./assets/example_images/**/*.png', recursive=True))
206
+
207
  example_imgs = get_example_img_list()
208
 
209
  HTML_OUTPUT_PLACEHOLDER = f"""
 
213
  </div>
214
  </div>
215
  """
 
216
 
217
+ # -------------------- Gradio UI --------------------
218
  title = "## AI 3D Model Generator"
219
+ description = "Transforms 2D photos into AI-generated 3D models."
220
 
221
  with gr.Blocks().queue() as demo:
222
  gr.Markdown(title)
223
  gr.Markdown(description)
224
  with gr.Row():
225
  with gr.Column(scale=3):
 
226
  image = gr.Image(sources=["upload"], label='Image', type='pil', image_mode='RGBA', height=290)
227
  gen_button = gr.Button(value='Generate Shape', variant='primary')
228
  with gr.Accordion("Advanced Options", open=False):
229
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=1234)
230
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
231
+ num_steps = gr.Slider(maximum=100, minimum=1, value=5, step=1, label='Inference Steps')
232
+ octree_resolution = gr.Slider(maximum=512, minimum=16, value=128, label='Octree Resolution')
233
+ cfg_scale = gr.Slider(maximum=20.0, minimum=1.0, value=5.5, step=0.1, label='Guidance Scale')
234
+ num_chunks = gr.Slider(maximum=50000, minimum=1000, value=2000, label='Number of Chunks')
235
+ target_face_num = gr.Slider(maximum=1000000, minimum=100, value=2500, label='Target Face Number')
 
 
 
 
 
 
 
 
 
 
 
236
  with gr.Column(scale=6):
 
237
  html_export_mesh = gr.HTML(HTML_OUTPUT_PLACEHOLDER, label='Output')
238
  file_export = gr.DownloadButton(label="Download", variant='primary', interactive=False)
239
+ objPath_output = gr.Text(label="Obj Path", interactive=False)
240
+ glbPath_output = gr.Text(label="Glb Path", interactive=False)
 
 
241
  with gr.Column(scale=3):
242
+ gr.Examples(examples=example_imgs, inputs=[image], examples_per_page=18)
243
+
 
 
244
  gen_button.click(
245
  fn=gen_shape,
246
+ inputs=[image, num_steps, cfg_scale, seed, octree_resolution, num_chunks, target_face_num, randomize_seed],
247
+ outputs=[html_export_mesh, file_export, glbPath_output, objPath_output]
248
+ )
249
 
250
+ # -------------------- FastAPI + Gradio --------------------
251
  if __name__ == "__main__":
252
+ # Device Info
253
+ print(f"Using device: {args.device}")
254
+
255
+ # Optional: FastAPI static files (für Assets)
256
  app = FastAPI()
 
257
  static_dir = Path(SAVE_DIR).absolute()
258
  static_dir.mkdir(parents=True, exist_ok=True)
259
  app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static")
260
  shutil.copytree('./assets/env_maps', os.path.join(static_dir, 'env_maps'), dirs_exist_ok=True)
261
 
262
+ # Low VRAM cleanup
263
+ if args.low_vram_mode and args.device == "cuda":
264
  torch.cuda.empty_cache()
265
 
266
+ # Gradio Demo starten CPU-kompatibel, funktioniert auch in HF Spaces
267
+ demo.launch(
268
+ server_name="0.0.0.0", # für Spaces oder lokal
269
+ server_port=args.port,
270
+ share=True # erstellt einen öffentlichen Link wie HF Spaces
271
+ )
272
  app = gr.mount_gradio_app(app, demo, path="/")
273
+ # from spaces import zero
274
+ # zero.startup()
 
275
  uvicorn.run(app, host=args.host, port=args.port)