Spaces:
Sleeping
Sleeping
File size: 4,395 Bytes
87f8e11 aab61a1 26a71ad 87f8e11 26a71ad ee1c180 26a71ad ee1c180 aab61a1 ee1c180 aab61a1 87f8e11 aab61a1 87f8e11 aab61a1 87f8e11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import os
import joblib
import threading
import traceback
from huggingface_hub import hf_hub_download
import sklearn.utils
# --- CONFIGURATION ---
# Public Hugging Face Hub repo that hosts the serialized model artifacts.
HF_REPO_ID = "CodebaseAi/netraids-ml-models" # Replace with your actual public repo ID
# ---------------------
# Key of the model set served by default; changed at runtime via set_active_model().
ACTIVE_MODEL = "bcc"
# Guards writes to ACTIVE_MODEL when multiple request threads switch models.
_ACTIVE_LOCK = threading.Lock()
# model_key -> dict of loaded artifacts; populated lazily by load_model().
_MODEL_CACHE = {}
# --- THE "BULLETPROOF" SKLEARN PATCH ---
# Old pickled pipelines reference sklearn.utils._get_column_indices, which
# newer sklearn releases relocated; re-expose it BEFORE any pickle is loaded.
# We must do this BEFORE any other ML imports.
import sklearn.utils

try:
    # BUG FIX: the original imported from sklearn.utils._column_transformer,
    # a module that has never existed (the try always failed, and the
    # fallback shim re-imported from the same missing module, so it crashed
    # at call time). sklearn >= 1.5 keeps the helper in
    # sklearn.utils._indexing; try the real location first.
    from sklearn.utils._indexing import _get_column_indices as _gci
    sklearn.utils._get_column_indices = _gci
    print("[Patch] Successfully injected _get_column_indices")
except Exception as e:
    # Fallback shim: resolve lazily at call time from
    # sklearn.compose._column_transformer, which imports the helper into its
    # namespace on older sklearn versions.
    def _get_column_indices(X, key):
        from sklearn.compose._column_transformer import _get_column_indices as gci
        return gci(X, key)
    sklearn.utils._get_column_indices = _get_column_indices
    print(f"[Patch] Manual injection fallback used: {e}")

# Patch for parse_version: old pickles expect sklearn.utils.parse_version,
# which newer sklearn versions no longer export from sklearn.utils.
if not hasattr(sklearn.utils, 'parse_version'):
    try:
        from packaging import version
        sklearn.utils.parse_version = version.parse
        print("[Patch] Successfully injected parse_version")
    except ImportError:
        print("[Patch] 'packaging' library missing. Ensure it is in requirements.txt")
# ---------------------------------------
# --------------------------------------------
# 1. FIXED PATH LOGIC:
#    __file__ is /app/utils/model_selector.py
#    dirname(__file__) is /app/utils
#    dirname(dirname(...)) is /app (the ROOT)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
ML_DIR = os.path.join(BASE_DIR, "ml_models")

# Ensure the local ml_models directory exists for caching.
# exist_ok=True already handles the "already there" case atomically, so the
# previous `if not os.path.exists(...)` guard was redundant and race-prone.
os.makedirs(ML_DIR, exist_ok=True)

print(f"[model_selector] ROOT BASE_DIR: {BASE_DIR}")
print(f"[model_selector] ML_DIR: {ML_DIR}")
def _get_model_path(filename):
    """
    Resolve a model artifact to a local filesystem path.

    First looks in the local 'ml_models' folder. If not found, downloads
    the file from the public Hugging Face Hub repo into that folder.

    Returns the local path string, or None when the file cannot be
    found or downloaded (the error is logged, never raised).
    """
    local_path = os.path.join(ML_DIR, filename)

    # 1. Check if the file is already there
    if os.path.exists(local_path):
        return local_path

    # 2. Download from Hub if missing
    try:
        # BUG FIX: these log messages printed the literal text "(unknown)"
        # instead of the filename; interpolate it so failures are diagnosable.
        print(f"[model_selector] {filename} not found locally. Downloading from Hub...")
        # We specify local_dir to force it into our ml_models folder
        downloaded_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=ML_DIR,
        )
        return downloaded_path
    except Exception as e:
        print(f"[model_selector] ERROR: Could not find/download {filename}: {e}")
        return None
def _try_load(filename):
    """
    Best-effort joblib load of a model artifact.

    Resolves `filename` via _get_model_path, then unpickles it with joblib.
    Returns the loaded object, or None when the file is missing, empty,
    or fails to deserialize (the error is logged, never raised).
    """
    path = _get_model_path(filename)
    if not path or not os.path.exists(path):
        # BUG FIX: log messages printed the literal "(unknown)" instead of
        # the filename; interpolate it so failures are diagnosable.
        print(f"[model_selector] SKIP: {filename} path invalid.")
        return None
    try:
        # Check file size > 0 first: a zero-byte file (e.g. an interrupted
        # download) would otherwise make joblib raise a confusing EOFError.
        if os.path.getsize(path) == 0:
            print(f"[model_selector] ERROR: {filename} is an empty file.")
            return None
        print(f"[model_selector] Attempting joblib.load for {filename}")
        return joblib.load(path)
    except Exception:
        print(f"[model_selector] CRITICAL FAILED to load {filename}")
        print(traceback.format_exc())  # This will show exactly why in HF logs
        return None
def load_model(model_key):
    """
    Return the artifact set for `model_key`, loading it on first use.

    Known keys:
      "bcc"    -> {"model", "scaler", "encoder"} (realtime pipeline)
      "cicids" -> {"model", "artifacts"} (RF pipeline files from the Hub)

    Raises ValueError for any other key. Results are memoized in
    _MODEL_CACHE, so each artifact is loaded at most once.
    """
    if model_key in _MODEL_CACHE:
        return _MODEL_CACHE[model_key]

    # Table of artifact-slot -> filename per supported model key.
    registry = {
        "bcc": {
            "model": "realtime_model.pkl",
            "scaler": "realtime_scaler.pkl",
            "encoder": "realtime_encoder.pkl",
        },
        "cicids": {
            # It will look for your RF files in the Hub
            "model": "rf_pipeline.joblib",
            "artifacts": "training_artifacts.joblib",
        },
    }

    try:
        files = registry[model_key]
    except KeyError:
        raise ValueError(f"Unknown model_key: {model_key}")

    loaded = {slot: _try_load(fname) for slot, fname in files.items()}
    _MODEL_CACHE[model_key] = loaded
    return loaded
def set_active_model(key: str):
    """Switch the module-wide active model key (write guarded by a lock)."""
    global ACTIVE_MODEL
    with _ACTIVE_LOCK:
        ACTIVE_MODEL = key
    print(f"[model_selector] ACTIVE_MODEL set to: {key}")
def get_active_model():
    """Return the key of the currently active model set."""
    return ACTIVE_MODEL
|