File size: 3,279 Bytes
31fda96 0117df3 31fda96 0117df3 31fda96 8d28be7 0117df3 8d28be7 0117df3 183f1c4 0117df3 31fda96 8d28be7 31fda96 0117df3 31fda96 8d28be7 0117df3 8d28be7 0117df3 31fda96 8d28be7 31fda96 8d28be7 31fda96 8d28be7 31fda96 8d28be7 31fda96 8d28be7 31fda96 8d28be7 31fda96 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 | import json
import logging
import pickle
import shutil
from pathlib import Path
import torch
from huggingface_hub import snapshot_download
from config import Config
# Hugging Face Hub repository the artifacts are downloaded from.
REPO_ID = Config.REPO_ID_LANG

# Local artifact directory; None when LANG_MODEL is unset or empty.
if Config.LANG_MODEL:
    MODEL_DIR = Path(Config.LANG_MODEL)
else:
    MODEL_DIR = None

# Token forwarded to snapshot_download when fetching the repository.
HF_TOKEN = Config.HF_TOKEN

# Sub-folder that may hold the artifacts, both in the Hub snapshot and locally.
ENGLISH_SUBDIR = "English_model"

# Torch device: CUDA when available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Every artifact load_model() reads; a directory is usable only if all are present.
REQUIRED_FILES = (
    "classifier.pkl",
    "scaler.pkl",
    "word_vectorizer.pkl",
    "char_vectorizer.pkl",
    "feature_names.json",
    "metadata.json",
)
def _has_required_artifacts(model_dir: Path) -> bool:
    """Return True if *model_dir* is a directory holding every REQUIRED_FILES entry.

    Each required entry must be a regular file (symlinks to files count).
    A directory that merely shares an artifact's name is rejected here;
    otherwise it would pass the check and only fail later when
    ``load_model`` tries to ``open()`` it.
    """
    # is_dir() is already False for a missing path, so no separate exists() check.
    if not model_dir.is_dir():
        return False
    return all((model_dir / filename).is_file() for filename in REQUIRED_FILES)
def _resolve_artifact_dir(base_dir: Path) -> Path | None:
    """Find the directory that actually holds the model artifacts.

    *base_dir* itself is tried first, then its ENGLISH_SUBDIR child. The
    first one accepted by _has_required_artifacts is returned; if neither
    qualifies, the result is None.
    """
    return next(
        (
            location
            for location in (base_dir, base_dir / ENGLISH_SUBDIR)
            if _has_required_artifacts(location)
        ),
        None,
    )
def warmup():
    """Ensure the model artifacts are present locally, downloading them if needed.

    Raises:
        ValueError: if LANG_MODEL is not configured. download_model_repo,
            called when the artifacts are missing, may raise further errors.
    """
    logging.info("Warming up model...")
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    artifacts_present = _resolve_artifact_dir(MODEL_DIR) is not None
    if not artifacts_present:
        download_model_repo()
        return
    logging.info("Model artifacts already exist, skipping download.")
def download_model_repo():
    """Download the model repository from the Hugging Face Hub into MODEL_DIR.

    Does nothing when MODEL_DIR (or its ENGLISH_SUBDIR child) already holds
    all required artifacts. Otherwise the repository snapshot is fetched with
    ``snapshot_download``. Its ENGLISH_SUBDIR folder is then copied into
    MODEL_DIR, or the whole snapshot root if no such folder exists.

    Raises:
        ValueError: if LANG_MODEL is not configured, or if a download is
            needed and the repository id (REPO_ID) is not configured.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    # Check local artifacts before requiring a repo id, so a pre-populated
    # MODEL_DIR works without any Hub configuration (matches warmup()).
    if _resolve_artifact_dir(MODEL_DIR):
        logging.info("Model artifacts already exist, skipping download.")
        return
    if not REPO_ID:
        raise ValueError("English_model repo id is not configured")
    logging.info("Downloading model artifacts from %s into %s", REPO_ID, MODEL_DIR)
    snapshot_path = Path(snapshot_download(repo_id=REPO_ID, token=HF_TOKEN))
    english_dir = snapshot_path / ENGLISH_SUBDIR
    source_dir = english_dir if english_dir.is_dir() else snapshot_path
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    # dirs_exist_ok merges into an existing (possibly partial) MODEL_DIR.
    # With the default symlinks=False, copytree copies link targets rather
    # than links (relevant if the Hub cache snapshot uses symlinks).
    shutil.copytree(source_dir, MODEL_DIR, dirs_exist_ok=True)
    if _resolve_artifact_dir(MODEL_DIR) is None:
        # Best-effort diagnostic; load_model() raises if it needs these files.
        logging.warning(
            "Downloaded %s but required artifacts are still missing in %s: %s",
            REPO_ID,
            MODEL_DIR,
            ", ".join(REQUIRED_FILES),
        )
def _load_pickle(path: Path) -> object:
    """Unpickle and return the object stored at *path*."""
    # SECURITY: pickle.load can execute arbitrary code. These files may come
    # from the Hugging Face Hub via download_model_repo(), so REPO_ID must
    # point at a trusted repository.
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_json(path: Path) -> object:
    """Parse and return the JSON document stored at *path*, decoded as UTF-8."""
    # JSON is UTF-8 by specification; without an explicit encoding, open()
    # falls back to the platform's locale encoding (e.g. a Windows code page).
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_model():
    """Load the language-model artifacts, downloading them first if missing.

    Returns:
        A 6-tuple ``(classifier, scaler, word_vectorizer, char_vectorizer,
        feature_names, metadata)``. The first four are unpickled objects;
        the last two are parsed JSON documents.

    Raises:
        ValueError: if LANG_MODEL is not configured, or (via
            download_model_repo) if a download is needed but no repo id is set.
        FileNotFoundError: if the required artifacts are still missing after
            the download attempt.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    artifact_dir = _resolve_artifact_dir(MODEL_DIR)
    if artifact_dir is None:
        logging.info("Model artifacts missing in %s, downloading now.", MODEL_DIR)
        download_model_repo()
        artifact_dir = _resolve_artifact_dir(MODEL_DIR)
    if artifact_dir is None:
        raise FileNotFoundError(
            f"Required model artifacts not found in {MODEL_DIR}. Expected files: {', '.join(REQUIRED_FILES)}"
        )
    return (
        _load_pickle(artifact_dir / "classifier.pkl"),
        _load_pickle(artifact_dir / "scaler.pkl"),
        _load_pickle(artifact_dir / "word_vectorizer.pkl"),
        _load_pickle(artifact_dir / "char_vectorizer.pkl"),
        _load_json(artifact_dir / "feature_names.json"),
        _load_json(artifact_dir / "metadata.json"),
    )
|