Spaces:
Running on Zero
Running on Zero
| from dotenv import load_dotenv | |
| # For debugging purposes apply dotenv file | |
| load_dotenv() | |
| import asyncio | |
| import art | |
| import hydra | |
| import torch | |
| import wandb | |
| from omegaconf import DictConfig, OmegaConf | |
| from linalg_zero.grpo.run_rl import test | |
| from linalg_zero.grpo.types import LinAlgPolicyConfig | |
def _to_plain_dict(node, label: str) -> dict:
    """Resolve one OmegaConf config section into a plain ``dict``.

    Args:
        node: A section of the composed config (e.g. ``cfg.init``).
        label: Name of the section, used in the error message.

    Returns:
        The section with interpolations resolved, as a plain ``dict``.

    Raises:
        TypeError: If the section does not resolve to a mapping (e.g. a list).
    """
    container = OmegaConf.to_container(node, resolve=True)
    # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
    if not isinstance(container, dict):
        raise TypeError(f"{label} must be provided")
    return container


def _reports_to_wandb(trainer_args: dict) -> bool:
    """Return True when the trainer's ``report_to`` setting includes ``"wandb"``.

    ``report_to`` may be absent/empty, a single string, or a collection of strings.
    """
    report_to = trainer_args.get("report_to")
    if not report_to:
        return False
    if isinstance(report_to, str):
        report_to = [report_to]
    return "wandb" in report_to


def main(cfg: DictConfig) -> None:
    """Evaluate ``cfg.run.base_model`` with the LinAlg ``test`` loop.

    Args:
        cfg: Composed config with ``init``, ``training``, ``run``, ``trainer`` and
            ``engine`` sections. ``run`` must provide ``project``, ``project_id``
            and ``base_model``; ``trainer`` may set ``report_to`` and ``engine``
            may set ``tensor_parallel_size``.

    Raises:
        TypeError: If any of the five sections does not resolve to a mapping.
    """
    # Convert every section to a plain, fully-resolved dict.
    init_config = _to_plain_dict(cfg.init, "Init config")
    training_config = _to_plain_dict(cfg.training, "Training config")
    run_config = _to_plain_dict(cfg.run, "Run config")
    trainer_args = _to_plain_dict(cfg.trainer, "Trainer args")
    engine_args = _to_plain_dict(cfg.engine, "Engine args")

    print(f"Evaluating model {cfg.run.base_model}")

    # Default to every visible GPU, but never 0: device_count() returns 0 when no
    # CUDA device is visible, and a tensor-parallel degree of 0 is not meaningful.
    if "tensor_parallel_size" not in engine_args:
        engine_args["tensor_parallel_size"] = max(1, torch.cuda.device_count())

    # Only start a W&B run when none is active; an already-active run is reused.
    if _reports_to_wandb(trainer_args) and wandb.run is None:
        wandb.init(project=run_config["project"], name=run_config["project_id"], job_type="eval")

    # Build the trainable model and run the evaluation loop.
    model = art.TrainableModel(
        name=run_config["project_id"],
        project=run_config["project"],
        base_model=run_config["base_model"],
        config=LinAlgPolicyConfig(
            training_config=training_config,
            run_config=run_config,
        ),
        _internal_config=art.dev.InternalModelConfig(
            init_args=init_config,
            engine_args=engine_args,
            trainer_args=trainer_args,
        ),
    )
    try:
        asyncio.run(test(model))
    finally:
        # Close whichever W&B run is active at the end, even if evaluation raised.
        if wandb.run is not None:
            wandb.finish()
if __name__ == "__main__":
    # ``main`` requires a composed config, so calling it bare raised TypeError.
    # Wrapping it with ``hydra.main`` makes Hydra build ``cfg`` from the command
    # line. No config directory is baked in here, so pass it explicitly, e.g.
    #   python <script> --config-path <dir> --config-name <name> [overrides...]
    # ``version_base=None`` selects the installed Hydra version's defaults.
    hydra.main(version_base=None)(main)()