rayli commited on
Commit
212726c
·
verified ·
1 Parent(s): 365cb90

Prefetch inference assets at startup

Browse files
app.py CHANGED
@@ -89,6 +89,9 @@ from instruct_particulate.utils.inference_utils import (
89
  from instruct_particulate.utils.inference_visualization_utils import (
90
  save_predicted_point_query_rest_visualization,
91
  )
 
 
 
92
 
93
 
94
  REPO_ROOT = Path(__file__).resolve().parent
@@ -3164,6 +3167,7 @@ def _ensure_instruct_checkpoint(checkpoint_path: Path) -> Path:
3164
  filename=CHECKPOINT_REPO_FILENAME,
3165
  local_dir=str(checkpoint_path.parent),
3166
  local_dir_use_symlinks=False,
 
3167
  )
3168
  downloaded_path = Path(downloaded_path)
3169
  if downloaded_path != checkpoint_path:
@@ -3173,6 +3177,68 @@ def _ensure_instruct_checkpoint(checkpoint_path: Path) -> Path:
3173
  return checkpoint_path
3174
 
3175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3176
  def _spaces_gpu(fn):
3177
  if spaces is None:
3178
  return fn
@@ -3196,6 +3262,7 @@ class InstructParticulateApp:
3196
  self.checkpoint_path = _ensure_instruct_checkpoint(self.checkpoint_path)
3197
  self.config = load_run_config(self.run_dir)
3198
  configure_runtime_environment(self.config)
 
3199
  self.num_shape_points, self.default_num_query_points, self.sharp_point_ratio = (
3200
  resolve_inference_sampling_config(self.config)
3201
  )
 
89
  from instruct_particulate.utils.inference_visualization_utils import (
90
  save_predicted_point_query_rest_visualization,
91
  )
92
+ from instruct_particulate.utils.partfield_feature_utils import (
93
+ ensure_partfield_assets_downloaded,
94
+ )
95
 
96
 
97
  REPO_ROOT = Path(__file__).resolve().parent
 
3167
  filename=CHECKPOINT_REPO_FILENAME,
3168
  local_dir=str(checkpoint_path.parent),
3169
  local_dir_use_symlinks=False,
3170
+ token=os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN"),
3171
  )
3172
  downloaded_path = Path(downloaded_path)
3173
  if downloaded_path != checkpoint_path:
 
3177
  return checkpoint_path
3178
 
3179
 
3180
+ def _hf_token() -> str | None:
3181
+ return os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
3182
+
3183
+
3184
+ def _prefetch_clip_text_assets(config: dict[str, Any]) -> None:
3185
+ model_config = config.get("model", {})
3186
+ if not isinstance(model_config, dict) or not bool(model_config.get("use_text_conditioning", True)):
3187
+ return
3188
+
3189
+ clip_model_name = str(model_config.get("clip_model_name", "openai/clip-vit-large-patch14"))
3190
+ print(f"Prefetching CLIP text assets: {clip_model_name}")
3191
+ try:
3192
+ from huggingface_hub import snapshot_download
3193
+ except ImportError as exc:
3194
+ raise ImportError("huggingface_hub is required to prefetch CLIP assets") from exc
3195
+
3196
+ snapshot_download(
3197
+ repo_id=clip_model_name,
3198
+ cache_dir=os.environ.get("HF_HOME") or None,
3199
+ allow_patterns=[
3200
+ "config.json",
3201
+ "tokenizer_config.json",
3202
+ "vocab.json",
3203
+ "merges.txt",
3204
+ "tokenizer.json",
3205
+ "special_tokens_map.json",
3206
+ "model.safetensors",
3207
+ ],
3208
+ token=_hf_token(),
3209
+ )
3210
+
3211
+
3212
+ def _prefetch_partfield_assets(config: dict[str, Any]) -> None:
3213
+ model_config = config.get("model", {})
3214
+ if not isinstance(model_config, dict):
3215
+ return
3216
+ needs_partfield = any(
3217
+ bool(model_config.get(key, False))
3218
+ for key in (
3219
+ "use_pretrained_features_shape",
3220
+ "use_pretrained_features_query",
3221
+ "use_pretrained_features_point_prompt",
3222
+ )
3223
+ )
3224
+ if not needs_partfield:
3225
+ return
3226
+ print("Prefetching PartField checkpoint assets")
3227
+ ensure_partfield_assets_downloaded()
3228
+
3229
+
3230
+ def _prefetch_startup_assets(config: dict[str, Any]) -> None:
3231
+ if os.environ.get("INSTRUCT_PARTICULATE_PREFETCH_ASSETS", "1").strip().lower() in {
3232
+ "0",
3233
+ "false",
3234
+ "no",
3235
+ }:
3236
+ print("Skipping startup asset prefetch because INSTRUCT_PARTICULATE_PREFETCH_ASSETS is disabled")
3237
+ return
3238
+ _prefetch_partfield_assets(config)
3239
+ _prefetch_clip_text_assets(config)
3240
+
3241
+
3242
  def _spaces_gpu(fn):
3243
  if spaces is None:
3244
  return fn
 
3262
  self.checkpoint_path = _ensure_instruct_checkpoint(self.checkpoint_path)
3263
  self.config = load_run_config(self.run_dir)
3264
  configure_runtime_environment(self.config)
3265
+ _prefetch_startup_assets(self.config)
3266
  self.num_shape_points, self.default_num_query_points, self.sharp_point_ratio = (
3267
  resolve_inference_sampling_config(self.config)
3268
  )
instruct_particulate/utils/partfield_feature_utils.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
 
3
  import argparse
4
  import importlib
 
5
  import sys
6
  from contextlib import nullcontext
7
  from pathlib import Path
@@ -45,6 +46,7 @@ def _ensure_partfield_checkpoint(checkpoint_path: Path) -> Path:
45
  filename=_PARTFIELD_CHECKPOINT_FILENAME,
46
  local_dir=str(checkpoint_path.parent),
47
  local_dir_use_symlinks=False,
 
48
  )
49
  except ImportError:
50
  pass
@@ -61,6 +63,12 @@ def _resolve_partfield_config_path() -> Path:
61
  return _PARTFIELD_CONFIG_PATH
62
 
63
 
 
 
 
 
 
 
64
  class PartFieldFeatureExtractor:
65
  """Lazy wrapper around the local PartField checkpoint used by `particulate`."""
66
 
 
2
 
3
  import argparse
4
  import importlib
5
+ import os
6
  import sys
7
  from contextlib import nullcontext
8
  from pathlib import Path
 
46
  filename=_PARTFIELD_CHECKPOINT_FILENAME,
47
  local_dir=str(checkpoint_path.parent),
48
  local_dir_use_symlinks=False,
49
+ token=os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN"),
50
  )
51
  except ImportError:
52
  pass
 
63
  return _PARTFIELD_CONFIG_PATH
64
 
65
 
66
+ def ensure_partfield_assets_downloaded() -> Path:
67
+ _resolve_partfield_root()
68
+ _resolve_partfield_config_path()
69
+ return _ensure_partfield_checkpoint(_PARTFIELD_CHECKPOINT_PATH)
70
+
71
+
72
  class PartFieldFeatureExtractor:
73
  """Lazy wrapper around the local PartField checkpoint used by `particulate`."""
74