File size: 3,279 Bytes
31fda96
 
 
0117df3
31fda96
 
0117df3
31fda96
 
 
 
 
 
8d28be7
 
0117df3
 
8d28be7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0117df3
183f1c4
0117df3
31fda96
 
 
8d28be7
 
31fda96
0117df3
 
 
 
31fda96
 
8d28be7
 
 
 
0117df3
8d28be7
 
 
 
0117df3
 
 
31fda96
 
8d28be7
 
 
 
 
 
 
 
 
 
 
31fda96
 
8d28be7
31fda96
 
8d28be7
31fda96
 
8d28be7
31fda96
 
8d28be7
31fda96
 
8d28be7
31fda96
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
import logging
import pickle
import shutil
from pathlib import Path

import torch
from huggingface_hub import snapshot_download

from config import Config

# Repo id handed to snapshot_download when the artifacts have to be fetched.
REPO_ID = Config.REPO_ID_LANG
# Local directory for the model artifacts; None when Config.LANG_MODEL is falsy.
MODEL_DIR = Path(Config.LANG_MODEL) if Config.LANG_MODEL else None
# Passed as the `token` argument of snapshot_download.
HF_TOKEN = Config.HF_TOKEN
# Sub-directory name probed for artifacts, both under MODEL_DIR and inside the downloaded snapshot.
ENGLISH_SUBDIR = "English_model"

# CUDA device when torch reports it as available, otherwise CPU.
# NOTE(review): not referenced by the functions visible in this file — presumably used by importers; confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A directory only counts as holding the model artifacts when every one of these files exists in it.
REQUIRED_FILES = (
    "classifier.pkl",
    "scaler.pkl",
    "word_vectorizer.pkl",
    "char_vectorizer.pkl",
    "feature_names.json",
    "metadata.json",
)


def _has_required_artifacts(model_dir: Path) -> bool:
    """Return True when *model_dir* is a directory containing every entry of REQUIRED_FILES."""
    # is_dir() is already False for a path that does not exist, so a single
    # check covers both the "missing" and the "not a directory" cases.
    if not model_dir.is_dir():
        return False
    for required_name in REQUIRED_FILES:
        if not (model_dir / required_name).exists():
            return False
    return True


def _resolve_artifact_dir(base_dir: Path) -> Path | None:
    """Locate the directory under *base_dir* that holds the model artifacts.

    *base_dir* itself is tried first, then its ``ENGLISH_SUBDIR`` child.

    Returns:
        The first location accepted by ``_has_required_artifacts``, or
        ``None`` when neither location qualifies.
    """
    search_order = (base_dir, base_dir / ENGLISH_SUBDIR)
    return next(
        (location for location in search_order if _has_required_artifacts(location)),
        None,
    )


def warmup():
    """Ensure the model artifacts are on disk, downloading them when absent.

    Raises:
        ValueError: If ``MODEL_DIR`` is ``None`` (``LANG_MODEL`` not configured).
            Errors raised by ``download_model_repo`` propagate unchanged.
    """
    logging.info("Warming up model...")
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")

    already_present = _resolve_artifact_dir(MODEL_DIR) is not None
    if not already_present:
        download_model_repo()
    else:
        logging.info("Model artifacts already exist, skipping download.")


def download_model_repo():
    """Fetch the model repo snapshot and copy its artifacts into ``MODEL_DIR``.

    Only logs and returns when the artifacts are already present. When the
    snapshot contains an ``ENGLISH_SUBDIR`` directory, the contents of that
    directory are copied; otherwise the whole snapshot is copied.

    Raises:
        ValueError: If ``MODEL_DIR`` is ``None`` or ``REPO_ID`` is falsy.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    if not REPO_ID:
        raise ValueError("English_model repo id is not configured")

    if _resolve_artifact_dir(MODEL_DIR) is not None:
        logging.info("Model artifacts already exist, skipping download.")
        return

    snapshot_root = Path(snapshot_download(repo_id=REPO_ID, token=HF_TOKEN))
    english_dir = snapshot_root / ENGLISH_SUBDIR
    if english_dir.is_dir():
        copy_from = english_dir
    else:
        copy_from = snapshot_root

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    # dirs_exist_ok lets the copy merge into the directory created just above.
    shutil.copytree(copy_from, MODEL_DIR, dirs_exist_ok=True)


def _load_pickle(path: Path):
    """Unpickle and return the object stored at *path*."""
    # SECURITY: pickle.load can execute arbitrary code during deserialization.
    # These files originate from the downloaded model repo, so REPO_ID and
    # LANG_MODEL must only ever point at trusted sources.
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_json(path: Path):
    """Parse and return the JSON document stored at *path*."""
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which differs between systems.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_model():
    """Load every model artifact from disk, downloading the repo first if needed.

    The artifact directory is resolved under ``MODEL_DIR``. If it cannot be
    found, ``download_model_repo`` is invoked once and resolution is retried.

    Returns:
        A 6-tuple ``(classifier, scaler, word_vectorizer, char_vectorizer,
        feature_names, metadata)``. The first four are unpickled objects and
        the last two are parsed JSON documents.

    Raises:
        ValueError: If ``MODEL_DIR`` is ``None`` (``LANG_MODEL`` not configured).
        FileNotFoundError: If the required files are still missing after the
            download attempt.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")

    artifact_dir = _resolve_artifact_dir(MODEL_DIR)
    if artifact_dir is None:
        logging.info("Model artifacts missing in %s, downloading now.", MODEL_DIR)
        download_model_repo()
        # Re-resolve: the download may have populated MODEL_DIR or its subdir.
        artifact_dir = _resolve_artifact_dir(MODEL_DIR)
    if artifact_dir is None:
        raise FileNotFoundError(
            f"Required model artifacts not found in {MODEL_DIR}. Expected files: {', '.join(REQUIRED_FILES)}"
        )

    loaded_classifier = _load_pickle(artifact_dir / "classifier.pkl")
    loaded_scaler = _load_pickle(artifact_dir / "scaler.pkl")
    loaded_word_vectorizer = _load_pickle(artifact_dir / "word_vectorizer.pkl")
    loaded_char_vectorizer = _load_pickle(artifact_dir / "char_vectorizer.pkl")
    loaded_features = _load_json(artifact_dir / "feature_names.json")
    loaded_metadata = _load_json(artifact_dir / "metadata.json")

    return (
        loaded_classifier,
        loaded_scaler,
        loaded_word_vectorizer,
        loaded_char_vectorizer,
        loaded_features,
        loaded_metadata,
    )