Spaces:
Sleeping
Sleeping
Deploy Streamlit Space app
Browse files
app.py
CHANGED
|
@@ -190,11 +190,21 @@ OUTPUT_ROOT = DEFAULT_OUTPUT_ROOT
|
|
| 190 |
|
| 191 |
|
| 192 |
@st.cache_resource(show_spinner=False)
|
| 193 |
-
def _download_weights(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
from huggingface_hub import snapshot_download
|
| 195 |
allow_patterns = []
|
| 196 |
if need_outputs:
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
if need_shakespeare:
|
| 199 |
allow_patterns += ["input.txt", "shakespeare_transformer.pt"]
|
| 200 |
if not allow_patterns:
|
|
@@ -210,39 +220,60 @@ def _download_weights(need_outputs: bool, need_shakespeare: bool) -> str:
|
|
| 210 |
|
| 211 |
@st.cache_resource(show_spinner=False)
|
| 212 |
def _download_model_outputs(model_dir: str) -> str:
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
]
|
| 218 |
-
return snapshot_download(
|
| 219 |
-
repo_id=WEIGHTS_REPO_ID,
|
| 220 |
-
repo_type="model",
|
| 221 |
-
local_dir=WEIGHTS_CACHE_DIR,
|
| 222 |
-
local_dir_use_symlinks=False,
|
| 223 |
-
allow_patterns=allow_patterns,
|
| 224 |
)
|
| 225 |
|
| 226 |
|
| 227 |
def _ensure_model_outputs_available(model_dir: str) -> None:
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
| 234 |
try:
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
|
|
|
| 238 |
|
| 239 |
|
| 240 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
output_root = DEFAULT_OUTPUT_ROOT
|
| 242 |
shakespeare_file = DEFAULT_SHAKESPEARE_FILE
|
| 243 |
shakespeare_weights = DEFAULT_SHAKESPEARE_WEIGHTS
|
| 244 |
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
have_shakespeare = (
|
| 247 |
os.path.exists(shakespeare_file) and os.path.exists(shakespeare_weights)
|
| 248 |
)
|
|
@@ -250,7 +281,11 @@ def _resolve_weight_paths(need_outputs: bool, need_shakespeare: bool):
|
|
| 250 |
return output_root, shakespeare_file, shakespeare_weights
|
| 251 |
|
| 252 |
try:
|
| 253 |
-
cache_dir = _download_weights(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
candidate_output_root = os.path.join(cache_dir, "outputs")
|
| 255 |
candidate_shakespeare_file = os.path.join(cache_dir, "input.txt")
|
| 256 |
candidate_shakespeare_weights = os.path.join(
|
|
@@ -291,7 +326,7 @@ def _has_finetuned(model_dir, subdir):
|
|
| 291 |
for path in candidates:
|
| 292 |
if os.path.isdir(path) and len(os.listdir(path)) > 0:
|
| 293 |
return True
|
| 294 |
-
return
|
| 295 |
|
| 296 |
|
| 297 |
def _ckpt_path(output_root, model_dir, subdir):
|
|
@@ -312,6 +347,7 @@ def _resolve_weight_source_for_model(model_name, requested_source):
|
|
| 312 |
_resolve_weight_paths(
|
| 313 |
need_outputs=True,
|
| 314 |
need_shakespeare=(model_dir == "custom_vlm"),
|
|
|
|
| 315 |
)
|
| 316 |
if _has_finetuned(model_dir, requested_source):
|
| 317 |
return requested_source, None
|
|
@@ -325,9 +361,6 @@ def _finetuned_available_for_model(model_name, requested_source):
|
|
| 325 |
model_dir = MODEL_DIR.get(model_name)
|
| 326 |
if not model_dir or model_dir in DISABLE_FINETUNE_FOR:
|
| 327 |
return False
|
| 328 |
-
if _has_finetuned(model_dir, requested_source):
|
| 329 |
-
return True
|
| 330 |
-
_ensure_model_outputs_available(model_dir)
|
| 331 |
return _has_finetuned(model_dir, requested_source)
|
| 332 |
|
| 333 |
|
|
@@ -346,7 +379,9 @@ def load_blip(weight_source="base"):
|
|
| 346 |
|
| 347 |
if weight_source != "base":
|
| 348 |
output_root, _, _ = _resolve_weight_paths(
|
| 349 |
-
need_outputs=True,
|
|
|
|
|
|
|
| 350 |
)
|
| 351 |
ckpt = _ckpt_path(output_root, "blip", weight_source)
|
| 352 |
if os.path.isdir(ckpt) and os.listdir(ckpt):
|
|
@@ -375,7 +410,9 @@ def load_vit_gpt2(weight_source="base"):
|
|
| 375 |
|
| 376 |
if weight_source != "base":
|
| 377 |
output_root, _, _ = _resolve_weight_paths(
|
| 378 |
-
need_outputs=True,
|
|
|
|
|
|
|
| 379 |
)
|
| 380 |
ckpt = _ckpt_path(output_root, "vit_gpt2", weight_source)
|
| 381 |
if os.path.isdir(ckpt) and os.listdir(ckpt):
|
|
@@ -400,7 +437,9 @@ def load_git(weight_source="base"):
|
|
| 400 |
|
| 401 |
if weight_source != "base":
|
| 402 |
output_root, _, _ = _resolve_weight_paths(
|
| 403 |
-
need_outputs=True,
|
|
|
|
|
|
|
| 404 |
)
|
| 405 |
ckpt = _ckpt_path(output_root, "git", weight_source)
|
| 406 |
if os.path.isdir(ckpt) and os.listdir(ckpt):
|
|
@@ -422,7 +461,9 @@ def load_custom_vlm(weight_source="base"):
|
|
| 422 |
device = get_device()
|
| 423 |
cfg = CFG()
|
| 424 |
output_root, shakespeare_file, shakespeare_weights = _resolve_weight_paths(
|
| 425 |
-
need_outputs=(weight_source != "base"),
|
|
|
|
|
|
|
| 426 |
)
|
| 427 |
cfg.output_root = output_root
|
| 428 |
cfg.shakespeare_file = shakespeare_file
|
|
@@ -501,7 +542,9 @@ def load_blip_attention_model(weight_source="base"):
|
|
| 501 |
|
| 502 |
if weight_source != "base":
|
| 503 |
output_root, _, _ = _resolve_weight_paths(
|
| 504 |
-
need_outputs=True,
|
|
|
|
|
|
|
| 505 |
)
|
| 506 |
ckpt = _ckpt_path(output_root, "blip", weight_source)
|
| 507 |
if os.path.isdir(ckpt) and os.listdir(ckpt):
|
|
|
|
| 190 |
|
| 191 |
|
| 192 |
@st.cache_resource(show_spinner=False)
|
| 193 |
+
def _download_weights(
|
| 194 |
+
need_outputs: bool,
|
| 195 |
+
need_shakespeare: bool,
|
| 196 |
+
output_model_dir: str | None = None,
|
| 197 |
+
) -> str:
|
| 198 |
from huggingface_hub import snapshot_download
|
| 199 |
allow_patterns = []
|
| 200 |
if need_outputs:
|
| 201 |
+
if output_model_dir:
|
| 202 |
+
allow_patterns += [
|
| 203 |
+
f"outputs/{output_model_dir}/*",
|
| 204 |
+
f"outputs/{output_model_dir}/**/*",
|
| 205 |
+
]
|
| 206 |
+
else:
|
| 207 |
+
allow_patterns += ["outputs/*", "outputs/**/*"]
|
| 208 |
if need_shakespeare:
|
| 209 |
allow_patterns += ["input.txt", "shakespeare_transformer.pt"]
|
| 210 |
if not allow_patterns:
|
|
|
|
| 220 |
|
| 221 |
@st.cache_resource(show_spinner=False)
def _download_model_outputs(model_dir: str) -> str:
    """Fetch the fine-tuned outputs for a single model directory.

    Thin wrapper that forwards to ``_download_weights`` with outputs
    requested, the Shakespeare files not requested, and the download
    scoped to ``model_dir``.

    Args:
        model_dir: Name of the model sub-directory, passed through as
            ``output_model_dir``.

    Returns:
        The value returned by ``_download_weights`` (annotated as ``str``).
    """
    return _download_weights(
        need_outputs=True,
        need_shakespeare=False,
        output_model_dir=model_dir,
    )
|
| 228 |
|
| 229 |
|
| 230 |
def _ensure_model_outputs_available(model_dir: str) -> None:
|
| 231 |
+
# Intentionally no eager snapshot download here.
|
| 232 |
+
# We only fetch checkpoints when a user explicitly selects a fine-tuned weight.
|
| 233 |
+
_ = model_dir
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@st.cache_data(show_spinner=False, ttl=900)
def _weights_repo_file_set() -> set[str]:
    """Return the set of file paths listed in the weights repository.

    The listing is requested for ``WEIGHTS_REPO_ID`` with
    ``repo_type="model"``. The call is best-effort: any exception raised
    while listing yields an empty set instead of propagating.

    Returns:
        The repository file paths as a set, or an empty set on failure.
    """
    from huggingface_hub import HfApi

    api = HfApi()
    try:
        files = api.list_repo_files(repo_id=WEIGHTS_REPO_ID, repo_type="model")
        return set(files)
    except Exception:
        # Deliberate fallback so a failed listing never breaks the caller.
        # NOTE(review): the cache decorator wraps the whole function, so this
        # empty result is cached for the TTL just like a successful listing.
        return set()
|
| 245 |
|
| 246 |
|
| 247 |
+
def _remote_has_finetuned(model_dir: str, subdir: str) -> bool:
    """Report whether the remote file listing has entries for a checkpoint.

    Args:
        model_dir: Model sub-directory name under ``outputs/``.
        subdir: Checkpoint sub-directory name under ``model_dir``.

    Returns:
        True if any path from ``_weights_repo_file_set`` starts with
        ``outputs/<model_dir>/<subdir>/``; False if none does or the
        listing is empty.
    """
    files = _weights_repo_file_set()
    if not files:
        return False
    # The trailing slash keeps "<subdir>" from matching "<subdir>-something".
    prefix = f"outputs/{model_dir}/{subdir}/"
    return any(path.startswith(prefix) for path in files)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _resolve_weight_paths(
|
| 256 |
+
need_outputs: bool,
|
| 257 |
+
need_shakespeare: bool,
|
| 258 |
+
output_model_dir: str | None = None,
|
| 259 |
+
):
|
| 260 |
output_root = DEFAULT_OUTPUT_ROOT
|
| 261 |
shakespeare_file = DEFAULT_SHAKESPEARE_FILE
|
| 262 |
shakespeare_weights = DEFAULT_SHAKESPEARE_WEIGHTS
|
| 263 |
|
| 264 |
+
if need_outputs:
|
| 265 |
+
if output_model_dir:
|
| 266 |
+
local_model = os.path.join(output_root, output_model_dir)
|
| 267 |
+
cache_model = os.path.join(WEIGHTS_CACHE_DIR, "outputs", output_model_dir)
|
| 268 |
+
have_outputs = (
|
| 269 |
+
(os.path.isdir(local_model) and len(os.listdir(local_model)) > 0)
|
| 270 |
+
or (os.path.isdir(cache_model) and len(os.listdir(cache_model)) > 0)
|
| 271 |
+
)
|
| 272 |
+
else:
|
| 273 |
+
have_outputs = os.path.isdir(output_root) and len(os.listdir(output_root)) > 0
|
| 274 |
+
else:
|
| 275 |
+
have_outputs = True
|
| 276 |
+
|
| 277 |
have_shakespeare = (
|
| 278 |
os.path.exists(shakespeare_file) and os.path.exists(shakespeare_weights)
|
| 279 |
)
|
|
|
|
| 281 |
return output_root, shakespeare_file, shakespeare_weights
|
| 282 |
|
| 283 |
try:
|
| 284 |
+
cache_dir = _download_weights(
|
| 285 |
+
need_outputs,
|
| 286 |
+
need_shakespeare,
|
| 287 |
+
output_model_dir=output_model_dir,
|
| 288 |
+
)
|
| 289 |
candidate_output_root = os.path.join(cache_dir, "outputs")
|
| 290 |
candidate_shakespeare_file = os.path.join(cache_dir, "input.txt")
|
| 291 |
candidate_shakespeare_weights = os.path.join(
|
|
|
|
| 326 |
for path in candidates:
|
| 327 |
if os.path.isdir(path) and len(os.listdir(path)) > 0:
|
| 328 |
return True
|
| 329 |
+
return _remote_has_finetuned(model_dir, subdir)
|
| 330 |
|
| 331 |
|
| 332 |
def _ckpt_path(output_root, model_dir, subdir):
|
|
|
|
| 347 |
_resolve_weight_paths(
|
| 348 |
need_outputs=True,
|
| 349 |
need_shakespeare=(model_dir == "custom_vlm"),
|
| 350 |
+
output_model_dir=model_dir,
|
| 351 |
)
|
| 352 |
if _has_finetuned(model_dir, requested_source):
|
| 353 |
return requested_source, None
|
|
|
|
| 361 |
model_dir = MODEL_DIR.get(model_name)
|
| 362 |
if not model_dir or model_dir in DISABLE_FINETUNE_FOR:
|
| 363 |
return False
|
|
|
|
|
|
|
|
|
|
| 364 |
return _has_finetuned(model_dir, requested_source)
|
| 365 |
|
| 366 |
|
|
|
|
| 379 |
|
| 380 |
if weight_source != "base":
|
| 381 |
output_root, _, _ = _resolve_weight_paths(
|
| 382 |
+
need_outputs=True,
|
| 383 |
+
need_shakespeare=False,
|
| 384 |
+
output_model_dir="blip",
|
| 385 |
)
|
| 386 |
ckpt = _ckpt_path(output_root, "blip", weight_source)
|
| 387 |
if os.path.isdir(ckpt) and os.listdir(ckpt):
|
|
|
|
| 410 |
|
| 411 |
if weight_source != "base":
|
| 412 |
output_root, _, _ = _resolve_weight_paths(
|
| 413 |
+
need_outputs=True,
|
| 414 |
+
need_shakespeare=False,
|
| 415 |
+
output_model_dir="vit_gpt2",
|
| 416 |
)
|
| 417 |
ckpt = _ckpt_path(output_root, "vit_gpt2", weight_source)
|
| 418 |
if os.path.isdir(ckpt) and os.listdir(ckpt):
|
|
|
|
| 437 |
|
| 438 |
if weight_source != "base":
|
| 439 |
output_root, _, _ = _resolve_weight_paths(
|
| 440 |
+
need_outputs=True,
|
| 441 |
+
need_shakespeare=False,
|
| 442 |
+
output_model_dir="git",
|
| 443 |
)
|
| 444 |
ckpt = _ckpt_path(output_root, "git", weight_source)
|
| 445 |
if os.path.isdir(ckpt) and os.listdir(ckpt):
|
|
|
|
| 461 |
device = get_device()
|
| 462 |
cfg = CFG()
|
| 463 |
output_root, shakespeare_file, shakespeare_weights = _resolve_weight_paths(
|
| 464 |
+
need_outputs=(weight_source != "base"),
|
| 465 |
+
need_shakespeare=True,
|
| 466 |
+
output_model_dir="custom_vlm" if weight_source != "base" else None,
|
| 467 |
)
|
| 468 |
cfg.output_root = output_root
|
| 469 |
cfg.shakespeare_file = shakespeare_file
|
|
|
|
| 542 |
|
| 543 |
if weight_source != "base":
|
| 544 |
output_root, _, _ = _resolve_weight_paths(
|
| 545 |
+
need_outputs=True,
|
| 546 |
+
need_shakespeare=False,
|
| 547 |
+
output_model_dir="blip",
|
| 548 |
)
|
| 549 |
ckpt = _ckpt_path(output_root, "blip", weight_source)
|
| 550 |
if os.path.isdir(ckpt) and os.listdir(ckpt):
|