Spaces:
Sleeping
Sleeping
Update src/environments/visual_trading_env.py
Browse files
src/environments/visual_trading_env.py
CHANGED
|
@@ -46,60 +46,87 @@ class VisualTradingEnvironment:
|
|
| 46 |
|
| 47 |
def _get_visual_observation(self):
|
| 48 |
"""Generate visual representation of current market state"""
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
end_idx
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
def reset(self):
|
| 105 |
"""Reset environment to initial state"""
|
|
@@ -174,9 +201,11 @@ class VisualTradingEnvironment:
|
|
| 174 |
'step': self.current_step
|
| 175 |
}
|
| 176 |
|
| 177 |
-
|
|
|
|
| 178 |
|
| 179 |
def get_price_history(self):
|
| 180 |
"""Get recent price history for visualization"""
|
| 181 |
window_size = min(50, self.current_step)
|
| 182 |
-
|
|
|
|
|
|
| 46 |
|
| 47 |
def _get_visual_observation(self):
|
| 48 |
"""Generate visual representation of current market state"""
|
| 49 |
+
try:
|
| 50 |
+
# Get recent price window
|
| 51 |
+
window_size = 50
|
| 52 |
+
start_idx = max(0, self.current_step - window_size)
|
| 53 |
+
end_idx = self.current_step + 1
|
| 54 |
+
|
| 55 |
+
if end_idx > len(self.price_data):
|
| 56 |
+
end_idx = len(self.price_data)
|
| 57 |
+
|
| 58 |
+
prices = self.price_data[start_idx:end_idx]
|
| 59 |
+
|
| 60 |
+
# Create matplotlib figure with fixed size
|
| 61 |
+
fig, ax = plt.subplots(figsize=(4.2, 4.2), dpi=20, facecolor='black')
|
| 62 |
+
ax.set_facecolor('black')
|
| 63 |
+
|
| 64 |
+
# Plot price
|
| 65 |
+
if len(prices) > 0:
|
| 66 |
+
ax.plot(prices, color='cyan', linewidth=1.5)
|
| 67 |
+
|
| 68 |
+
# Remove axes for cleaner visual
|
| 69 |
+
ax.set_xticks([])
|
| 70 |
+
ax.set_yticks([])
|
| 71 |
+
ax.spines['top'].set_visible(False)
|
| 72 |
+
ax.spines['right'].set_visible(False)
|
| 73 |
+
ax.spines['bottom'].set_visible(False)
|
| 74 |
+
ax.spines['left'].set_visible(False)
|
| 75 |
+
|
| 76 |
+
# Set fixed limits to ensure consistent size
|
| 77 |
+
ax.set_xlim(0, 50)
|
| 78 |
+
if len(prices) > 0:
|
| 79 |
+
price_min, price_max = min(prices), max(prices)
|
| 80 |
+
price_range = price_max - price_min
|
| 81 |
+
if price_range == 0:
|
| 82 |
+
price_range = 1
|
| 83 |
+
ax.set_ylim(price_min - price_range * 0.1, price_max + price_range * 0.1)
|
| 84 |
+
else:
|
| 85 |
+
ax.set_ylim(0, 100)
|
| 86 |
+
|
| 87 |
+
# Convert to numpy array with consistent size
|
| 88 |
+
buf = io.BytesIO()
|
| 89 |
+
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, facecolor='black', dpi=20)
|
| 90 |
+
buf.seek(0)
|
| 91 |
+
img = Image.open(buf).convert('RGB')
|
| 92 |
+
|
| 93 |
+
# Resize to consistent dimensions
|
| 94 |
+
img = img.resize((84, 84), Image.Resampling.LANCZOS)
|
| 95 |
+
img_array = np.array(img)
|
| 96 |
+
|
| 97 |
+
plt.close(fig)
|
| 98 |
+
|
| 99 |
+
# Create attention map with same dimensions
|
| 100 |
+
attention_map = np.zeros((84, 84))
|
| 101 |
+
if len(prices) > 1:
|
| 102 |
+
recent_change = (prices[-1] - prices[-2]) / prices[-2]
|
| 103 |
+
intensity = min(255, abs(recent_change) * 5000)
|
| 104 |
+
|
| 105 |
+
# Simple attention based on price movement
|
| 106 |
+
center_x, center_y = 42, 42
|
| 107 |
+
size = max(5, int(intensity / 50))
|
| 108 |
+
|
| 109 |
+
for i in range(max(0, center_x-size), min(84, center_x+size)):
|
| 110 |
+
for j in range(max(0, center_y-size), min(84, center_y+size)):
|
| 111 |
+
distance = np.sqrt((i-center_x)**2 + (j-center_y)**2)
|
| 112 |
+
if distance <= size:
|
| 113 |
+
attention_map[i, j] = max(attention_map[i, j], intensity * (1 - distance/size))
|
| 114 |
+
|
| 115 |
+
# Ensure both arrays have same shape before concatenation
|
| 116 |
+
attention_map = attention_map.astype(np.uint8)
|
| 117 |
+
|
| 118 |
+
# Combine RGB with attention map - now both are 84x84
|
| 119 |
+
visual_obs = np.concatenate([
|
| 120 |
+
img_array,
|
| 121 |
+
attention_map[:, :, np.newaxis] # Add channel dimension
|
| 122 |
+
], axis=2)
|
| 123 |
+
|
| 124 |
+
return visual_obs
|
| 125 |
+
|
| 126 |
+
except Exception as e:
|
| 127 |
+
print(f"Error in visual observation: {e}")
|
| 128 |
+
# Return default observation in case of error
|
| 129 |
+
return np.zeros((84, 84, 4), dtype=np.uint8)
|
| 130 |
|
| 131 |
def reset(self):
|
| 132 |
"""Reset environment to initial state"""
|
|
|
|
| 201 |
'step': self.current_step
|
| 202 |
}
|
| 203 |
|
| 204 |
+
obs = self._get_visual_observation()
|
| 205 |
+
return obs, reward, self.done, info
|
| 206 |
|
| 207 |
def get_price_history(self):
|
| 208 |
"""Get recent price history for visualization"""
|
| 209 |
window_size = min(50, self.current_step)
|
| 210 |
+
start_idx = max(0, self.current_step - window_size)
|
| 211 |
+
return self.price_data[start_idx:self.current_step].tolist()
|