from typing import (
    Any,
    Dict,
    List,
    Tuple,
    Union,
)

import gym
import numpy as np

# Step results as produced by the wrapped env: either the legacy 4-tuple
# (obs, reward, done, info) or the 5-tuple
# (obs, reward, terminated, truncated, info).
_StepResult = Union[
    Tuple[np.ndarray, float, bool, Dict],
    Tuple[np.ndarray, float, bool, bool, Dict],
]


class SuccessCounter(gym.Wrapper):
    """Helper class to keep count of successes in MetaWorld environments.

    An episode counts as successful if any of its steps reported a truthy
    ``info["success"]``. One boolean per finished episode is recorded and can
    be collected (and cleared) with ``pop_successes``.
    """

    def __init__(self, env: gym.Env) -> None:
        """Wrap ``env`` and start with no recorded episodes."""
        super().__init__(env)
        # One entry per completed episode, in completion order.
        self.successes: List[bool] = []
        # Whether the episode in progress has reported success at any step.
        self.current_success = False

    def step(self, action: Any) -> _StepResult:
        """Step the wrapped env and record the episode outcome when it ends.

        Accepts both the legacy 4-tuple step API ``(obs, reward, done, info)``
        and the 5-tuple API ``(obs, reward, terminated, truncated, info)``.
        The result is returned as a tuple in whichever form the env produced.

        Args:
            action: Action forwarded unchanged to the wrapped env.

        Returns:
            The wrapped env's step result, as a tuple.

        Raises:
            ValueError: If the env's step result has neither 4 nor 5 items.
        """
        result = tuple(self.env.step(action))
        if len(result) == 5:
            _, _, terminated, truncated, info = result
            # Either termination or truncation ends the episode.
            done = terminated or truncated
        else:
            # Legacy API; a wrong-length result raises ValueError here, as the
            # original 4-way unpacking did.
            _, _, done, info = result
        # Success is sticky: once reported, the whole episode counts.
        if info.get("success", False):
            self.current_success = True
        if done:
            self.successes.append(self.current_success)
        return result

    def pop_successes(self) -> List[bool]:
        """Return the success flags recorded so far and clear the record.

        Returns:
            The list of per-episode success flags accumulated since the last
            call (the list object itself; a fresh list replaces it).
        """
        res = self.successes
        self.successes = []
        return res

    def reset(self, **kwargs) -> Any:
        """Clear the in-progress success flag and reset the wrapped env.

        Args:
            **kwargs: Forwarded unchanged to the wrapped env's ``reset``.

        Returns:
            Whatever the wrapped env's ``reset`` returns (an observation in
            the legacy API, ``(obs, info)`` in the newer one).
        """
        self.current_success = False
        return self.env.reset(**kwargs)