Spaces:
Sleeping
Sleeping
| import fnmatch | |
| import os | |
| import random | |
| import time | |
| import pybullet_envs_gymnasium # noqa: F401 pylint: disable=unused-import | |
| from datasets import load_dataset | |
| from huggingface_hub import HfApi | |
| from src.evaluation import evaluate | |
| from src.logging import setup_logger | |
# Module-level logger, built by the project's `setup_logger` helper.
logger = setup_logger(__name__)
# Hugging Face Hub client. The token is read from the TOKEN environment
# variable; `os.environ.get` yields None when that variable is unset.
API = HfApi(token=os.environ.get("TOKEN"))
# Hub repository id from which existing results are loaded and to which new
# result rows are pushed.
RESULTS_REPO = "open-rl-leaderboard/results_v2"
# Environment IDs the backend knows about. The tags of a model are matched
# against this list (via `pattern_match`) to decide which environment the
# model is evaluated on.
ALL_ENV_IDS = [
    # Atari
    "AdventureNoFrameskip-v4",
    "AirRaidNoFrameskip-v4",
    "AlienNoFrameskip-v4",
    "AmidarNoFrameskip-v4",
    "AssaultNoFrameskip-v4",
    "AsterixNoFrameskip-v4",
    "AsteroidsNoFrameskip-v4",
    "AtlantisNoFrameskip-v4",
    "BankHeistNoFrameskip-v4",
    "BattleZoneNoFrameskip-v4",
    "BeamRiderNoFrameskip-v4",
    "BerzerkNoFrameskip-v4",
    "BowlingNoFrameskip-v4",
    "BoxingNoFrameskip-v4",
    "BreakoutNoFrameskip-v4",
    "CarnivalNoFrameskip-v4",
    "CentipedeNoFrameskip-v4",
    "ChopperCommandNoFrameskip-v4",
    "CrazyClimberNoFrameskip-v4",
    "DefenderNoFrameskip-v4",
    "DemonAttackNoFrameskip-v4",
    "DoubleDunkNoFrameskip-v4",
    "ElevatorActionNoFrameskip-v4",
    "EnduroNoFrameskip-v4",
    "FishingDerbyNoFrameskip-v4",
    "FreewayNoFrameskip-v4",
    "FrostbiteNoFrameskip-v4",
    "GopherNoFrameskip-v4",
    "GravitarNoFrameskip-v4",
    "HeroNoFrameskip-v4",
    "IceHockeyNoFrameskip-v4",
    "JamesbondNoFrameskip-v4",
    "JourneyEscapeNoFrameskip-v4",
    "KangarooNoFrameskip-v4",
    "KrullNoFrameskip-v4",
    "KungFuMasterNoFrameskip-v4",
    "MontezumaRevengeNoFrameskip-v4",
    "MsPacmanNoFrameskip-v4",
    "NameThisGameNoFrameskip-v4",
    "PhoenixNoFrameskip-v4",
    "PitfallNoFrameskip-v4",
    "PongNoFrameskip-v4",
    "PooyanNoFrameskip-v4",
    "PrivateEyeNoFrameskip-v4",
    "QbertNoFrameskip-v4",
    "RiverraidNoFrameskip-v4",
    "RoadRunnerNoFrameskip-v4",
    "RobotankNoFrameskip-v4",
    "SeaquestNoFrameskip-v4",
    "SkiingNoFrameskip-v4",
    "SolarisNoFrameskip-v4",
    "SpaceInvadersNoFrameskip-v4",
    "StarGunnerNoFrameskip-v4",
    "TennisNoFrameskip-v4",
    "TimePilotNoFrameskip-v4",
    "TutankhamNoFrameskip-v4",
    "UpNDownNoFrameskip-v4",
    "VentureNoFrameskip-v4",
    "VideoPinballNoFrameskip-v4",
    "WizardOfWorNoFrameskip-v4",
    "YarsRevengeNoFrameskip-v4",
    "ZaxxonNoFrameskip-v4",
    # Box2D
    "BipedalWalker-v3",
    "BipedalWalkerHardcore-v3",
    "CarRacing-v2",
    "LunarLander-v2",
    "LunarLanderContinuous-v2",
    # Toy text
    "Blackjack-v1",
    "CliffWalking-v0",
    "FrozenLake-v1",
    "FrozenLake8x8-v1",
    # Classic control
    "Acrobot-v1",
    "CartPole-v1",
    "MountainCar-v0",
    "MountainCarContinuous-v0",
    "Pendulum-v1",
    # MuJoCo
    "Ant-v4",
    "HalfCheetah-v4",
    "Hopper-v4",
    "Humanoid-v4",
    "HumanoidStandup-v4",
    "InvertedDoublePendulum-v4",
    "InvertedPendulum-v4",
    "Pusher-v4",
    "Reacher-v4",
    "Swimmer-v4",
    "Walker2d-v4",
    # PyBullet
    "AntBulletEnv-v0",
    "HalfCheetahBulletEnv-v0",
    "HopperBulletEnv-v0",
    "HumanoidBulletEnv-v0",
    "InvertedDoublePendulumBulletEnv-v0",
    "InvertedPendulumSwingupBulletEnv-v0",
    "MinitaurBulletEnv-v0",
    "ReacherBulletEnv-v0",
    "Walker2DBulletEnv-v0",
]
def pattern_match(patterns, source_list):
    """Return the sorted, de-duplicated entries of ``source_list`` matching any pattern.

    Args:
        patterns: A single shell-style pattern, an iterable of such patterns,
            or ``None`` (treated as "no patterns").
        source_list: Candidate strings to test against the patterns.

    Returns:
        A sorted list of the distinct entries of ``source_list`` that match at
        least one pattern according to ``fnmatch.filter``. Empty when nothing
        matches or when no patterns are given.
    """
    if patterns is None:
        # Nothing to match against; avoids a TypeError when iterating None.
        return []
    if isinstance(patterns, str):
        # A bare string would otherwise be iterated character by character.
        patterns = [patterns]
    matched = set()
    for pattern in patterns:
        matched.update(fnmatch.filter(source_list, pattern))
    return sorted(matched)
