chaos4455 committed on
Commit
f7afb2c
·
verified ·
1 Parent(s): 74dd954

Update train_and_save_models.py

Browse files
Files changed (1) hide show
  1. train_and_save_models.py +37 -16
train_and_save_models.py CHANGED
@@ -10,9 +10,9 @@ from pathlib import Path
10
  from concurrent.futures import ProcessPoolExecutor, as_completed
11
  import numpy as np
12
  import multiprocessing
13
- import joblib # Importado para salvar e carregar modelos
14
 
15
- # REMOVIDOS Flask e CORS, pois este script é apenas para treinamento local.
16
  # from flask import Flask, request, jsonify
17
  # from flask_cors import CORS
18
 
@@ -46,11 +46,13 @@ torch.manual_seed(RANDOM_SEED)
46
  if torch.cuda.is_available():
47
  torch.cuda.manual_seed_all(RANDOM_SEED)
48
 
49
- model_base = None
50
  mlp_regressor, scaler = None, None
51
  tfidf_vectorizer, tfidf_regressor = None, None
52
 
53
- # --- Vocabulário Expandido (compartilhado) ---
 
 
54
  ADVERSARIAL_RISK_ACTORS = [
55
  "Unsandboxed process", "Leaked API key", "Misconfigured service account", "Shadow IT application",
56
  "Dormant user account", "Ransomware payload", "Phishing attempt", "Insider threat",
@@ -237,23 +239,23 @@ ADVERSARIAL_SAFE_TARGETS = [
237
  "Kubernetes cluster security posture", "Docker container security configuration", "AWS cloud infrastructure",
238
  "Azure cloud resources", "GCP cloud services", "container orchestration security",
239
  "serverless function security", "cloud API security", "microservice security architecture",
240
- "container registry security", "cloud logging security", "infrastructure as code security",
241
  "Git repository security", "CI/CD pipeline security", "Docker image security",
242
  "artifact repository security", "infrastructure provisioning security", "secret management vault",
243
  "code signing certificate store", "dependency management system", "deployment automation platform",
244
- "build environment security", "CI/CD security scanning", "infrastructure monitoring security",
245
  "industrial control system security", "SCADA system security", "IoT device security",
246
  "edge computing security", "smart city infrastructure security", "medical device network",
247
  "automotive system security", "home automation security", "sensor security",
248
- "industrial protocol security", "edge gateway security", "IoT device management security",
249
- "enterprise mobile device security", "mobile app security", "mobile device management security",
250
- "mobile banking security", "mobile certificate security", "mobile security scanning",
251
- "BYOD policy security", "mobile endpoint security", "mobile app store security",
252
- "mobile device fingerprinting security", "mobile phishing protection", "mobile security testing",
253
- "network segmentation security", "firewall security", "VPN security",
254
- "DNS security", "BGP routing security", "wireless network security",
255
- "Bluetooth security", "NFC security", "network monitoring security",
256
- "traffic analysis security", "protocol security", "network infrastructure security"
257
  ]
258
  ADVERSARIAL_SAFE_OUTCOMES = [
259
  "all tests passed, security posture confirmed", "the configuration was hardened as per policy",
@@ -299,6 +301,7 @@ LOW_RISK_KEYWORDS = {
299
  'backup completed': -20, 'schema migration successful': -15, 'network policy updated': -10
300
  }
301
 
 
302
  def generate_event_text_for_training(is_risk: bool) -> tuple[str, float]:
303
  if is_risk:
304
  actor = random.choice(ADVERSARIAL_RISK_ACTORS)
@@ -351,8 +354,24 @@ def populate_database_initial():
351
  conn.close()
352
  print("Banco de dados populado inicialmente com sucesso.")
353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  def train_and_save_all_models():
355
- global mlp_regressor, scaler, tfidf_vectorizer, tfidf_regressor, model_base
356
  print("Iniciando o treinamento de todos os modelos a partir do banco de dados...")
357
 
358
  conn = sqlite3.connect(DB_NAME)
@@ -367,6 +386,7 @@ def train_and_save_all_models():
367
  train_texts = [row[0] for row in train_data]
368
  y_train = np.array([row[1] for row in train_data])
369
 
 
370
  print("1. Treinando modelo de Embedding Profundo (MLPRegressor)...")
371
  X_train_embeddings = []
372
 
@@ -406,6 +426,7 @@ def train_and_save_all_models():
406
 
407
  print(" ... modelo de Embedding Profundo treinado.")
408
 
 
409
  print("2. Treinando modelo Vetorial Clássico (TF-IDF + Ridge)...")
410
  tfidf_vectorizer = TfidfVectorizer(ngram_range=(1, 3), min_df=5, max_df=0.7, max_features=10000)
411
  X_train_tfidf = tfidf_vectorizer.fit_transform(train_texts)
 
10
  from concurrent.futures import ProcessPoolExecutor, as_completed
11
  import numpy as np
12
  import multiprocessing
13
+ import joblib # Importado para salvar os modelos
14
 
15
+ # REMOVIDAS importações Flask e CORS, pois este script é apenas para treinamento local.
16
  # from flask import Flask, request, jsonify
17
  # from flask_cors import CORS
18
 
 
46
  if torch.cuda.is_available():
47
  torch.cuda.manual_seed_all(RANDOM_SEED)
48
 
49
+ model_base = None # Será inicializado durante o treinamento
50
  mlp_regressor, scaler = None, None
51
  tfidf_vectorizer, tfidf_regressor = None, None
52
 
53
+ # --- Vocabulário Expandido (compartilhado para geração de dados) ---
54
+ # Estas listas são usadas SOMENTE no trainer.py para gerar os dados de treinamento.
55
+ # Elas serão duplicadas (ou uma versão reduzida) no app.py se o "Gerar Evento Aleatório" for mantido.
56
  ADVERSARIAL_RISK_ACTORS = [
57
  "Unsandboxed process", "Leaked API key", "Misconfigured service account", "Shadow IT application",
58
  "Dormant user account", "Ransomware payload", "Phishing attempt", "Insider threat",
 
239
  "Kubernetes cluster security posture", "Docker container security configuration", "AWS cloud infrastructure",
240
  "Azure cloud resources", "GCP cloud services", "container orchestration security",
241
  "serverless function security", "cloud API security", "microservice security architecture",
242
+ "container registry security", "cloud logging security", "infrastructure as code repository",
243
  "Git repository security", "CI/CD pipeline security", "Docker image security",
244
  "artifact repository security", "infrastructure provisioning security", "secret management vault",
245
  "code signing certificate store", "dependency management system", "deployment automation platform",
246
+ "build environment with elevated privileges", "CI/CD security scanning tools", "infrastructure monitoring system",
247
  "industrial control system security", "SCADA system security", "IoT device security",
248
  "edge computing security", "smart city infrastructure security", "medical device network",
249
  "automotive system security", "home automation security", "sensor security",
250
+ "industrial protocol gateway", "edge security monitoring system", "IoT device firmware repository",
251
+ "enterprise mobile device fleet", "mobile app store backend", "mobile device management system",
252
+ "mobile banking infrastructure", "mobile certificate authority", "mobile security scanning service",
253
+ "BYOD policy enforcement system", "mobile endpoint detection system", "mobile app security testing platform",
254
+ "mobile device fingerprinting database", "mobile phishing detection system", "mobile app code signing service",
255
+ "network segmentation firewall", "VPN concentrator", "DNS authoritative server",
256
+ "BGP route reflector", "wireless access point controller", "network monitoring system",
257
+ "traffic analysis platform", "network security scanning tool", "protocol analysis system",
258
+ "network infrastructure management", "security information system", "network forensics platform"
259
  ]
260
  ADVERSARIAL_SAFE_OUTCOMES = [
261
  "all tests passed, security posture confirmed", "the configuration was hardened as per policy",
 
301
  'backup completed': -20, 'schema migration successful': -15, 'network policy updated': -10
302
  }
303
 
304
+ # --- Funções de Geração de Dados de TREINAMENTO (Base Sólida) ---
305
  def generate_event_text_for_training(is_risk: bool) -> tuple[str, float]:
306
  if is_risk:
307
  actor = random.choice(ADVERSARIAL_RISK_ACTORS)
 
354
  conn.close()
355
  print("Banco de dados populado inicialmente com sucesso.")
356
 
357
+ # --- Funções de Embedding e Treinamento ---
358
+ def init_sbert_worker():
359
+ global model_base
360
+ if model_base is None:
361
+ print(f"Processo worker {os.getpid()} carregando o modelo {MODEL_NAME}...")
362
+ model_base = SentenceTransformer(MODEL_NAME)
363
+ torch.set_num_threads(1)
364
+
365
+ def extract_embeddings_batch_worker(texts: list[str]) -> list[list[float]]:
366
+ global model_base
367
+ if model_base is None:
368
+ raise RuntimeError("SentenceTransformer não foi inicializado no worker.")
369
+
370
+ embeddings = model_base.encode(texts, convert_to_numpy=True, show_progress_bar=False)
371
+ return embeddings.tolist()
372
+
373
  def train_and_save_all_models():
374
+ global mlp_regressor, scaler, tfidf_vectorizer, tfidf_regressor # Model_base é para os workers, não para o principal
375
  print("Iniciando o treinamento de todos os modelos a partir do banco de dados...")
376
 
377
  conn = sqlite3.connect(DB_NAME)
 
386
  train_texts = [row[0] for row in train_data]
387
  y_train = np.array([row[1] for row in train_data])
388
 
389
+ # --- Cabeça 1: Embedding Profundo (MLPRegressor) ---
390
  print("1. Treinando modelo de Embedding Profundo (MLPRegressor)...")
391
  X_train_embeddings = []
392
 
 
426
 
427
  print(" ... modelo de Embedding Profundo treinado.")
428
 
429
+ # --- Cabeça 2: Vetorial Clássico (TF-IDF) ---
430
  print("2. Treinando modelo Vetorial Clássico (TF-IDF + Ridge)...")
431
  tfidf_vectorizer = TfidfVectorizer(ngram_range=(1, 3), min_df=5, max_df=0.7, max_features=10000)
432
  X_train_tfidf = tfidf_vectorizer.fit_transform(train_texts)