OmidSakaki committed on
Commit
0f09ca5
·
verified ·
1 Parent(s): b72531e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +413 -157
app.py CHANGED
@@ -1,30 +1,38 @@
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
- from pathlib import Path
5
- from typing import Dict, Tuple, Any
6
- from loguru import logger
7
- import yaml
8
- from gymnasium import spaces
 
 
 
 
 
 
9
 
10
  class TradingConfig:
11
- def __init__(self):
12
- self.initial_balance = 10000.0
 
13
  self.max_steps = 1000
14
  self.transaction_cost = 0.001
15
- self.risk_level = "Medium"
16
- self.asset_type = "Crypto"
17
  self.learning_rate = 0.0001
18
  self.gamma = 0.99
19
  self.epsilon_start = 1.0
20
  self.epsilon_min = 0.01
21
- self.epsilon_decay = 0.9995
22
  self.batch_size = 32
23
- self.memory_size = 10000
24
- self.target_update = 100
25
 
26
  class AdvancedTradingEnvironment:
27
- def __init__(self, config):
 
 
 
28
  self.initial_balance = config.initial_balance
29
  self.balance = self.initial_balance
30
  self.position = 0.0
@@ -33,147 +41,218 @@ class AdvancedTradingEnvironment:
33
  self.max_steps = config.max_steps
34
  self.price_history = []
35
  self.sentiment_history = []
 
 
 
36
  self._initialize_data()
37
- self.action_space = spaces.Discrete(4)
38
- self.observation_space = spaces.Box(low=-2.0, high=2.0, shape=(12,), dtype=np.float32)
39
-
 
 
40
  def _initialize_data(self):
41
- n_points = 100
 
42
  base_price = 100.0
 
43
  for i in range(n_points):
44
- price = base_price + np.sin(i * 0.1) * 10 + np.random.normal(0, 2)
45
- self.price_history.append(max(10.0, price))
46
- sentiment = 0.5 + np.random.normal(0, 0.1)
 
 
 
 
 
47
  self.sentiment_history.append(np.clip(sentiment, 0.0, 1.0))
 
48
  self.current_price = self.price_history[-1]
49
-
50
- def reset(self):
 
51
  self.balance = self.initial_balance
52
  self.position = 0.0
53
  self.step_count = 0
54
- self.price_history = [100.0 + np.random.normal(0, 5)]
55
- self.sentiment_history = [0.5]
 
 
 
56
  obs = self._get_observation()
57
  info = self._get_info()
 
 
58
  return obs, info
59
-
60
- def step(self, action):
 
61
  self.step_count += 1
62
- price_change = np.random.normal(0, 0.02)
63
- self.current_price = max(10.0, self.current_price * (1 + price_change))
64
- self.price_history.append(self.current_price)
65
 
66
- sentiment_change = np.random.normal(0, 0.05)
67
- new_sentiment = np.clip(self.sentiment_history[-1] + sentiment_change, 0.0, 1.0)
68
- self.sentiment_history.append(new_sentiment)
69
 
 
70
  reward = self._execute_action(action)
71
 
72
- terminated = self.balance <= 0 or self.step_count >= self.max_steps
 
73
  truncated = False
74
 
75
  obs = self._get_observation()
76
  info = self._get_info()
77
 
78
  return obs, reward, terminated, truncated, info
79
-
80
- def _execute_action(self, action):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  reward = 0.0
82
  prev_net_worth = self.balance + self.position * self.current_price
83
 
 
 
 
84
  if action == 1: # Buy
85
- trade_amount = min(self.balance * 0.2, self.balance)
86
- cost = trade_amount
87
- if cost <= self.balance:
88
- self.position += trade_amount / self.current_price
89
- self.balance -= cost
 
 
 
90
 
91
  elif action == 2: # Sell
92
  if self.position > 0:
93
- sell_amount = min(self.position * 0.2, self.position)
94
- proceeds = sell_amount * self.current_price
95
- self.position -= sell_amount
96
  self.balance += proceeds
 
97
 
