Spaces:
Running on Zero
Running on Zero
Prefetch inference assets at startup
Browse files
app.py
CHANGED
|
@@ -89,6 +89,9 @@ from instruct_particulate.utils.inference_utils import (
|
|
| 89 |
from instruct_particulate.utils.inference_visualization_utils import (
|
| 90 |
save_predicted_point_query_rest_visualization,
|
| 91 |
)
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
REPO_ROOT = Path(__file__).resolve().parent
|
|
@@ -3164,6 +3167,7 @@ def _ensure_instruct_checkpoint(checkpoint_path: Path) -> Path:
|
|
| 3164 |
filename=CHECKPOINT_REPO_FILENAME,
|
| 3165 |
local_dir=str(checkpoint_path.parent),
|
| 3166 |
local_dir_use_symlinks=False,
|
|
|
|
| 3167 |
)
|
| 3168 |
downloaded_path = Path(downloaded_path)
|
| 3169 |
if downloaded_path != checkpoint_path:
|
|
@@ -3173,6 +3177,68 @@ def _ensure_instruct_checkpoint(checkpoint_path: Path) -> Path:
|
|
| 3173 |
return checkpoint_path
|
| 3174 |
|
| 3175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3176 |
def _spaces_gpu(fn):
|
| 3177 |
if spaces is None:
|
| 3178 |
return fn
|
|
@@ -3196,6 +3262,7 @@ class InstructParticulateApp:
|
|
| 3196 |
self.checkpoint_path = _ensure_instruct_checkpoint(self.checkpoint_path)
|
| 3197 |
self.config = load_run_config(self.run_dir)
|
| 3198 |
configure_runtime_environment(self.config)
|
|
|
|
| 3199 |
self.num_shape_points, self.default_num_query_points, self.sharp_point_ratio = (
|
| 3200 |
resolve_inference_sampling_config(self.config)
|
| 3201 |
)
|
|
|
|
| 89 |
from instruct_particulate.utils.inference_visualization_utils import (
|
| 90 |
save_predicted_point_query_rest_visualization,
|
| 91 |
)
|
| 92 |
+
from instruct_particulate.utils.partfield_feature_utils import (
|
| 93 |
+
ensure_partfield_assets_downloaded,
|
| 94 |
+
)
|
| 95 |
|
| 96 |
|
| 97 |
REPO_ROOT = Path(__file__).resolve().parent
|
|
|
|
| 3167 |
filename=CHECKPOINT_REPO_FILENAME,
|
| 3168 |
local_dir=str(checkpoint_path.parent),
|
| 3169 |
local_dir_use_symlinks=False,
|
| 3170 |
+
token=os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN"),
|
| 3171 |
)
|
| 3172 |
downloaded_path = Path(downloaded_path)
|
| 3173 |
if downloaded_path != checkpoint_path:
|
|
|
|
| 3177 |
return checkpoint_path
|
| 3178 |
|
| 3179 |
|
| 3180 |
+
def _hf_token() -> str | None:
|
| 3181 |
+
return os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
|
| 3182 |
+
|
| 3183 |
+
|
| 3184 |
+
def _prefetch_clip_text_assets(config: dict[str, Any]) -> None:
|
| 3185 |
+
model_config = config.get("model", {})
|
| 3186 |
+
if not isinstance(model_config, dict) or not bool(model_config.get("use_text_conditioning", True)):
|
| 3187 |
+
return
|
| 3188 |
+
|
| 3189 |
+
clip_model_name = str(model_config.get("clip_model_name", "openai/clip-vit-large-patch14"))
|
| 3190 |
+
print(f"Prefetching CLIP text assets: {clip_model_name}")
|
| 3191 |
+
try:
|
| 3192 |
+
from huggingface_hub import snapshot_download
|
| 3193 |
+
except ImportError as exc:
|
| 3194 |
+
raise ImportError("huggingface_hub is required to prefetch CLIP assets") from exc
|
| 3195 |
+
|
| 3196 |
+
snapshot_download(
|
| 3197 |
+
repo_id=clip_model_name,
|
| 3198 |
+
cache_dir=os.environ.get("HF_HOME") or None,
|
| 3199 |
+
allow_patterns=[
|
| 3200 |
+
"config.json",
|
| 3201 |
+
"tokenizer_config.json",
|
| 3202 |
+
"vocab.json",
|
| 3203 |
+
"merges.txt",
|
| 3204 |
+
"tokenizer.json",
|
| 3205 |
+
"special_tokens_map.json",
|
| 3206 |
+
"model.safetensors",
|
| 3207 |
+
],
|
| 3208 |
+
token=_hf_token(),
|
| 3209 |
+
)
|
| 3210 |
+
|
| 3211 |
+
|
| 3212 |
+
def _prefetch_partfield_assets(config: dict[str, Any]) -> None:
|
| 3213 |
+
model_config = config.get("model", {})
|
| 3214 |
+
if not isinstance(model_config, dict):
|
| 3215 |
+
return
|
| 3216 |
+
needs_partfield = any(
|
| 3217 |
+
bool(model_config.get(key, False))
|
| 3218 |
+
for key in (
|
| 3219 |
+
"use_pretrained_features_shape",
|
| 3220 |
+
"use_pretrained_features_query",
|
| 3221 |
+
"use_pretrained_features_point_prompt",
|
| 3222 |
+
)
|
| 3223 |
+
)
|
| 3224 |
+
if not needs_partfield:
|
| 3225 |
+
return
|
| 3226 |
+
print("Prefetching PartField checkpoint assets")
|
| 3227 |
+
ensure_partfield_assets_downloaded()
|
| 3228 |
+
|
| 3229 |
+
|
| 3230 |
+
def _prefetch_startup_assets(config: dict[str, Any]) -> None:
|
| 3231 |
+
if os.environ.get("INSTRUCT_PARTICULATE_PREFETCH_ASSETS", "1").strip().lower() in {
|
| 3232 |
+
"0",
|
| 3233 |
+
"false",
|
| 3234 |
+
"no",
|
| 3235 |
+
}:
|
| 3236 |
+
print("Skipping startup asset prefetch because INSTRUCT_PARTICULATE_PREFETCH_ASSETS is disabled")
|
| 3237 |
+
return
|
| 3238 |
+
_prefetch_partfield_assets(config)
|
| 3239 |
+
_prefetch_clip_text_assets(config)
|
| 3240 |
+
|
| 3241 |
+
|
| 3242 |
def _spaces_gpu(fn):
|
| 3243 |
if spaces is None:
|
| 3244 |
return fn
|
|
|
|
| 3262 |
self.checkpoint_path = _ensure_instruct_checkpoint(self.checkpoint_path)
|
| 3263 |
self.config = load_run_config(self.run_dir)
|
| 3264 |
configure_runtime_environment(self.config)
|
| 3265 |
+
_prefetch_startup_assets(self.config)
|
| 3266 |
self.num_shape_points, self.default_num_query_points, self.sharp_point_ratio = (
|
| 3267 |
resolve_inference_sampling_config(self.config)
|
| 3268 |
)
|
instruct_particulate/utils/partfield_feature_utils.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
import argparse
|
| 4 |
import importlib
|
|
|
|
| 5 |
import sys
|
| 6 |
from contextlib import nullcontext
|
| 7 |
from pathlib import Path
|
|
@@ -45,6 +46,7 @@ def _ensure_partfield_checkpoint(checkpoint_path: Path) -> Path:
|
|
| 45 |
filename=_PARTFIELD_CHECKPOINT_FILENAME,
|
| 46 |
local_dir=str(checkpoint_path.parent),
|
| 47 |
local_dir_use_symlinks=False,
|
|
|
|
| 48 |
)
|
| 49 |
except ImportError:
|
| 50 |
pass
|
|
@@ -61,6 +63,12 @@ def _resolve_partfield_config_path() -> Path:
|
|
| 61 |
return _PARTFIELD_CONFIG_PATH
|
| 62 |
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
class PartFieldFeatureExtractor:
|
| 65 |
"""Lazy wrapper around the local PartField checkpoint used by `particulate`."""
|
| 66 |
|
|
|
|
| 2 |
|
| 3 |
import argparse
|
| 4 |
import importlib
|
| 5 |
+
import os
|
| 6 |
import sys
|
| 7 |
from contextlib import nullcontext
|
| 8 |
from pathlib import Path
|
|
|
|
| 46 |
filename=_PARTFIELD_CHECKPOINT_FILENAME,
|
| 47 |
local_dir=str(checkpoint_path.parent),
|
| 48 |
local_dir_use_symlinks=False,
|
| 49 |
+
token=os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN"),
|
| 50 |
)
|
| 51 |
except ImportError:
|
| 52 |
pass
|
|
|
|
| 63 |
return _PARTFIELD_CONFIG_PATH
|
| 64 |
|
| 65 |
|
| 66 |
+
def ensure_partfield_assets_downloaded() -> Path:
|
| 67 |
+
_resolve_partfield_root()
|
| 68 |
+
_resolve_partfield_config_path()
|
| 69 |
+
return _ensure_partfield_checkpoint(_PARTFIELD_CHECKPOINT_PATH)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
class PartFieldFeatureExtractor:
|
| 73 |
"""Lazy wrapper around the local PartField checkpoint used by `particulate`."""
|
| 74 |
|