griddev commited on
Commit
fce3441
·
verified ·
1 Parent(s): d12499f

Deploy Streamlit Space app

Browse files
Files changed (1) hide show
  1. app.py +77 -34
app.py CHANGED
@@ -190,11 +190,21 @@ OUTPUT_ROOT = DEFAULT_OUTPUT_ROOT
190
 
191
 
192
  @st.cache_resource(show_spinner=False)
193
- def _download_weights(need_outputs: bool, need_shakespeare: bool) -> str:
 
 
 
 
194
  from huggingface_hub import snapshot_download
195
  allow_patterns = []
196
  if need_outputs:
197
- allow_patterns += ["outputs/*", "outputs/**/*"]
 
 
 
 
 
 
198
  if need_shakespeare:
199
  allow_patterns += ["input.txt", "shakespeare_transformer.pt"]
200
  if not allow_patterns:
@@ -210,39 +220,60 @@ def _download_weights(need_outputs: bool, need_shakespeare: bool) -> str:
210
 
211
  @st.cache_resource(show_spinner=False)
212
  def _download_model_outputs(model_dir: str) -> str:
213
- from huggingface_hub import snapshot_download
214
- allow_patterns = [
215
- f"outputs/{model_dir}/*",
216
- f"outputs/{model_dir}/**/*",
217
- ]
218
- return snapshot_download(
219
- repo_id=WEIGHTS_REPO_ID,
220
- repo_type="model",
221
- local_dir=WEIGHTS_CACHE_DIR,
222
- local_dir_use_symlinks=False,
223
- allow_patterns=allow_patterns,
224
  )
225
 
226
 
227
  def _ensure_model_outputs_available(model_dir: str) -> None:
228
- if not model_dir:
229
- return
230
- local = os.path.isdir(os.path.join(DEFAULT_OUTPUT_ROOT, model_dir))
231
- cached = os.path.isdir(os.path.join(WEIGHTS_CACHE_DIR, "outputs", model_dir))
232
- if local or cached:
233
- return
 
 
 
234
  try:
235
- _download_model_outputs(model_dir)
236
- except Exception as e:
237
- print(f"⚠️ Could not prefetch outputs for {model_dir}: {e}")
 
238
 
239
 
240
- def _resolve_weight_paths(need_outputs: bool, need_shakespeare: bool):
 
 
 
 
 
 
 
 
 
 
 
 
241
  output_root = DEFAULT_OUTPUT_ROOT
242
  shakespeare_file = DEFAULT_SHAKESPEARE_FILE
243
  shakespeare_weights = DEFAULT_SHAKESPEARE_WEIGHTS
244
 
245
- have_outputs = os.path.isdir(output_root) and len(os.listdir(output_root)) > 0
 
 
 
 
 
 
 
 
 
 
 
 
246
  have_shakespeare = (
247
  os.path.exists(shakespeare_file) and os.path.exists(shakespeare_weights)
248
  )
@@ -250,7 +281,11 @@ def _resolve_weight_paths(need_outputs: bool, need_shakespeare: bool):
250
  return output_root, shakespeare_file, shakespeare_weights
251
 
252
  try:
253
- cache_dir = _download_weights(need_outputs, need_shakespeare)
 
 
 
 
254
  candidate_output_root = os.path.join(cache_dir, "outputs")
255
  candidate_shakespeare_file = os.path.join(cache_dir, "input.txt")
