File size: 437 Bytes
13bec41 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 | from stable_baselines3 import A2C
from agent.observation_wrapper import CustomObsWrapper
class Agent:
    """Policy wrapper pairing a saved A2C model with an observation wrapper.

    The model is loaded from the path ``agent/my_model`` and every incoming
    observation is passed through ``CustomObsWrapper.observation`` before
    being handed to the model.
    """

    def __init__(self, env) -> None:
        """Load the saved model and build the observation wrapper around *env*.

        Args:
            env: Environment object forwarded unchanged to ``CustomObsWrapper``.
        """
        # NOTE(review): the path is not anchored to this file, so it is
        # presumably resolved against the working directory - confirm.
        model_path = "agent/my_model"
        self.model = A2C.load(model_path)
        self.observation_wrapper = CustomObsWrapper(env)

    def act(self, observation):
        """Return the model's deterministic prediction for *observation*.

        Args:
            observation: Raw observation; it is transformed by the wrapper's
                ``observation`` method before prediction.

        Returns:
            Whatever ``self.model.predict`` returns, unmodified.
            NOTE(review): for Stable-Baselines3 models this is presumably an
            ``(action, state)`` pair rather than a bare action - confirm
            against the caller.
        """
        wrapped = self.observation_wrapper.observation(observation)
        prediction = self.model.predict(wrapped, deterministic=True)
        return prediction
|