| | import argparse |
| | import subprocess |
| | import wandb |
| | import wandb.apis.public |
| |
|
| | from collections import defaultdict |
| | from multiprocessing.pool import ThreadPool |
| | from typing import List, NamedTuple |
| |
|
| |
|
class RunGroup(NamedTuple):
    """Hashable (algorithm, environment) key used to group W&B runs.

    Runs that share both values are collected under the same key and are
    published together by a single publish job.
    """

    # Value of the run config's "algo" entry.
    algo: str
    # Value of the run config's "env" entry.
    env_id: str
| |
|
| |
|
def _parse_args() -> argparse.Namespace:
    """Define the command-line options, parse them, and echo the result."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls-benchmarks",
        help="WandB project name to load runs from",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default=None,
        help="WandB team of project. None uses default entity",
    )
    # Required: the run filter is built from these tags. Leaving the option
    # out used to surface as an opaque TypeError from set(None); argparse now
    # reports the missing option instead.
    parser.add_argument(
        "--wandb-tags", type=str, nargs="+", required=True, help="WandB tags"
    )
    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
    parser.add_argument(
        "--envs", type=str, nargs="*", help="Optional filter down to these envs"
    )
    parser.add_argument(
        "--exclude-envs",
        type=str,
        nargs="*",
        help="Environments to exclude from publishing",
    )
    parser.add_argument(
        "--huggingface-user",
        type=str,
        default=None,
        help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
    )
    parser.add_argument(
        "--pool-size",
        type=int,
        default=3,
        help="How many publish jobs can run in parallel",
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    args = parser.parse_args()
    print(args)
    return args


def _group_run_paths(all_runs, args: argparse.Namespace) -> "defaultdict[RunGroup, List[str]]":
    """Collect the paths of the selected runs, keyed by (algo, env) group.

    A run is selected when its config's "wandb_tags" contain every tag in
    ``args.wandb_tags``, its state is "finished", and its "env" passes the
    ``--envs`` / ``--exclude-envs`` filters. Each selected run contributes
    the "/"-joined form of its ``path``.
    """
    required_tags = set(args.wandb_tags)
    runs_paths_by_group = defaultdict(list)
    for r in all_runs:
        if not required_tags.issubset(r.config.get("wandb_tags", [])):
            continue
        if r.state != "finished":
            continue
        algo = r.config["algo"]
        env = r.config["env"]
        if args.envs and env not in args.envs:
            continue
        if args.exclude_envs and env in args.exclude_envs:
            continue
        runs_paths_by_group[RunGroup(algo, env)].append("/".join(r.path))
    return runs_paths_by_group


def _publish_command(run_paths: List[str], args: argparse.Namespace) -> List[str]:
    """Build the ``huggingface_publish.py`` command line for one run group."""
    command = ["python", "huggingface_publish.py", "--wandb-run-paths"]
    command.extend(run_paths)
    # Forward optional values only when they were supplied: a None element in
    # the argument list makes subprocess.run() raise TypeError.
    if args.wandb_report_url:
        command.extend(["--wandb-report-url", args.wandb_report_url])
    if args.huggingface_user:
        command.extend(["--huggingface-user", args.huggingface_user])
    if args.virtual_display:
        command.append("--virtual-display")
    return command


def _publish_group(run_paths: List[str], args: argparse.Namespace) -> int:
    """Run the publish script for one run group and return its exit code."""
    return subprocess.run(_publish_command(run_paths, args)).returncode


def main() -> None:
    """Publish every selected (algo, env) run group using a pool of workers.

    Raises:
        SystemExit: after all jobs have finished, if any publish job could
            not be launched or exited with a non-zero status.
    """
    args = _parse_args()

    api = wandb.Api()
    all_runs = api.runs(
        f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
    )
    runs_paths_by_group = _group_run_paths(all_runs, args)

    # Threads are sufficient here: each worker just waits on a subprocess.
    pool = ThreadPool(args.pool_size)
    jobs = [
        (group, pool.apply_async(_publish_group, (run_paths, args)))
        for group, run_paths in runs_paths_by_group.items()
    ]
    pool.close()
    pool.join()

    # apply_async() stores worker exceptions on the result object; without
    # get() they would be dropped silently. Checking only after join() keeps
    # the run best-effort: one failing group does not stop the others.
    failures = []
    for group, job in jobs:
        try:
            returncode = job.get()
        except OSError as err:
            failures.append(f"{group.algo}/{group.env_id}: could not launch ({err})")
            continue
        if returncode != 0:
            failures.append(f"{group.algo}/{group.env_id}: exit code {returncode}")
    if failures:
        # A string argument makes the interpreter print it to stderr and
        # exit with status 1.
        raise SystemExit("Publishing failed for:\n  " + "\n  ".join(failures))


if __name__ == "__main__":
    main()
| |
|