# model_selector.py — resolves, downloads, and caches ML model artifacts
# (locally or from the Hugging Face Hub) and tracks the active model key.
import os
import joblib
import threading
import traceback
from huggingface_hub import hf_hub_download
import sklearn.utils


# --- CONFIGURATION ---
HF_REPO_ID = "CodebaseAi/netraids-ml-models"  # Replace with your actual public repo ID
# ---------------------

# Key of the model bundle currently in use; changed via set_active_model().
ACTIVE_MODEL = "bcc"
# Serializes writes to ACTIVE_MODEL across threads (see set_active_model).
_ACTIVE_LOCK = threading.Lock()
# model_key -> dict of loaded artifacts; filled lazily by load_model().
_MODEL_CACHE = {}


# --- SKLEARN COMPATIBILITY PATCH ---
# Pipelines pickled under older scikit-learn releases reference
# sklearn.utils._get_column_indices, which newer releases relocated
# (sklearn >= 1.5 keeps it in sklearn.utils._indexing; some versions
# also expose it via sklearn.compose._column_transformer).
# Re-inject it into sklearn.utils so joblib.load of old models works.
# NOTE: this must run BEFORE any model is unpickled.
import importlib

if not hasattr(sklearn.utils, "_get_column_indices"):
    _gci = None
    # Probe the known historical homes of the helper, newest first.
    for _candidate in ("sklearn.utils._indexing",
                       "sklearn.compose._column_transformer"):
        try:
            _gci = getattr(importlib.import_module(_candidate),
                           "_get_column_indices", None)
        except ImportError:
            _gci = None
        if _gci is not None:
            break
    if _gci is not None:
        sklearn.utils._get_column_indices = _gci
        print("[Patch] Successfully injected _get_column_indices")
    else:
        print("[Patch] WARNING: _get_column_indices not found in any known "
              "sklearn module; loading old pipelines may fail")

# Patch for parse_version: old pickled artifacts may look for
# sklearn.utils.parse_version, absent in some sklearn builds.
if not hasattr(sklearn.utils, 'parse_version'):
    try:
        from packaging import version
        sklearn.utils.parse_version = version.parse
        print("[Patch] Successfully injected parse_version")
    except ImportError:
        print("[Patch] 'packaging' library missing. Ensure it is in requirements.txt")
# ---------------------------------------

# --- LOCAL CACHE DIRECTORY ---
# __file__ is <root>/utils/model_selector.py, so two dirname() calls
# yield the project root; model artifacts are cached in <root>/ml_models.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
ML_DIR = os.path.join(BASE_DIR, "ml_models")

# exist_ok=True makes a pre-existence check redundant (and avoids the
# check-then-create race the old `if not os.path.exists(...)` guard had).
os.makedirs(ML_DIR, exist_ok=True)

print(f"[model_selector] ROOT BASE_DIR: {BASE_DIR}")
print(f"[model_selector] ML_DIR: {ML_DIR}")

def _get_model_path(filename):
    """Resolve *filename* to a local path, downloading it if missing.

    Looks in the local ``ml_models`` cache first; on a miss, downloads
    the file from the public Hugging Face Hub repo into that cache.

    Args:
        filename: Name of the artifact file inside the Hub repo.

    Returns:
        Path to the file, or ``None`` if it could not be found locally
        nor downloaded.
    """
    local_path = os.path.join(ML_DIR, filename)

    # 1. Serve straight from the local cache when present.
    if os.path.exists(local_path):
        return local_path

    # 2. Otherwise pull it from the Hub (interpolate the real filename
    # into the log lines — the old messages printed a literal placeholder).
    try:
        print(f"[model_selector] {filename} not found locally. Downloading from Hub...")
        # local_dir forces the download into our ml_models folder
        downloaded_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=ML_DIR
        )
        return downloaded_path
    except Exception as e:
        print(f"[model_selector] ERROR: Could not find/download {filename}: {e}")
        return None

def _try_load(filename):
    """Load a joblib artifact by name, returning ``None`` on any failure.

    Resolves the file via _get_model_path (local cache or Hub download),
    rejects empty files, and logs a full traceback on deserialization
    errors instead of raising, so one bad artifact does not take the
    whole service down.

    Args:
        filename: Artifact file name (e.g. a ``.pkl``/``.joblib`` file).

    Returns:
        The unpickled object, or ``None`` if the file is missing, empty,
        or fails to load.
    """
    path = _get_model_path(filename)
    if not path or not os.path.exists(path):
        print(f"[model_selector] SKIP: {filename} path invalid.")
        return None
    try:
        # Guard against truncated/zero-byte downloads before unpickling.
        if os.path.getsize(path) == 0:
            print(f"[model_selector] ERROR: {filename} is an empty file.")
            return None

        print(f"[model_selector] Attempting joblib.load for {filename}")
        return joblib.load(path)
    except Exception:
        print(f"[model_selector] CRITICAL FAILED to load {filename}")
        print(traceback.format_exc())  # Full stack trace shows up in HF logs
        return None

def load_model(model_key):
    """Return the artifact bundle for *model_key*, loading it lazily.

    Supported keys:
        "bcc"    -- realtime model, scaler, and encoder pickles.
        "cicids" -- RF pipeline plus its training artifacts.

    Results are memoized in _MODEL_CACHE; individual artifacts that fail
    to load appear as ``None`` inside the returned dict.

    Raises:
        ValueError: if *model_key* is not a supported key.
    """
    if model_key in _MODEL_CACHE:
        return _MODEL_CACHE[model_key]

    if model_key == "bcc":
        bundle = {
            "model": _try_load("realtime_model.pkl"),
            "scaler": _try_load("realtime_scaler.pkl"),
            "encoder": _try_load("realtime_encoder.pkl"),
        }
    elif model_key == "cicids":
        # The RF files are fetched from the Hub on first use.
        bundle = {
            "model": _try_load("rf_pipeline.joblib"),
            "artifacts": _try_load("training_artifacts.joblib"),
        }
    else:
        raise ValueError(f"Unknown model_key: {model_key}")

    _MODEL_CACHE[model_key] = bundle
    return bundle

def set_active_model(key: str):
    """Switch the module-global ACTIVE_MODEL to *key* under the lock.

    The assignment is serialized by _ACTIVE_LOCK so concurrent callers
    cannot interleave writes.
    """
    global ACTIVE_MODEL
    _ACTIVE_LOCK.acquire()
    try:
        ACTIVE_MODEL = key
    finally:
        _ACTIVE_LOCK.release()
    print(f"[model_selector] ACTIVE_MODEL set to: {ACTIVE_MODEL}")

def get_active_model():
    """Return the key of the currently active model bundle.

    NOTE(review): the read is not guarded by _ACTIVE_LOCK; it relies on
    name-rebinding being atomic in CPython — confirm if ported elsewhere.
    """
    return ACTIVE_MODEL