Alibrown commited on
Commit
dc2f1b4
·
verified ·
1 Parent(s): 4ee3607

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +17 -7
model.py CHANGED
@@ -57,16 +57,26 @@ def get_api() -> Optional[HfApi]:
57
  def get_model_id() -> str:
58
  """
59
  Returns model ID to load.
60
- Prefers private fine-tuned model if available, falls back to base model.
 
61
  """
62
  api = get_api()
63
  if api and PRIVATE_MODEL:
64
  try:
65
- api.model_info(PRIVATE_MODEL, token=TOKEN)
66
- logger.info(f"Using private model: {PRIVATE_MODEL}")
67
- return PRIVATE_MODEL
68
- except Exception:
69
- logger.info(f"Private model not ready — using base: {MODEL_REPO}")
 
 
 
 
 
 
 
 
 
70
  return MODEL_REPO
71
 
72
 
@@ -176,4 +186,4 @@ def status() -> dict:
176
  "private_model": PRIVATE_MODEL,
177
  "dataset_repo": DATASET_REPO,
178
  "hf_api": "authenticated" if get_api() else "unauthenticated",
179
- }
 
57
  def get_model_id() -> str:
58
  """
59
  Returns model ID to load.
60
+ Prefers private fine-tuned model only if it has actual weights (config.json with model_type).
61
+ Falls back to base model if private repo is empty or not ready.
62
  """
63
  api = get_api()
64
  if api and PRIVATE_MODEL:
65
  try:
66
+ files = api.list_repo_files(PRIVATE_MODEL, repo_type="model", token=TOKEN)
67
+ has_config = "config.json" in list(files)
68
+ if has_config:
69
+ # Double-check it's a real model config, not just a README
70
+ from huggingface_hub import hf_hub_download
71
+ import json
72
+ cfg_path = hf_hub_download(PRIVATE_MODEL, "config.json", token=TOKEN)
73
+ cfg = json.loads(open(cfg_path).read())
74
+ if "model_type" in cfg:
75
+ logger.info(f"Using private model: {PRIVATE_MODEL}")
76
+ return PRIVATE_MODEL
77
+ logger.info(f"Private repo exists but has no weights yet — using base: {MODEL_REPO}")
78
+ except Exception as e:
79
+ logger.info(f"Private model check failed ({type(e).__name__}) — using base: {MODEL_REPO}")
80
  return MODEL_REPO
81
 
82
 
 
186
  "private_model": PRIVATE_MODEL,
187
  "dataset_repo": DATASET_REPO,
188
  "hf_api": "authenticated" if get_api() else "unauthenticated",
189
+ }