OmidSakaki commited on
Commit
1437e6d
·
verified ·
1 Parent(s): de24124

Create src/visualizers/chart_renderer.py

Browse files
Files changed (1) hide show
  1. src/visualizers/chart_renderer.py +216 -0
src/visualizers/chart_renderer.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import plotly.graph_objects as go
2
+ from plotly.subplots import make_subplots
3
+ import numpy as np
4
+
5
+ class ChartRenderer:
6
+ def __init__(self):
7
+ pass
8
+
9
+ def render_price_chart(self, prices, actions=None, current_step=0):
10
+ """Render price chart with actions"""
11
+ fig = go.Figure()
12
+
13
+ if not prices:
14
+ # Return empty figure if no data
15
+ fig.update_layout(
16
+ title="Price Chart - No Data Available",
17
+ xaxis_title="Time Step",
18
+ yaxis_title="Price",
19
+ height=300,
20
+ template="plotly_white"
21
+ )
22
+ return fig
23
+
24
+ # Add price line
25
+ fig.add_trace(go.Scatter(
26
+ x=list(range(len(prices))),
27
+ y=prices,
28
+ mode='lines',
29
+ name='Price',
30
+ line=dict(color='blue', width=2)
31
+ ))
32
+
33
+ # Add action markers if provided
34
+ if actions and len(actions) == len(prices):
35
+ buy_indices = [i for i, action in enumerate(actions) if action == 1]
36
+ sell_indices = [i for i, action in enumerate(actions) if action == 2]
37
+ close_indices = [i for i, action in enumerate(actions) if action == 3]
38
+
39
+ if buy_indices:
40
+ fig.add_trace(go.Scatter(
41
+ x=buy_indices,
42
+ y=[prices[i] for i in buy_indices],
43
+ mode='markers',
44
+ name='Buy',
45
+ marker=dict(color='green', size=10, symbol='triangle-up',
46
+ line=dict(width=2, color='darkgreen'))
47
+ ))
48
+
49
+ if sell_indices:
50
+ fig.add_trace(go.Scatter(
51
+ x=sell_indices,
52
+ y=[prices[i] for i in sell_indices],
53
+ mode='markers',
54
+ name='Sell',
55
+ marker=dict(color='red', size=10, symbol='triangle-down',
56
+ line=dict(width=2, color='darkred'))
57
+ ))
58
+
59
+ if close_indices:
60
+ fig.add_trace(go.Scatter(
61
+ x=close_indices,
62
+ y=[prices[i] for i in close_indices],
63
+ mode='markers',
64
+ name='Close',
65
+ marker=dict(color='orange', size=8, symbol='x',
66
+ line=dict(width=2, color='darkorange'))
67
+ ))
68
+
69
+ fig.update_layout(
70
+ title=f"Price Chart (Step: {current_step})",
71
+ xaxis_title="Time Step",
72
+ yaxis_title="Price",
73
+ height=300,
74
+ showlegend=True,
75
+ template="plotly_white"
76
+ )
77
+
78
+ return fig
79
+
80
+ def create_performance_chart(self, net_worth_history, reward_history, initial_balance):
81
+ """Create portfolio performance chart"""
82
+ fig = make_subplots(
83
+ rows=2, cols=1,
84
+ subplot_titles=['Portfolio Value Over Time', 'Step Rewards'],
85
+ vertical_spacing=0.15
86
+ )
87
+
88
+ if not net_worth_history:
89
+ fig.update_layout(title="No Data Available", height=400)
90
+ return fig
91
+
92
+ # Portfolio value
93
+ fig.add_trace(go.Scatter(
94
+ x=list(range(len(net_worth_history))),
95
+ y=net_worth_history,
96
+ mode='lines+markers',
97
+ name='Net Worth',
98
+ line=dict(color='green', width=3),
99
+ marker=dict(size=4)
100
+ ), row=1, col=1)
101
+
102
+ # Add initial balance reference line
103
+ fig.add_hline(y=initial_balance, line_dash="dash",
104
+ line_color="red", annotation_text="Initial Balance",
105
+ row=1, col=1)
106
+
107
+ # Rewards as bar chart
108
+ if reward_history:
109
+ fig.add_trace(go.Bar(
110
+ x=list(range(len(reward_history))),
111
+ y=reward_history,
112
+ name='Reward',
113
+ marker_color=['green' if r >= 0 else 'red' for r in reward_history],
114
+ opacity=0.7
115
+ ), row=2, col=1)
116
+
117
+ fig.update_layout(height=500, showlegend=False, template="plotly_white")
118
+ fig.update_yaxes(title_text="Value ($)", row=1, col=1)
119
+ fig.update_yaxes(title_text="Reward", row=2, col=1)
120
+ fig.update_xaxes(title_text="Step", row=2, col=1)
121
+
122
+ return fig
123
+
124
+ def create_action_distribution(self, actions):
125
+ """Create action distribution pie chart"""
126
+ fig = go.Figure()
127
+
128
+ if not actions:
129
+ fig.update_layout(title="No Actions Available", height=300)
130
+ return fig
131
+
132
+ action_names = ['Hold', 'Buy', 'Sell', 'Close']
133
+ action_counts = [actions.count(i) for i in range(4)]
134
+
135
+ colors = ['blue', 'green', 'red', 'orange']
136
+
137
+ fig = go.Figure(data=[go.Pie(
138
+ labels=action_names,
139
+ values=action_counts,
140
+ hole=.4,
141
+ marker_colors=colors,
142
+ textinfo='label+percent+value',
143
+ hoverinfo='label+percent+value'
144
+ )])
145
+
146
+ fig.update_layout(
147
+ title="Action Distribution",
148
+ height=350,
149
+ annotations=[dict(text='Actions', x=0.5, y=0.5, font_size=16, showarrow=False)],
150
+ template="plotly_white"
151
+ )
152
+
153
+ return fig
154
+
155
+ def create_training_progress(self, training_history):
156
+ """Create training progress visualization"""
157
+ if not training_history:
158
+ fig = go.Figure()
159
+ fig.update_layout(title="No Training Data Available", height=500)
160
+ return fig
161
+
162
+ episodes = [h['episode'] for h in training_history]
163
+ rewards = [h['reward'] for h in training_history]
164
+ net_worths = [h['net_worth'] for h in training_history]
165
+ losses = [h.get('loss', 0) for h in training_history]
166
+
167
+ fig = make_subplots(
168
+ rows=2, cols=2,
169
+ subplot_titles=['Episode Rewards', 'Portfolio Value',
170
+ 'Training Loss', 'Moving Average Reward (5)'],
171
+ specs=[[{}, {}], [{}, {}]]
172
+ )
173
+
174
+ # Rewards
175
+ fig.add_trace(go.Scatter(
176
+ x=episodes, y=rewards, mode='lines+markers',
177
+ name='Reward', line=dict(color='blue', width=2),
178
+ marker=dict(size=4)
179
+ ), row=1, col=1)
180
+
181
+ # Portfolio value
182
+ fig.add_trace(go.Scatter(
183
+ x=episodes, y=net_worths, mode='lines+markers',
184
+ name='Net Worth', line=dict(color='green', width=2),
185
+ marker=dict(size=4)
186
+ ), row=1, col=2)
187
+
188
+ # Loss
189
+ if any(loss > 0 for loss in losses):
190
+ fig.add_trace(go.Scatter(
191
+ x=episodes, y=losses, mode='lines+markers',
192
+ name='Loss', line=dict(color='red', width=2),
193
+ marker=dict(size=4)
194
+ ), row=2, col=1)
195
+
196
+ # Moving average reward
197
+ if len(rewards) > 5:
198
+ ma_rewards = []
199
+ for i in range(len(rewards)):
200
+ start_idx = max(0, i - 4)
201
+ ma = np.mean(rewards[start_idx:i+1])
202
+ ma_rewards.append(ma)
203
+
204
+ fig.add_trace(go.Scatter(
205
+ x=episodes, y=ma_rewards, mode='lines',
206
+ name='MA Reward (5)', line=dict(color='orange', width=3, dash='dash')
207
+ ), row=2, col=2)
208
+
209
+ fig.update_layout(
210
+ height=600,
211
+ showlegend=True,
212
+ title_text="Training Progress Over Episodes",
213
+ template="plotly_white"
214
+ )
215
+
216
+ return fig