OmidSakaki commited on
Commit
208b262
·
verified ·
1 Parent(s): 5a78d94

Update src/environments/visual_trading_env.py

Browse files
src/environments/visual_trading_env.py CHANGED
@@ -46,60 +46,87 @@ class VisualTradingEnvironment:
46
 
47
  def _get_visual_observation(self):
48
  """Generate visual representation of current market state"""
49
- # Get recent price window
50
- window_size = 50
51
- start_idx = max(0, self.current_step - window_size)
52
- end_idx = self.current_step + 1
53
-
54
- if end_idx > len(self.price_data):
55
- end_idx = len(self.price_data)
56
-
57
- prices = self.price_data[start_idx:end_idx]
58
-
59
- # Create matplotlib figure
60
- fig, ax = plt.subplots(figsize=(4.2, 4.2), dpi=20, facecolor='black')
61
- ax.set_facecolor('black')
62
-
63
- # Plot price
64
- ax.plot(prices, color='cyan', linewidth=1.5)
65
-
66
- # Remove axes for cleaner visual
67
- ax.set_xticks([])
68
- ax.set_yticks([])
69
- ax.spines['top'].set_visible(False)
70
- ax.spines['right'].set_visible(False)
71
- ax.spines['bottom'].set_visible(False)
72
- ax.spines['left'].set_visible(False)
73
-
74
- # Convert to numpy array
75
- buf = io.BytesIO()
76
- plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, facecolor='black')
77
- buf.seek(0)
78
- img = Image.open(buf).convert('RGB')
79
- img_array = np.array(img)
80
-
81
- plt.close(fig)
82
-
83
- # Create simple attention map (in real implementation, this would come from CNN)
84
- attention_map = np.zeros((84, 84))
85
- if len(prices) > 1:
86
- recent_change = (prices[-1] - prices[-2]) / prices[-2]
87
- intensity = min(255, abs(recent_change) * 5000)
88
-
89
- # Simple attention based on price movement
90
- center_x, center_y = 42, 42
91
- size = max(5, int(intensity / 50))
92
-
93
- for i in range(max(0, center_x-size), min(84, center_x+size)):
94
- for j in range(max(0, center_y-size), min(84, center_y+size)):
95
- distance = np.sqrt((i-center_x)**2 + (j-center_y)**2)
96
- if distance <= size:
97
- attention_map[i, j] = max(attention_map[i, j], intensity * (1 - distance/size))
98
-
99
- # Combine RGB with attention map
100
- visual_obs = np.concatenate([img_array, attention_map[:, :, np.newaxis]], axis=2)
101
-
102
- return visual_obs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  def reset(self):
105
  """Reset environment to initial state"""
@@ -174,9 +201,11 @@ class VisualTradingEnvironment:
174
  'step': self.current_step
175
  }
176
 
177
- return self._get_visual_observation(), reward, self.done, info
 
178
 
179
  def get_price_history(self):
180
  """Get recent price history for visualization"""
181
  window_size = min(50, self.current_step)
182
- return self.price_data[self.current_step - window_size:self.current_step].tolist()
 
 
46
 
47
  def _get_visual_observation(self):
48
  """Generate visual representation of current market state"""
49
+ try:
50
+ # Get recent price window
51
+ window_size = 50
52
+ start_idx = max(0, self.current_step - window_size)
53
+ end_idx = self.current_step + 1
54
+
55
+ if end_idx > len(self.price_data):
56
+ end_idx = len(self.price_data)
57
+
58
+ prices = self.price_data[start_idx:end_idx]
59
+
60
+ # Create matplotlib figure with fixed size
61
+ fig, ax = plt.subplots(figsize=(4.2, 4.2), dpi=20, facecolor='black')
62
+ ax.set_facecolor('black')
63
+
64
+ # Plot price
65
+ if len(prices) > 0:
66
+ ax.plot(prices, color='cyan', linewidth=1.5)
67
+
68
+ # Remove axes for cleaner visual
69
+ ax.set_xticks([])
70
+ ax.set_yticks([])
71
+ ax.spines['top'].set_visible(False)
72
+ ax.spines['right'].set_visible(False)
73
+ ax.spines['bottom'].set_visible(False)
74
+ ax.spines['left'].set_visible(False)
75
+
76
+ # Set fixed limits to ensure consistent size
77
+ ax.set_xlim(0, 50)
78
+ if len(prices) > 0:
79
+ price_min, price_max = min(prices), max(prices)
80
+ price_range = price_max - price_min
81
+ if price_range == 0:
82
+ price_range = 1
83
+ ax.set_ylim(price_min - price_range * 0.1, price_max + price_range * 0.1)
84
+ else:
85
+ ax.set_ylim(0, 100)
86
+
87
+ # Convert to numpy array with consistent size
88
+ buf = io.BytesIO()
89
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, facecolor='black', dpi=20)
90
+ buf.seek(0)
91
+ img = Image.open(buf).convert('RGB')
92
+
93
+ # Resize to consistent dimensions
94
+ img = img.resize((84, 84), Image.Resampling.LANCZOS)
95
+ img_array = np.array(img)
96
+
97
+ plt.close(fig)
98
+
99
+ # Create attention map with same dimensions
100
+ attention_map = np.zeros((84, 84))
101
+ if len(prices) > 1:
102
+ recent_change = (prices[-1] - prices[-2]) / prices[-2]
103
+ intensity = min(255, abs(recent_change) * 5000)
104
+
105
+ # Simple attention based on price movement
106
+ center_x, center_y = 42, 42
107
+ size = max(5, int(intensity / 50))
108
+
109
+ for i in range(max(0, center_x-size), min(84, center_x+size)):
110
+ for j in range(max(0, center_y-size), min(84, center_y+size)):
111
+ distance = np.sqrt((i-center_x)**2 + (j-center_y)**2)
112
+ if distance <= size:
113
+ attention_map[i, j] = max(attention_map[i, j], intensity * (1 - distance/size))
114
+
115
+ # Ensure both arrays have same shape before concatenation
116
+ attention_map = attention_map.astype(np.uint8)
117
+
118
+ # Combine RGB with attention map - now both are 84x84
119
+ visual_obs = np.concatenate([
120
+ img_array,
121
+ attention_map[:, :, np.newaxis] # Add channel dimension
122
+ ], axis=2)
123
+
124
+ return visual_obs
125
+
126
+ except Exception as e:
127
+ print(f"Error in visual observation: {e}")
128
+ # Return default observation in case of error
129
+ return np.zeros((84, 84, 4), dtype=np.uint8)
130
 
131
  def reset(self):
132
  """Reset environment to initial state"""
 
201
  'step': self.current_step
202
  }
203
 
204
+ obs = self._get_visual_observation()
205
+ return obs, reward, self.done, info
206
 
207
  def get_price_history(self):
208
  """Get recent price history for visualization"""
209
  window_size = min(50, self.current_step)
210
+ start_idx = max(0, self.current_step - window_size)
211
+ return self.price_data[start_idx:self.current_step].tolist()