notenoughram committed on
Commit
067e49b
·
verified ·
1 Parent(s): ab5bb1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -14
app.py CHANGED
@@ -29,8 +29,8 @@ from trellis.pipelines import TrellisVGGTTo3DPipeline
29
  from trellis.representations import Gaussian, MeshExtractResult
30
  from trellis.utils import render_utils, postprocessing_utils
31
 
32
- # accelerate 로드 (위에서 설치했으므로 통과됨)
33
- from accelerate import dispatch_model
34
 
35
  MAX_SEED = np.iinfo(np.int32).max
36
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
@@ -87,7 +87,9 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
87
  }
88
 
89
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
 
90
  device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 
91
  gs = Gaussian(
92
  aabb=state['gaussian']['aabb'],
93
  sh_degree=state['gaussian']['sh_degree'],
@@ -149,8 +151,10 @@ def generate_and_extract_glb(
149
  )
150
  except Exception as e:
151
  torch.cuda.empty_cache()
152
- raise e
 
153
 
 
154
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
155
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
156
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
@@ -222,7 +226,7 @@ demo = gr.Blocks(
222
  )
223
  with demo:
224
  gr.Markdown("""
225
- # 💻 ReconViaGen (Multi-GPU Auto-Install)
226
  """)
227
 
228
  with gr.Row():
@@ -315,21 +319,24 @@ if __name__ == "__main__":
315
  print(f"⚡ Detected {gpu_count} GPUs.")
316
 
317
  if gpu_count > 1:
318
- print("⚡ Multi-GPU Mode Activated: Distributing model across all available GPUs.")
319
 
320
- pipeline.VGGT_model = dispatch_model(
321
- pipeline.VGGT_model,
322
- device_map="balanced"
323
- )
324
 
325
- pipeline.slat_model = dispatch_model(
326
- pipeline.slat_model,
327
- device_map="balanced"
328
- )
 
 
 
329
 
 
330
  pipeline.birefnet_model.to("cuda:0")
331
 
332
- print("✅ Model dispatched successfully via Accelerate.")
333
 
334
  else:
335
  print("⚠️ Warning: Only 1 GPU detected.")
 
29
  from trellis.representations import Gaussian, MeshExtractResult
30
  from trellis.utils import render_utils, postprocessing_utils
31
 
32
+ # [수정] infer_auto_device_map 추가 임포트
33
+ from accelerate import dispatch_model, infer_auto_device_map
34
 
35
  MAX_SEED = np.iinfo(np.int32).max
36
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 
87
  }
88
 
89
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
90
+ # 결과 수집용 디바이스
91
  device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
92
+
93
  gs = Gaussian(
94
  aabb=state['gaussian']['aabb'],
95
  sh_degree=state['gaussian']['sh_degree'],
 
151
  )
152
  except Exception as e:
153
  torch.cuda.empty_cache()
154
+ # 에러 메시지에 메모리 팁 추가
155
+ raise RuntimeError(f"Generation Failed: {str(e)}\n(Try reducing image size or restart space)")
156
 
157
+ # ๋ Œ๋”๋ง์€ CPU ํ˜น์€ 0๋ฒˆ GPU์—์„œ ์ˆ˜ํ–‰ (๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ)
158
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
159
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
160
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
226
  )
227
  with demo:
228
  gr.Markdown("""
229
+ # 💻 ReconViaGen (Fixed Multi-GPU)
230
  """)
231
 
232
  with gr.Row():
 
319
  print(f"⚡ Detected {gpu_count} GPUs.")
320
 
321
  if gpu_count > 1:
322
+ print("⚡ Multi-GPU Mode Activated.")
323
 
324
+ # [수정] infer_auto_device_map으로 맵(Dictionary)을 먼저 생성해야 함
325
+ # "balanced"는 문자열이 아니라 내부 동작 방식이므로, infer_auto_device_map을 통해
326
+ # 실제 레이어별 GPU 할당표(dict)를 받아와야 dispatch_model이 알아먹습니다.
 
327
 
328
+ print(" - Calculating Device Map for VGGT Model...")
329
+ vggt_map = infer_auto_device_map(pipeline.VGGT_model)
330
+ pipeline.VGGT_model = dispatch_model(pipeline.VGGT_model, device_map=vggt_map)
331
+
332
+ print(" - Calculating Device Map for SLAT Model...")
333
+ slat_map = infer_auto_device_map(pipeline.slat_model)
334
+ pipeline.slat_model = dispatch_model(pipeline.slat_model, device_map=slat_map)
335
 
336
+ # 가벼운 모델은 0번 고정
337
  pipeline.birefnet_model.to("cuda:0")
338
 
339
+ print("✅ Models dispatched successfully.")
340
 
341
  else:
342
  print("⚠️ Warning: Only 1 GPU detected.")