File size: 2,919 Bytes
87f8e11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import joblib
import threading
import traceback
from huggingface_hub import hf_hub_download

# --- CONFIGURATION ---
HF_REPO_ID = "CodebaseAi/netraids-ml-models"  # Replace with your actual public repo ID
# ---------------------

ACTIVE_MODEL = "bcc"
_ACTIVE_LOCK = threading.Lock()
_MODEL_CACHE = {}

# 1. FIXED PATH LOGIC:
# __file__ is /app/utils/model_selector.py
# dirname(__file__) is /app/utils
# dirname(dirname(...)) is /app (the ROOT)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
ML_DIR = os.path.join(BASE_DIR, "ml_models")

# Ensure the local ml_models directory exists for caching
if not os.path.exists(ML_DIR):
    os.makedirs(ML_DIR, exist_ok=True)

print(f"[model_selector] ROOT BASE_DIR: {BASE_DIR}")
print(f"[model_selector] ML_DIR: {ML_DIR}")

def _get_model_path(filename):
    """
    Resolve a model artifact to a local filesystem path.

    First looks in the local 'ml_models' folder; if the file is not there,
    downloads it from the public Hugging Face Hub repo (HF_REPO_ID) into
    ML_DIR so subsequent calls hit the local cache.

    Args:
        filename: Bare artifact filename, e.g. "realtime_model.pkl".

    Returns:
        The local path to the file, or None if it could not be
        found locally nor downloaded.
    """
    local_path = os.path.join(ML_DIR, filename)

    # 1. Check if the file is already there
    if os.path.exists(local_path):
        return local_path

    # 2. Download from Hub if missing
    try:
        # NOTE: the original message printed a garbled "(unknown)" where the
        # filename belongs — restored to the actual filename.
        print(f"[model_selector] {filename} not found locally. Downloading from Hub...")
        # We specify local_dir to force it into our ml_models folder
        downloaded_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=ML_DIR,
        )
        return downloaded_path
    except Exception as e:
        # Broad catch is deliberate: any failure (network, 404, auth) is
        # reported and degraded to a None return for the caller to handle.
        print(f"[model_selector] ERROR: Could not find/download {filename}: {e}")
        return None

def _try_load(filename):
    """
    Best-effort load of a joblib artifact by filename.

    Resolves the path via _get_model_path (local cache, then Hub download)
    and deserializes it with joblib. Any failure is logged and swallowed.

    Args:
        filename: Bare artifact filename, e.g. "realtime_scaler.pkl".

    Returns:
        The deserialized object, or None if the path could not be
        resolved or loading failed.
    """
    path = _get_model_path(filename)
    if not path or not os.path.exists(path):
        # Original message printed a garbled "(unknown)"; restored filename.
        print(f"[model_selector] SKIP: {filename} path invalid.")
        return None
    try:
        return joblib.load(path)
    except Exception as e:
        print(f"[model_selector] FAILED to load {filename}: {e}")
        return None

def load_model(model_key):
    """
    Return the artifact bundle for *model_key*, loading it on first use.

    Results are memoized in _MODEL_CACHE, so each bundle is loaded from
    disk/Hub at most once per process.

    Args:
        model_key: Either "bcc" (realtime model/scaler/encoder) or
            "cicids" (random-forest pipeline + training artifacts).

    Returns:
        A dict mapping artifact names to loaded objects (values may be
        None if an individual artifact failed to load).

    Raises:
        ValueError: If *model_key* is not a known key.
    """
    if model_key in _MODEL_CACHE:
        return _MODEL_CACHE[model_key]

    if model_key == "bcc":
        bundle = {
            name: _try_load(fname)
            for name, fname in (
                ("model", "realtime_model.pkl"),
                ("scaler", "realtime_scaler.pkl"),
                ("encoder", "realtime_encoder.pkl"),
            )
        }
    elif model_key == "cicids":
        # It will look for your RF files in the Hub
        bundle = {
            "model": _try_load("rf_pipeline.joblib"),
            "artifacts": _try_load("training_artifacts.joblib"),
        }
    else:
        raise ValueError(f"Unknown model_key: {model_key}")

    _MODEL_CACHE[model_key] = bundle
    return bundle

def set_active_model(key: str):
    """Switch the module-wide active model key under _ACTIVE_LOCK."""
    global ACTIVE_MODEL
    # Serialize writers so concurrent switches don't interleave.
    with _ACTIVE_LOCK:
        ACTIVE_MODEL = key
    print(f"[model_selector] ACTIVE_MODEL set to: {ACTIVE_MODEL}")

def get_active_model():
    """Return the key of the currently active model (no locking on read)."""
    return ACTIVE_MODEL