"""
|
|
|
A unified tracking interface that supports logging data to different backend
|
|
|
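
Example (illustrative sketch; the backend name must be one of ``Tracking.supported_backend``
below, and the project name, experiment name, and metric key are placeholders):

    tracker = Tracking(project_name="my_project", experiment_name="run_1", default_backend="console")
    tracker.log(data={"train/loss": 0.5}, step=1)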
"""

import dataclasses
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Union


class Tracking:
    """A unified tracker that dispatches metric logging to one or more supported backends."""

    supported_backend = ["wandb", "mlflow", "swanlab", "vemlp_wandb", "tensorboard", "console"]

    def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = "console", config=None):
        if isinstance(default_backend, str):
            default_backend = [default_backend]
        for backend in default_backend:
            if backend == "tracking":
                import warnings

                warnings.warn("`tracking` logger is deprecated. Use `wandb` instead.", DeprecationWarning, stacklevel=2)
            else:
                assert backend in self.supported_backend, f"{backend} is not supported"

        self.logger = {}

        if "tracking" in default_backend or "wandb" in default_backend:
            import wandb

            wandb.init(project=project_name, name=experiment_name, config=config)
            self.logger["wandb"] = wandb

        if "mlflow" in default_backend:
            import os

            import mlflow

            MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", None)
            if MLFLOW_TRACKING_URI:
                mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

            experiment = mlflow.set_experiment(project_name)
            mlflow.start_run(experiment_id=experiment.experiment_id, run_name=experiment_name)
            mlflow.log_params(_compute_mlflow_params_from_objects(config))
            self.logger["mlflow"] = _MlflowLoggingAdapter()

        if "swanlab" in default_backend:
            import os

            import swanlab

            SWANLAB_API_KEY = os.environ.get("SWANLAB_API_KEY", None)
            SWANLAB_LOG_DIR = os.environ.get("SWANLAB_LOG_DIR", "swanlog")
            SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud")
            if SWANLAB_API_KEY:
                swanlab.login(SWANLAB_API_KEY)

            if config is None:
                config = {}
            swanlab.init(
                project=project_name,
                experiment_name=experiment_name,
                config={"FRAMEWORK": "verl", **config},
                logdir=SWANLAB_LOG_DIR,
                mode=SWANLAB_MODE,
            )
            self.logger["swanlab"] = swanlab

        if "vemlp_wandb" in default_backend:
            import os

            import volcengine_ml_platform
            from volcengine_ml_platform import wandb as vemlp_wandb

            volcengine_ml_platform.init(
                ak=os.environ["VOLC_ACCESS_KEY_ID"],
                sk=os.environ["VOLC_SECRET_ACCESS_KEY"],
                region=os.environ["MLP_TRACKING_REGION"],
            )

            vemlp_wandb.init(
                project=project_name,
                name=experiment_name,
                config=config,
                sync_tensorboard=True,
            )
            self.logger["vemlp_wandb"] = vemlp_wandb

        if "tensorboard" in default_backend:
            self.logger["tensorboard"] = _TensorboardAdapter()

        if "console" in default_backend:
            from verl.utils.logger.aggregate_logger import LocalLogger

            self.console_logger = LocalLogger(print_to_console=True)
            self.logger["console"] = self.console_logger

    def log(self, data, step, backend=None):
        """Log a dict of metrics at ``step`` to every active backend, or only to those named in ``backend``."""
        for default_backend, logger_instance in self.logger.items():
            if backend is None or default_backend in backend:
                logger_instance.log(data=data, step=step)

    def __del__(self):
        if "wandb" in self.logger:
            self.logger["wandb"].finish(exit_code=0)
        if "swanlab" in self.logger:
            self.logger["swanlab"].finish()
        if "vemlp_wandb" in self.logger:
            self.logger["vemlp_wandb"].finish(exit_code=0)
        if "tensorboard" in self.logger:
            self.logger["tensorboard"].finish()


class _TensorboardAdapter:
    def __init__(self):
        import os

        from torch.utils.tensorboard import SummaryWriter

        tensorboard_dir = os.environ.get("TENSORBOARD_DIR", "tensorboard_log")
        os.makedirs(tensorboard_dir, exist_ok=True)
        print(f"Saving tensorboard log to {tensorboard_dir}.")
        self.writer = SummaryWriter(tensorboard_dir)

    def log(self, data, step):
        for key in data:
            self.writer.add_scalar(key, data[key], step)

    def finish(self):
        self.writer.close()


class _MlflowLoggingAdapter:
    def log(self, data, step):
        import mlflow

        # Rewrite "@" in metric keys, since mlflow does not accept it in metric names.
        results = {k.replace("@", "_at_"): v for k, v in data.items()}
        mlflow.log_metrics(metrics=results, step=step)


def _compute_mlflow_params_from_objects(params) -> Dict[str, Any]:
    if params is None:
        return {}

    return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep="/")


def _transform_params_to_json_serializable(x, convert_list_to_dict: bool):
    _transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict)

    if dataclasses.is_dataclass(x):
        return _transform(dataclasses.asdict(x))
    if isinstance(x, dict):
        return {k: _transform(v) for k, v in x.items()}
    if isinstance(x, list):
        if convert_list_to_dict:
            return {"list_len": len(x)} | {f"{i}": _transform(v) for i, v in enumerate(x)}
        else:
            return [_transform(v) for v in x]
    if isinstance(x, Path):
        return str(x)
    if isinstance(x, Enum):
        return x.value

    return x


def _flatten_dict(raw: Dict[str, Any], *, sep: str) -> Dict[str, Any]:
    import pandas as pd

    ans = pd.json_normalize(raw, sep=sep).to_dict(orient="records")[0]
    assert isinstance(ans, dict)
    return ans


@dataclasses.dataclass
class ValidationGenerationsLogger:
    """Log validation samples, given as (input, output, score) triples, to the selected backends."""

    def log(self, loggers, samples, step):
        if "wandb" in loggers:
            self.log_generations_to_wandb(samples, step)
        if "swanlab" in loggers:
            self.log_generations_to_swanlab(samples, step)
        if "mlflow" in loggers:
            self.log_generations_to_mlflow(samples, step)

    def log_generations_to_wandb(self, samples, step):
        """Log samples to wandb as a table"""
        import wandb

        columns = ["step"] + sum([[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], [])

        if not hasattr(self, "validation_table"):
            # Initialize the table on the first call.
            self.validation_table = wandb.Table(columns=columns)

        # Create a new table that carries over the rows logged so far.
        new_table = wandb.Table(columns=columns, data=self.validation_table.data)

        # Append one row: the step followed by each sample's input, output, and score.
        row_data = []
        row_data.append(step)
        for sample in samples:
            row_data.extend(sample)

        new_table.add_data(*row_data)

        # Log the updated table and keep a reference for the next call.
        wandb.log({"val/generations": new_table}, step=step)
        self.validation_table = new_table

    def log_generations_to_swanlab(self, samples, step):
        """Log samples to swanlab as text"""
        import swanlab

        swanlab_text_list = []
        for i, sample in enumerate(samples):
            row_text = f"""
input: {sample[0]}

---

output: {sample[1]}

---

score: {sample[2]}
"""
            swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}"))

        swanlab.log({"val/generations": swanlab_text_list}, step=step)

    def log_generations_to_mlflow(self, samples, step):
        """Log validation generations to mlflow as artifacts"""
        import json
        import tempfile

        import mlflow

        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                validation_gen_step_file = Path(tmp_dir, f"val_step{step}.json")
                row_data = []
                for sample in samples:
                    data = {"input": sample[0], "output": sample[1], "score": sample[2]}
                    row_data.append(data)
                with open(validation_gen_step_file, "w") as file:
                    json.dump(row_data, file)
                mlflow.log_artifact(validation_gen_step_file)
        except Exception as e:
            print(f"WARNING: saving validation generations to mlflow failed with error {e}")