Spaces:
Runtime error
Runtime error
Update train_and_save_models.py
Browse files- train_and_save_models.py +37 -16
train_and_save_models.py
CHANGED
|
@@ -10,9 +10,9 @@ from pathlib import Path
|
|
| 10 |
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 11 |
import numpy as np
|
| 12 |
import multiprocessing
|
| 13 |
-
import joblib # Importado para salvar
|
| 14 |
|
| 15 |
-
#
|
| 16 |
# from flask import Flask, request, jsonify
|
| 17 |
# from flask_cors import CORS
|
| 18 |
|
|
@@ -46,11 +46,13 @@ torch.manual_seed(RANDOM_SEED)
|
|
| 46 |
if torch.cuda.is_available():
|
| 47 |
torch.cuda.manual_seed_all(RANDOM_SEED)
|
| 48 |
|
| 49 |
-
model_base = None
|
| 50 |
mlp_regressor, scaler = None, None
|
| 51 |
tfidf_vectorizer, tfidf_regressor = None, None
|
| 52 |
|
| 53 |
-
# --- Vocabulário Expandido (compartilhado) ---
|
|
|
|
|
|
|
| 54 |
ADVERSARIAL_RISK_ACTORS = [
|
| 55 |
"Unsandboxed process", "Leaked API key", "Misconfigured service account", "Shadow IT application",
|
| 56 |
"Dormant user account", "Ransomware payload", "Phishing attempt", "Insider threat",
|
|
@@ -237,23 +239,23 @@ ADVERSARIAL_SAFE_TARGETS = [
|
|
| 237 |
"Kubernetes cluster security posture", "Docker container security configuration", "AWS cloud infrastructure",
|
| 238 |
"Azure cloud resources", "GCP cloud services", "container orchestration security",
|
| 239 |
"serverless function security", "cloud API security", "microservice security architecture",
|
| 240 |
-
"container registry security", "cloud logging security", "infrastructure as code
|
| 241 |
"Git repository security", "CI/CD pipeline security", "Docker image security",
|
| 242 |
"artifact repository security", "infrastructure provisioning security", "secret management vault",
|
| 243 |
"code signing certificate store", "dependency management system", "deployment automation platform",
|
| 244 |
-
"build environment
|
| 245 |
"industrial control system security", "SCADA system security", "IoT device security",
|
| 246 |
"edge computing security", "smart city infrastructure security", "medical device network",
|
| 247 |
"automotive system security", "home automation security", "sensor security",
|
| 248 |
-
"industrial protocol
|
| 249 |
-
"enterprise mobile device
|
| 250 |
-
"mobile banking
|
| 251 |
-
"BYOD policy
|
| 252 |
-
"mobile device fingerprinting
|
| 253 |
-
"network segmentation
|
| 254 |
-
"
|
| 255 |
-
"
|
| 256 |
-
"
|
| 257 |
]
|
| 258 |
ADVERSARIAL_SAFE_OUTCOMES = [
|
| 259 |
"all tests passed, security posture confirmed", "the configuration was hardened as per policy",
|
|
@@ -299,6 +301,7 @@ LOW_RISK_KEYWORDS = {
|
|
| 299 |
'backup completed': -20, 'schema migration successful': -15, 'network policy updated': -10
|
| 300 |
}
|
| 301 |
|
|
|
|
| 302 |
def generate_event_text_for_training(is_risk: bool) -> tuple[str, float]:
|
| 303 |
if is_risk:
|
| 304 |
actor = random.choice(ADVERSARIAL_RISK_ACTORS)
|
|
@@ -351,8 +354,24 @@ def populate_database_initial():
|
|
| 351 |
conn.close()
|
| 352 |
print("Banco de dados populado inicialmente com sucesso.")
|
| 353 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
def train_and_save_all_models():
|
| 355 |
-
global mlp_regressor, scaler, tfidf_vectorizer, tfidf_regressor,
|
| 356 |
print("Iniciando o treinamento de todos os modelos a partir do banco de dados...")
|
| 357 |
|
| 358 |
conn = sqlite3.connect(DB_NAME)
|
|
@@ -367,6 +386,7 @@ def train_and_save_all_models():
|
|
| 367 |
train_texts = [row[0] for row in train_data]
|
| 368 |
y_train = np.array([row[1] for row in train_data])
|
| 369 |
|
|
|
|
| 370 |
print("1. Treinando modelo de Embedding Profundo (MLPRegressor)...")
|
| 371 |
X_train_embeddings = []
|
| 372 |
|
|
@@ -406,6 +426,7 @@ def train_and_save_all_models():
|
|
| 406 |
|
| 407 |
print(" ... modelo de Embedding Profundo treinado.")
|
| 408 |
|
|
|
|
| 409 |
print("2. Treinando modelo Vetorial Clássico (TF-IDF + Ridge)...")
|
| 410 |
tfidf_vectorizer = TfidfVectorizer(ngram_range=(1, 3), min_df=5, max_df=0.7, max_features=10000)
|
| 411 |
X_train_tfidf = tfidf_vectorizer.fit_transform(train_texts)
|
|
|
|
| 10 |
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 11 |
import numpy as np
|
| 12 |
import multiprocessing
|
| 13 |
+
import joblib # Importado para salvar os modelos
|
| 14 |
|
| 15 |
+
# REMOVIDAS importações Flask e CORS, pois este script é apenas para treinamento local.
|
| 16 |
# from flask import Flask, request, jsonify
|
| 17 |
# from flask_cors import CORS
|
| 18 |
|
|
|
|
| 46 |
if torch.cuda.is_available():
|
| 47 |
torch.cuda.manual_seed_all(RANDOM_SEED)
|
| 48 |
|
| 49 |
+
model_base = None # Será inicializado durante o treinamento
|
| 50 |
mlp_regressor, scaler = None, None
|
| 51 |
tfidf_vectorizer, tfidf_regressor = None, None
|
| 52 |
|
| 53 |
+
# --- Vocabulário Expandido (compartilhado para geração de dados) ---
|
| 54 |
+
# Estas listas são usadas SOMENTE no trainer.py para gerar os dados de treinamento.
|
| 55 |
+
# Elas serão duplicadas (ou uma versão reduzida) no app.py se o "Gerar Evento Aleatório" for mantido.
|
| 56 |
ADVERSARIAL_RISK_ACTORS = [
|
| 57 |
"Unsandboxed process", "Leaked API key", "Misconfigured service account", "Shadow IT application",
|
| 58 |
"Dormant user account", "Ransomware payload", "Phishing attempt", "Insider threat",
|
|
|
|
| 239 |
"Kubernetes cluster security posture", "Docker container security configuration", "AWS cloud infrastructure",
|
| 240 |
"Azure cloud resources", "GCP cloud services", "container orchestration security",
|
| 241 |
"serverless function security", "cloud API security", "microservice security architecture",
|
| 242 |
+
"container registry security", "cloud logging security", "infrastructure as code repository",
|
| 243 |
"Git repository security", "CI/CD pipeline security", "Docker image security",
|
| 244 |
"artifact repository security", "infrastructure provisioning security", "secret management vault",
|
| 245 |
"code signing certificate store", "dependency management system", "deployment automation platform",
|
| 246 |
+
"build environment with elevated privileges", "CI/CD security scanning tools", "infrastructure monitoring system",
|
| 247 |
"industrial control system security", "SCADA system security", "IoT device security",
|
| 248 |
"edge computing security", "smart city infrastructure security", "medical device network",
|
| 249 |
"automotive system security", "home automation security", "sensor security",
|
| 250 |
+
"industrial protocol gateway", "edge security monitoring system", "IoT device firmware repository",
|
| 251 |
+
"enterprise mobile device fleet", "mobile app store backend", "mobile device management system",
|
| 252 |
+
"mobile banking infrastructure", "mobile certificate authority", "mobile security scanning service",
|
| 253 |
+
"BYOD policy enforcement system", "mobile endpoint detection system", "mobile app security testing platform",
|
| 254 |
+
"mobile device fingerprinting database", "mobile phishing detection system", "mobile app code signing service",
|
| 255 |
+
"network segmentation firewall", "VPN concentrator", "DNS authoritative server",
|
| 256 |
+
"BGP route reflector", "wireless access point controller", "network monitoring system",
|
| 257 |
+
"traffic analysis platform", "network security scanning tool", "protocol analysis system",
|
| 258 |
+
"network infrastructure management", "security information system", "network forensics platform"
|
| 259 |
]
|
| 260 |
ADVERSARIAL_SAFE_OUTCOMES = [
|
| 261 |
"all tests passed, security posture confirmed", "the configuration was hardened as per policy",
|
|
|
|
| 301 |
'backup completed': -20, 'schema migration successful': -15, 'network policy updated': -10
|
| 302 |
}
|
| 303 |
|
| 304 |
+
# --- Funções de Geração de Dados de TREINAMENTO (Base Sólida) ---
|
| 305 |
def generate_event_text_for_training(is_risk: bool) -> tuple[str, float]:
|
| 306 |
if is_risk:
|
| 307 |
actor = random.choice(ADVERSARIAL_RISK_ACTORS)
|
|
|
|
| 354 |
conn.close()
|
| 355 |
print("Banco de dados populado inicialmente com sucesso.")
|
| 356 |
|
| 357 |
+
# --- Funções de Embedding e Treinamento ---
|
| 358 |
+
def init_sbert_worker():
|
| 359 |
+
global model_base
|
| 360 |
+
if model_base is None:
|
| 361 |
+
print(f"Processo worker {os.getpid()} carregando o modelo {MODEL_NAME}...")
|
| 362 |
+
model_base = SentenceTransformer(MODEL_NAME)
|
| 363 |
+
torch.set_num_threads(1)
|
| 364 |
+
|
| 365 |
+
def extract_embeddings_batch_worker(texts: list[str]) -> list[list[float]]:
|
| 366 |
+
global model_base
|
| 367 |
+
if model_base is None:
|
| 368 |
+
raise RuntimeError("SentenceTransformer não foi inicializado no worker.")
|
| 369 |
+
|
| 370 |
+
embeddings = model_base.encode(texts, convert_to_numpy=True, show_progress_bar=False)
|
| 371 |
+
return embeddings.tolist()
|
| 372 |
+
|
| 373 |
def train_and_save_all_models():
|
| 374 |
+
global mlp_regressor, scaler, tfidf_vectorizer, tfidf_regressor # Model_base é para os workers, não para o principal
|
| 375 |
print("Iniciando o treinamento de todos os modelos a partir do banco de dados...")
|
| 376 |
|
| 377 |
conn = sqlite3.connect(DB_NAME)
|
|
|
|
| 386 |
train_texts = [row[0] for row in train_data]
|
| 387 |
y_train = np.array([row[1] for row in train_data])
|
| 388 |
|
| 389 |
+
# --- Cabeça 1: Embedding Profundo (MLPRegressor) ---
|
| 390 |
print("1. Treinando modelo de Embedding Profundo (MLPRegressor)...")
|
| 391 |
X_train_embeddings = []
|
| 392 |
|
|
|
|
| 426 |
|
| 427 |
print(" ... modelo de Embedding Profundo treinado.")
|
| 428 |
|
| 429 |
+
# --- Cabeça 2: Vetorial Clássico (TF-IDF) ---
|
| 430 |
print("2. Treinando modelo Vetorial Clássico (TF-IDF + Ridge)...")
|
| 431 |
tfidf_vectorizer = TfidfVectorizer(ngram_range=(1, 3), min_df=5, max_df=0.7, max_features=10000)
|
| 432 |
X_train_tfidf = tfidf_vectorizer.fit_transform(train_texts)
|