def _backend_routine():
    """Evaluate one pending reinforcement-learning model and record the outcome.

    The routine:

    1. lists hub models carrying the ``reinforcement-learning`` filter and
       keeps those that ship an ``agent.pt`` file;
    2. removes every ``(repo_id, sha)`` pair already present in
       ``RESULTS_REPO``;
    3. picks one remaining model at random, derives its environment from the
       model tags and runs ``evaluate`` on it;
    4. appends a row (status ``DONE`` or ``FAILED``) to the results dataset,
       pushes it to the hub, then sleeps for one minute.

    Returns:
        None. Returns early, without pushing or sleeping, when no model is
        pending.
    """
    # List only the reinforcement-learning models
    rl_models = list(API.list_models(filter=["reinforcement-learning"]))
    logger.info(f"Found {len(rl_models)} RL models")

    compatible_models = []
    for model in rl_models:
        # NOTE(review): `or []` guards against `siblings` being None for a
        # listing that carries no file information; harmless otherwise.
        filenames = [sib.rfilename for sib in model.siblings or []]
        if "agent.pt" in filenames:
            compatible_models.append((model.modelId, model.sha))
    logger.info(f"Found {len(compatible_models)} compatible models")

    dataset = load_dataset(RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks")
    evaluated_models = [("/".join([x["user_id"], x["model_id"]]), x["sha"]) for x in dataset]
    pending_models = list(set(compatible_models) - set(evaluated_models))
    logger.info(f"Found {len(pending_models)} pending models")

    if not pending_models:
        return None

    # Select a random model
    repo_id, sha = random.choice(pending_models)
    user_id, model_id = repo_id.split("/")
    row = {"model_id": model_id, "user_id": user_id, "sha": sha}

    # Run an evaluation on the model
    model_info = API.model_info(repo_id, revision=sha)

    # Extract the environment IDs from the tags (usually only one)
    env_ids = pattern_match(model_info.tags, ALL_ENV_IDS)
    if len(env_ids) > 0:
        env_id = env_ids[0]
        logger.info(f"Running evaluation on {user_id}/{model_id}")
        try:
            episodic_returns = evaluate(repo_id, sha, env_id)
            row["status"] = "DONE"
            row["env_id"] = env_id
            row["episodic_returns"] = episodic_returns
        except Exception as e:
            # A failed evaluation is still recorded so that the model is not
            # picked again on the next pass.
            logger.error(f"Error evaluating {repo_id}: {e}")
            logger.exception(e)
            row["status"] = "FAILED"
    else:
        logger.error(f"No environment found for {model_id}")
        row["status"] = "FAILED"

    # Load the last version of the dataset: the evaluation may have taken a
    # while, so the copy loaded earlier is not reused.
    dataset = load_dataset(
        RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks"
    )
    # `Dataset.add_item` returns a new dataset instead of mutating in place,
    # so the result must be rebound or the new row is never pushed.
    dataset = dataset.add_item(row)
    dataset.push_to_hub(RESULTS_REPO, split="train", token=API.token)
    time.sleep(60)  # Sleep for 1 minute to avoid rate limiting
def backend_routine():
    """Run `_backend_routine`, logging any exception instead of propagating it."""
    try:
        _backend_routine()
    except Exception as exc:
        error_name = exc.__class__.__name__
        logger.error(f"{error_name}: {exc!s}")
        logger.exception(exc)
# Run a single backend pass when this file is executed as a script.
if __name__ == "__main__":
    backend_routine()