98
- elif action == 3: # Close
99
  if self.position > 0:
100
- proceeds = self.position * self.current_price
101
  self.balance += proceeds
102
  self.position = 0
 
103
 
104
- net_worth = self.balance + self.position * self.current_price
105
- reward = (net_worth - prev_net_worth) / self.initial_balance * 100
 
 
106
 
107
  return reward
108
-
109
- def _get_observation(self):
110
- recent_prices = self.price_history[-10:] if len(self.price_history) >= 10 else [self.current_price] * 10
111
- recent_sentiments = self.sentiment_history[-10:] if len(self.sentiment_history) >= 10 else [0.5] * 10
 
 
112
 
 
113
  features = [
114
- self.balance / self.initial_balance,
115
- self.position * self.current_price / self.initial_balance,
116
- self.current_price / 100.0,
117
- np.mean(recent_prices) / 100.0,
118
- np.std(recent_prices) / 100.0,
119
- np.mean(recent_sentiments),
120
- np.std(recent_sentiments),
121
- self.step_count / self.max_steps,
122
- 0.0, 0.0, 0.0, 0.0 # Padding
 
 
 
123
  ]
124
 
125
  return np.array(features[:12], dtype=np.float32)
126
-
127
- def _get_info(self):
 
128
  net_worth = self.balance + self.position * self.current_price
129
- return {'net_worth': net_worth}
 
 
 
 
 
 
 
130
 
131
  class DQNAgent:
132
- def __init__(self, state_dim, action_dim, config, device='cpu'):
 
 
 
 
133
  self.device = torch.device(device)
134
- self.q_network = torch.nn.Sequential(
135
- torch.nn.Linear(state_dim, 128),
136
- torch.nn.ReLU(),
137
- torch.nn.Linear(128, 128),
138
- torch.nn.ReLU(),
139
- torch.nn.Linear(128, action_dim)
140
- ).to(self.device)
141
-
142
- self.target_network = torch.nn.Sequential(
143
- torch.nn.Linear(state_dim, 128),
144
- torch.nn.ReLU(),
145
- torch.nn.Linear(128, 128),
146
- torch.nn.ReLU(),
147
- torch.nn.Linear(128, action_dim)
148
- ).to(self.device)
149
 
 
 
 
150
  self.target_network.load_state_dict(self.q_network.state_dict())
151
 
152
  self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=config.learning_rate)
153
  self.memory = deque(maxlen=config.memory_size)
154
- self.gamma = config.gamma
155
  self.epsilon = config.epsilon_start
156
- self.epsilon_min = config.epsilon_min
157
- self.epsilon_decay = config.epsilon_decay
158
- self.batch_size = config.batch_size
159
- self.target_update = config.target_update
160
  self.steps = 0
161
-
162
- def select_action(self, state, training=True):
163
- state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
 
 
 
 
 
 
 
 
 
164
  if training and random.random() < self.epsilon:
165
- return random.randint(0, 3)
 
 
166
  with torch.no_grad():
167
- return self.q_network(state).argmax(1).item()
168
-
 
 
 
 
 
 
169
  def store_transition(self, state, action, reward, next_state, done):
170
- self.memory.append((state, action, reward, next_state, done))
171
-
172
- def update(self):
173
- if len(self.memory) < self.batch_size:
 
 
174
  return 0.0
175
 
176
- batch = random.sample(self.memory, self.batch_size)
 
177
  states, actions, rewards, next_states, dones = zip(*batch)
178
 
179
  states = torch.FloatTensor(np.array(states)).to(self.device)
@@ -182,98 +261,275 @@ class DQNAgent:
182
  next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
183
  dones = torch.FloatTensor(dones).to(self.device)
184
 
 
185
  current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
186
  next_q = self.target_network(next_states).max(1)[0]
187
- target_q = rewards + self.gamma * next_q * (1 - dones)
188
 
189
  loss = torch.nn.MSELoss()(current_q, target_q)
190
 
191
  self.optimizer.zero_grad()
192
  loss.backward()
 
193
  self.optimizer.step()
