| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import contextlib |
| import functools |
| import inspect |
| import os |
| from contextvars import ContextVar |
| from typing import Optional |
|
|
| from pydantic import BaseModel |
|
|
| from verl.utils.ray_utils import get_event_loop |
|
|
# Per-context tracing switch (defaults to True). rollout_trace_attr sets it to False for the
# duration of a `trace=False` context, and rollout_trace_op checks it before recording a call.
| _trace_enabled: ContextVar[bool] = ContextVar("_trace_enabled", default=True) |
|
|
|
|
class RolloutTraceConfig:
    """Configuration for rollout tracing with various backends.

    Singleton configuration class for managing rollout trace settings across different
    tracing backends like Weave and MLflow. Every construction returns the same instance
    until :meth:`reset` is called.

    Args:
        backend (Optional[str]): Tracing backend to use ('weave', 'mlflow', or None).
        client (Optional[object]): Client instance for the selected backend.
        token2text (bool): Whether to convert tokens to text in traces. Defaults to False.
        project_name (str): Name of the project for tracing.
        experiment_name (str): Name of the experiment for tracing.
        max_samples_per_step_per_worker (Optional[int]): Maximum number of unique samples to trace
            per worker per step. If None, all samples are traced. If set, each worker will randomly
            select up to this many unique samples to trace (including all their rollouts for GRPO).
            Total traces = max_samples_per_step_per_worker * num_workers * n_rollouts_per_sample.
    """

    _instance: Optional["RolloutTraceConfig"] = None
    backend: Optional[str] = None
    client: Optional[object] = None
    token2text: bool = False
    _initialized: bool = False
    project_name: Optional[str] = None
    experiment_name: Optional[str] = None
    max_samples_per_step_per_worker: Optional[int] = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    @classmethod
    def get_instance(cls) -> "RolloutTraceConfig":
        """Return the singleton, creating an uninitialized one if needed."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def init(
        cls,
        project_name: str,
        experiment_name: str,
        backend: str,
        token2text: bool = False,
        max_samples_per_step_per_worker: Optional[int] = None,
    ):
        """Configure the singleton and set up the backend client.

        Does nothing if the singleton has already been initialized. Any backend
        other than 'weave' or 'mlflow' is stored as given with no client.

        Args:
            project_name: Project name; passed to ``weave.init`` or used as the
                mlflow experiment name.
            experiment_name: Experiment name stored on the config.
            backend: 'weave', 'mlflow', or any other value for no client.
            token2text: Whether traces should include decoded token text.
            max_samples_per_step_per_worker: Optional cap stored on the config.

        Raises:
            ImportError: If the selected backend package cannot be imported.
                In that case (or on any other setup failure) the singleton is
                left untouched and uninitialized.
        """
        config = cls.get_instance()
        if config._initialized:
            return

        # Build the client BEFORE touching the singleton: if the import or the
        # backend setup raises, the config must not end up advertising a backend
        # whose client is None.
        if backend == "weave":
            import weave

            client = weave.init(project_name)
        elif backend == "mlflow":
            import mlflow

            mlflow.config.enable_async_logging()

            # Fall back to a local sqlite store when no tracking URI is configured.
            tracking_uri = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db")
            mlflow.set_tracking_uri(tracking_uri)

            mlflow.set_experiment(project_name)
            client = mlflow
        else:
            client = None

        config.backend = backend
        config.client = client
        config.token2text = token2text
        config.project_name = project_name
        config.experiment_name = experiment_name
        config.max_samples_per_step_per_worker = max_samples_per_step_per_worker
        config._initialized = True

    @classmethod
    def get_backend(cls) -> Optional[str]:
        """Return the configured backend name, or None if none is set."""
        return cls.get_instance().backend

    @classmethod
    def get_client(cls) -> Optional[object]:
        """Return the backend client, or None if none is set."""
        return cls.get_instance().client

    @classmethod
    def enable_token2text(cls) -> bool:
        """Return whether token-to-text conversion is enabled."""
        return cls.get_instance().token2text

    @classmethod
    def reset(cls):
        """Discard the singleton so the next access creates a fresh, uninitialized one."""
        cls._instance = None
|
|
|
|
@contextlib.contextmanager
def rollout_trace_attr(
    sample_index=None, step=None, rollout_n=None, name="rollout_trace", validate=False, trace: bool = True
):
    """Context manager that attaches rollout attributes to traces of the active backend.

    Args:
        sample_index: Sample index for the trace.
        step: Training step number.
        rollout_n: Rollout number (for GRPO with multiple rollouts per sample).
        name: Name for the trace span (used by mlflow backend).
        validate: Whether this is a validation run.
        trace: If False, disables tracing for the duration of the context.
    """
    active_backend = RolloutTraceConfig.get_backend()

    # A backend is configured but the caller opted out: switch the per-context
    # flag off for the body and restore it afterwards.
    if active_backend is not None and not trace:
        reset_token = _trace_enabled.set(False)
        try:
            yield
        finally:
            _trace_enabled.reset(reset_token)
        return

    # No backend configured: nothing to attach.
    if not active_backend:
        yield
        return

    # Optional attributes are recorded only when provided; the last two always are.
    optional_tags = (("sample_index", sample_index), ("step", step), ("rollout_n", rollout_n))
    tags = {key: value for key, value in optional_tags if value is not None}
    tags["validate"] = validate
    tags["experiment_name"] = RolloutTraceConfig.get_instance().experiment_name

    if active_backend == "mlflow":
        import mlflow

        with mlflow.start_span(name=name) as span:
            trace_id = span.trace_id
            # mlflow tags are strings, so both key and value are stringified.
            for key, value in tags.items():
                mlflow.set_trace_tag(trace_id, str(key), str(value))
            yield
    elif active_backend == "weave":
        import weave

        with weave.attributes(tags):
            yield
    else:
        yield
|
|
|
|
def _bind_call_inputs(signature, receiver, args, kwargs):
    """Map a call's arguments to parameter names (defaults applied), without the receiver."""
    bound = signature.bind(receiver, *args, **kwargs)
    bound.apply_defaults()
    inputs = dict(bound.arguments)
    # The receiver is whatever the first parameter is called ("self", "cls", ...).
    # If that parameter is *args it holds more than the receiver, so it is kept.
    first_param = next(iter(signature.parameters.values()))
    if first_param.kind is not inspect.Parameter.VAR_POSITIONAL:
        inputs.pop(first_param.name, None)
    return inputs


