Spaces:
Paused
Paused
File size: 1,354 Bytes
74aaccc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 | import random
from models import Observation, Action, Reward
class EcommerceEnv:
    """Toy e-commerce recommendation environment.

    Each episode samples one scripted user session made of a browsing
    ``history`` (a list of product ids) and a hidden ``target`` product id.
    On every ``step`` the agent recommends one product and is rewarded:

    * ``PURCHASE_REWARD`` if the recommendation is the target (episode ends);
    * ``CLICK_REWARD`` if the product is already in the user's history;
    * ``IRRELEVANT_REWARD`` otherwise.

    The episode also ends once ``max_steps`` steps have been taken.
    """

    # Reward shaping: the sparse purchase signal is supplemented with a small
    # positive reward for "related" items and a penalty for irrelevant ones.
    PURCHASE_REWARD = 1.0
    CLICK_REWARD = 0.3
    IRRELEVANT_REWARD = -0.2

    def __init__(self, sessions=None, max_steps=5, seed=None):
        """Create the environment.

        Args:
            sessions: Optional list of session dicts, each with a ``"history"``
                list of product ids, a ``"target"`` product id and, optionally,
                a ``"user_id"`` (defaults to 1). When ``None``, three built-in
                sessions over products 1-3 are used.
            max_steps: Maximum number of ``step`` calls per episode.
            seed: Optional seed for a private RNG used to sample sessions.
                When ``None`` the global ``random`` module is used, so
                ``random.seed(...)`` keeps controlling sampling as before.

        Raises:
            ValueError: if ``sessions`` is empty or ``max_steps`` < 1.
        """
        if sessions is None:
            sessions = [
                {"history": [1, 2], "target": 3},
                {"history": [2, 3], "target": 1},
                {"history": [1, 3], "target": 2},
            ]
        if not sessions:
            raise ValueError("sessions must contain at least one session")
        if max_steps < 1:
            raise ValueError("max_steps must be >= 1")
        self.sessions = list(sessions)
        self.max_steps = max_steps
        # `random` (module) and `random.Random` both expose `.choice`.
        self._rng = random if seed is None else random.Random(seed)
        # No episode is active until reset() is called.
        self.current_session = None
        self.steps = 0
        self.done = False

    def reset(self):
        """Start a new episode with a randomly sampled session.

        Returns:
            The initial observation for the sampled session.
        """
        self.current_session = self._rng.choice(self.sessions)
        self.steps = 0
        self.done = False
        return self._get_obs()

    def _require_session(self):
        """Return the active session, or raise if reset() was never called."""
        if self.current_session is None:
            raise RuntimeError("No active episode; call reset() first.")
        return self.current_session

    def _get_obs(self):
        """Build the observation for the current session."""
        session = self._require_session()
        return Observation(
            user_id=session.get("user_id", 1),
            # Copy so callers cannot mutate the environment's session data.
            history=list(session["history"]),
        )

    def step(self, action: "Action"):
        """Apply one recommendation and return ``(obs, reward, done, info)``.

        Args:
            action: Object exposing ``recommended_product``.

        Returns:
            A 4-tuple of the current observation, a ``Reward`` with the step
            score, whether the episode is finished, and an empty info dict.

        Raises:
            RuntimeError: if called before ``reset()`` or after the episode
                has already finished.
        """
        session = self._require_session()
        if self.done:
            # Without this guard the agent could keep stepping past the end of
            # the episode and collect further (even repeated purchase) rewards.
            raise RuntimeError("Episode has finished; call reset() before step().")

        self.steps += 1
        product = action.recommended_product

        if product == session["target"]:
            reward = self.PURCHASE_REWARD
            self.done = True
        elif product in session["history"]:
            reward = self.CLICK_REWARD
        else:
            reward = self.IRRELEVANT_REWARD

        if self.steps >= self.max_steps:
            self.done = True

        return self._get_obs(), Reward(score=reward), self.done, {}

    def state(self):
        """Return a debug snapshot: steps taken and the hidden target.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        session = self._require_session()
        return {
            "steps": self.steps,
            "target": session["target"],
        }