"""Model + preprocessor loading utilities (local + Hugging Face Hub).
S3 support removed as project no longer uses AWS. Loading order:
1. Local baked artifacts (if present)
2. Hugging Face Hub (HF_MODEL_REPO)
Environment variables quick reference:
HF_MODEL_REPO remote repo to pull snapshot from (e.g. j2damax/hotel-cancel-model)
FORCE_HF_LOAD=true always fetch from HF even if local artifacts exist
HF_HUB_CACHE=path optional override for huggingface hub cache directory
ALLOW_START_WITHOUT_MODEL=true allow API to start even when model load fails
DECISION_THRESHOLD manual override for probability decision threshold
"""
from __future__ import annotations

import json
import os
import time
from typing import Optional

import joblib

from . import config
from src.preprocessing import PreprocessingPipeline

# Module-level singleton state, populated by load_model()
model = None
preprocessor: Optional[PreprocessingPipeline] = None
model_version: Optional[str] = None
champion_meta_threshold: Optional[float] = None
_last_reload_time: float | None = None
def _resolve_git_sha() -> str | None:
    """Best-effort commit SHA: prefer the GIT_SHA env var, then read .git/HEAD."""
    git_sha = os.getenv('GIT_SHA')
    if git_sha:
        return git_sha[:12]
    head_path = os.path.join('.git', 'HEAD')
    try:
        if os.path.exists(head_path):
            with open(head_path) as hf:
                ref = hf.read().strip()
            if ref.startswith('ref:'):
                # Symbolic ref: follow it to the branch file that holds the SHA.
                ref_file = ref.split(' ', 1)[1]
                ref_path = os.path.join('.git', ref_file)
                if os.path.exists(ref_path):
                    with open(ref_path) as rf:
                        return rf.read().strip()[:12]
            else:
                # Detached HEAD: the file contains the SHA directly.
                return ref[:12]
    except Exception:
        return None
    return None
def load_model() -> None:
    """Idempotent loading routine.

    Order (unless overridden):
    1. Local baked artifacts (if present)
    2. Hugging Face Hub snapshot (HF_MODEL_REPO)

    Env flags:
    FORCE_HF_LOAD=true   -> skip local artifacts and always pull from HF
    HF_HUB_CACHE=/custom -> redirect the huggingface_hub cache (optional; the snapshot is still stored in models/hf/...)
    """
    global model, preprocessor, model_version, champion_meta_threshold, _last_reload_time
    force_hf = os.getenv('FORCE_HF_LOAD', 'false').lower() == 'true'
    hf_repo = getattr(config, 'HF_MODEL_REPO', None)
    # Optional hub-wide cache redirection for environments with restricted root perms
    hf_hub_cache = os.getenv('HF_HUB_CACHE')
    if hf_hub_cache:
        try:
            os.makedirs(hf_hub_cache, exist_ok=True)
            os.environ['HF_HOME'] = hf_hub_cache  # respected by huggingface_hub
        except Exception as e:
            print(f"Could not create HF_HUB_CACHE directory {hf_hub_cache}: {e}")
    # Local load (unless forcing HF)
    if not force_hf:
        if model is None and os.path.exists(config.LOCAL_MODEL_PATH):
            try:
                model_candidate = joblib.load(config.LOCAL_MODEL_PATH)
                if hasattr(model_candidate, 'predict'):
                    model = model_candidate
                    mtime = int(os.path.getmtime(config.LOCAL_MODEL_PATH))
                    model_version = f"local_{mtime}"
            except Exception as e:
                print(f"Local model load failed: {e}")
        if preprocessor is None and os.path.exists(config.LOCAL_PREPROCESSOR_PATH):
            try:
                preprocessor = PreprocessingPipeline.load(config.LOCAL_PREPROCESSOR_PATH)
            except Exception:
                preprocessor = None
    else:
        print("FORCE_HF_LOAD=true -> Skipping local artifact loading")
    # Hugging Face Hub load if artifacts are still missing or a pull is forced
    need_hf = (model is None or preprocessor is None or force_hf) and hf_repo
    if need_hf:
        try:
            from huggingface_hub import snapshot_download
            repo_id = hf_repo
            cache_dir = os.path.join('models', 'hf', repo_id.replace('/', '__'))
            os.makedirs(cache_dir, exist_ok=True)
            local_dir = snapshot_download(repo_id=repo_id, local_dir=cache_dir, local_dir_use_symlinks=False)
            model_path = os.path.join(local_dir, 'champion_model.pkl')
            preproc_path = os.path.join(local_dir, 'preprocessor.pkl')
            meta_path = os.path.join(local_dir, 'champion_meta.json')
            if (model is None or force_hf) and os.path.exists(model_path):
                try:
                    m_candidate = joblib.load(model_path)
                    if hasattr(m_candidate, 'predict'):
                        model = m_candidate
                        model_version = f"hf_{os.path.getmtime(model_path):.0f}"
                except Exception as e:
                    print(f"HF model load failed: {e}")
            if (preprocessor is None or force_hf) and os.path.exists(preproc_path):
                try:
                    preprocessor = PreprocessingPipeline.load(preproc_path)
                except Exception:
                    preprocessor = None
            if os.path.exists(meta_path):
                try:
                    with open(meta_path) as mf:
                        meta = json.load(mf)
                    if 'decision_threshold' in meta:
                        champion_meta_threshold = meta['decision_threshold']
                except Exception:
                    pass
            if model is not None:
                mode = 'HF FORCE' if force_hf else 'HF'
                print(f"Loaded model ({mode}) repo={repo_id} version={model_version}")
        except Exception as e:
            print(f"HF load failed: {e}")
    if model is None:
        print("No model loaded (checked local + HF). API will report model_not_loaded.")
    else:
        _last_reload_time = time.time()  # record the time of the last successful (re)load
def resolve_threshold() -> tuple[float, str]:
    """Return (threshold, source) with precedence: env override > champion metadata > 0.5 default."""
    if config.DECISION_THRESHOLD_ENV is not None:
        try:
            return float(config.DECISION_THRESHOLD_ENV), 'env'
        except ValueError:
            pass
    if champion_meta_threshold is not None:
        try:
            return float(champion_meta_threshold), 'champion_meta'
        except ValueError:
            pass
    return 0.5, 'default'
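# Worked example of the precedence above (values are hypothetical):
#   DECISION_THRESHOLD=0.30 in the environment         -> (0.30, 'env')
#   no env var, champion_meta decision_threshold=0.47  -> (0.47, 'champion_meta')
#   neither available                                  -> (0.5, 'default')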
def load_model_and_preprocessor():
    """Convenience helper: ensure artifacts are loaded, then return them with minimal metadata.

    Returns (model, preprocessor, metadata_dict), where metadata_dict contains
    the keys: version, threshold, threshold_source.
    """
    if model is None or preprocessor is None:
        load_model()
    thr, source = resolve_threshold()
    meta = {
        'version': model_version,
        'threshold': thr,
        'threshold_source': source,
    }
    return model, preprocessor, meta
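

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the service: loads the
    # artifacts via the convenience helper and prints the resolved metadata.
    # Nothing here is required by the API itself.
    mdl, pre, info = load_model_and_preprocessor()
    print(f"model loaded: {mdl is not None} | preprocessor loaded: {pre is not None}")
    print(f"version={info['version']} threshold={info['threshold']} ({info['threshold_source']})")
    print(f"git sha: {_resolve_git_sha()}")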