Shoraky commited on
Commit
4dbfe23
·
verified ·
1 Parent(s): 43f10b3

Prefer private runtime weights from assets repo

Browse files
Files changed (1) hide show
  1. api.py +7 -3
api.py CHANGED
@@ -107,7 +107,7 @@ def resolve_pipeline_root():
107
  revision=os.environ.get("SPORALIZE_ASSETS_REVISION"),
108
  token=get_hf_token(),
109
  local_dir=assets_dir,
110
- allow_patterns=["pipeline.py", "ViTPose/**", "Storage/**"],
111
  )
112
  seed_storage_if_needed(os.path.join(assets_dir, "Storage"), STORAGE_ROOT)
113
  return assets_dir
@@ -138,11 +138,15 @@ def load_pipeline_callable(pipeline_root: str):
138
  return run_pipeline
139
 
140
 
141
- def ensure_weight_file(spec: dict):
142
  override_path = os.environ.get(spec["override_env"])
143
  if override_path and os.path.isfile(override_path):
144
  return override_path
145
 
 
 
 
 
146
  local_fallback = spec.get("local_fallback")
147
  if local_fallback and os.path.isfile(local_fallback):
148
  return local_fallback
@@ -170,7 +174,7 @@ def ensure_runtime_ready(force: bool = False):
170
 
171
  pipeline_root = resolve_pipeline_root()
172
  run_pipeline = load_pipeline_callable(pipeline_root)
173
- weight_paths = {name: ensure_weight_file(spec) for name, spec in DEFAULT_WEIGHT_SPECS.items()}
174
 
175
  runtime_state.update({
176
  "ready": True,
 
107
  revision=os.environ.get("SPORALIZE_ASSETS_REVISION"),
108
  token=get_hf_token(),
109
  local_dir=assets_dir,
110
+ allow_patterns=["pipeline.py", "ViTPose/**", "Storage/**", "Weights/**"],
111
  )
112
  seed_storage_if_needed(os.path.join(assets_dir, "Storage"), STORAGE_ROOT)
113
  return assets_dir
 
138
  return run_pipeline
139
 
140
 
141
+ def ensure_weight_file(spec: dict, pipeline_root: str):
142
  override_path = os.environ.get(spec["override_env"])
143
  if override_path and os.path.isfile(override_path):
144
  return override_path
145
 
146
+ pipeline_weight = os.path.join(pipeline_root, "Weights", spec["filename"])
147
+ if os.path.isfile(pipeline_weight):
148
+ return pipeline_weight
149
+
150
  local_fallback = spec.get("local_fallback")
151
  if local_fallback and os.path.isfile(local_fallback):
152
  return local_fallback
 
174
 
175
  pipeline_root = resolve_pipeline_root()
176
  run_pipeline = load_pipeline_callable(pipeline_root)
177
+ weight_paths = {name: ensure_weight_file(spec, pipeline_root) for name, spec in DEFAULT_WEIGHT_SPECS.items()}
178
 
179
  runtime_state.update({
180
  "ready": True,