notenoughram commited on
Commit
613dce8
ยท
verified ยท
1 Parent(s): 98247dd

Update app_fine.py

Browse files
Files changed (1) hide show
  1. app_fine.py +26 -9
app_fine.py CHANGED
@@ -840,14 +840,31 @@ with demo:
840
  )
841
 
842
 
843
- # Launch the Gradio app
 
844
  if __name__ == "__main__":
845
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
846
- # pipeline = TrellisVGGTTo3DPipeline.from_pretrained("weights/trellis-vggt-v0-1")
847
- pipeline.cuda()
848
- pipeline.VGGT_model.cuda()
849
- pipeline.birefnet_model.cuda()
850
- pipeline.dreamsim_model.cuda()
851
- mast3r_model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").cuda().eval()
852
- # mast3r_model = AsymmetricMASt3R.from_pretrained("weights/MAST3R/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth").cuda().eval()
853
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
840
  )
841
 
842
 
843
+ # ํŒŒ์ผ ์ตœํ•˜๋‹จ if __name__ == "__main__": ๋‚ด๋ถ€๋ฅผ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ˆ˜์ •ํ•˜์„ธ์š”.
844
+
845
  if __name__ == "__main__":
846
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
847
+
848
+ # --- ๋ฉ€ํ‹ฐ GPU ์„ค์ • ์‹œ์ž‘ ---
849
+ num_gpus = torch.cuda.device_count()
850
+ print(f"์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ GPU ๊ฐœ์ˆ˜: {num_gpus}")
851
+
852
+ # ๋ฉ”์ธ ํŒŒ์ดํ”„๋ผ์ธ ๋ฐ ๋ชจ๋ธ๋“ค์„ ๋ฉ€ํ‹ฐ GPU์— ๋ถ„์‚ฐ
853
+ if num_gpus > 1:
854
+ pipeline.cuda() # ์šฐ์„  ๋ฉ”์ธ ๋กœ๋“œ
855
+ # ๋‚ด๋ถ€์˜ ํฐ ๋ชจ๋ธ๋“ค์„ DataParallel๋กœ ๊ฐ์‹ธ๊ธฐ
856
+ pipeline.VGGT_model = torch.nn.DataParallel(pipeline.VGGT_model)
857
+ pipeline.birefnet_model = torch.nn.DataParallel(pipeline.birefnet_model)
858
+ # pipeline.dreamsim_model = torch.nn.DataParallel(pipeline.dreamsim_model) # ํ•„์š”์‹œ ํ•ด์ œ
859
+ else:
860
+ pipeline.cuda()
861
+
862
+ # MAST3R ๋ชจ๋ธ ๋ฉ€ํ‹ฐ GPU ์„ค์ •
863
+ mast3r_raw = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric")
864
+ if num_gpus > 1:
865
+ mast3r_model = torch.nn.DataParallel(mast3r_raw).cuda().eval()
866
+ else:
867
+ mast3r_model = mast3r_raw.cuda().eval()
868
+ # --- ๋ฉ€ํ‹ฐ GPU ์„ค์ • ๋ ---
869
+
870
+ demo.launch()