TetrisAI / agent.py
marci0929's picture
Upload with huggingface_hub
13bec41
raw
history blame contribute delete
437 Bytes
from stable_baselines3 import A2C
from agent.observation_wrapper import CustomObsWrapper
class Agent:
    """Inference-only wrapper around a pre-trained Stable-Baselines3 A2C model.

    Raw observations are first passed through ``CustomObsWrapper.observation``
    and the result is handed to the model's ``predict``.
    NOTE(review): presumably this reproduces the observation format the model
    was trained with -- confirm against the training script.
    """

    def __init__(self, env, *, model_path: str = "agent/my_model") -> None:
        """Load the saved model and build the observation wrapper.

        Args:
            env: Environment handed to ``CustomObsWrapper``; within this class
                it is only used to construct that wrapper.
            model_path: Path passed to ``A2C.load``. Defaults to the previously
                hard-coded ``"agent/my_model"``. Because it is a relative path,
                it is resolved against the current working directory, not
                against this file's directory.
        """
        self.model = A2C.load(model_path)
        # Only the wrapper's ``observation`` method is used (in ``act``); the
        # wrapper is never stepped as an environment by this class.
        self.observation_wrapper = CustomObsWrapper(env)

    def act(self, observation, *, deterministic: bool = True):
        """Choose an action for one raw observation.

        Args:
            observation: Observation as produced by the unwrapped ``env``.
            deterministic: Forwarded to ``model.predict``. ``True`` (the
                original, previously fixed behaviour) requests deterministic
                actions rather than actions sampled from the policy.

        Returns:
            The value of ``self.model.predict`` unchanged. For
            Stable-Baselines3 models this is an ``(action, state)`` tuple, so
            callers wanting only the action must take element ``[0]``.
        """
        extended_observation = self.observation_wrapper.observation(observation)
        return self.model.predict(
            extended_observation, deterministic=deterministic
        )