Spaces:
Sleeping
Sleeping
Update src/environments/visual_trading_env.py
Browse files
src/environments/visual_trading_env.py
CHANGED
|
@@ -1,18 +1,17 @@
|
|
| 1 |
import numpy as np
|
| 2 |
-
import pandas as pd
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
from PIL import Image
|
| 5 |
import io
|
| 6 |
|
| 7 |
class VisualTradingEnvironment:
|
| 8 |
def __init__(self, initial_balance=10000, risk_level="Medium", asset_type="Stock"):
|
| 9 |
-
self.initial_balance = initial_balance
|
| 10 |
self.risk_level = risk_level
|
| 11 |
self.asset_type = asset_type
|
| 12 |
|
| 13 |
# Risk multipliers
|
| 14 |
risk_multipliers = {"Low": 0.5, "Medium": 1.0, "High": 2.0}
|
| 15 |
-
self.risk_multiplier = risk_multipliers
|
| 16 |
|
| 17 |
# Generate market data
|
| 18 |
self._generate_market_data()
|
|
@@ -25,21 +24,23 @@ class VisualTradingEnvironment:
|
|
| 25 |
np.random.seed(42)
|
| 26 |
|
| 27 |
# Base parameters based on asset type
|
| 28 |
-
|
| 29 |
-
volatility
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
|
| 38 |
prices = [100.0]
|
| 39 |
for i in range(1, num_points):
|
| 40 |
-
# Random walk with trend and
|
| 41 |
change = np.random.normal(trend, volatility)
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
prices.append(price)
|
| 44 |
|
| 45 |
self.price_data = np.array(prices)
|
|
@@ -61,7 +62,7 @@ class VisualTradingEnvironment:
|
|
| 61 |
fig, ax = plt.subplots(figsize=(4.2, 4.2), dpi=20, facecolor='black')
|
| 62 |
ax.set_facecolor('black')
|
| 63 |
|
| 64 |
-
# Plot price
|
| 65 |
if len(prices) > 0:
|
| 66 |
ax.plot(prices, color='cyan', linewidth=1.5)
|
| 67 |
|
|
@@ -97,7 +98,7 @@ class VisualTradingEnvironment:
|
|
| 97 |
plt.close(fig)
|
| 98 |
|
| 99 |
# Create attention map with same dimensions
|
| 100 |
-
attention_map = np.zeros((84, 84))
|
| 101 |
if len(prices) > 1:
|
| 102 |
recent_change = (prices[-1] - prices[-2]) / prices[-2]
|
| 103 |
intensity = min(255, abs(recent_change) * 5000)
|
|
@@ -110,12 +111,10 @@ class VisualTradingEnvironment:
|
|
| 110 |
for j in range(max(0, center_y-size), min(84, center_y+size)):
|
| 111 |
distance = np.sqrt((i-center_x)**2 + (j-center_y)**2)
|
| 112 |
if distance <= size:
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
# Ensure both arrays have same shape before concatenation
|
| 116 |
-
attention_map = attention_map.astype(np.uint8)
|
| 117 |
|
| 118 |
-
# Combine RGB with attention map
|
| 119 |
visual_obs = np.concatenate([
|
| 120 |
img_array,
|
| 121 |
attention_map[:, :, np.newaxis] # Add channel dimension
|
|
@@ -132,8 +131,8 @@ class VisualTradingEnvironment:
|
|
| 132 |
"""Reset environment to initial state"""
|
| 133 |
self.current_step = 50 # Start with some history
|
| 134 |
self.balance = self.initial_balance
|
| 135 |
-
self.position_size = 0
|
| 136 |
-
self.entry_price = 0
|
| 137 |
self.net_worth = self.initial_balance
|
| 138 |
self.total_trades = 0
|
| 139 |
self.done = False
|
|
@@ -142,70 +141,88 @@ class VisualTradingEnvironment:
|
|
| 142 |
|
| 143 |
def step(self, action):
|
| 144 |
"""Execute one trading step"""
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
#
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
def get_price_history(self):
|
| 208 |
"""Get recent price history for visualization"""
|
| 209 |
window_size = min(50, self.current_step)
|
| 210 |
start_idx = max(0, self.current_step - window_size)
|
| 211 |
-
|
|
|
|
|
|
| 1 |
import numpy as np
|
|
|
|
| 2 |
import matplotlib.pyplot as plt
|
| 3 |
from PIL import Image
|
| 4 |
import io
|
| 5 |
|
| 6 |
class VisualTradingEnvironment:
|
| 7 |
def __init__(self, initial_balance=10000, risk_level="Medium", asset_type="Stock"):
|
| 8 |
+
self.initial_balance = float(initial_balance)
|
| 9 |
self.risk_level = risk_level
|
| 10 |
self.asset_type = asset_type
|
| 11 |
|
| 12 |
# Risk multipliers
|
| 13 |
risk_multipliers = {"Low": 0.5, "Medium": 1.0, "High": 2.0}
|
| 14 |
+
self.risk_multiplier = risk_multipliers.get(risk_level, 1.0)
|
| 15 |
|
| 16 |
# Generate market data
|
| 17 |
self._generate_market_data()
|
|
|
|
| 24 |
np.random.seed(42)
|
| 25 |
|
| 26 |
# Base parameters based on asset type
|
| 27 |
+
base_params = {
|
| 28 |
+
"Stock": {"volatility": 0.01, "trend": 0.0005},
|
| 29 |
+
"Crypto": {"volatility": 0.02, "trend": 0.001},
|
| 30 |
+
"Forex": {"volatility": 0.005, "trend": 0.0002}
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
params = base_params.get(self.asset_type, base_params["Stock"])
|
| 34 |
+
volatility = params["volatility"] * self.risk_multiplier
|
| 35 |
+
trend = params["trend"]
|
| 36 |
|
| 37 |
prices = [100.0]
|
| 38 |
for i in range(1, num_points):
|
| 39 |
+
# Random walk with trend and some mean reversion
|
| 40 |
change = np.random.normal(trend, volatility)
|
| 41 |
+
# Add some mean reversion
|
| 42 |
+
mean_reversion = (100 - prices[-1]) * 0.001
|
| 43 |
+
price = max(1.0, prices[-1] * (1 + change) + mean_reversion)
|
| 44 |
prices.append(price)
|
| 45 |
|
| 46 |
self.price_data = np.array(prices)
|
|
|
|
| 62 |
fig, ax = plt.subplots(figsize=(4.2, 4.2), dpi=20, facecolor='black')
|
| 63 |
ax.set_facecolor('black')
|
| 64 |
|
| 65 |
+
# Plot price if we have data
|
| 66 |
if len(prices) > 0:
|
| 67 |
ax.plot(prices, color='cyan', linewidth=1.5)
|
| 68 |
|
|
|
|
| 98 |
plt.close(fig)
|
| 99 |
|
| 100 |
# Create attention map with same dimensions
|
| 101 |
+
attention_map = np.zeros((84, 84), dtype=np.uint8)
|
| 102 |
if len(prices) > 1:
|
| 103 |
recent_change = (prices[-1] - prices[-2]) / prices[-2]
|
| 104 |
intensity = min(255, abs(recent_change) * 5000)
|
|
|
|
| 111 |
for j in range(max(0, center_y-size), min(84, center_y+size)):
|
| 112 |
distance = np.sqrt((i-center_x)**2 + (j-center_y)**2)
|
| 113 |
if distance <= size:
|
| 114 |
+
attention_value = intensity * (1 - distance/size)
|
| 115 |
+
attention_map[i, j] = max(attention_map[i, j], int(attention_value))
|
|
|
|
|
|
|
| 116 |
|
| 117 |
+
# Combine RGB with attention map
|
| 118 |
visual_obs = np.concatenate([
|
| 119 |
img_array,
|
| 120 |
attention_map[:, :, np.newaxis] # Add channel dimension
|
|
|
|
| 131 |
"""Reset environment to initial state"""
|
| 132 |
self.current_step = 50 # Start with some history
|
| 133 |
self.balance = self.initial_balance
|
| 134 |
+
self.position_size = 0.0
|
| 135 |
+
self.entry_price = 0.0
|
| 136 |
self.net_worth = self.initial_balance
|
| 137 |
self.total_trades = 0
|
| 138 |
self.done = False
|
|
|
|
| 141 |
|
| 142 |
def step(self, action):
    """Execute one trading step and advance the environment by one tick.

    Actions handled:
        1 -- open a position (only when flat) with 10% of the balance,
             scaled by ``risk_multiplier``.
        2 -- add to an already open position with 5% of the balance,
             scaled by ``risk_multiplier``.
        3 -- close the open position and realise its profit or loss.
    Any other value, or an action that does not apply to the current
    position state, leaves the portfolio untouched.

    Args:
        action: Integer action code as described above.

    Returns:
        A ``(observation, reward, done, info)`` tuple. ``observation`` is
        whatever ``_get_visual_observation()`` returns; ``info`` holds
        ``net_worth``, ``balance``, ``position_size``, ``current_price``,
        ``total_trades`` and ``step``.

    Unexpected errors are reported on stdout and turned into a terminal
    transition with fallback ``info`` values instead of propagating.
    """
    try:
        def build_info(price):
            # Cast everything to built-in scalars, as the original did.
            return {
                'net_worth': float(self.net_worth),
                'balance': float(self.balance),
                'position_size': float(self.position_size),
                'current_price': float(price),
                'total_trades': int(self.total_trades),
                'step': int(self.current_step)
            }

        num_prices = len(self.price_data)
        if self.current_step >= num_prices:
            # No price left to trade on: end the episode and report the
            # real portfolio state rather than relying on an IndexError.
            self.done = True
            last_price = self.price_data[-1] if num_prices else 0.0
            return self._get_visual_observation(), 0.0, True, build_info(last_price)

        current_price = self.price_data[self.current_step]
        prev_net_worth = self.net_worth

        reward = 0.0

        # Execute action
        if action == 1 and self.position_size == 0:  # Buy (open position)
            # Risk-adjusted position sizing
            position_value = self.balance * 0.1 * self.risk_multiplier
            self.position_size = position_value / current_price
            self.entry_price = current_price
            self.balance -= position_value
            self.total_trades += 1
            reward = -0.01  # Small penalty for transaction

        elif action == 2 and self.position_size > 0:  # Add to open position
            # NOTE(review): the original comment called this "Sell", but the
            # code spends balance to buy more units -- confirm the intended
            # action mapping with the caller.
            additional_value = self.balance * 0.05 * self.risk_multiplier
            additional_size = additional_value / current_price
            combined_size = self.position_size + additional_size
            # Keep entry_price as the size-weighted average cost so the
            # profit/loss computed on close covers the whole position.
            self.entry_price = (
                self.entry_price * self.position_size
                + current_price * additional_size
            ) / combined_size
            self.position_size = combined_size
            self.balance -= additional_value
            self.total_trades += 1
            reward = -0.005

        elif action == 3 and self.position_size > 0:  # Close position
            close_value = self.position_size * current_price
            self.balance += close_value
            if self.entry_price > 0:
                profit_loss = (current_price - self.entry_price) / self.entry_price
                reward = profit_loss * 10  # Scale profit/loss
            self.position_size = 0.0
            self.entry_price = 0.0
            self.total_trades += 1

        # Update net worth
        position_value = self.position_size * current_price if self.position_size > 0 else 0.0
        self.net_worth = self.balance + position_value

        # Add small reward for portfolio growth since the previous step
        if prev_net_worth > 0:
            portfolio_change = (self.net_worth - prev_net_worth) / prev_net_worth
            reward += portfolio_change * 5

        # Move to next step
        self.current_step += 1

        # Check if episode is done
        if self.current_step >= num_prices - 1:
            self.done = True
            # Final reward based on overall performance
            if self.initial_balance > 0:
                final_return = (self.net_worth - self.initial_balance) / self.initial_balance
                reward += final_return * 20

        obs = self._get_visual_observation()
        return obs, float(reward), bool(self.done), build_info(current_price)

    except Exception as e:
        # Deliberate best-effort boundary: report the problem and end the
        # episode rather than crash the caller.
        print(f"Error in step execution: {e}")
        # Keep the instance flag in agreement with the returned done=True.
        self.done = True
        # Return safe default values in case of error
        default_info = {
            'net_worth': float(self.initial_balance),
            'balance': float(self.initial_balance),
            'position_size': 0.0,
            'current_price': 100.0,
            'total_trades': 0,
            'step': int(self.current_step)
        }
        return self._get_visual_observation(), 0.0, True, default_info
|
| 222 |
|
| 223 |
def get_price_history(self, window_size=50):
    """Get recent price history for visualization.

    Args:
        window_size: Maximum number of most recent prices to return.
            Defaults to 50. Values below zero are treated as zero.

    Returns:
        A list of built-in floats covering up to ``window_size`` prices that
        end just before ``current_step``; the price at ``current_step``
        itself is not included.
    """
    # Never reach further back than the number of steps already taken.
    window = max(0, min(window_size, self.current_step))
    start_idx = self.current_step - window
    recent = self.price_data[start_idx:self.current_step]
    return [float(price) for price in recent]
|