| |
| """ |
Telemetry Knowledge-Graph Classifier — Training Pipeline
========================================================
Generates synthetic telemetry data, trains the dual-model ensemble
(River ARF + LightGBM), builds the knowledge graph, and saves
all artifacts for inference.

Usage:
    python train.py [--samples N] [--output DIR] [--seed SEED] [--lgb-retrain N]
| """ |
| import sys |
| import os |
| import time |
| import json |
| import argparse |
| import random |
| import numpy as np |
| from collections import Counter |
|
|
# Make the sibling ``telemetry_kg`` package importable when this file is run
# directly as a script (e.g. ``python train.py``) from any working directory.
_PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _PROJECT_ROOT)
|
|
| from telemetry_kg.synthetic_data import generate_stream, STATES |
| from telemetry_kg.features import FeatureEngineer |
| from telemetry_kg.knowledge_graph import build_default_kg |
| from telemetry_kg.models import DualModelEnsemble |
|
|
|
|
def _print_final_metrics(metrics):
    """Print the evaluation summary returned by ``ensemble.get_metrics()``.

    Expects the keys ``total_predictions``, ``total_corrections``,
    ``accuracy``, ``n_retrains``, ``replay_buffer_size``,
    ``ensemble_weights`` (with ``arf`` / ``lgbm``), ``per_class_accuracy``
    and ``kg_stats`` (with ``nodes``, ``edges``, ``tracked_processes``,
    ``override_rules``).
    """
    weights = metrics["ensemble_weights"]
    print(f" Total predictions: {metrics['total_predictions']:,}")
    print(f" Total corrections: {metrics['total_corrections']:,}")
    print(f" Overall accuracy: {metrics['accuracy']:.4f}")
    print(f" LightGBM retrains: {metrics['n_retrains']}")
    print(f" Replay buffer size: {metrics['replay_buffer_size']:,}")
    print(f" Ensemble weights: ARF={weights['arf']:.3f}, "
          f"LGBM={weights['lgbm']:.3f}")
    print()
    print(" Per-class accuracy:")
    for state, acc in sorted(metrics["per_class_accuracy"].items()):
        print(f" {state:<15s}: {acc:.4f}")
    print()
    print(" Knowledge Graph:")
    kg_stats = metrics["kg_stats"]
    print(f" Nodes: {kg_stats['nodes']}")
    print(f" Edges: {kg_stats['edges']}")
    print(f" Tracked processes: {kg_stats['tracked_processes']}")
    print(f" Override rules: {kg_stats['override_rules']}")
    print()


def _print_artifact_summary(output_dir):
    """Print a tree of the artifacts a training run leaves in *output_dir*."""
    # NOTE(review): the file names produced by ``DualModelEnsemble.save`` are
    # not visible from this script; keep the first six entries in sync with it.
    artifacts = [
        ("arf_model.pkl", "River ARF online model"),
        ("lgbm_model.txt", "LightGBM batch model"),
        ("ensemble_state.json", "weights & config"),
        ("feature_names.json", "feature ordering"),
        ("replay_buffer.pkl", "experience replay"),
        ("knowledge_graph.json", "KG with learned priors"),
        ("feature_engineer_state.json", "feature engineer state"),
        ("training_meta.json", "training metadata"),
    ]
    width = max(len(name) for name, _ in artifacts)

    print()
    print("=" * 70)
    print(" TRAINING COMPLETE")
    print("=" * 70)
    print(f" Artifacts saved to: {output_dir}/")
    last = len(artifacts) - 1
    for idx, (name, description) in enumerate(artifacts):
        branch = "└──" if idx == last else "├──"
        print(f" {branch} {name:<{width}} ({description})")
    print("=" * 70)