256
  candidate_shakespeare_weights = os.path.join(
@@ -291,7 +326,7 @@ def _has_finetuned(model_dir, subdir):
291
  for path in candidates:
292
  if os.path.isdir(path) and len(os.listdir(path)) > 0:
293
  return True
294
- return False
295
 
296
 
297
  def _ckpt_path(output_root, model_dir, subdir):
@@ -312,6 +347,7 @@ def _resolve_weight_source_for_model(model_name, requested_source):
312
  _resolve_weight_paths(
313
  need_outputs=True,
314
  need_shakespeare=(model_dir == "custom_vlm"),
 
315
  )
316
  if _has_finetuned(model_dir, requested_source):
317
  return requested_source, None
@@ -325,9 +361,6 @@ def _finetuned_available_for_model(model_name, requested_source):
325
  model_dir = MODEL_DIR.get(model_name)
326
  if not model_dir or model_dir in DISABLE_FINETUNE_FOR:
327
  return False
328
- if _has_finetuned(model_dir, requested_source):
329
- return True
330
- _ensure_model_outputs_available(model_dir)
331
  return _has_finetuned(model_dir, requested_source)
332
 
333
 
@@ -346,7 +379,9 @@ def load_blip(weight_source="base"):
346
 
347
  if weight_source != "base":
348
  output_root, _, _ = _resolve_weight_paths(
349
- need_outputs=True, need_shakespeare=False
 
 
350
  )
351
  ckpt = _ckpt_path(output_root, "blip", weight_source)
352
  if os.path.isdir(ckpt) and os.listdir(ckpt):
@@ -375,7 +410,9 @@ def load_vit_gpt2(weight_source="base"):
375
 
376
  if weight_source != "base":
377
  output_root, _, _ = _resolve_weight_paths(
378
- need_outputs=True, need_shakespeare=False
 
 
379
  )
380
  ckpt = _ckpt_path(output_root, "vit_gpt2", weight_source)
381
  if os.path.isdir(ckpt) and os.listdir(ckpt):
@@ -400,7 +437,9 @@ def load_git(weight_source="base"):
400
 
401
  if weight_source != "base":
402
  output_root, _, _ = _resolve_weight_paths(
403
- need_outputs=True, need_shakespeare=False
 
 
404
  )
405
  ckpt = _ckpt_path(output_root, "git", weight_source)
406
  if os.path.isdir(ckpt) and os.listdir(ckpt):
@@ -422,7 +461,9 @@ def load_custom_vlm(weight_source="base"):
422
  device = get_device()
423
  cfg = CFG()
424
  output_root, shakespeare_file, shakespeare_weights = _resolve_weight_paths(
425
- need_outputs=(weight_source != "base"), need_shakespeare=True
 
 
426
  )
427
  cfg.output_root = output_root
428
  cfg.shakespeare_file = shakespeare_file
@@ -501,7 +542,9 @@ def load_blip_attention_model(weight_source="base"):
501
 
502
  if weight_source != "base":
503
  output_root, _, _ = _resolve_weight_paths(
504
- need_outputs=True, need_shakespeare=False
 
 
505
  )
506
  ckpt = _ckpt_path(output_root, "blip", weight_source)
507
  if os.path.isdir(ckpt) and os.listdir(ckpt):
 
190
 
191
 
192
  @st.cache_resource(show_spinner=False)
193
+ def _download_weights(
194
+ need_outputs: bool,
195
+ need_shakespeare: bool,
196
+ output_model_dir: str | None = None,
197
+ ) -> str:
198
  from huggingface_hub import snapshot_download
199
  allow_patterns = []
200
  if need_outputs:
201
+ if output_model_dir:
202
+ allow_patterns += [
203
+ f"outputs/{output_model_dir}/*",
204
+ f"outputs/{output_model_dir}/**/*",
205
+ ]
206
+ else:
207
+ allow_patterns += ["outputs/*", "outputs/**/*"]
208
  if need_shakespeare:
209
  allow_patterns += ["input.txt", "shakespeare_transformer.pt"]
210
  if not allow_patterns:
 
220
 
221
  @st.cache_resource(show_spinner=False)
222
  def _download_model_outputs(model_dir: str) -> str:
223
+ return _download_weights(
224
+ need_outputs=True,
225
+ need_shakespeare=False,
226
+ output_model_dir=model_dir,
 
 
 
 
 
 
 
227
  )
228
 
229
 
230
  def _ensure_model_outputs_available(model_dir: str) -> None:
231
+ # Intentionally no eager snapshot download here.
232
+ # We only fetch checkpoints when a user explicitly selects a fine-tuned weight.
233
+ _ = model_dir
234
+
235
+
236
+ @st.cache_data(show_spinner=False, ttl=900)
237
+ def _weights_repo_file_set() -> set[str]:
238
+ from huggingface_hub import HfApi
239
+ api = HfApi()
240
  try:
241
+ files = api.list_repo_files(repo_id=WEIGHTS_REPO_ID, repo_type="model")
242
+ return set(files)
243
+ except Exception:
244
+ return set()
245
 
246
 
247
+ def _remote_has_finetuned(model_dir: str, subdir: str) -> bool:
248
+ files = _weights_repo_file_set()
249
+ if not files:
250
+ return False
251
+ prefix = f"outputs/{model_dir}/{subdir}/"
252
+ return any(path.startswith(prefix) for path in files)
253
+
254
+
255
+ def _resolve_weight_paths(
256
+ need_outputs: bool,
257
+ need_shakespeare: bool,
258
+ output_model_dir: str | None = None,
259
+ ):
260
  output_root = DEFAULT_OUTPUT_ROOT
261
  shakespeare_file = DEFAULT_SHAKESPEARE_FILE
262
  shakespeare_weights = DEFAULT_SHAKESPEARE_WEIGHTS
263
 
264
+ if need_outputs:
265
+ if output_model_dir:
266
+ local_model = os.path.join(output_root, output_model_dir)
267
+ cache_model = os.path.join(WEIGHTS_CACHE_DIR, "outputs", output_model_dir)
268
+ have_outputs = (
269
+ (os.path.isdir(local_model) and len(os.listdir(local_model)) > 0)
270
+ or (os.path.isdir(cache_model) and len(os.listdir(cache_model)) > 0)
271
+ )
272
+ else:
273
+ have_outputs = os.path.isdir(output_root) and len(os.listdir(output_root)) > 0
274
+ else:
275
+ have_outputs = True
276
+
277
  have_shakespeare = (
278
  os.path.exists(shakespeare_file) and os.path.exists(shakespeare_weights)
279
  )
 
281
  return output_root, shakespeare_file, shakespeare_weights
282
 
283
  try:
284
+ cache_dir = _download_weights(
285
+ need_outputs,
286
+ need_shakespeare,
287
+ output_model_dir=output_model_dir,
288
+ )
289
  candidate_output_root = os.path.join(cache_dir, "outputs")
290
  candidate_shakespeare_file = os.path.join(cache_dir, "input.txt")
291
  candidate_shakespeare_weights = os.path.join(
 
326
  for path in candidates:
327
  if os.path.isdir(path) and len(os.listdir(path)) > 0:
328
  return True
329
+ return _remote_has_finetuned(model_dir, subdir)
330
 
331
 
332
  def _ckpt_path(output_root, model_dir, subdir):
 
347
  _resolve_weight_paths(
348
  need_outputs=True,
349
  need_shakespeare=(model_dir == "custom_vlm"),
350
+ output_model_dir=model_dir,
351
  )
352
  if _has_finetuned(model_dir, requested_source):
353
  return requested_source, None
 
361
  model_dir = MODEL_DIR.get(model_name)
362
  if not model_dir or model_dir in DISABLE_FINETUNE_FOR:
363
  return False
 
 
 
364
  return _has_finetuned(model_dir, requested_source)
365
 
366
 
 
379
 
380
  if weight_source != "base":
381
  output_root, _, _ = _resolve_weight_paths(
382
+ need_outputs=True,
383
+ need_shakespeare=False,
384
+ output_model_dir="blip",
385
  )
386
  ckpt = _ckpt_path(output_root, "blip", weight_source)
387
  if os.path.isdir(ckpt) and os.listdir(ckpt):
 
410
 
411
  if weight_source != "base":
412
  output_root, _, _ = _resolve_weight_paths(
413
+ need_outputs=True,
414
+ need_shakespeare=False,
415
+ output_model_dir="vit_gpt2",
416
  )
417
  ckpt = _ckpt_path(output_root, "vit_gpt2", weight_source)
418
  if os.path.isdir(ckpt) and os.listdir(ckpt):
 
437
 
438
  if weight_source != "base":
439
  output_root, _, _ = _resolve_weight_paths(
440
+ need_outputs=True,
441
+ need_shakespeare=False,
442
+ output_model_dir="git",
443
  )
444
  ckpt = _ckpt_path(output_root, "git", weight_source)
445
  if os.path.isdir(ckpt) and os.listdir(ckpt):
 
461
  device = get_device()
462
  cfg = CFG()
463
  output_root, shakespeare_file, shakespeare_weights = _resolve_weight_paths(
464
+ need_outputs=(weight_source != "base"),
465
+ need_shakespeare=True,
466
+ output_model_dir="custom_vlm" if weight_source != "base" else None,
467
  )
468
  cfg.output_root = output_root
469
  cfg.shakespeare_file = shakespeare_file
 
542
 
543
  if weight_source != "base":
544
  output_root, _, _ = _resolve_weight_paths(
545
+ need_outputs=True,
546
+ need_shakespeare=False,
547
+ output_model_dir="blip",
548
  )
549
  ckpt = _ckpt_path(output_root, "blip", weight_source)
550
  if os.path.isdir(ckpt) and os.listdir(ckpt):