Niansuh1's picture
Upload train.py with huggingface_hub
ef3e281 verified
#!/usr/bin/env python3
"""
Telemetry Knowledge-Graph Classifier — Training Pipeline
=========================================================
Generates synthetic telemetry data, trains the dual-model ensemble
(River ARF + LightGBM), builds the knowledge graph, and saves
all artifacts for inference.
Usage:
python train.py [--samples N] [--output DIR] [--seed S] [--lgb-retrain N]
"""
import sys
import os
import time
import json
import argparse
import random
import numpy as np
from collections import Counter
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from telemetry_kg.synthetic_data import generate_stream, STATES
from telemetry_kg.features import FeatureEngineer
from telemetry_kg.knowledge_graph import build_default_kg
from telemetry_kg.models import DualModelEnsemble
def train(
    n_samples: int = 30000,
    output_dir: str = "/app/model_artifacts",
    seed: int = 42,
    lgb_retrain_interval: int = 3000,
    eval_interval: int = 5000,
):
    """Run the full training pipeline and save all artifacts.

    Steps:
        1. Generate the synthetic telemetry stream.
        2. Feature-engineer each sample.
        3. Train the ensemble online: predict on a sample, then learn from it.
        4. Periodically retrain LightGBM.
        5. Report the ensemble's metrics (including knowledge-graph stats).
        6. Save the ensemble, the feature-engineer state and training metadata.

    Args:
        n_samples: Number of synthetic samples requested from the generator.
        output_dir: Directory the artifacts are written to; created if missing.
        seed: Seed for ``random``, ``numpy.random`` and the data generator.
        lgb_retrain_interval: Retrain LightGBM every this many samples.
            A value <= 0 disables the periodic retrain.
        eval_interval: Print a progress line every this many samples.
            A value <= 0 disables progress logging.

    Returns:
        The trained ``DualModelEnsemble``.

    Raises:
        ValueError: If the generated stream contains no samples.
    """
    print("=" * 70)
    print(" Telemetry Knowledge-Graph Classifier — Training")
    print("=" * 70)
    print(f" Samples: {n_samples:,}")
    print(f" Output: {output_dir}")
    print(f" Seed: {seed}")
    print(f" LGB retrain every: {lgb_retrain_interval:,} samples")
    print("=" * 70)
    print()

    # Create the output directory up front so a bad path fails before the
    # (long) training run instead of after it.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    random.seed(seed)
    np.random.seed(seed)

    # ── 1. Initialize components ────────────────────────────────────
    print("[1/5] Initializing components...")
    feature_eng = FeatureEngineer()
    kg = build_default_kg()
    # Single source of truth: the same dict configures the ensemble and is
    # recorded in training_meta.json, so the two cannot drift apart.
    model_config = {
        "arf_n_models": 10,
        "arf_max_depth": 8,
        "lgb_n_trees": 100,
        "lgb_num_leaves": 31,
        "model_alpha": 0.7,
        "prior_beta": 0.3,
    }
    ensemble = DualModelEnsemble(
        **model_config,
        correction_amplify=3,
        buffer_capacity=2000,
        kg=kg,
    )

    # ── 2. Generate synthetic data ──────────────────────────────────
    print(f"[2/5] Generating {n_samples:,} synthetic telemetry samples...")
    t0 = time.perf_counter()
    stream = generate_stream(n_total=n_samples, seed=seed)
    print(f" Data generated in {time.perf_counter() - t0:.2f}s")
    print()

    # ── 3. Online training loop ─────────────────────────────────────
    print("[3/5] Online training (streaming through samples)...")
    t_start = time.perf_counter()

    # Collect all data first for reproducibility; the label is split off so
    # the record passed to the models does not contain it.
    all_records = []
    all_labels = []
    for record in stream:
        all_labels.append(record.pop("label"))
        all_records.append(record)
    if not all_records:
        raise ValueError(
            "generate_stream produced no samples; n_samples must be positive"
        )

    # Get feature names from the first sample.
    first_features = feature_eng.transform(all_records[0])
    feature_names = sorted(first_features.keys())
    ensemble.set_feature_names(feature_names)
    print(f" Feature dimension: {len(feature_names)}")
    print()

    # Simulated labeling: the model always learns from the true label (in a
    # real system it would come from implicit signals / user overrides).
    # A misprediction is additionally flagged as an explicit user correction
    # with this probability.
    correction_probability = 0.8

    label_dist = Counter()
    running_correct = 0
    running_total = 0
    for i, (record, true_label) in enumerate(zip(all_records, all_labels)):
        # Feature engineering. The first record was already transformed
        # above; reuse that result so it is not fed through the feature
        # engineer twice (its state is persisted below, so transform() may
        # update running statistics).
        features = first_features if i == 0 else feature_eng.transform(record)

        # Predict (before learning — to measure online accuracy)
        predicted, _confidence, _probs = ensemble.predict(
            features, raw_telemetry=record
        )

        # random.random() is only consumed on mispredictions (short-circuit),
        # which keeps the random sequence identical for a given seed.
        is_correction = (
            predicted != true_label and random.random() < correction_probability
        )

        # Learn
        ensemble.learn(
            features=features,
            true_label=true_label,
            raw_telemetry=record,
            is_correction=is_correction,
            confidence=1.0 if is_correction else 0.5,
        )

        label_dist[true_label] += 1
        running_total += 1
        if predicted == true_label:
            running_correct += 1

        # ── Progress logging ────────────────────────────────────────
        if eval_interval > 0 and (i + 1) % eval_interval == 0:
            acc = running_correct / running_total
            elapsed = time.perf_counter() - t_start
            samples_per_sec = (i + 1) / elapsed if elapsed > 0 else 0.0
            print(f" [{i+1:>6,}/{n_samples:,}] "
                  f"Online Acc: {acc:.3f} | "
                  f"Buffer: {len(ensemble.replay_buffer):>5,} | "
                  f"Retrains: {ensemble.n_retrains} | "
                  f"Speed: {samples_per_sec:,.0f} samples/s")

        # ── Periodic LightGBM retrain ───────────────────────────────
        if (
            lgb_retrain_interval > 0
            and (i + 1) % lgb_retrain_interval == 0
            and i > 0
        ):
            success = ensemble.retrain_lgbm(feature_names)
            if success:
                print(f" >>> LightGBM retrained (#{ensemble.n_retrains}) | "
                      f"ARF weight: {ensemble.w_arf:.3f} | LGB weight: {ensemble.w_lgb:.3f}")

    total_time = time.perf_counter() - t_start
    avg_speed = running_total / total_time if total_time > 0 else 0.0
    print()
    print(f" Training complete in {total_time:.1f}s")
    print(f" Average speed: {avg_speed:,.0f} samples/sec")
    print()

    # ── 4. Final evaluation ─────────────────────────────────────────
    print("[4/5] Final evaluation...")
    metrics = ensemble.get_metrics()
    print(f" Total predictions: {metrics['total_predictions']:,}")
    print(f" Total corrections: {metrics['total_corrections']:,}")
    print(f" Overall accuracy: {metrics['accuracy']:.4f}")
    print(f" LightGBM retrains: {metrics['n_retrains']}")
    print(f" Replay buffer size: {metrics['replay_buffer_size']:,}")
    print(f" Ensemble weights: ARF={metrics['ensemble_weights']['arf']:.3f}, "
          f"LGBM={metrics['ensemble_weights']['lgbm']:.3f}")
    print()
    print(" Per-class accuracy:")
    for state, acc in sorted(metrics["per_class_accuracy"].items()):
        print(f" {state:<15s}: {acc:.4f}")
    print()
    print(" Knowledge Graph:")
    kg_stats = metrics["kg_stats"]
    print(f" Nodes: {kg_stats['nodes']}")
    print(f" Edges: {kg_stats['edges']}")
    print(f" Tracked processes: {kg_stats['tracked_processes']}")
    print(f" Override rules: {kg_stats['override_rules']}")
    print()

    # ── 5. Save artifacts ───────────────────────────────────────────
    print(f"[5/5] Saving model artifacts to {output_dir}...")
    ensemble.save(output_dir)

    # Save FeatureEngineer state (critical for inference-time normalization)
    fe_state_path = os.path.join(output_dir, "feature_engineer_state.json")
    feature_eng.save_state(fe_state_path)
    print(f" Feature engineer state saved to {fe_state_path}")

    # Save training metadata
    meta = {
        "n_samples": n_samples,
        "seed": seed,
        "feature_dim": len(feature_names),
        "feature_names": feature_names,
        "training_time_seconds": total_time,
        "label_distribution": dict(label_dist),
        "final_metrics": metrics,
        "model_config": model_config,
    }
    meta_path = os.path.join(output_dir, "training_meta.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        # default=str: the metrics may hold values json cannot encode
        # natively; they are stringified rather than aborting the save.
        json.dump(meta, f, indent=2, default=str)

    print()
    print("=" * 70)
    print(" TRAINING COMPLETE")
    print("=" * 70)
    print(f" Artifacts saved to: {output_dir}/")
    print(" ├── arf_model.pkl (River ARF online model)")
    print(" ├── lgbm_model.txt (LightGBM batch model)")
    print(" ├── ensemble_state.json (weights & config)")
    print(" ├── feature_names.json (feature ordering)")
    print(" ├── replay_buffer.pkl (experience replay)")
    print(" ├── knowledge_graph.json (KG with learned priors)")
    print(" ├── feature_engineer_state.json (feature engineer state)")
    print(" └── training_meta.json (training metadata)")
    print("=" * 70)
    return ensemble
if __name__ == "__main__":
    # Command-line entry point: parse the options and forward each one to the
    # matching keyword argument of train().
    parser = argparse.ArgumentParser(description="Train the Telemetry KG Classifier")
    parser.add_argument("--samples", type=int, default=30000,
                        help="Number of synthetic samples to generate")
    parser.add_argument("--output", type=str, default="/app/model_artifacts",
                        help="Output directory for model artifacts")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--lgb-retrain", type=int, default=3000,
                        help="Retrain LightGBM every N samples")
    # train() already accepts eval_interval; expose it so the progress-log
    # cadence can be changed without editing the script. The default matches
    # train()'s own default, so existing invocations behave the same.
    parser.add_argument("--eval-interval", type=int, default=5000,
                        help="Print a progress line every N samples")
    args = parser.parse_args()
    train(
        n_samples=args.samples,
        output_dir=args.output,
        seed=args.seed,
        lgb_retrain_interval=args.lgb_retrain,
        eval_interval=args.eval_interval,
    )