Spaces:
Runtime error
Runtime error
algohunt
commited on
Commit
·
415f10f
1
Parent(s):
029882a
app.py
CHANGED
|
@@ -42,6 +42,7 @@ from src.data import DemoData
|
|
| 42 |
from src.models import LiNo_UniPS
|
| 43 |
from torch.utils.data import DataLoader
|
| 44 |
import pytorch_lightning as pl
|
|
|
|
| 45 |
|
| 46 |
MAX_SEED = np.iinfo(np.int32).max
|
| 47 |
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
|
@@ -49,12 +50,8 @@ WEIGHTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights'
|
|
| 49 |
os.makedirs(TMP_DIR, exist_ok=True)
|
| 50 |
os.makedirs(WEIGHTS_DIR, exist_ok=True)
|
| 51 |
|
| 52 |
-
is_gpu_available = torch.cuda.is_available()
|
| 53 |
|
| 54 |
|
| 55 |
-
if is_gpu_available:
|
| 56 |
-
print("✅ NVIDIA GPU。")
|
| 57 |
-
|
| 58 |
|
| 59 |
def cache_weights(weights_dir: str) -> dict:
|
| 60 |
import os
|
|
@@ -88,7 +85,7 @@ def preprocess_mesh(mesh_prompt):
|
|
| 88 |
trimesh_mesh = trimesh.load_mesh(mesh_prompt)
|
| 89 |
trimesh_mesh.export(mesh_prompt+'.glb')
|
| 90 |
return mesh_prompt+'.glb'
|
| 91 |
-
|
| 92 |
def generate_3d(image, seed=-1,
|
| 93 |
ss_guidance_strength=3, ss_sampling_steps=50,
|
| 94 |
slat_guidance_strength=3, slat_sampling_steps=6,normal_bridge=None):
|
|
@@ -136,7 +133,7 @@ def generate_3d(image, seed=-1,
|
|
| 136 |
trimesh_mesh.export(mesh_path)
|
| 137 |
|
| 138 |
return mesh_path, mesh_path
|
| 139 |
-
|
| 140 |
def predict_normal(input_images,input_mask):
|
| 141 |
test_dataset = DemoData(input_imgs_list=input_images,input_mask=input_mask)
|
| 142 |
test_loader = DataLoader(test_dataset, batch_size=1)
|
|
|
|
| 42 |
from src.models import LiNo_UniPS
|
| 43 |
from torch.utils.data import DataLoader
|
| 44 |
import pytorch_lightning as pl
|
| 45 |
+
import spaces
|
| 46 |
|
| 47 |
MAX_SEED = np.iinfo(np.int32).max
|
| 48 |
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
|
|
|
| 50 |
os.makedirs(TMP_DIR, exist_ok=True)
|
| 51 |
os.makedirs(WEIGHTS_DIR, exist_ok=True)
|
| 52 |
|
|
|
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def cache_weights(weights_dir: str) -> dict:
|
| 57 |
import os
|
|
|
|
| 85 |
trimesh_mesh = trimesh.load_mesh(mesh_prompt)
|
| 86 |
trimesh_mesh.export(mesh_prompt+'.glb')
|
| 87 |
return mesh_prompt+'.glb'
|
| 88 |
+
@spaces.GPU
|
| 89 |
def generate_3d(image, seed=-1,
|
| 90 |
ss_guidance_strength=3, ss_sampling_steps=50,
|
| 91 |
slat_guidance_strength=3, slat_sampling_steps=6,normal_bridge=None):
|
|
|
|
| 133 |
trimesh_mesh.export(mesh_path)
|
| 134 |
|
| 135 |
return mesh_path, mesh_path
|
| 136 |
+
@spaces.GPU
|
| 137 |
def predict_normal(input_images,input_mask):
|
| 138 |
test_dataset = DemoData(input_imgs_list=input_images,input_mask=input_mask)
|
| 139 |
test_loader = DataLoader(test_dataset, batch_size=1)
|