194
 
195
  self.steps += 1
196
- if self.steps % self.target_update == 0:
197
- self.target_network.load_state_dict(self.q_network.state_dict())
198
 
199
- self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
 
 
200
 
201
  return loss.item()
202
 
203
  class TradingDemo:
 
 
204
  def __init__(self):
205
- self.config = TradingConfig()
206
  self.env = None
207
  self.agent = None
208
  self.device = 'cpu'
209
-
210
- def initialize(self, balance, risk, asset):
211
- self.config.initial_balance = balance
212
- self.config.risk_level = risk
213
- self.config.asset_type = asset
214
- self.env = AdvancedTradingEnvironment(self.config)
215
- self.agent = DQNAgent(12, 4, self.config, self.device)
216
- return "βœ… Initialized!"
217
-
218
- def train(self, episodes):
219
- for ep in range(episodes):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  obs, _ = self.env.reset()
221
- total_reward = 0
222
- done = False
223
- while not done:
224
- action = self.agent.select_action(obs)
 
 
 
 
225
  next_obs, reward, done, _, info = self.env.step(action)
226
- self.agent.store_transition(obs, action, reward, next_obs, done)
 
 
 
 
 
227
  obs = next_obs
228
- total_reward += reward
229
- self.agent.update()
230
- yield f"Episode {ep+1}/{episodes} | Reward: {total_reward:.2f}", None
231
- yield "βœ… Training complete!", None
232
-
233
- def simulate(self, steps):
234
- obs, _ = self.env.reset()
235
- prices = []
236
- actions = []
237
- net_worths = []
238
- for _ in range(steps):
239
- action = self.agent.select_action(obs, training=False)
240
- next_obs, reward, done, _, info = self.env.step(action)
241
- prices.append(self.env.current_price)
242
- actions.append(action)
243
- net_worths.append(info['net_worth'])
244
- obs = next_obs
245
- if done:
246
- break
247
-
248
- import plotly.graph_objects as go
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  fig = go.Figure()
250
- fig.add_trace(go.Scatter(y=prices, mode='lines', name='Price'))
251
- fig.add_trace(go.Scatter(y=net_worths, mode='lines', name='Net Worth'))
252
- return "βœ… Simulation complete!", fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
 
254
  demo = TradingDemo()
255
 
256
- with gr.Blocks() as interface:
257
- gr.Markdown("# Trading AI Demo")
 
258
 
 
259
  with gr.Row():
260
- balance = gr.Slider(1000, 50000, 10000, label="Balance")
261
- risk = gr.Radio(["Low", "Medium", "High"], value="Medium", label="Risk")
262
- asset = gr.Radio(["Crypto", "Stock", "Forex"], value="Crypto", label="Asset")
263
- init_btn = gr.Button("Initialize")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
- status = gr.Textbox(label="Status")
 
 
 
 
266
 
267
- episodes = gr.Number(value=50, label="Episodes")
268
- train_btn = gr.Button("Train")
269
- train_plot = gr.Plot()
 
 
270
 
271
- steps = gr.Number(value=100, label="Simulation Steps")
272
- sim_btn = gr.Button("Simulate")
273
- sim_plot = gr.Plot()
 
 
274
 
275
- init_btn.click(demo.initialize, [balance, risk, asset], status)
276
- train_btn.click(demo.train, episodes, [status, train_plot])
277
- sim_btn.click(demo.simulate, steps, [status, sim_plot])
278
 
279
- interface.launch()
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
+ import random
5
+ from collections import deque
6
+ from typing import Tuple, Iterator, Any
7
+ import plotly.graph_objects as go
8
+ from plotly.subplots import make_subplots
9
+ import logging
10
+ from datetime import datetime
11
+
12
+ # Setup logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
 
