File size: 4,631 Bytes
10150dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Plotting utilities for experiment results.

Generates static PNG plots for each experiment:
1. Price series with fair value line
2. Bid-ask spread over time
3. Agent PnL over time
4. Trade volume per tick
"""

import os
from pathlib import Path

try:
    import matplotlib
    matplotlib.use("Agg")  # Non-interactive backend
    import matplotlib.pyplot as plt
    HAS_MPL = True
except ImportError:
    HAS_MPL = False


def plot_experiment(engine, title: str, output_dir: str, fair_value: float = 100.0):
    """
    Generate all experiment plots and save to output_dir.
    Falls back to text summary if matplotlib is not installed.
    """
    if not HAS_MPL:
        print(f"[WARN] matplotlib not installed — skipping plots for {title}")
        _text_summary(engine, title, output_dir)
        return

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    _plot_price_series(engine, title, fair_value, out)
    _plot_spread(engine, title, out)
    _plot_agent_pnl(engine, title, out)
    _plot_volume(engine, title, out)

    print(f"Plots saved to {out}/")


def _plot_price_series(engine, title: str, fair_value: float, out: Path):
    """Mid price over time with fair value reference line."""
    ticks = [m.tick for m in engine.metrics.tick_history if m.mid_price is not None]
    prices = [m.mid_price for m in engine.metrics.tick_history if m.mid_price is not None]

    fig, ax = plt.subplots(figsize=(12, 5))
    ax.plot(ticks, prices, linewidth=1.5, color="#2196F3", label="Mid Price")
    ax.axhline(y=fair_value, color="#F44336", linestyle="--", linewidth=1, alpha=0.7, label=f"Fair Value ({fair_value})")
    ax.set_xlabel("Tick")
    ax.set_ylabel("Price")
    ax.set_title(f"{title}\nPrice Series")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out / "price_series.png", dpi=150)
    plt.close(fig)


def _plot_spread(engine, title: str, out: Path):
    """Bid-ask spread over time."""
    ticks = [m.tick for m in engine.metrics.tick_history if m.spread is not None]
    spreads = [m.spread for m in engine.metrics.tick_history if m.spread is not None]

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.fill_between(ticks, spreads, alpha=0.4, color="#FF9800")
    ax.plot(ticks, spreads, linewidth=1, color="#E65100")
    ax.set_xlabel("Tick")
    ax.set_ylabel("Spread")
    ax.set_title(f"{title}\nBid-Ask Spread")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out / "spread.png", dpi=150)
    plt.close(fig)


def _plot_agent_pnl(engine, title: str, out: Path):
    """Per-agent PnL over time."""
    # Collect PnL per agent per tick
    agent_data: dict[str, list[tuple[int, float]]] = {}

    for row in engine.agent_pnl_rows:
        key = f"{row['agent_id']} ({row['agent_type']})"
        if key not in agent_data:
            agent_data[key] = []
        agent_data[key].append((row["tick"], row["pnl"]))

    colors = ["#2196F3", "#4CAF50", "#F44336", "#FF9800", "#9C27B0", "#00BCD4", "#795548", "#607D8B"]
    fig, ax = plt.subplots(figsize=(12, 5))

    for i, (label, data) in enumerate(agent_data.items()):
        ticks = [d[0] for d in data]
        pnls = [d[1] for d in data]
        color = colors[i % len(colors)]
        ax.plot(ticks, pnls, linewidth=1.2, label=label, color=color)

    ax.axhline(y=0, color="gray", linestyle="-", linewidth=0.5)
    ax.set_xlabel("Tick")
    ax.set_ylabel("PnL")
    ax.set_title(f"{title}\nAgent PnL")
    ax.legend(fontsize=8, loc="best")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out / "agent_pnl.png", dpi=150)
    plt.close(fig)


def _plot_volume(engine, title: str, out: Path):
    """Trade volume per tick."""
    ticks = [m.tick for m in engine.metrics.tick_history]
    volumes = [m.volume for m in engine.metrics.tick_history]

    fig, ax = plt.subplots(figsize=(12, 3))
    ax.bar(ticks, volumes, width=1.0, color="#4CAF50", alpha=0.7)
    ax.set_xlabel("Tick")
    ax.set_ylabel("Volume")
    ax.set_title(f"{title}\nTrade Volume per Tick")
    ax.grid(True, alpha=0.3, axis="y")
    fig.tight_layout()
    fig.savefig(out / "volume.png", dpi=150)
    plt.close(fig)


def _text_summary(engine, title: str, output_dir: str):
    """Fallback text summary when matplotlib is unavailable."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    summary = engine.metrics.summary()
    with open(out / "summary.txt", "w") as f:
        f.write(f"{title}\n{'=' * len(title)}\n\n")
        for k, v in summary.items():
            f.write(f"{k}: {v}\n")
    print(f"Text summary saved to {out}/summary.txt")