"""Model + preprocessor loading utilities (local + Hugging Face Hub).
S3 support removed as project no longer uses AWS. Loading order:
1. Local baked artifacts (if present)
2. Hugging Face Hub (HF_MODEL_REPO)
Environment variables quick reference:
HF_MODEL_REPO remote repo to pull snapshot from (e.g. j2damax/hotel-cancel-model)
FORCE_HF_LOAD=true always fetch from HF even if local artifacts exist
HF_HUB_CACHE=path optional override for huggingface hub cache directory
ALLOW_START_WITHOUT_MODEL=true allow API to start even when model load fails
DECISION_THRESHOLD manual override for probability decision threshold
"""
from __future__ import annotations

import json
import os
import time
from typing import Optional

import joblib

from . import config
from src.preprocessing import PreprocessingPipeline

# Module-level state shared with the API layer.
model = None
preprocessor: Optional[PreprocessingPipeline] = None
model_version: Optional[str] = None
champion_meta_threshold: Optional[float] = None
_last_reload_time: float | None = None

def _resolve_git_sha() -> str | None:
    """Best-effort commit SHA: GIT_SHA env var first, then .git/HEAD."""
    git_sha = os.getenv('GIT_SHA')
    if git_sha:
        return git_sha[:12]
    head_path = os.path.join('.git', 'HEAD')
    try:
        if os.path.exists(head_path):
            with open(head_path) as hf:
                ref = hf.read().strip()
            if ref.startswith('ref:'):
                ref_file = ref.split(' ', 1)[1]
                ref_path = os.path.join('.git', ref_file)
                if os.path.exists(ref_path):
                    with open(ref_path) as rf:
                        return rf.read().strip()[:12]
            else:
                return ref[:12]
    except Exception:
        return None
    return None
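
# Note: .git/HEAD normally holds either a symbolic ref such as
# "ref: refs/heads/main" (resolved above by reading the referenced file) or,
# in detached-HEAD state, a bare commit SHA, returned directly and truncated
# to 12 characters.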

def load_model() -> None:
    """Idempotent loading routine.

    Order (unless overridden):
    1. Local baked artifacts (if present)
    2. Hugging Face Hub snapshot (HF_MODEL_REPO)

    Env flags:
    FORCE_HF_LOAD=true   -> skip local artifacts and always pull from HF
    HF_HUB_CACHE=/custom -> redirect the huggingface_hub cache (optional; the
                            snapshot is still stored under models/hf/...)
    """
    global model, preprocessor, model_version, champion_meta_threshold, _last_reload_time
    force_hf = os.getenv('FORCE_HF_LOAD', 'false').lower() == 'true'
    hf_repo = getattr(config, 'HF_MODEL_REPO', None)
    # Optional hub-wide cache redirection for environments with restricted root perms
    hf_hub_cache = os.getenv('HF_HUB_CACHE')
    if hf_hub_cache:
        try:
            os.makedirs(hf_hub_cache, exist_ok=True)
            os.environ['HF_HOME'] = hf_hub_cache  # respected by huggingface_hub
        except Exception as e:
            print(f"Could not create HF_HUB_CACHE directory {hf_hub_cache}: {e}")
    # Local load (unless forcing HF)
    if not force_hf:
        if model is None and os.path.exists(config.LOCAL_MODEL_PATH):
            try:
                model_candidate = joblib.load(config.LOCAL_MODEL_PATH)
                if hasattr(model_candidate, 'predict'):
                    model = model_candidate
                    mtime = int(os.path.getmtime(config.LOCAL_MODEL_PATH))
                    model_version = f"local_{mtime}"
            except Exception as e:
                print(f"Local model load failed: {e}")
        if preprocessor is None and os.path.exists(config.LOCAL_PREPROCESSOR_PATH):
            try:
                preprocessor = PreprocessingPipeline.load(config.LOCAL_PREPROCESSOR_PATH)
            except Exception:
                preprocessor = None
    else:
        print("FORCE_HF_LOAD=true -> Skipping local artifact loading")
    # Hugging Face Hub load if still missing or forced
    need_hf = (model is None or preprocessor is None or force_hf) and hf_repo
    if need_hf:
        try:
            from huggingface_hub import snapshot_download
            repo_id = hf_repo
            cache_dir = os.path.join('models', 'hf', repo_id.replace('/', '__'))
            os.makedirs(cache_dir, exist_ok=True)
            local_dir = snapshot_download(repo_id=repo_id, local_dir=cache_dir, local_dir_use_symlinks=False)
            model_path = os.path.join(local_dir, 'champion_model.pkl')
            preproc_path = os.path.join(local_dir, 'preprocessor.pkl')
            meta_path = os.path.join(local_dir, 'champion_meta.json')
            if (model is None or force_hf) and os.path.exists(model_path):
                try:
                    m_candidate = joblib.load(model_path)
                    if hasattr(m_candidate, 'predict'):
                        model = m_candidate
                        model_version = f"hf_{os.path.getmtime(model_path):.0f}"
                except Exception as e:
                    print(f"HF model load failed: {e}")
            if (preprocessor is None or force_hf) and os.path.exists(preproc_path):
                try:
                    preprocessor = PreprocessingPipeline.load(preproc_path)
                except Exception:
                    preprocessor = None
            if os.path.exists(meta_path):
                try:
                    with open(meta_path) as mf:
                        meta = json.load(mf)
                    if 'decision_threshold' in meta:
                        champion_meta_threshold = meta['decision_threshold']
                except Exception:
                    pass
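            # For reference, champion_meta.json is expected to contain at least
            # a 'decision_threshold' key, e.g. {"decision_threshold": 0.42}
            # (illustrative value); any other keys are ignored here.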
            if model is not None and force_hf:
                print(f"Loaded model (HF FORCE) repo={repo_id} version={model_version}")
            elif model is not None:
                print(f"Loaded model (HF) repo={repo_id} version={model_version}")
        except Exception as e:
            print(f"HF load failed: {e}")
    _last_reload_time = time.time()
    if model is None:
        print("No model loaded (checked local + HF). API will report model_not_loaded.")

def resolve_threshold() -> tuple[float, str]:
    """Return (threshold, source) with precedence: env var, champion meta, default 0.5."""
    if config.DECISION_THRESHOLD_ENV is not None:
        try:
            return float(config.DECISION_THRESHOLD_ENV), 'env'
        except ValueError:
            pass
    if champion_meta_threshold is not None:
        try:
            return float(champion_meta_threshold), 'champion_meta'
        except ValueError:
            pass
    return 0.5, 'default'
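
# Example (illustrative values): with DECISION_THRESHOLD=0.3 exported,
# resolve_threshold() returns (0.3, 'env'); with no env var but a
# champion_meta.json threshold of 0.42 it returns (0.42, 'champion_meta');
# otherwise (0.5, 'default').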

def load_model_and_preprocessor():
    """Convenience helper: ensure artifacts are loaded and return them with minimal metadata.

    Returns (model, preprocessor, metadata_dict);
    metadata_dict contains the keys: version, threshold, threshold_source.
    """
    if model is None or preprocessor is None:
        load_model()
    thr, source = resolve_threshold()
    meta = {
        'version': model_version,
        'threshold': thr,
        'threshold_source': source,
    }
    return model, preprocessor, meta
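
# Minimal usage sketch. The 'transform' and 'predict_proba' calls and the
# raw_df input are assumptions for illustration; the model is only guaranteed
# to expose 'predict', and the actual pipeline API lives in src.preprocessing.
#
#     mdl, pre, meta = load_model_and_preprocessor()
#     if mdl is not None and pre is not None:
#         X = pre.transform(raw_df)                          # raw feature rows
#         proba = mdl.predict_proba(X)[:, 1]                 # cancel probability
#         labels = (proba >= meta['threshold']).astype(int)  # apply threshold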