rayli commited on
Commit
1614db8
·
verified ·
1 Parent(s): 2dd4628

Make Space ZeroGPU compatible

Browse files
Files changed (2) hide show
  1. app.py +46 -20
  2. requirements.txt +6 -5
app.py CHANGED
@@ -5,6 +5,7 @@ import json
5
  import os
6
  import shutil
7
  import tempfile
 
8
  import traceback
9
  import zipfile
10
  from datetime import datetime
@@ -16,6 +17,11 @@ import numpy as np
16
  import torch
17
  import torch.nn.functional as F
18
 
 
 
 
 
 
19
  from infer import (
20
  _base_metadata,
21
  _compute_motion_prediction_artifacts,
@@ -2573,6 +2579,12 @@ def _ensure_instruct_checkpoint(checkpoint_path: Path) -> Path:
2573
  return checkpoint_path
2574
 
2575
 
 
 
 
 
 
 
2576
  class InstructParticulateApp:
2577
  def __init__(
2578
  self,
@@ -2587,23 +2599,35 @@ class InstructParticulateApp:
2587
  self.output_root.mkdir(parents=True, exist_ok=True)
2588
 
2589
  self.checkpoint_path = _ensure_instruct_checkpoint(self.checkpoint_path)
2590
- if not torch.cuda.is_available():
2591
- raise RuntimeError("CUDA is required for this demo but is not available.")
2592
- self.device = torch.device("cuda")
2593
-
2594
- config = load_run_config(self.run_dir)
2595
- configure_runtime_environment(config)
2596
  self.num_shape_points, self.default_num_query_points, self.sharp_point_ratio = (
2597
- resolve_inference_sampling_config(config)
2598
  )
2599
- self.model = Particulate2ArticulationModel(**config["model"])
2600
- load_model_checkpoint_for_inference(
2601
- self.model,
2602
- self.checkpoint_path,
2603
- device=self.device,
2604
- )
2605
- if self.model.encoder.use_text_conditioning:
2606
- self.model.encoder.compute_link_text_embeddings_on_the_fly = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2607
 
2608
  def register_mesh(
2609
  self,
@@ -2822,6 +2846,7 @@ class InstructParticulateApp:
2822
  finally:
2823
  torch.cuda.empty_cache()
2824
 
 
2825
  def predict(
2826
  self,
2827
  mesh_path_value: Any,
@@ -2889,6 +2914,7 @@ class InstructParticulateApp:
2889
  ),
2890
  )
2891
 
 
2892
  mesh_geometry = self._prepare_geometry(mesh_path, canonical_up)
2893
  segmentation_num_query_points = self._segmentation_num_query_points(
2894
  args,
@@ -2903,7 +2929,7 @@ class InstructParticulateApp:
2903
  sharp_point_ratio=self.sharp_point_ratio,
2904
  link_names=link_names,
2905
  joint_specs=joint_specs,
2906
- device=self.device,
2907
  )
2908
  else:
2909
  raw_prompt_points, raw_prompt_normals, has_point_prompt = point_prompt_arrays
@@ -2920,18 +2946,18 @@ class InstructParticulateApp:
2920
  sharp_point_ratio=self.sharp_point_ratio,
2921
  link_names=link_names,
2922
  joint_specs=joint_specs,
2923
- device=self.device,
2924
  link_point_prompts=torch.from_numpy(normalized_prompt_points).float(),
2925
  link_point_prompt_normals=torch.from_numpy(normalized_prompt_normals).float(),
2926
  require_unique_link_names=False,
2927
  )
2928
  no_prompt_mask = torch.from_numpy(~has_point_prompt).bool().unsqueeze(0).to(
2929
- self.device
2930
  )
2931
  batch["link_point_prompt_dropout_eligible"] = no_prompt_mask
2932
 
2933
  output = run_batched_model_inference(
2934
- self.model,
2935
  query_batch_size=int(args.query_batch_size),
2936
  no_point_prompt_for_unique_text=bool(args.no_point_prompt),
2937
  decode_joint_parameters=False,
@@ -3003,7 +3029,7 @@ class InstructParticulateApp:
3003
  )
3004
  motion_artifacts = _compute_motion_prediction_artifacts(
3005
  args,
3006
- model=self.model,
3007
  batch=batch,
3008
  normalized_mesh=mesh_geometry.normalized_mesh,
3009
  face_part_ids=face_part_ids,
 
5
  import os
6
  import shutil
7
  import tempfile
8
+ import threading
9
  import traceback
10
  import zipfile
11
  from datetime import datetime
 
17
  import torch
18
  import torch.nn.functional as F
19
 
20
+ try:
21
+ import spaces
22
+ except ImportError:
23
+ spaces = None
24
+
25
  from infer import (
26
  _base_metadata,
27
  _compute_motion_prediction_artifacts,
 
2579
  return checkpoint_path
2580
 
2581
 
2582
+ def _spaces_gpu(fn):
2583
+ if spaces is None:
2584
+ return fn
2585
+ return spaces.GPU(duration=1800)(fn)
2586
+
2587
+
2588
  class InstructParticulateApp:
2589
  def __init__(
2590
  self,
 
2599
  self.output_root.mkdir(parents=True, exist_ok=True)
2600
 
2601
  self.checkpoint_path = _ensure_instruct_checkpoint(self.checkpoint_path)
2602
+ self.config = load_run_config(self.run_dir)
2603
+ configure_runtime_environment(self.config)
 
 
 
 
2604
  self.num_shape_points, self.default_num_query_points, self.sharp_point_ratio = (
2605
+ resolve_inference_sampling_config(self.config)
2606
  )
2607
+ self.device: torch.device | None = None
2608
+ self.model: Particulate2ArticulationModel | None = None
2609
+ self._model_lock = threading.Lock()
2610
+
2611
+ def _ensure_model_loaded(self) -> tuple[Particulate2ArticulationModel, torch.device]:
2612
+ if self.model is not None and self.device is not None:
2613
+ return self.model, self.device
2614
+ with self._model_lock:
2615
+ if self.model is not None and self.device is not None:
2616
+ return self.model, self.device
2617
+ if not torch.cuda.is_available():
2618
+ raise RuntimeError("CUDA is required for this demo but is not available.")
2619
+ device = torch.device("cuda")
2620
+ model = Particulate2ArticulationModel(**self.config["model"])
2621
+ load_model_checkpoint_for_inference(
2622
+ model,
2623
+ self.checkpoint_path,
2624
+ device=device,
2625
+ )
2626
+ if model.encoder.use_text_conditioning:
2627
+ model.encoder.compute_link_text_embeddings_on_the_fly = True
2628
+ self.model = model
2629
+ self.device = device
2630
+ return model, device
2631
 
2632
  def register_mesh(
2633
  self,
 
2846
  finally:
2847
  torch.cuda.empty_cache()
2848
 
2849
+ @_spaces_gpu
2850
  def predict(
2851
  self,
2852
  mesh_path_value: Any,
 
2914
  ),
2915
  )
2916
 
2917
+ model, device = self._ensure_model_loaded()
2918
  mesh_geometry = self._prepare_geometry(mesh_path, canonical_up)
2919
  segmentation_num_query_points = self._segmentation_num_query_points(
2920
  args,
 
2929
  sharp_point_ratio=self.sharp_point_ratio,
2930
  link_names=link_names,
2931
  joint_specs=joint_specs,
2932
+ device=device,
2933
  )
2934
  else:
2935
  raw_prompt_points, raw_prompt_normals, has_point_prompt = point_prompt_arrays
 
2946
  sharp_point_ratio=self.sharp_point_ratio,
2947
  link_names=link_names,
2948
  joint_specs=joint_specs,
2949
+ device=device,
2950
  link_point_prompts=torch.from_numpy(normalized_prompt_points).float(),
2951
  link_point_prompt_normals=torch.from_numpy(normalized_prompt_normals).float(),
2952
  require_unique_link_names=False,
2953
  )
2954
  no_prompt_mask = torch.from_numpy(~has_point_prompt).bool().unsqueeze(0).to(
2955
+ device
2956
  )
2957
  batch["link_point_prompt_dropout_eligible"] = no_prompt_mask
2958
 
2959
  output = run_batched_model_inference(
2960
+ model,
2961
  query_batch_size=int(args.query_batch_size),
2962
  no_point_prompt_for_unique_text=bool(args.no_point_prompt),
2963
  decode_joint_parameters=False,
 
3029
  )
3030
  motion_artifacts = _compute_motion_prediction_artifacts(
3031
  args,
3032
+ model=model,
3033
  batch=batch,
3034
  normalized_mesh=mesh_geometry.normalized_mesh,
3035
  face_part_ids=face_part_ids,
requirements.txt CHANGED
@@ -1,12 +1,13 @@
1
- --extra-index-url https://download.pytorch.org/whl/cu124
2
- --find-links https://data.pyg.org/whl/torch-2.4.0+cu124.html
3
 
4
- torch==2.4.0
5
- torchvision==0.19.0
6
- torchaudio==2.4.0
7
  torch-scatter
8
 
9
  gradio
 
10
  google-genai
11
  openai
12
  transformers
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu128
2
+ --find-links https://data.pyg.org/whl/torch-2.8.0+cu128.html
3
 
4
+ torch==2.8.0
5
+ torchvision==0.23.0
6
+ torchaudio==2.8.0
7
  torch-scatter
8
 
9
  gradio
10
+ spaces
11
  google-genai
12
  openai
13
  transformers