# Source: VisualTradingAI / src/environments/visual_trading_env.py
# Hugging Face file-viewer header (not part of the program): uploaded by
# OmidSakaki, commit 8f54cbf (verified), message
# "Update src/environments/visual_trading_env.py", file size 9.5 kB.
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io
class VisualTradingEnvironment:
    """Synthetic single-asset trading environment with image observations.

    The environment walks over a pre-generated price series. Each
    observation is an ``(IMAGE_SIZE, IMAGE_SIZE, 4)`` ``uint8`` array: three
    RGB channels holding a rendered line chart of the recent prices plus one
    "attention" channel whose brightness scales with the latest price move.

    Actions understood by :meth:`step`:

    * ``1`` - open a position (only when flat).
    * ``2`` - add to the open position (only when a position is held).
    * ``3`` - close the open position.
    * anything else - hold; the portfolio is only re-valued.
    """

    # Number of past steps drawn in the chart / returned as price history.
    WINDOW_SIZE = 50
    # Side length, in pixels, of the square observation image.
    IMAGE_SIZE = 84

    # Scale applied to both volatility and position sizing per risk level.
    _RISK_MULTIPLIERS = {"Low": 0.5, "Medium": 1.0, "High": 2.0}
    # Per-step volatility and drift of the synthetic series per asset type.
    _ASSET_PARAMS = {
        "Stock": {"volatility": 0.01, "trend": 0.0005},
        "Crypto": {"volatility": 0.02, "trend": 0.001},
        "Forex": {"volatility": 0.005, "trend": 0.0002},
    }

    def __init__(self, initial_balance=10000, risk_level="Medium", asset_type="Stock"):
        """Create the environment, generate its price series and reset it.

        Args:
            initial_balance: Starting cash; stored as ``float``.
            risk_level: ``"Low"``, ``"Medium"`` or ``"High"``. Unknown values
                fall back to a multiplier of 1.0.
            asset_type: ``"Stock"``, ``"Crypto"`` or ``"Forex"``. Unknown
                values fall back to the ``"Stock"`` parameters.
        """
        self.initial_balance = float(initial_balance)
        self.risk_level = risk_level
        self.asset_type = asset_type
        self.risk_multiplier = self._RISK_MULTIPLIERS.get(risk_level, 1.0)

        self._generate_market_data()
        self.reset()

    def _generate_market_data(self, num_points=1000):
        """Generate the synthetic price series into ``self.price_data``.

        The series is a random walk with drift and a weak pull back towards
        100, floored at 1.0. It is deterministic (fixed seed 42).

        Args:
            num_points: Length of the generated series.
        """
        # A private generator reproduces the exact stream that
        # ``np.random.seed(42)`` would give, without reseeding the global
        # NumPy RNG that callers (e.g. an agent's exploration) may rely on.
        rng = np.random.RandomState(42)

        params = self._ASSET_PARAMS.get(self.asset_type, self._ASSET_PARAMS["Stock"])
        volatility = params["volatility"] * self.risk_multiplier
        trend = params["trend"]

        prices = [100.0]
        for _ in range(1, num_points):
            change = rng.normal(trend, volatility)
            # Weak mean reversion towards the starting level of 100.
            mean_reversion = (100 - prices[-1]) * 0.001
            prices.append(max(1.0, prices[-1] * (1 + change) + mean_reversion))

        self.price_data = np.array(prices)

    def _render_price_chart(self, prices):
        """Render ``prices`` as a cyan-on-black line chart.

        Args:
            prices: 1-D sequence of prices for the visible window.

        Returns:
            ``(IMAGE_SIZE, IMAGE_SIZE, 3)`` RGB array.
        """
        fig = None
        try:
            fig, ax = plt.subplots(figsize=(4.2, 4.2), dpi=20, facecolor='black')
            ax.set_facecolor('black')

            if len(prices) > 0:
                ax.plot(prices, color='cyan', linewidth=1.5)

            # Strip ticks and frame so only the price line is visible.
            ax.set_xticks([])
            ax.set_yticks([])
            for side in ('top', 'right', 'bottom', 'left'):
                ax.spines[side].set_visible(False)

            # Fixed x-range keeps the horizontal scale constant across steps.
            ax.set_xlim(0, self.WINDOW_SIZE)
            if len(prices) > 0:
                price_min, price_max = min(prices), max(prices)
                # Avoid a zero-height axis when the window is flat.
                price_range = (price_max - price_min) or 1
                ax.set_ylim(price_min - price_range * 0.1, price_max + price_range * 0.1)
            else:
                ax.set_ylim(0, 100)

            buf = io.BytesIO()
            # Save this specific figure rather than whatever pyplot currently
            # considers the "current" one.
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0,
                        facecolor='black', dpi=20)
            buf.seek(0)
            img = Image.open(buf).convert('RGB')

            # ``Image.Resampling`` only exists in newer Pillow releases; older
            # ones expose the filter directly on the ``Image`` module.
            lanczos = getattr(Image, "Resampling", Image).LANCZOS
            img = img.resize((self.IMAGE_SIZE, self.IMAGE_SIZE), lanczos)
            return np.array(img)
        finally:
            # Always release the figure, including when rendering fails, so
            # repeated failures do not accumulate open figures.
            if fig is not None:
                plt.close(fig)

    def _build_attention_map(self, prices):
        """Build the attention channel from the most recent price move.

        A disc centred in the image fades linearly from ``intensity`` at the
        centre to 0 at its edge, where ``intensity`` is the absolute latest
        relative price change times 5000, capped at 255.

        Args:
            prices: 1-D sequence of prices for the visible window.

        Returns:
            ``(IMAGE_SIZE, IMAGE_SIZE)`` ``uint8`` array; all zeros when fewer
            than two prices are available.
        """
        side = self.IMAGE_SIZE
        attention_map = np.zeros((side, side), dtype=np.uint8)
        if len(prices) < 2:
            return attention_map

        recent_change = (prices[-1] - prices[-2]) / prices[-2]
        intensity = min(255, abs(recent_change) * 5000)

        center = side // 2
        radius = max(5, int(intensity / 50))
        lo = max(0, center - radius)
        hi = min(side, center + radius)  # exclusive upper bound

        rows, cols = np.ogrid[lo:hi, lo:hi]
        distance = np.sqrt((rows - center) ** 2 + (cols - center) ** 2)
        falloff = intensity * (1 - distance / radius)
        # Pixels outside the disc stay 0; the cast truncates like ``int()``.
        attention_map[lo:hi, lo:hi] = np.where(distance <= radius, falloff, 0).astype(np.uint8)
        return attention_map

    def _get_visual_observation(self):
        """Generate the visual representation of the current market state.

        Returns:
            ``(IMAGE_SIZE, IMAGE_SIZE, 4)`` ``uint8`` array (RGB chart plus
            attention channel). If anything fails while building it, the
            error is printed and an all-zero array of that shape is returned.
        """
        try:
            # Window covers up to WINDOW_SIZE past steps plus the current one.
            start_idx = max(0, self.current_step - self.WINDOW_SIZE)
            end_idx = min(self.current_step + 1, len(self.price_data))
            prices = self.price_data[start_idx:end_idx]

            img_array = self._render_price_chart(prices)
            attention_map = self._build_attention_map(prices)

            return np.concatenate(
                [img_array, attention_map[:, :, np.newaxis]],
                axis=2,
            )
        except Exception as e:
            # Best effort: a rendering problem must not abort an episode.
            print(f"Error in visual observation: {e}")
            return np.zeros((self.IMAGE_SIZE, self.IMAGE_SIZE, 4), dtype=np.uint8)

    def _build_info(self, current_price):
        """Return the diagnostic ``info`` dict describing the current state."""
        return {
            'net_worth': float(self.net_worth),
            'balance': float(self.balance),
            'position_size': float(self.position_size),
            'current_price': float(current_price),
            'total_trades': int(self.total_trades),
            'step': int(self.current_step)
        }

    def reset(self):
        """Reset the environment to its initial state.

        Returns:
            The first observation (see :meth:`_get_visual_observation`).
        """
        # Start one full window in so the first chart already has history.
        self.current_step = self.WINDOW_SIZE
        self.balance = self.initial_balance
        self.position_size = 0.0
        self.entry_price = 0.0
        self.net_worth = self.initial_balance
        self.total_trades = 0
        self.done = False
        return self._get_visual_observation()

    def step(self, action):
        """Execute one trading step.

        Args:
            action: ``1`` open, ``2`` add to position, ``3`` close; any other
                value holds.

        Returns:
            ``(observation, reward, done, info)``. Once the episode has
            finished, further calls leave the state untouched and return the
            terminal observation with a reward of ``0.0``. If an unexpected
            error occurs, it is printed, the episode is ended and placeholder
            ``info`` values are returned.
        """
        if self.done:
            # Do not trade, advance, or re-award the terminal bonus after the
            # episode has ended.
            last_idx = min(self.current_step, len(self.price_data) - 1)
            return (
                self._get_visual_observation(),
                0.0,
                True,
                self._build_info(self.price_data[last_idx]),
            )

        try:
            current_price = self.price_data[self.current_step]
            prev_net_worth = self.net_worth
            reward = 0.0

            if action == 1 and self.position_size == 0:  # Open a position
                # Risk-adjusted position sizing: 10% of cash times multiplier.
                position_value = self.balance * 0.1 * self.risk_multiplier
                self.position_size = position_value / current_price
                self.entry_price = current_price
                self.balance -= position_value
                self.total_trades += 1
                reward = -0.01  # Small penalty for transaction
            elif action == 2 and self.position_size > 0:  # Add to position
                # NOTE(review): callers may label action 2 "Sell", but the
                # logic here buys more of the asset - confirm the intent.
                additional_value = self.balance * 0.05 * self.risk_multiplier
                additional_size = additional_value / current_price
                # Keep entry_price as the average cost so the profit/loss
                # computed on close reflects every purchase, not just the
                # first one.
                cost_basis = self.entry_price * self.position_size + additional_value
                self.position_size += additional_size
                self.entry_price = cost_basis / self.position_size
                self.balance -= additional_value
                self.total_trades += 1
                reward = -0.005
            elif action == 3 and self.position_size > 0:  # Close position
                self.balance += self.position_size * current_price
                if self.entry_price > 0:
                    profit_loss = (current_price - self.entry_price) / self.entry_price
                    reward = profit_loss * 10  # Scale profit/loss
                self.position_size = 0.0
                self.entry_price = 0.0
                self.total_trades += 1

            # Mark the portfolio to the current price.
            position_value = self.position_size * current_price if self.position_size > 0 else 0.0
            self.net_worth = self.balance + position_value

            # Small shaping reward for step-over-step portfolio growth.
            if prev_net_worth > 0:
                portfolio_change = (self.net_worth - prev_net_worth) / prev_net_worth
                reward += portfolio_change * 5

            self.current_step += 1

            # The episode ends one step before the series runs out.
            if self.current_step >= len(self.price_data) - 1:
                self.done = True
                # Terminal bonus based on the overall return.
                if self.initial_balance > 0:
                    final_return = (self.net_worth - self.initial_balance) / self.initial_balance
                    reward += final_return * 20

            info = self._build_info(current_price)
            obs = self._get_visual_observation()
            return obs, float(reward), bool(self.done), info

        except Exception as e:
            print(f"Error in step execution: {e}")
            # Keep internal state consistent with the ``done=True`` we report.
            self.done = True
            # Placeholder values: the real state may be unusable here.
            default_info = {
                'net_worth': float(self.initial_balance),
                'balance': float(self.initial_balance),
                'position_size': 0.0,
                'current_price': 100.0,
                'total_trades': 0,
                'step': int(self.current_step)
            }
            return self._get_visual_observation(), 0.0, True, default_info

    def get_price_history(self):
        """Get recent price history for visualization.

        Returns:
            Up to ``WINDOW_SIZE`` prices preceding (not including) the current
            step, as a list of Python floats.
        """
        start_idx = max(0, self.current_step - self.WINDOW_SIZE)
        return [float(price) for price in self.price_data[start_idx:self.current_step]]