Hugging Face Spaces: app.py was failing with a build error; this commit updates app.py to fix it (old and new versions of the changed hunk are shown below).
@@ -78,30 +78,41 @@ def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
|
|
| 78 |
return LayerNormalization(epsilon=1e-6)(x + res)
|
| 79 |
|
| 80 |
# RL Environment
|
| 81 |
-
# RL environment for dynamic pricing: one step per row of `data`.
# NOTE(review): the extracted original was truncated — `class PricingEnv(`
# had no base class and `done = self.current_step >=` had no right-hand
# side, so the module could not parse (the reported build error).  This
# reconstruction completes both and tracks an explicit episode bound.
class PricingEnv(Env):
    """Gym-style pricing environment backed by a pandas DataFrame.

    Each row of `data` is one timestep; the agent proposes a scalar
    price in [0, 100] and is rewarded by the negative absolute error
    versus the row's actual 'price' column.
    """

    def __init__(self, data):
        super(PricingEnv, self).__init__()
        self.data = data
        self.current_step = 0
        # Last valid row index — completes the truncated `done` comparison.
        self.max_steps = len(data) - 1
        # Agent picks a single scalar price in [0, 100].
        self.action_space = spaces.Box(low=0, high=100, shape=(1,), dtype=np.float32)
        # 5 feature columns + the step index = 6 observation entries.
        self.observation_space = spaces.Box(low=0, high=np.inf, shape=(6,), dtype=np.float32)

    def step(self, action):
        """Apply `action`, advance one row, return (obs, reward, done, info)."""
        reward = self._get_reward(action)
        self.current_step += 1
        done = self.current_step >= self.max_steps
        obs = self._get_observation()
        return obs, reward, done, {}

    def reset(self):
        """Rewind to the first row and return the initial observation."""
        self.current_step = 0
        return self._get_observation()

    def _get_observation(self):
        """Feature vector for the current row plus the step counter."""
        # Clamp so a read past the final row re-reads the last valid one.
        step = min(self.current_step, self.max_steps)
        obs = self.data.iloc[step][['demand_index', 'competitor_price',
                                    'past_sales', 'genre_encoded',
                                    'region_encoded']].values
        return np.append(obs, step)

    def _get_reward(self, action):
        """Negative absolute error between proposed and actual price."""
        price = action[0]
        actual_price = self.data.iloc[self.current_step]['price']
        return -abs(price - actual_price)
|
|
|
|
| 78 |
return LayerNormalization(epsilon=1e-6)(x + res)
|
| 79 |
|
| 80 |
# RL Environment
|
| 81 |
+
class PricingEnv(Env):
    """Gymnasium pricing environment backed by a pandas DataFrame.

    Each row of `data` is one timestep.  The agent proposes a scalar
    price in [0, 100]; the reward is the negative absolute error versus
    the row's actual 'price' column.  Follows the Gymnasium API:
    `step` returns (obs, reward, terminated, truncated, info) and
    `reset` returns (obs, info).
    """

    def __init__(self, data):
        super(PricingEnv, self).__init__()
        self.data = data
        self.current_step = 0
        self.max_steps = len(data) - 1  # last valid row index
        # Agent picks a single scalar price in [0, 100].
        self.action_space = spaces.Box(low=0, high=100, shape=(1,), dtype=np.float32)
        # 5 feature columns + the step index = 6 observation entries.
        self.observation_space = spaces.Box(low=0, high=np.inf, shape=(6,), dtype=np.float32)

    def step(self, action):
        """Advance one row; return (obs, reward, terminated, truncated, info)."""
        reward = self._get_reward(action)
        self.current_step += 1
        done = self.current_step >= self.max_steps
        obs = self._get_observation()
        # truncated is always False: episodes only end by exhausting the data.
        return obs, reward, done, False, {}

    def reset(self, seed=None, options=None):
        """Reset to the first row; return (observation, info) per Gymnasium."""
        super().reset(seed=seed)  # seeds self.np_random in the base class
        self.current_step = 0
        return self._get_observation(), {}

    def _get_observation(self):
        """Feature vector for the current (clamped) row plus the step index."""
        # Clamp so reads past the last row return the final valid row.
        step = min(self.current_step, self.max_steps)
        obs = self.data.iloc[step][['demand_index', 'competitor_price',
                                    'past_sales', 'genre_encoded',
                                    'region_encoded']].values
        # Cast to match the declared float32 observation_space dtype;
        # np.append would otherwise yield float64.
        return np.append(obs, step).astype(np.float32)

    def _get_reward(self, action):
        """Negative absolute pricing error; 0.0 once past the data."""
        if self.current_step > self.max_steps:
            return 0.0  # out of bounds — no learning signal
        price = action[0]
        actual_price = self.data.iloc[self.current_step]['price']
        return -abs(price - actual_price)
|