def train(
    n_samples: int = 30000,
    output_dir: str = "/app/model_artifacts",
    seed: int = 42,
    lgb_retrain_interval: int = 3000,
    eval_interval: int = 5000,
):
    """Run the full training pipeline and return the trained ensemble.

    Steps:
      1. Build the feature engineer, the default knowledge graph and the
         dual-model ensemble (which is given the knowledge graph).
      2. Generate a synthetic telemetry stream.
      3. Stream through the samples: predict first (prequential accuracy),
         then ``ensemble.learn`` on the true label.
      4. Every ``lgb_retrain_interval`` samples, call ``ensemble.retrain_lgbm``.
      5. Report final metrics and save all artifacts to ``output_dir``.

    Args:
        n_samples: Number of samples requested from ``generate_stream``.
        output_dir: Directory for model artifacts; created if it is missing.
        seed: Seed for ``random``, ``numpy.random`` and the data generator.
        lgb_retrain_interval: Retrain LightGBM every this many samples.
        eval_interval: Print running online accuracy every this many samples.

    Returns:
        The trained ``DualModelEnsemble``.

    Raises:
        ValueError: If a count/interval argument is not a positive integer,
            or if the generated stream contains no samples.
    """
    # Validate up front; otherwise these surface later as IndexError or
    # ZeroDivisionError after the (potentially slow) data generation step.
    for arg_name, value in (
        ("n_samples", n_samples),
        ("lgb_retrain_interval", lgb_retrain_interval),
        ("eval_interval", eval_interval),
    ):
        if value < 1:
            raise ValueError(f"{arg_name} must be a positive integer, got {value!r}")

    print("=" * 70)
    print(" Telemetry Knowledge-Graph Classifier — Training")
    print("=" * 70)
    print(f" Samples: {n_samples:,}")
    print(f" Output: {output_dir}")
    print(f" Seed: {seed}")
    print(f" LGB retrain every: {lgb_retrain_interval:,} samples")
    print(f" Eval every: {eval_interval:,} samples")
    print("=" * 70)
    print()

    random.seed(seed)
    np.random.seed(seed)

    # Fail fast if the artifact directory cannot be created, and make sure it
    # exists before the JSON side files are written below.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print("[1/5] Initializing components...")
    # Single source of truth: these values configure the ensemble AND are
    # recorded verbatim in training_meta.json, so the two cannot drift apart.
    model_config = {
        "arf_n_models": 10,
        "arf_max_depth": 8,
        "lgb_n_trees": 100,
        "lgb_num_leaves": 31,
        "model_alpha": 0.7,
        "prior_beta": 0.3,
        "correction_amplify": 3,
        "buffer_capacity": 2000,
    }
    feature_eng = FeatureEngineer()
    kg = build_default_kg()
    ensemble = DualModelEnsemble(**model_config, kg=kg)

    print(f"[2/5] Generating {n_samples:,} synthetic telemetry samples...")
    t0 = time.perf_counter()
    stream = generate_stream(n_total=n_samples, seed=seed)
    data_gen_time = time.perf_counter() - t0
    print(f" Data generated in {data_gen_time:.2f}s")
    print()

    print("[3/5] Online training (streaming through samples)...")
    t_start = time.perf_counter()

    # Split labels off each record so features/raw telemetry never contain them.
    all_records = []
    all_labels = []
    for record in stream:
        all_labels.append(record.pop("label"))
        all_records.append(record)
    if not all_records:
        raise ValueError("generate_stream() produced no samples")
    n_total = len(all_records)

    # Fix the feature ordering from the first sample's (sorted) feature keys.
    feature_names = sorted(feature_eng.transform(all_records[0]).keys())
    ensemble.set_feature_names(feature_names)
    print(f" Feature dimension: {len(feature_names)}")
    print()

    running_correct = 0
    for i, (record, true_label) in enumerate(zip(all_records, all_labels)):
        seen = i + 1
        features = feature_eng.transform(record)

        # Prequential evaluation: predict before the ensemble learns the sample.
        predicted, _confidence, _probs = ensemble.predict(features, raw_telemetry=record)
        if predicted == true_label:
            running_correct += 1

        # Simulated feedback: a misprediction is flagged as an explicit
        # correction 80% of the time. The short-circuit means random.random()
        # is drawn only on mispredictions; keep it for seeded reproducibility.
        is_correction = predicted != true_label and random.random() < 0.8
        ensemble.learn(
            features=features,
            true_label=true_label,
            raw_telemetry=record,
            is_correction=is_correction,
            confidence=1.0 if is_correction else 0.5,
        )

        if seen % eval_interval == 0:
            elapsed = time.perf_counter() - t_start
            print(f" [{seen:>6,}/{n_total:,}] "
                  f"Online Acc: {running_correct / seen:.3f} | "
                  f"Buffer: {len(ensemble.replay_buffer):>5,} | "
                  f"Retrains: {ensemble.n_retrains} | "
                  f"Speed: {seen / max(elapsed, 1e-9):,.0f} samples/s")

        # `i > 0` only matters when the interval is 1: never retrain on sample 0.
        if seen % lgb_retrain_interval == 0 and i > 0:
            if ensemble.retrain_lgbm(feature_names):
                print(f" >>> LightGBM retrained (#{ensemble.n_retrains}) | "
                      f"ARF weight: {ensemble.w_arf:.3f} | LGB weight: {ensemble.w_lgb:.3f}")

    total_time = time.perf_counter() - t_start
    label_dist = Counter(all_labels)
    print()
    print(f" Training complete in {total_time:.1f}s")
    print(f" Average speed: {n_total / max(total_time, 1e-9):,.0f} samples/sec")
    print()

    print("[4/5] Final evaluation...")
    metrics = ensemble.get_metrics()
    _print_final_metrics(metrics)

    print(f"[5/5] Saving model artifacts to {output_dir}...")
    ensemble.save(output_dir)

    fe_state_path = os.path.join(output_dir, "feature_engineer_state.json")
    feature_eng.save_state(fe_state_path)
    print(f" Feature engineer state saved to {fe_state_path}")

    meta = {
        "n_samples": n_samples,
        "seed": seed,
        "feature_dim": len(feature_names),
        "feature_names": feature_names,
        "training_time_seconds": total_time,
        "label_distribution": dict(label_dist),
        "final_metrics": metrics,
        "model_config": dict(model_config),
    }
    # default=str: metrics may hold values json cannot encode natively
    # (e.g. numpy scalars); they are stringified rather than failing the save.
    meta_path = os.path.join(output_dir, "training_meta.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, default=str)

    _print_artifact_summary(output_dir)

    return ensemble
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options and run the training pipeline.
    parser = argparse.ArgumentParser(description="Train the Telemetry KG Classifier")
    parser.add_argument("--samples", type=int, default=30000,
                        help="Number of synthetic samples to generate")
    parser.add_argument("--output", type=str, default="/app/model_artifacts",
                        help="Output directory for model artifacts")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--lgb-retrain", type=int, default=3000,
                        help="Retrain LightGBM every N samples")
    parser.add_argument("--eval-interval", type=int, default=5000,
                        help="Print running online accuracy every N samples")
    args = parser.parse_args()

    # Reject non-positive counts with a clean usage error instead of letting
    # them fail later with IndexError / ZeroDivisionError mid-training.
    for flag, value in (
        ("--samples", args.samples),
        ("--lgb-retrain", args.lgb_retrain),
        ("--eval-interval", args.eval_interval),
    ):
        if value < 1:
            parser.error(f"{flag} must be a positive integer (got {value})")

    train(
        n_samples=args.samples,
        output_dir=args.output,
        seed=args.seed,
        lgb_retrain_interval=args.lgb_retrain,
        eval_interval=args.eval_interval,
    )
|
|