Spaces:
Running on Zero
Running on Zero
Make Space ZeroGPU compatible
Browse files- app.py +46 -20
- requirements.txt +6 -5
app.py
CHANGED
|
@@ -5,6 +5,7 @@ import json
|
|
| 5 |
import os
|
| 6 |
import shutil
|
| 7 |
import tempfile
|
|
|
|
| 8 |
import traceback
|
| 9 |
import zipfile
|
| 10 |
from datetime import datetime
|
|
@@ -16,6 +17,11 @@ import numpy as np
|
|
| 16 |
import torch
|
| 17 |
import torch.nn.functional as F
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
from infer import (
|
| 20 |
_base_metadata,
|
| 21 |
_compute_motion_prediction_artifacts,
|
|
@@ -2573,6 +2579,12 @@ def _ensure_instruct_checkpoint(checkpoint_path: Path) -> Path:
|
|
| 2573 |
return checkpoint_path
|
| 2574 |
|
| 2575 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2576 |
class InstructParticulateApp:
|
| 2577 |
def __init__(
|
| 2578 |
self,
|
|
@@ -2587,23 +2599,35 @@ class InstructParticulateApp:
|
|
| 2587 |
self.output_root.mkdir(parents=True, exist_ok=True)
|
| 2588 |
|
| 2589 |
self.checkpoint_path = _ensure_instruct_checkpoint(self.checkpoint_path)
|
| 2590 |
-
|
| 2591 |
-
|
| 2592 |
-
self.device = torch.device("cuda")
|
| 2593 |
-
|
| 2594 |
-
config = load_run_config(self.run_dir)
|
| 2595 |
-
configure_runtime_environment(config)
|
| 2596 |
self.num_shape_points, self.default_num_query_points, self.sharp_point_ratio = (
|
| 2597 |
-
resolve_inference_sampling_config(config)
|
| 2598 |
)
|
| 2599 |
-
self.
|
| 2600 |
-
|
| 2601 |
-
|
| 2602 |
-
|
| 2603 |
-
|
| 2604 |
-
|
| 2605 |
-
|
| 2606 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2607 |
|
| 2608 |
def register_mesh(
|
| 2609 |
self,
|
|
@@ -2822,6 +2846,7 @@ class InstructParticulateApp:
|
|
| 2822 |
finally:
|
| 2823 |
torch.cuda.empty_cache()
|
| 2824 |
|
|
|
|
| 2825 |
def predict(
|
| 2826 |
self,
|
| 2827 |
mesh_path_value: Any,
|
|
@@ -2889,6 +2914,7 @@ class InstructParticulateApp:
|
|
| 2889 |
),
|
| 2890 |
)
|
| 2891 |
|
|
|
|
| 2892 |
mesh_geometry = self._prepare_geometry(mesh_path, canonical_up)
|
| 2893 |
segmentation_num_query_points = self._segmentation_num_query_points(
|
| 2894 |
args,
|
|
@@ -2903,7 +2929,7 @@ class InstructParticulateApp:
|
|
| 2903 |
sharp_point_ratio=self.sharp_point_ratio,
|
| 2904 |
link_names=link_names,
|
| 2905 |
joint_specs=joint_specs,
|
| 2906 |
-
device=
|
| 2907 |
)
|
| 2908 |
else:
|
| 2909 |
raw_prompt_points, raw_prompt_normals, has_point_prompt = point_prompt_arrays
|
|
@@ -2920,18 +2946,18 @@ class InstructParticulateApp:
|
|
| 2920 |
sharp_point_ratio=self.sharp_point_ratio,
|
| 2921 |
link_names=link_names,
|
| 2922 |
joint_specs=joint_specs,
|
| 2923 |
-
device=
|
| 2924 |
link_point_prompts=torch.from_numpy(normalized_prompt_points).float(),
|
| 2925 |
link_point_prompt_normals=torch.from_numpy(normalized_prompt_normals).float(),
|
| 2926 |
require_unique_link_names=False,
|
| 2927 |
)
|
| 2928 |
no_prompt_mask = torch.from_numpy(~has_point_prompt).bool().unsqueeze(0).to(
|
| 2929 |
-
|
| 2930 |
)
|
| 2931 |
batch["link_point_prompt_dropout_eligible"] = no_prompt_mask
|
| 2932 |
|
| 2933 |
output = run_batched_model_inference(
|
| 2934 |
-
|
| 2935 |
query_batch_size=int(args.query_batch_size),
|
| 2936 |
no_point_prompt_for_unique_text=bool(args.no_point_prompt),
|
| 2937 |
decode_joint_parameters=False,
|
|
@@ -3003,7 +3029,7 @@ class InstructParticulateApp:
|
|
| 3003 |
)
|
| 3004 |
motion_artifacts = _compute_motion_prediction_artifacts(
|
| 3005 |
args,
|
| 3006 |
-
model=
|
| 3007 |
batch=batch,
|
| 3008 |
normalized_mesh=mesh_geometry.normalized_mesh,
|
| 3009 |
face_part_ids=face_part_ids,
|
|
|
|
| 5 |
import os
|
| 6 |
import shutil
|
| 7 |
import tempfile
|
| 8 |
+
import threading
|
| 9 |
import traceback
|
| 10 |
import zipfile
|
| 11 |
from datetime import datetime
|
|
|
|
| 17 |
import torch
|
| 18 |
import torch.nn.functional as F
|
| 19 |
|
| 20 |
+
try:
|
| 21 |
+
import spaces
|
| 22 |
+
except ImportError:
|
| 23 |
+
spaces = None
|
| 24 |
+
|
| 25 |
from infer import (
|
| 26 |
_base_metadata,
|
| 27 |
_compute_motion_prediction_artifacts,
|
|
|
|
| 2579 |
return checkpoint_path
|
| 2580 |
|
| 2581 |
|
| 2582 |
+
def _spaces_gpu(fn):
|
| 2583 |
+
if spaces is None:
|
| 2584 |
+
return fn
|
| 2585 |
+
return spaces.GPU(duration=1800)(fn)
|
| 2586 |
+
|
| 2587 |
+
|
| 2588 |
class InstructParticulateApp:
|
| 2589 |
def __init__(
|
| 2590 |
self,
|
|
|
|
| 2599 |
self.output_root.mkdir(parents=True, exist_ok=True)
|
| 2600 |
|
| 2601 |
self.checkpoint_path = _ensure_instruct_checkpoint(self.checkpoint_path)
|
| 2602 |
+
self.config = load_run_config(self.run_dir)
|
| 2603 |
+
configure_runtime_environment(self.config)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2604 |
self.num_shape_points, self.default_num_query_points, self.sharp_point_ratio = (
|
| 2605 |
+
resolve_inference_sampling_config(self.config)
|
| 2606 |
)
|
| 2607 |
+
self.device: torch.device | None = None
|
| 2608 |
+
self.model: Particulate2ArticulationModel | None = None
|
| 2609 |
+
self._model_lock = threading.Lock()
|
| 2610 |
+
|
| 2611 |
+
def _ensure_model_loaded(self) -> tuple[Particulate2ArticulationModel, torch.device]:
|
| 2612 |
+
if self.model is not None and self.device is not None:
|
| 2613 |
+
return self.model, self.device
|
| 2614 |
+
with self._model_lock:
|
| 2615 |
+
if self.model is not None and self.device is not None:
|
| 2616 |
+
return self.model, self.device
|
| 2617 |
+
if not torch.cuda.is_available():
|
| 2618 |
+
raise RuntimeError("CUDA is required for this demo but is not available.")
|
| 2619 |
+
device = torch.device("cuda")
|
| 2620 |
+
model = Particulate2ArticulationModel(**self.config["model"])
|
| 2621 |
+
load_model_checkpoint_for_inference(
|
| 2622 |
+
model,
|
| 2623 |
+
self.checkpoint_path,
|
| 2624 |
+
device=device,
|
| 2625 |
+
)
|
| 2626 |
+
if model.encoder.use_text_conditioning:
|
| 2627 |
+
model.encoder.compute_link_text_embeddings_on_the_fly = True
|
| 2628 |
+
self.model = model
|
| 2629 |
+
self.device = device
|
| 2630 |
+
return model, device
|
| 2631 |
|
| 2632 |
def register_mesh(
|
| 2633 |
self,
|
|
|
|
| 2846 |
finally:
|
| 2847 |
torch.cuda.empty_cache()
|
| 2848 |
|
| 2849 |
+
@_spaces_gpu
|
| 2850 |
def predict(
|
| 2851 |
self,
|
| 2852 |
mesh_path_value: Any,
|
|
|
|
| 2914 |
),
|
| 2915 |
)
|
| 2916 |
|
| 2917 |
+
model, device = self._ensure_model_loaded()
|
| 2918 |
mesh_geometry = self._prepare_geometry(mesh_path, canonical_up)
|
| 2919 |
segmentation_num_query_points = self._segmentation_num_query_points(
|
| 2920 |
args,
|
|
|
|
| 2929 |
sharp_point_ratio=self.sharp_point_ratio,
|
| 2930 |
link_names=link_names,
|
| 2931 |
joint_specs=joint_specs,
|
| 2932 |
+
device=device,
|
| 2933 |
)
|
| 2934 |
else:
|
| 2935 |
raw_prompt_points, raw_prompt_normals, has_point_prompt = point_prompt_arrays
|
|
|
|
| 2946 |
sharp_point_ratio=self.sharp_point_ratio,
|
| 2947 |
link_names=link_names,
|
| 2948 |
joint_specs=joint_specs,
|
| 2949 |
+
device=device,
|
| 2950 |
link_point_prompts=torch.from_numpy(normalized_prompt_points).float(),
|
| 2951 |
link_point_prompt_normals=torch.from_numpy(normalized_prompt_normals).float(),
|
| 2952 |
require_unique_link_names=False,
|
| 2953 |
)
|
| 2954 |
no_prompt_mask = torch.from_numpy(~has_point_prompt).bool().unsqueeze(0).to(
|
| 2955 |
+
device
|
| 2956 |
)
|
| 2957 |
batch["link_point_prompt_dropout_eligible"] = no_prompt_mask
|
| 2958 |
|
| 2959 |
output = run_batched_model_inference(
|
| 2960 |
+
model,
|
| 2961 |
query_batch_size=int(args.query_batch_size),
|
| 2962 |
no_point_prompt_for_unique_text=bool(args.no_point_prompt),
|
| 2963 |
decode_joint_parameters=False,
|
|
|
|
| 3029 |
)
|
| 3030 |
motion_artifacts = _compute_motion_prediction_artifacts(
|
| 3031 |
args,
|
| 3032 |
+
model=model,
|
| 3033 |
batch=batch,
|
| 3034 |
normalized_mesh=mesh_geometry.normalized_mesh,
|
| 3035 |
face_part_ids=face_part_ids,
|
requirements.txt
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
-
--extra-index-url https://download.pytorch.org/whl/
|
| 2 |
-
--find-links https://data.pyg.org/whl/torch-2.
|
| 3 |
|
| 4 |
-
torch==2.
|
| 5 |
-
torchvision==0.
|
| 6 |
-
torchaudio==2.
|
| 7 |
torch-scatter
|
| 8 |
|
| 9 |
gradio
|
|
|
|
| 10 |
google-genai
|
| 11 |
openai
|
| 12 |
transformers
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu128
|
| 2 |
+
--find-links https://data.pyg.org/whl/torch-2.8.0+cu128.html
|
| 3 |
|
| 4 |
+
torch==2.8.0
|
| 5 |
+
torchvision==0.23.0
|
| 6 |
+
torchaudio==2.8.0
|
| 7 |
torch-scatter
|
| 8 |
|
| 9 |
gradio
|
| 10 |
+
spaces
|
| 11 |
google-genai
|
| 12 |
openai
|
| 13 |
transformers
|