| from typing import Union, List | |
| from verl.utils.tracking import Tracking | |
| class ReasonRLTracking(Tracking): | |
| def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = 'console', config=None, resume='never', run_id=None, tags: List[str] = None): | |
| if isinstance(default_backend, str): | |
| default_backend = [default_backend] | |
| for backend in default_backend: | |
| if backend == 'tracking': | |
| import warnings | |
| warnings.warn("`tracking` logger is deprecated. use `wandb` instead.", DeprecationWarning) | |
| else: | |
| assert backend in self.supported_backend, f'{backend} is not supported' | |
| self.logger = {} | |
| if 'tracking' in default_backend or 'wandb' in default_backend: | |
| import wandb | |
| wandb_kwargs = {} | |
| if resume == 'must': | |
| wandb_kwargs = {'resume': 'must', 'id': run_id} | |
| elif resume == 'allow': | |
| wandb_kwargs = {'resume': 'allow', 'id': run_id} | |
| if tags is not None: | |
| wandb_kwargs['tags'] = tags | |
| run = wandb.init(project=project_name, settings=wandb.Settings(start_method="thread"), name=experiment_name, config=config, **wandb_kwargs) | |
| self.run_id = run.id | |
| self.logger['wandb'] = wandb | |
| if 'console' in default_backend: | |
| from verl.utils.logger.aggregate_logger import LocalLogger | |
| self.console_logger = LocalLogger(print_to_console=True) | |
| self.logger['console'] = self.console_logger | |