Apple / apple /wrappers.py
New Author Name
init
4b714e2
raw
history blame contribute delete
914 Bytes
from typing import (
Any,
Dict,
List,
Tuple,
)
import gym
import numpy as np
class SuccessCounter(gym.Wrapper):
    """Record, for each finished episode, whether the task was ever solved.

    Meant for MetaWorld environments, whose ``step`` ``info`` dict may carry a
    ``"success"`` entry. An episode counts as successful if any of its steps
    reported a truthy ``"success"`` value. Outcomes accumulate until they are
    drained with :meth:`pop_successes`.

    NOTE(review): ``step`` unpacks exactly four values from the wrapped env
    (the classic ``obs, reward, done, info`` gym API). Confirm this matches
    the installed gym version.
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        # One flag per episode that ended with ``done``; emptied by pop_successes().
        self.successes: List[bool] = []
        # Latches to True once any step of the running episode reports success;
        # only reset() clears it again.
        self.current_success = False

    def step(self, action: Any) -> Tuple[np.ndarray, float, bool, Dict]:
        """Forward ``action`` to the wrapped env and track the success flag.

        Returns the wrapped env's ``(obs, reward, done, info)`` unchanged.
        When ``done`` is truthy, the episode's latched success flag is
        appended to ``self.successes``.
        """
        observation, rew, finished, step_info = self.env.step(action)
        solved_now = bool(step_info.get("success", False))
        self.current_success = self.current_success or solved_now
        if finished:
            # The flag is checked above first, so success on the final step counts.
            self.successes.append(self.current_success)
        return observation, rew, finished, step_info

    def pop_successes(self) -> List[bool]:
        """Return the recorded per-episode outcomes and start a fresh list."""
        drained, self.successes = self.successes, []
        return drained

    def reset(self, **kwargs) -> np.ndarray:
        """Clear the running episode's success flag, then reset the wrapped env.

        All keyword arguments are passed straight to the wrapped env's
        ``reset``, and its return value is returned as-is.
        """
        self.current_success = False
        initial = self.env.reset(**kwargs)
        return initial