Spaces:
Sleeping
Sleeping
Update model.py
Browse files
model.py
CHANGED
|
@@ -57,16 +57,26 @@ def get_api() -> Optional[HfApi]:
|
|
| 57 |
def get_model_id() -> str:
|
| 58 |
"""
|
| 59 |
Returns model ID to load.
|
| 60 |
-
Prefers private fine-tuned model if
|
|
|
|
| 61 |
"""
|
| 62 |
api = get_api()
|
| 63 |
if api and PRIVATE_MODEL:
|
| 64 |
try:
|
| 65 |
-
api.
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
return MODEL_REPO
|
| 71 |
|
| 72 |
|
|
@@ -176,4 +186,4 @@ def status() -> dict:
|
|
| 176 |
"private_model": PRIVATE_MODEL,
|
| 177 |
"dataset_repo": DATASET_REPO,
|
| 178 |
"hf_api": "authenticated" if get_api() else "unauthenticated",
|
| 179 |
-
}
|
|
|
|
| 57 |
def get_model_id() -> str:
|
| 58 |
"""
|
| 59 |
Returns model ID to load.
|
| 60 |
+
Prefers private fine-tuned model only if it has actual weights (config.json with model_type).
|
| 61 |
+
Falls back to base model if private repo is empty or not ready.
|
| 62 |
"""
|
| 63 |
api = get_api()
|
| 64 |
if api and PRIVATE_MODEL:
|
| 65 |
try:
|
| 66 |
+
files = api.list_repo_files(PRIVATE_MODEL, repo_type="model", token=TOKEN)
|
| 67 |
+
has_config = "config.json" in list(files)
|
| 68 |
+
if has_config:
|
| 69 |
+
# Double-check it's a real model config, not just a README
|
| 70 |
+
from huggingface_hub import hf_hub_download
|
| 71 |
+
import json
|
| 72 |
+
cfg_path = hf_hub_download(PRIVATE_MODEL, "config.json", token=TOKEN)
|
| 73 |
+
cfg = json.loads(open(cfg_path).read())
|
| 74 |
+
if "model_type" in cfg:
|
| 75 |
+
logger.info(f"Using private model: {PRIVATE_MODEL}")
|
| 76 |
+
return PRIVATE_MODEL
|
| 77 |
+
logger.info(f"Private repo exists but has no weights yet — using base: {MODEL_REPO}")
|
| 78 |
+
except Exception as e:
|
| 79 |
+
logger.info(f"Private model check failed ({type(e).__name__}) — using base: {MODEL_REPO}")
|
| 80 |
return MODEL_REPO
|
| 81 |
|
| 82 |
|
|
|
|
| 186 |
"private_model": PRIVATE_MODEL,
|
| 187 |
"dataset_repo": DATASET_REPO,
|
| 188 |
"hf_api": "authenticated" if get_api() else "unauthenticated",
|
| 189 |
+
}
|