Spaces:
Paused
Paused
Update app_fine.py
Browse files — app_fine.py (+26 −9)
app_fine.py
CHANGED
|
@@ -840,14 +840,31 @@ with demo:
|
|
| 840 |
)
|
| 841 |
|
| 842 |
|
| 843 |
-
#
|
|
|
|
| 844 |
if __name__ == "__main__":
|
| 845 |
pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 840 |
)
|
| 841 |
|
| 842 |
|
if __name__ == "__main__":
    # Entry point: load the 3D reconstruction pipeline, distribute its heavy
    # sub-models across the available GPUs, then start the Gradio UI.
    pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")

    # --- multi-GPU setup start ---
    num_gpus = torch.cuda.device_count()
    print(f"사용 가능한 GPU 개수: {num_gpus}")

    # Move the main pipeline to GPU first (needed in both the single- and
    # multi-GPU case, so it is hoisted out of the branch below).
    pipeline.cuda()

    # Wrap the large internal models for data-parallel inference when more
    # than one GPU is present.
    # NOTE(review): torch.nn.DataParallel is discouraged in favor of
    # DistributedDataParallel, but switching would change how this script is
    # launched, so it is kept as-is here.
    if num_gpus > 1:
        pipeline.VGGT_model = torch.nn.DataParallel(pipeline.VGGT_model)
        pipeline.birefnet_model = torch.nn.DataParallel(pipeline.birefnet_model)
        # pipeline.dreamsim_model = torch.nn.DataParallel(pipeline.dreamsim_model)  # enable if needed

    # MAST3R model: same multi-GPU treatment.
    # NOTE(review): mast3r_model is not referenced below in this block —
    # presumably the Gradio handlers read it as a module-level global; verify.
    mast3r_raw = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric")
    if num_gpus > 1:
        mast3r_model = torch.nn.DataParallel(mast3r_raw).cuda().eval()
    else:
        mast3r_model = mast3r_raw.cuda().eval()
    # --- multi-GPU setup end ---

    demo.launch()