OmidSakaki committed on
Commit
8f54cbf
·
verified ·
1 Parent(s): 769c366

Update src/environments/visual_trading_env.py

Browse files
Files changed (1) hide show
  1. src/environments/visual_trading_env.py +102 -85
src/environments/visual_trading_env.py CHANGED
@@ -1,18 +1,17 @@
1
  import numpy as np
2
- import pandas as pd
3
  import matplotlib.pyplot as plt
4
  from PIL import Image
5
  import io
6
 
7
  class VisualTradingEnvironment:
8
  def __init__(self, initial_balance=10000, risk_level="Medium", asset_type="Stock"):
9
- self.initial_balance = initial_balance
10
  self.risk_level = risk_level
11
  self.asset_type = asset_type
12
 
13
  # Risk multipliers
14
  risk_multipliers = {"Low": 0.5, "Medium": 1.0, "High": 2.0}
15
- self.risk_multiplier = risk_multipliers[risk_level]
16
 
17
  # Generate market data
18
  self._generate_market_data()
@@ -25,21 +24,23 @@ class VisualTradingEnvironment:
25
  np.random.seed(42)
26
 
27
  # Base parameters based on asset type
28
- if self.asset_type == "Crypto":
29
- volatility = 0.02 * self.risk_multiplier
30
- trend = 0.001
31
- elif self.asset_type == "Forex":
32
- volatility = 0.005 * self.risk_multiplier
33
- trend = 0.0002
34
- else: # Stock
35
- volatility = 0.01 * self.risk_multiplier
36
- trend = 0.0005
37
 
38
  prices = [100.0]
39
  for i in range(1, num_points):
40
- # Random walk with trend and volatility
41
  change = np.random.normal(trend, volatility)
42
- price = max(1.0, prices[-1] * (1 + change))
 
 
43
  prices.append(price)
44
 
45
  self.price_data = np.array(prices)
@@ -61,7 +62,7 @@ class VisualTradingEnvironment:
61
  fig, ax = plt.subplots(figsize=(4.2, 4.2), dpi=20, facecolor='black')
62
  ax.set_facecolor('black')
63
 
64
- # Plot price
65
  if len(prices) > 0:
66
  ax.plot(prices, color='cyan', linewidth=1.5)
67
 
@@ -97,7 +98,7 @@ class VisualTradingEnvironment:
97
  plt.close(fig)
98
 
99
  # Create attention map with same dimensions
100
- attention_map = np.zeros((84, 84))
101
  if len(prices) > 1:
102
  recent_change = (prices[-1] - prices[-2]) / prices[-2]
103
  intensity = min(255, abs(recent_change) * 5000)
@@ -110,12 +111,10 @@ class VisualTradingEnvironment:
110
  for j in range(max(0, center_y-size), min(84, center_y+size)):
111
  distance = np.sqrt((i-center_x)**2 + (j-center_y)**2)
112
  if distance <= size:
113
- attention_map[i, j] = max(attention_map[i, j], intensity * (1 - distance/size))
114
-
115
- # Ensure both arrays have same shape before concatenation
116
- attention_map = attention_map.astype(np.uint8)
117
 