16
class TradingConfig:
    """Configuration for the trading environment and DQN agent.

    Every hyper-parameter can be overridden via keyword arguments,
    e.g. ``TradingConfig(initial_balance=5000.0, gamma=0.95)``.
    Unknown option names raise ``TypeError`` to catch typos early.
    """

    def __init__(self, initial_balance: float = 10000.0, **overrides):
        # --- Environment parameters ---
        self.initial_balance = initial_balance
        self.max_steps = 1000
        self.transaction_cost = 0.001
        # --- DQN hyper-parameters ---
        self.learning_rate = 0.0001
        self.gamma = 0.99
        self.epsilon_start = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.999
        self.batch_size = 32
        self.memory_size = 5000  # Reduced for Spaces memory limits
        self.target_update = 50
        # Apply keyword overrides after the defaults so callers can tune
        # any hyper-parameter without subclassing.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Unknown config option: {name!r}")
            setattr(self, name, value)
30
 
31
class AdvancedTradingEnvironment:
    """Simplified trading environment compatible with HF Spaces.

    Gymnasium-style API: ``reset() -> (obs, info)`` and
    ``step(action) -> (obs, reward, terminated, truncated, info)``.
    Actions: 0=Hold, 1=Buy, 2=Sell, 3=Close position.
    """

    def __init__(self, config):
        self.config = config
        self.initial_balance = config.initial_balance
        self.balance = self.initial_balance
        self.position = 0.0  # number of shares/units held
        self.step_count = 0
        self.max_steps = config.max_steps
        self.price_history = []
        self.sentiment_history = []
        self.action_history = []

        # Build the synthetic market series.
        self._initialize_data()

        # Number of discrete actions (0=Hold, 1=Buy, 2=Sell, 3=Close).
        # FIX: the previous code used ``gr.utils.Discrete(4)``, which does not
        # exist in Gradio (gymnasium was dropped from the imports) and raised
        # AttributeError at construction time. A plain int is sufficient here.
        self.action_space = 4
        self.observation_space = None  # kept for API compatibility

    def _initialize_data(self):
        """(Re)build the synthetic price and sentiment history.

        FIX: the lists are cleared first; previously every ``reset()``
        appended another 50 points onto the existing history, so the
        series grew across resets instead of starting fresh.
        """
        self.price_history = []
        self.sentiment_history = []
        n_points = 50  # small series for fast init
        base_price = 100.0

        for i in range(n_points):
            # Sinusoidal trend plus Gaussian noise, floored at 50.
            trend = np.sin(i * 0.2) * 5
            noise = np.random.normal(0, 3)
            price = max(50.0, base_price + trend + noise)
            self.price_history.append(price)

            # Sentiment loosely correlated with the price deviation from base.
            sentiment = 0.5 + 0.3 * np.tanh((price - base_price) / base_price) + np.random.normal(0, 0.1)
            self.sentiment_history.append(np.clip(sentiment, 0.0, 1.0))

        self.current_price = self.price_history[-1]

    def reset(self):
        """Reset the environment to its initial state; return (obs, info)."""
        self.balance = self.initial_balance
        self.position = 0.0
        self.step_count = 0
        self.action_history = []

        # Regenerate the market series from scratch.
        self._initialize_data()

        obs = self._get_observation()
        info = self._get_info()

        # Same underlying logger as the module-level ``logger`` (both are
        # ``logging.getLogger(__name__)``), but self-contained here.
        logging.getLogger(__name__).info(
            f"Environment reset: balance=${self.balance}, price=${self.current_price}")
        return obs, info

    def step(self, action):
        """Advance one step: evolve the market, apply *action*, compute reward."""
        self.step_count += 1
        self.action_history.append(action)

        # Market evolution happens before the order executes.
        self._market_step()

        reward = self._execute_action(action)

        # Episode ends on bankruptcy or when the step budget is exhausted.
        terminated = (self.balance <= 0 or self.step_count >= self.max_steps)
        truncated = False

        obs = self._get_observation()
        info = self._get_info()

        return obs, reward, terminated, truncated, info

    def _market_step(self):
        """Simulate one market tick (GBM-style price move + sentiment drift)."""
        drift = 0.0001
        volatility = 0.02
        price_change = drift + volatility * np.random.normal()
        self.current_price = max(10.0, self.current_price * np.exp(price_change))

        self.price_history.append(self.current_price)
        if len(self.price_history) > 100:  # keep a bounded window
            self.price_history.pop(0)

        # Sentiment drifts with the price change plus noise, clipped to [0, 1].
        sentiment_change = 0.1 * price_change + np.random.normal(0, 0.05)
        new_sentiment = np.clip(self.sentiment_history[-1] + sentiment_change, 0.0, 1.0)
        self.sentiment_history.append(new_sentiment)
        if len(self.sentiment_history) > 100:
            self.sentiment_history.pop(0)

    def _execute_action(self, action):
        """Execute the trading action and return the (scaled) reward."""
        reward = 0.0
        prev_net_worth = self.balance + self.position * self.current_price
        cost_factor = self.config.transaction_cost

        if action == 1:  # Buy with 10% of balance, keeping a $100 cash floor
            if self.balance > 100:
                trade_value = min(self.balance * 0.1, self.balance - 100)
                cost = trade_value * (1 + cost_factor)
                if cost <= self.balance:
                    self.position += trade_value / self.current_price
                    self.balance -= cost
                    reward -= cost_factor * trade_value  # pay transaction cost

        elif action == 2:  # Sell 10% of the position
            if self.position > 0:
                sell_shares = min(self.position * 0.1, self.position)
                proceeds = sell_shares * self.current_price * (1 - cost_factor)
                self.position -= sell_shares
                self.balance += proceeds
                reward -= cost_factor * sell_shares * self.current_price

        elif action == 3:  # Close the entire position
            if self.position > 0:
                close_value = self.position * self.current_price
                proceeds = close_value * (1 - cost_factor)
                self.balance += proceeds
                self.position = 0
                # FIX: the old code zeroed the position *before* computing the
                # penalty (so it was always 0) and multiplied current_price
                # twice; the penalty now mirrors the buy/sell branches.
                reward -= cost_factor * close_value

        # Mark-to-market P&L, scaled so typical rewards are O(1).
        current_net_worth = self.balance + self.position * self.current_price
        pnl = (current_net_worth - prev_net_worth) / self.initial_balance
        reward += pnl * 100

        return reward

    def _get_observation(self):
        """Return the 12-dim float32 feature vector describing the state."""
        recent_prices = self.price_history[-10:] if len(self.price_history) >= 10 else self.price_history
        recent_sentiments = self.sentiment_history[-10:] if len(self.sentiment_history) >= 10 else self.sentiment_history

        features = [
            self.balance / self.initial_balance,                              # normalized cash
            (self.position * self.current_price) / self.initial_balance,      # position value
            self.current_price / 200.0,                                       # normalized price
            np.mean(recent_prices) / 200.0,                                   # short-term average
            np.std(recent_prices) / 200.0 if len(recent_prices) > 1 else 0,   # volatility
            np.mean(recent_sentiments),                                       # avg sentiment
            np.std(recent_sentiments) if len(recent_sentiments) > 1 else 0,   # sentiment vol
            self.step_count / self.max_steps,                                 # episode progress
            len([a for a in self.action_history[-5:] if a == 1]) / 5.0,       # recent buy ratio
            len([a for a in self.action_history[-5:] if a == 2]) / 5.0,       # recent sell ratio
            np.mean(np.diff(recent_prices)) if len(recent_prices) > 1 else 0, # momentum
            0.0                                                               # padding
        ]

        return np.array(features[:12], dtype=np.float32)

    def _get_info(self):
        """Return a diagnostics dict (net worth, balance, position, price, step, pnl %)."""
        net_worth = self.balance + self.position * self.current_price
        return {
            'net_worth': net_worth,
            'balance': self.balance,
            'position': self.position,
            'current_price': self.current_price,
            'step': self.step_count,
            'pnl': (net_worth - self.initial_balance) / self.initial_balance * 100
        }
