"""
Evaluation utilities for comparing a trained agent against a random-action baseline.
"""
import numpy as np
import pandas as pd
from typing import List, Dict, Optional
from training.config import TrainingConfig
from training.train import train, run_random_baseline
from utils.visualization import (
plot_reward_curve,
plot_grade_progression,
plot_comparison_table,
)
# (metric key, display label) pairs reported in the comparison table.
_COMPARISON_METRICS = (
    ("total_reward", "Avg Reward"),
    ("final_grade", "Avg Grade"),
    ("pnl_pct", "Avg PnL %"),
    ("max_drawdown", "Avg Max DD"),
    ("sharpe_ratio", "Avg Sharpe"),
)


def _mean_metric(metrics: List[Dict], key: str) -> float:
    """Return the mean of ``m[key]`` over every dict in *metrics*.

    An empty *metrics* list yields NaN. NumPy would return NaN as well, but it
    also emits a "Mean of empty slice" RuntimeWarning, so that case is
    short-circuited here.

    Raises:
        KeyError: if any dict in *metrics* lacks *key*.
    """
    if not metrics:
        return np.nan
    return np.mean([m[key] for m in metrics])


def _print_comparison(baseline_metrics: List[Dict], trained_metrics: List[Dict]) -> None:
    """Print a fixed-width table of baseline vs trained averages and their delta."""
    print(f"\n{'='*60}")
    print("EVALUATION RESULTS")
    print(f"{'='*60}")
    print(f"\n{'Metric':<20} {'Random':>12} {'Trained':>12} {'Improvement':>14}")
    print("-" * 60)
    for key, label in _COMPARISON_METRICS:
        r = _mean_metric(baseline_metrics, key)
        t = _mean_metric(trained_metrics, key)
        # "Improvement" is always trained minus random.
        # NOTE(review): for metrics where lower is better (possibly
        # max_drawdown, depending on its sign convention) a positive delta is
        # not an improvement -- confirm the convention.
        imp = t - r
        sign = "+" if imp > 0 else ""
        # Build the signed number first, then right-align it as a whole so the
        # "+" stays attached to the digits and the column width matches the header.
        imp_text = f"{sign}{imp:.4f}"
        print(f"  {label:<18} {r:>12.4f} {t:>12.4f} {imp_text:>14}")


def evaluate(
    config: Optional[TrainingConfig] = None,
    trained_metrics: Optional[List[Dict]] = None,
    baseline_episodes: int = 10,
    df: Optional[pd.DataFrame] = None,
) -> Dict:
    """
    Run full evaluation: train agent, run random baseline, compare, and plot.

    Args:
        config: Training configuration (a default ``TrainingConfig()`` is used if None).
        trained_metrics: Pre-computed training metrics (skips training if provided).
            Each dict must contain the keys "total_reward", "final_grade",
            "pnl_pct", "max_drawdown" and "sharpe_ratio" (presumably one dict
            per episode -- confirm against ``train``).
        baseline_episodes: Number of random baseline episodes.
        df: Optional dataframe forwarded to ``train`` and ``run_random_baseline``.

    Returns:
        Dict with keys "trained_metrics", "baseline_metrics",
        "trained_avg_grade", "baseline_avg_grade" and "grade_improvement"
        (trained average "final_grade" minus baseline average). An empty
        metrics list yields NaN for its average.

    Raises:
        KeyError: if a metrics dict is missing one of the keys listed above.
    """
    if config is None:
        config = TrainingConfig()

    # Caller-supplied metrics let an earlier training run be evaluated without retraining.
    if trained_metrics is None:
        print("Running training...")
        trained_metrics = train(config, df=df)

    print(f"\nRunning random baseline ({baseline_episodes} episodes)...")
    baseline_metrics = run_random_baseline(config, df=df, num_episodes=baseline_episodes)

    _print_comparison(baseline_metrics, trained_metrics)

    print("\nGenerating plots...")
    plot_reward_curve(trained_metrics, baseline_metrics)
    plot_grade_progression(trained_metrics, baseline_metrics)
    plot_comparison_table(trained_metrics, baseline_metrics)

    # Compute each grade average once and reuse it for the improvement.
    trained_grade = _mean_metric(trained_metrics, "final_grade")
    baseline_grade = _mean_metric(baseline_metrics, "final_grade")
    return {
        "trained_metrics": trained_metrics,
        "baseline_metrics": baseline_metrics,
        "trained_avg_grade": trained_grade,
        "baseline_avg_grade": baseline_grade,
        "grade_improvement": trained_grade - baseline_grade,
    }
if __name__ == "__main__":
    # Script entry point: run the full evaluation with default settings
    # (default TrainingConfig, fresh training run, 10 baseline episodes).
    evaluate()
|