File size: 3,705 Bytes
0162f5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Manually train the ensemble model
Run this to test model training or manually trigger retraining
"""
import sys
from pathlib import Path

# Make the directory one level above this script's own folder importable by
# putting it at the front of sys.path (skipped when it is already listed).
# NOTE(review): presumably this is the project root that holds the
# SelfTrainService and DataService packages imported below — confirm against
# the repository layout.
parent_dir = str(Path(__file__).parent.parent)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from SelfTrainService.trainer import ModelTrainer
from SelfTrainService.data_store import ScheduleDataStore
from DataService.metro_data_generator import MetroDataGenerator
from DataService.schedule_optimizer import MetroScheduleOptimizer
import json


def generate_sample_data(num_schedules: int = 150) -> None:
    """Generate sample schedules and persist them for training.

    For each of ``num_schedules`` iterations a ``MetroDataGenerator`` is
    created with a train count that cycles through 25..39; its route and
    train-health statuses are passed to a ``MetroScheduleOptimizer``, and the
    optimized schedule is stored via ``ScheduleDataStore.save_schedule`` as
    the dict returned by ``schedule.model_dump()``.

    Progress is printed after every 10th saved schedule.

    Args:
        num_schedules: Number of schedules to generate and save.
    """
    from datetime import datetime

    print(f"Generating {num_schedules} sample schedules...")

    data_store = ScheduleDataStore()
    # The date string is the same for the whole batch, so build it once
    # instead of on every iteration.
    schedule_date = datetime.now().strftime("%Y-%m-%d")

    for i in range(num_schedules):
        # Vary the fleet size between samples: 25 + (0..14) -> 25-39 trains.
        num_trains = 25 + (i % 15)
        generator = MetroDataGenerator(num_trains=num_trains)
        route = generator.generate_route()
        train_health = generator.generate_train_health_statuses()

        optimizer = MetroScheduleOptimizer(
            date=schedule_date,
            num_trains=num_trains,
            route=route,
            train_health=train_health,
        )
        schedule = optimizer.optimize_schedule()

        # Save schedule
        data_store.save_schedule(schedule.model_dump())

        # Report progress only once the schedule has actually been saved, so
        # the count never runs ahead of the stored data if a step above fails.
        if (i + 1) % 10 == 0:
            print(f"  Generated {i + 1}/{num_schedules}")

    print(f"✓ Generated {num_schedules} schedules")


def main() -> None:
    """Train the ensemble model and print a report of the outcome.

    Workflow:
        1. Count the stored schedules; when fewer than 100 are available,
           generate 150 sample schedules first.
        2. Run ``ModelTrainer().train(force=True)``.
        3. On success, print the trained models, best model, sample count,
           ensemble weights and per-model test metrics, then write the
           result dict to ``models/training_summary.json`` (relative to the
           current working directory). On failure, print the result's
           ``reason`` entry, falling back to its ``error`` entry.
        4. In either case, print ``trainer.get_model_info()`` as JSON.
    """
    min_schedules = 100      # below this count, sample data is generated first
    sample_batch_size = 150  # how many sample schedules to generate in that case
    banner = "=" * 60

    print(banner)
    print("Multi-Model Ensemble Training")
    print(banner)

    # Check if we have enough data
    data_store = ScheduleDataStore()
    count = data_store.count_schedules()

    print(f"\nCurrent data: {count} schedules")

    if count < min_schedules:
        print(f"Need at least {min_schedules} schedules for training")
        generate_sample_data(sample_batch_size)

    # Initialize trainer
    print("\nInitializing model trainer...")
    trainer = ModelTrainer()

    # Train models
    print("\nTraining ensemble models...")
    # Informational only: this list is hard-coded here, not read from the trainer.
    print("Models: gradient_boosting, random_forest, xgboost, lightgbm, catboost")
    print()

    result = trainer.train(force=True)

    if result["success"]:
        print("\n" + banner)
        print("Training Complete!")
        print(banner)
        print(f"\nModels trained: {', '.join(result['models_trained'])}")
        print(f"Best model: {result['best_model']}")
        print(f"Samples used: {result['samples_used']}")
        print("\nEnsemble Weights:")
        for model, weight in result['ensemble_weights'].items():
            print(f"  {model}: {weight:.4f}")

        print("\nModel Performance:")
        for model, metrics in result['metrics'].items():
            print(f"\n{model}:")
            print(f"  Test R²: {metrics['test_r2']:.4f}")
            print(f"  Test RMSE: {metrics['test_rmse']:.4f}")

        # Save summary
        summary_path = Path("models/training_summary.json")
        summary_path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding keeps the written file independent of the
        # platform's locale default.
        with open(summary_path, 'w', encoding="utf-8") as f:
            # default=str stringifies any value json cannot encode natively.
            json.dump(result, f, indent=2, default=str)

        print(f"\n✓ Training summary saved to {summary_path}")
    else:
        print(f"\n✗ Training failed: {result.get('reason', result.get('error'))}")

    # Show model info (printed whether or not training succeeded)
    print("\n" + banner)
    print("Current Model Info")
    print(banner)
    info = trainer.get_model_info()
    print(json.dumps(info, indent=2, default=str))


# Run the training workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()