201
 
202
class DQNAgent:
    """Deep Q-Network agent with a target network and replay memory."""

    def __init__(self, state_dim: int, action_dim: int, config, device: str = 'cpu'):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = torch.device(device)
        self.config = config

        # Online and (frozen) target networks share the same architecture.
        self.q_network = self._build_network().to(self.device)
        self.target_network = self._build_network().to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=config.learning_rate)
        self.memory = deque(maxlen=config.memory_size)
        self.epsilon = config.epsilon_start
        self.steps = 0

    def _build_network(self) -> torch.nn.Module:
        """Two-hidden-layer MLP mapping a state vector to per-action Q-values."""
        return torch.nn.Sequential(
            torch.nn.Linear(self.state_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, self.action_dim)
        )

    def select_action(self, state: np.ndarray, training: bool = True) -> int:
        """Epsilon-greedy action selection.

        FIX: epsilon now decays on every training call. Previously the decay
        ran only after a *greedy* pick; with epsilon_start=1.0 nearly every
        early call took the random branch and returned before the decay, so
        epsilon effectively never decreased.
        """
        explore = training and random.random() < self.epsilon
        if training:
            self.epsilon = max(self.config.epsilon_min, self.epsilon * self.config.epsilon_decay)
        if explore:
            return random.randint(0, self.action_dim - 1)

        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.q_network(state_tensor)
        return q_values.argmax().item()

    def store_transition(self, state, action, reward, next_state, done):
        """Append one experience tuple to the replay memory."""
        self.memory.append((state, action, float(reward), next_state, done))

    def update(self) -> float:
        """Run one DQN training step; return the loss (0.0 if memory too small)."""
        if len(self.memory) < self.config.batch_size:
            return 0.0

        # Sample a random minibatch of experiences.
        batch = random.sample(self.memory, self.config.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # Standard Q-learning target from the target network. The forward
        # pass runs under no_grad: the target is a constant for the loss, and
        # this avoids accumulating gradients in the frozen network (the
        # q_network gradients are numerically identical either way).
        current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q = self.target_network(next_states).max(1)[0]
        target_q = rewards + self.config.gamma * next_q * (1 - dones)

        loss = torch.nn.MSELoss()(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip gradients to stabilize training.
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
        self.optimizer.step()

        self.steps += 1
        # Periodically sync the target network with the online network.
        if self.steps % self.config.target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        return loss.item()
283
 
284
class TradingDemo:
    """Main demo class with proper state management.

    Owns the config, environment, agent and training history; its methods are
    wired directly to Gradio event handlers (initialize / train / simulate).
    """

    def __init__(self):
        # All state is created lazily in initialize(); is_initialized guards
        # train()/simulate() against being called before setup.
        self.config = None
        self.env = None
        self.agent = None
        self.device = 'cpu'
        self.is_initialized = False
        self.training_history = []  # list of {'episode', 'reward', 'loss'} dicts

    def initialize(self, balance: float) -> str:
        """Initialize environment and agent; returns a status message for the UI."""
        try:
            self.config = TradingConfig(initial_balance=float(balance))
            self.env = AdvancedTradingEnvironment(self.config)
            state_dim = 12  # Fixed observation size
            self.agent = DQNAgent(state_dim, 4, self.config, self.device)
            self.is_initialized = True
            self.training_history = []

            logger.info(f"Initialized: balance=${balance}")
            return f"βœ… Environment initialized! Balance: ${balance:,.2f}"

        except Exception as e:
            # Broad catch is deliberate: this feeds the UI status box.
            logger.error(f"Initialization failed: {e}")
            self.is_initialized = False
            return f"❌ Initialization failed: {str(e)}"

    def train(self, episodes: int) -> Iterator[Tuple[str, go.Figure]]:
        """Train the agent, yielding (status, chart) progress updates.

        Implemented as a generator so Gradio streams intermediate updates to
        the UI every 10 episodes.
        """
        if not self.is_initialized or self.env is None or self.agent is None:
            yield "❌ Please initialize first!", None
            return

        try:
            episodes = int(episodes)
            episode_rewards = []

            for ep in range(episodes):
                obs, _ = self.env.reset()
                total_reward = 0
                done = False

                while not done:
                    action = self.agent.select_action(obs, training=True)
                    # NOTE(review): env.step returns (obs, reward, terminated,
                    # truncated, info); 'done' here is the terminated flag.
                    next_obs, reward, done, _, info = self.env.step(action)

                    self.agent.store_transition(obs, action, reward, next_obs, done)
                    obs = next_obs
                    total_reward += reward

                # Update agent once per episode (after the rollout).
                loss = self.agent.update()
                episode_rewards.append(total_reward)
                self.training_history.append({
                    'episode': ep + 1,
                    'reward': total_reward,
                    'loss': loss
                })

                # Progress update every 10 episodes (and on the last one)
                if (ep + 1) % 10 == 0 or ep == episodes - 1:
                    avg_reward = np.mean(episode_rewards[-10:])
                    progress = (ep + 1) / episodes * 100

                    status = (f"πŸ“ˆ Training Progress: {ep+1}/{episodes} ({progress:.1f}%)\n"
                              f"🎯 Episode Reward: {total_reward:.2f}\n"
                              f"πŸ“Š Avg Reward (last 10): {avg_reward:.2f}\n"
                              f"πŸ“‰ Loss: {loss:.4f}")

                    # Create progress chart
                    chart = self._create_training_chart()
                    yield status, chart

            final_status = f"βœ… Training completed! Final Avg Reward: {np.mean(episode_rewards):.2f}"
            yield final_status, self._create_training_chart()

        except Exception as e:
            logger.error(f"Training error: {e}")
            yield f"❌ Training failed: {str(e)}", None

    def simulate(self, steps: int) -> Tuple[str, go.Figure, go.Figure]:
        """Run a greedy (no-exploration) trading simulation; return status + charts."""
        if not self.is_initialized or self.env is None or self.agent is None:
            return "❌ Please initialize and train first!", None, None

        try:
            steps = int(steps)
            obs, _ = self.env.reset()

            prices = []
            actions = []
            net_worths = []
            rewards = []

            for step in range(steps):
                # training=False: greedy policy, no epsilon exploration.
                action = self.agent.select_action(obs, training=False)
                next_obs, reward, done, _, info = self.env.step(action)

                prices.append(self.env.current_price)
                actions.append(action)
                net_worths.append(info['net_worth'])
                rewards.append(reward)

                obs = next_obs
                if done:
                    break

            # Create charts
            price_chart = self._create_price_chart(prices, actions)
            performance_chart = self._create_performance_chart(net_worths, rewards)

            # NOTE(review): assumes at least one step ran (UI slider min is 20),
            # otherwise net_worths[-1] would raise IndexError.
            final_pnl = (net_worths[-1] - self.config.initial_balance) / self.config.initial_balance * 100
            status = f"βœ… Simulation completed! Steps: {len(prices)}, Final P&L: {final_pnl:+.2f}%"

            return status, price_chart, performance_chart

        except Exception as e:
            logger.error(f"Simulation error: {e}")
            return f"❌ Simulation failed: {str(e)}", None, None

    def _create_training_chart(self) -> go.Figure:
        """Create the training progress chart (reward and loss subplots)."""
        if not self.training_history:
            # Placeholder figure shown before any episode completes.
            fig = go.Figure()
            fig.update_layout(title="Training in progress...", height=400)
            return fig

        episodes = [h['episode'] for h in self.training_history]
        rewards = [h['reward'] for h in self.training_history]
        losses = [h['loss'] for h in self.training_history]

        fig = make_subplots(rows=2, cols=1, subplot_titles=["Rewards", "Loss"])

        fig.add_trace(go.Scatter(x=episodes, y=rewards, mode='lines+markers', name='Reward'), row=1, col=1)
        fig.add_trace(go.Scatter(x=episodes, y=losses, mode='lines', name='Loss'), row=2, col=1)

        fig.update_layout(height=500, title="Training Progress", showlegend=True)
        return fig

    def _create_price_chart(self, prices: list, actions: list) -> go.Figure:
        """Create the price chart with buy/sell markers overlaid."""
        fig = go.Figure()

        # Price line
        fig.add_trace(go.Scatter(x=list(range(len(prices))), y=prices, mode='lines',
                                 name='Price', line=dict(color='blue', width=2)))

        # Action markers (action 1 = buy, action 2 = sell)
        buy_indices = [i for i, a in enumerate(actions) if a == 1]
        sell_indices = [i for i, a in enumerate(actions) if a == 2]

        if buy_indices:
            buy_prices = [prices[i] for i in buy_indices]
            fig.add_trace(go.Scatter(x=buy_indices, y=buy_prices, mode='markers',
                                     name='Buy', marker=dict(color='green', size=10, symbol='triangle-up')))

        if sell_indices:
            sell_prices = [prices[i] for i in sell_indices]
            fig.add_trace(go.Scatter(x=sell_indices, y=sell_prices, mode='markers',
                                     name='Sell', marker=dict(color='red', size=10, symbol='triangle-down')))

        fig.update_layout(title="Price Action Simulation", height=400, showlegend=True)
        return fig

    def _create_performance_chart(self, net_worths: list, rewards: list) -> go.Figure:
        """Create the performance dashboard (net worth line + per-step reward bars)."""
        fig = make_subplots(rows=2, cols=1, subplot_titles=["Net Worth", "Rewards"])

        fig.add_trace(go.Scatter(x=list(range(len(net_worths))), y=net_worths,
                                 mode='lines', name='Net Worth'), row=1, col=1)
        fig.add_trace(go.Bar(x=list(range(len(rewards))), y=rewards,
                             name='Rewards'), row=2, col=1)

        fig.update_layout(title="Performance", height=500, showlegend=True)
        return fig
461
 
462
# Initialize demo (single shared instance backing all event handlers)
demo = TradingDemo()

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="πŸ€– AI Trading Demo") as interface:
    gr.Markdown("# πŸš€ AI Trading Assistant\nReinforcement Learning Trading Demo")

    # Configuration
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## βš™οΈ Configuration")
            balance = gr.Slider(1000, 50000, value=10000, step=1000, label="Initial Balance ($)")
            init_btn = gr.Button("πŸš€ Initialize Environment", variant="primary")
            status = gr.Textbox(label="Status", interactive=False, lines=2)

    # Training
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## πŸŽ“ Training")
            episodes = gr.Slider(10, 100, value=30, step=5, label="Training Episodes")
            train_btn = gr.Button("πŸ€– Start Training", variant="primary")

        with gr.Column(scale=2):
            train_status = gr.Textbox(label="Training Progress", interactive=False, lines=3)
            train_chart = gr.Plot(label="Training Metrics")

    # Simulation
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🎯 Simulation")
            sim_steps = gr.Slider(20, 200, value=50, step=10, label="Simulation Steps")
            sim_btn = gr.Button("▢️ Run Simulation", variant="secondary")

        with gr.Column(scale=2):
            sim_status = gr.Textbox(label="Simulation Status", interactive=False)
            price_chart = gr.Plot(label="Price Chart")

    performance_chart = gr.Plot(label="Performance Dashboard")

    # Event handlers with proper error handling.
    # demo.train is a generator, so Gradio streams its yields to the UI.
    init_btn.click(
        demo.initialize,
        inputs=[balance],
        outputs=[status]
    )

    train_btn.click(
        demo.train,
        inputs=[episodes],
        outputs=[train_status, train_chart]
    )

    sim_btn.click(
        demo.simulate,
        inputs=[sim_steps],
        outputs=[sim_status, price_chart, performance_chart]
    )

    gr.Markdown("""
    ## πŸ“ Usage Instructions:
    1. **Initialize**: Set balance and click "Initialize Environment"
    2. **Train**: Adjust episodes and click "Start Training"
    3. **Simulate**: Set steps and click "Run Simulation"

    **Note**: Training must complete before simulation!
    """)

# Standard Spaces entry point: bind on all interfaces, fixed port 7860.
if __name__ == "__main__":
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True
    )