118
- # Combine RGB with attention map - now both are 84x84
119
  visual_obs = np.concatenate([
120
  img_array,
121
  attention_map[:, :, np.newaxis] # Add channel dimension
@@ -132,8 +131,8 @@ class VisualTradingEnvironment:
132
  """Reset environment to initial state"""
133
  self.current_step = 50 # Start with some history
134
  self.balance = self.initial_balance
135
- self.position_size = 0
136
- self.entry_price = 0
137
  self.net_worth = self.initial_balance
138
  self.total_trades = 0
139
  self.done = False
@@ -142,70 +141,88 @@ class VisualTradingEnvironment:
142
 
143
  def step(self, action):
144
  """Execute one trading step"""
145
- current_price = self.price_data[self.current_step]
146
- prev_net_worth = self.net_worth
147
-
148
- reward = 0
149
-
150
- # Execute action
151
- if action == 1 and self.position_size == 0: # Buy
152
- # Risk-adjusted position sizing
153
- position_value = self.balance * 0.1 * self.risk_multiplier
154
- self.position_size = position_value / current_price
155
- self.entry_price = current_price
156
- self.balance -= position_value
157
- self.total_trades += 1
158
- reward = -0.01 # Small penalty for transaction
159
-
160
- elif action == 2 and self.position_size > 0: # Sell (increase position)
161
- additional_value = self.balance * 0.05 * self.risk_multiplier
162
- additional_size = additional_value / current_price
163
- self.position_size += additional_size
164
- self.balance -= additional_value
165
- self.total_trades += 1
166
- reward = -0.005
167
-
168
- elif action == 3 and self.position_size > 0: # Close position
169
- close_value = self.position_size * current_price
170
- self.balance += close_value
171
- profit_loss = (current_price - self.entry_price) / self.entry_price
172
- reward = profit_loss * 10 # Scale profit/loss
173
- self.position_size = 0
174
- self.entry_price = 0
175
- self.total_trades += 1
176
-
177
- # Update net worth
178
- position_value = self.position_size * current_price if self.position_size > 0 else 0
179
- self.net_worth = self.balance + position_value
180
-
181
- # Add small reward for portfolio growth
182
- portfolio_change = (self.net_worth - prev_net_worth) / prev_net_worth
183
- reward += portfolio_change * 5
184
-
185
- # Move to next step
186
- self.current_step += 1
187
-
188
- # Check if episode is done
189
- if self.current_step >= len(self.price_data) - 1:
190
- self.done = True
191
- # Final reward based on overall performance
192
- final_return = (self.net_worth - self.initial_balance) / self.initial_balance
193
- reward += final_return * 20
194
-
195
- info = {
196
- 'net_worth': self.net_worth,
197
- 'balance': self.balance,
198
- 'position_size': self.position_size,
199
- 'current_price': current_price,
200
- 'total_trades': self.total_trades,
201
- 'step': self.current_step
202
- }
203
-
204
- obs = self._get_visual_observation()
205
- return obs, reward, self.done, info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  def get_price_history(self):
208
  """Get recent price history for visualization"""
209
  window_size = min(50, self.current_step)
210
  start_idx = max(0, self.current_step - window_size)
211
- return self.price_data[start_idx:self.current_step].tolist()
 
 
1
  import numpy as np
 
2
  import matplotlib.pyplot as plt
3
  from PIL import Image
4
  import io
5
 
6
  class VisualTradingEnvironment:
7
  def __init__(self, initial_balance=10000, risk_level="Medium", asset_type="Stock"):
8
+ self.initial_balance = float(initial_balance)
9
  self.risk_level = risk_level
10
  self.asset_type = asset_type
11
 
12
  # Risk multipliers
13
  risk_multipliers = {"Low": 0.5, "Medium": 1.0, "High": 2.0}
14
+ self.risk_multiplier = risk_multipliers.get(risk_level, 1.0)
15
 
16
  # Generate market data
17
  self._generate_market_data()
 
24
  np.random.seed(42)
25
 
26
  # Base parameters based on asset type
27
+ base_params = {
28
+ "Stock": {"volatility": 0.01, "trend": 0.0005},
29
+ "Crypto": {"volatility": 0.02, "trend": 0.001},
30
+ "Forex": {"volatility": 0.005, "trend": 0.0002}
31
+ }
32
+
33
+ params = base_params.get(self.asset_type, base_params["Stock"])
34
+ volatility = params["volatility"] * self.risk_multiplier
35
+ trend = params["trend"]
36
 
37
  prices = [100.0]
38
  for i in range(1, num_points):
39
+ # Random walk with trend and some mean reversion
40
  change = np.random.normal(trend, volatility)
41
+ # Add some mean reversion
42
+ mean_reversion = (100 - prices[-1]) * 0.001
43
+ price = max(1.0, prices[-1] * (1 + change) + mean_reversion)
44
  prices.append(price)
45
 
46
  self.price_data = np.array(prices)
 
62
  fig, ax = plt.subplots(figsize=(4.2, 4.2), dpi=20, facecolor='black')
63
  ax.set_facecolor('black')
64
 
65
+ # Plot price if we have data
66
  if len(prices) > 0:
67
  ax.plot(prices, color='cyan', linewidth=1.5)
68
 
 
98
  plt.close(fig)
99
 
100
  # Create attention map with same dimensions
101
+ attention_map = np.zeros((84, 84), dtype=np.uint8)
102
  if len(prices) > 1:
103
  recent_change = (prices[-1] - prices[-2]) / prices[-2]
104
  intensity = min(255, abs(recent_change) * 5000)
 
111
  for j in range(max(0, center_y-size), min(84, center_y+size)):
112
  distance = np.sqrt((i-center_x)**2 + (j-center_y)**2)
113
  if distance <= size:
114
+ attention_value = intensity * (1 - distance/size)
115
+ attention_map[i, j] = max(attention_map[i, j], int(attention_value))
 
 
116
 
117
+ # Combine RGB with attention map
118
  visual_obs = np.concatenate([
119
  img_array,
120
  attention_map[:, :, np.newaxis] # Add channel dimension
 
131
  """Reset environment to initial state"""
132
  self.current_step = 50 # Start with some history
133
  self.balance = self.initial_balance
134
+ self.position_size = 0.0
135
+ self.entry_price = 0.0
136
  self.net_worth = self.initial_balance
137
  self.total_trades = 0
138
  self.done = False
 
141
 
142
  def step(self, action):
143
  """Execute one trading step"""
144
+ try:
145
+ current_price = self.price_data[self.current_step]
146
+ prev_net_worth = self.net_worth
147
+
148
+ reward = 0.0
149
+
150
+ # Execute action
151
+ if action == 1 and self.position_size == 0: # Buy
152
+ # Risk-adjusted position sizing
153
+ position_value = self.balance * 0.1 * self.risk_multiplier
154
+ self.position_size = position_value / current_price
155
+ self.entry_price = current_price
156
+ self.balance -= position_value
157
+ self.total_trades += 1
158
+ reward = -0.01 # Small penalty for transaction
159
+
160
+ elif action == 2 and self.position_size > 0: # Sell (increase position)
161
+ additional_value = self.balance * 0.05 * self.risk_multiplier
162
+ additional_size = additional_value / current_price
163
+ self.position_size += additional_size
164
+ self.balance -= additional_value
165
+ self.total_trades += 1
166
+ reward = -0.005
167
+
168
+ elif action == 3 and self.position_size > 0: # Close position
169
+ close_value = self.position_size * current_price
170
+ self.balance += close_value
171
+ if self.entry_price > 0:
172
+ profit_loss = (current_price - self.entry_price) / self.entry_price
173
+ reward = profit_loss * 10 # Scale profit/loss
174
+ self.position_size = 0.0
175
+ self.entry_price = 0.0
176
+ self.total_trades += 1
177
+
178
+ # Update net worth
179
+ position_value = self.position_size * current_price if self.position_size > 0 else 0.0
180
+ self.net_worth = self.balance + position_value
181
+
182
+ # Add small reward for portfolio growth
183
+ if prev_net_worth > 0:
184
+ portfolio_change = (self.net_worth - prev_net_worth) / prev_net_worth
185
+ reward += portfolio_change * 5
186
+
187
+ # Move to next step
188
+ self.current_step += 1
189
+
190
+ # Check if episode is done
191
+ if self.current_step >= len(self.price_data) - 1:
192
+ self.done = True
193
+ # Final reward based on overall performance
194
+ if self.initial_balance > 0:
195
+ final_return = (self.net_worth - self.initial_balance) / self.initial_balance
196
+ reward += final_return * 20
197
+
198
+ info = {
199
+ 'net_worth': float(self.net_worth),
200
+ 'balance': float(self.balance),
201
+ 'position_size': float(self.position_size),
202
+ 'current_price': float(current_price),
203
+ 'total_trades': int(self.total_trades),
204
+ 'step': int(self.current_step)
205
+ }
206
+
207
+ obs = self._get_visual_observation()
208
+ return obs, float(reward), bool(self.done), info
209
+
210
+ except Exception as e:
211
+ print(f"Error in step execution: {e}")
212
+ # Return safe default values in case of error
213
+ default_info = {
214
+ 'net_worth': float(self.initial_balance),
215
+ 'balance': float(self.initial_balance),
216
+ 'position_size': 0.0,
217
+ 'current_price': 100.0,
218
+ 'total_trades': 0,
219
+ 'step': int(self.current_step)
220
+ }
221
+ return self._get_visual_observation(), 0.0, True, default_info
222
 
223
  def get_price_history(self):
224
  """Get recent price history for visualization"""
225
  window_size = min(50, self.current_step)
226
  start_idx = max(0, self.current_step - window_size)
227
+ prices = self.price_data[start_idx:self.current_step]
228
+ return [float(price) for price in prices]