|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
A unified tracking interface that supports logging data to different backends.
|
|
""" |
|
|
|
|
|
import os |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
from ..py_functional import convert_dict_to_str, flatten_dict, is_package_available, unflatten_dict, filter_config_for_hparams |
|
|
from .gen_logger import AggregateGenerationsLogger |
|
|
|
|
|
|
|
|
# Optional tracking backends: each third-party package is imported only when it
# is installed, so environments that don't use a given backend need not ship it.
# The corresponding Logger subclass will simply fail at construction time if
# the package is missing.
if is_package_available("mlflow"):
    import mlflow


if is_package_available("wandb"):
    import wandb


if is_package_available("swanlab"):
    import swanlab
|
|
|
|
|
|
|
|
class Logger(ABC):
    """Abstract interface for a tracking backend.

    A concrete backend records the experiment configuration when it is
    constructed and receives scalar metrics through :meth:`log`.
    :meth:`finish` releases any resources the backend holds; the default
    implementation does nothing.
    """

    @abstractmethod
    def __init__(self, config: Dict[str, Any]) -> None:
        """Initialize the backend with the experiment configuration."""

    @abstractmethod
    def log(self, data: Dict[str, Any], step: int) -> None:
        """Record a flat mapping of metric name to value at the given step."""

    def finish(self) -> None:
        """Tear down the backend. No-op unless a subclass overrides it."""
|
|
|
|
|
|
|
|
class ConsoleLogger(Logger):
    """Backend that simply prints the configuration and metrics to stdout."""

    def __init__(self, config: Dict[str, Any]) -> None:
        """Print the experiment configuration once at startup."""
        print("Config\n" + convert_dict_to_str(config))

    def log(self, data: Dict[str, Any], step: int) -> None:
        """Print the metrics for one step as a nested dict."""
        nested = unflatten_dict(data)
        print(f"Step {step}\n" + convert_dict_to_str(nested))
|
|
|
|
|
|
|
|
class MlflowLogger(Logger):
    """Backend that reports the configuration and metrics to an MLflow run.

    A run named after ``config["trainer"]["experiment_name"]`` is started at
    construction time and the flattened config is logged as run parameters.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        mlflow.start_run(run_name=config["trainer"]["experiment_name"])
        # MLflow params must be flat key/value pairs, hence the flattening.
        mlflow.log_params(flatten_dict(config))

    def log(self, data: Dict[str, Any], step: int) -> None:
        """Log a flat mapping of metric name to value at ``step``."""
        mlflow.log_metrics(metrics=data, step=step)

    def finish(self) -> None:
        # Bug fix: the run started in __init__ was never ended; every other
        # backend releases its resources in finish(), so do the same here.
        mlflow.end_run()
|
|
|
|
|
|
|
|
class TensorBoardLogger(Logger): |
|
|
def __init__(self, config: Dict[str, Any]) -> None: |
|
|
tensorboard_dir = os.getenv("TENSORBOARD_DIR", "tensorboard_log") |
|
|
os.makedirs(tensorboard_dir, exist_ok=True) |
|
|
print(f"Saving tensorboard log to {tensorboard_dir}.") |
|
|
self.writer = SummaryWriter(tensorboard_dir) |
|
|
filtered_config = filter_config_for_hparams(config) |
|
|
self.writer.add_hparams(flatten_dict(filtered_config), {}) |
|
|
|
|
|
def log(self, data: Dict[str, Any], step: int) -> None: |
|
|
for key, value in data.items(): |
|
|
self.writer.add_scalar(key, value, step) |
|
|
|
|
|
def finish(self): |
|
|
self.writer.close() |
|
|
|
|
|
|
|
|
class WandbLogger(Logger):
    """Backend that streams the configuration and metrics to Weights & Biases."""

    def __init__(self, config: Dict[str, Any]) -> None:
        """Start a wandb run named from the trainer section of ``config``."""
        trainer_cfg = config["trainer"]
        wandb.init(
            project=trainer_cfg["project_name"],
            name=trainer_cfg["experiment_name"],
            config=config,
        )

    def log(self, data: Dict[str, Any], step: int) -> None:
        """Forward the metrics for one step to the active wandb run."""
        wandb.log(data=data, step=step)

    def finish(self) -> None:
        """Close the active wandb run."""
        wandb.finish()
|
|
|
|
|
|
|
|
class SwanlabLogger(Logger):
    """Backend that streams the configuration and metrics to SwanLab.

    Environment variables: ``SWANLAB_API_KEY`` (optional; triggers a login
    when set), ``SWANLAB_DIR`` (log directory, default ``swanlab_log``) and
    ``SWANLAB_MODE`` (default ``cloud``).
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        api_key = os.getenv("SWANLAB_API_KEY")
        log_dir = os.getenv("SWANLAB_DIR", "swanlab_log")
        mode = os.getenv("SWANLAB_MODE", "cloud")
        if api_key:
            swanlab.login(api_key)

        trainer_cfg = config["trainer"]
        swanlab.init(
            project=trainer_cfg["project_name"],
            experiment_name=trainer_cfg["experiment_name"],
            config={"UPPERFRAMEWORK": "EasyR1", "FRAMEWORK": "veRL", **config},
            logdir=log_dir,
            mode=mode,
        )

    def log(self, data: Dict[str, Any], step: int) -> None:
        """Forward the metrics for one step to the active swanlab run."""
        swanlab.log(data=data, step=step)

    def finish(self) -> None:
        """Close the active swanlab run."""
        swanlab.finish()
|
|
|
|
|
|
|
|
# Registry mapping backend names (as accepted by Tracker) to Logger classes.
LOGGERS = {
    "wandb": WandbLogger,
    "mlflow": MlflowLogger,
    "tensorboard": TensorBoardLogger,
    "console": ConsoleLogger,
    "swanlab": SwanlabLogger,
}
|
|
|
|
|
|
|
|
class Tracker:
    """Fan-out tracker that dispatches logging calls to one or more backends.

    Args:
        loggers: a backend name or list of backend names; each must be a key
            of ``LOGGERS`` ("wandb", "mlflow", "tensorboard", "console",
            "swanlab").
        config: experiment configuration forwarded to every backend.

    Raises:
        ValueError: if any requested backend name is not in ``LOGGERS``.
    """

    def __init__(self, loggers: Union[str, List[str]] = "console", config: Optional[Dict[str, Any]] = None):
        if isinstance(loggers, str):
            loggers = [loggers]

        self.loggers: List[Logger] = []
        for logger in loggers:
            if logger not in LOGGERS:
                raise ValueError(f"{logger} is not supported.")

            self.loggers.append(LOGGERS[logger](config))

        self.gen_logger = AggregateGenerationsLogger(loggers)

    def log(self, data: Dict[str, Any], step: int) -> None:
        """Send the metrics for one step to every configured backend."""
        for logger in self.loggers:
            logger.log(data=data, step=step)

    def log_generation(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
        """Forward generation samples to the aggregate generations logger.

        Each sample is a (str, str, str, float) tuple — presumably
        (prompt, response, label, score); confirm against
        AggregateGenerationsLogger.
        """
        self.gen_logger.log(samples, step)

    def __del__(self):
        # Bug fix: guard with getattr — __del__ can run on a partially
        # initialized instance (e.g. __init__ raised before self.loggers was
        # set), which would otherwise raise AttributeError during GC.
        for logger in getattr(self, "loggers", []):
            logger.finish()
|
|
|