async def _with_decoded_text(owner, result):
    """Return ``result`` as a dict with decoded prompt/response text added when possible.

    ``result`` is returned unchanged unless it has ``prompt_ids`` and ``owner`` has a
    ``tokenizer`` exposing ``decode``.
    """
    can_decode = (
        hasattr(result, "prompt_ids") and hasattr(owner, "tokenizer") and hasattr(owner.tokenizer, "decode")
    )
    if not can_decode:
        return result

    if isinstance(result, BaseModel):
        payload = result.model_dump()
    else:
        payload = dict(vars(result))

    # Decoding runs in the loop's default executor rather than inline in the coroutine.
    loop = get_event_loop()
    payload["prompt_text"] = await loop.run_in_executor(None, owner.tokenizer.decode, result.prompt_ids)
    if hasattr(result, "response_ids"):
        payload["response_text"] = await loop.run_in_executor(None, owner.tokenizer.decode, result.response_ids)
    return payload


def rollout_trace_op(func):
    """Decorate a method so its calls are recorded by the configured tracing backend.

    The call is passed straight through to ``func`` when tracing is disabled for the
    current context or no backend is configured. With the 'weave' backend a call is
    created and finished (with the output or the exception) around ``func``. With
    'mlflow' the call runs inside a span (async functions) or through ``mlflow.trace``
    (sync functions).

    Args:
        func: The sync or async method to wrap; the first positional argument of every
            call is treated as the receiver and left out of the recorded inputs.

    Returns:
        An async wrapper if ``func`` is a coroutine function, otherwise a sync wrapper.
    """
    # The signature does not change between calls, so resolve it once here.
    signature = inspect.signature(func)
    op_name = func.__qualname__

    @functools.wraps(func)
    async def async_wrapper(self, *args, **kwargs):
        if not _trace_enabled.get():
            return await func(self, *args, **kwargs)

        backend = RolloutTraceConfig.get_backend()

        if backend == "weave":
            from weave.trace.context import call_context

            tracer = RolloutTraceConfig.get_client()
            token2text = RolloutTraceConfig.enable_token2text()
            inputs = _bind_call_inputs(signature, self, args, kwargs)
            # Copy the attributes active in the current weave context onto the call.
            cur_attributes = {**call_context.call_attributes.get()}
            call = tracer.create_call(op=op_name, inputs=inputs, attributes=cur_attributes)
            try:
                result = await func(self, *args, **kwargs)
                output = await _with_decoded_text(self, result) if token2text else result
                tracer.finish_call(call, output=output)
                return result
            except Exception as exc:
                tracer.finish_call(call, exception=exc)
                raise

        if backend == "mlflow":
            import mlflow

            token2text = RolloutTraceConfig.enable_token2text()
            inputs = _bind_call_inputs(signature, self, args, kwargs)
            with mlflow.start_span(name=op_name) as span:
                span.set_inputs(inputs)
                result = await func(self, *args, **kwargs)
                output = await _with_decoded_text(self, result) if token2text else result
                span.set_outputs(output)
            return result

        # No backend, or one this decorator does not know: plain call.
        return await func(self, *args, **kwargs)

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if not _trace_enabled.get():
            return func(self, *args, **kwargs)

        backend = RolloutTraceConfig.get_backend()

        if backend == "weave":
            from weave.trace.context import call_context

            tracer = RolloutTraceConfig.get_client()
            inputs = _bind_call_inputs(signature, self, args, kwargs)
            cur_attributes = {**call_context.call_attributes.get()}
            call = tracer.create_call(op=op_name, inputs=inputs, attributes=cur_attributes)
            try:
                result = func(self, *args, **kwargs)
                tracer.finish_call(call, output=result)
                return result
            except Exception as exc:
                tracer.finish_call(call, exception=exc)
                raise

        if backend == "mlflow":
            import mlflow

            return mlflow.trace(func)(self, *args, **kwargs)

        # No backend, or one this decorator does not know: plain call.
        return func(self, *args, **kwargs)

    return async_wrapper if inspect.iscoroutinefunction(func) else wrapper
|
|