"""Verbose debug logging for the OPSD / TriMode training pipeline.""" from __future__ import annotations import json import os import time import traceback from contextlib import contextmanager from typing import Any, Optional _DEBUG_ENABLED = False _DETAIL_EVERY = 10 _PROBE_ON_GENERATE = False _PROBE_FIRST_TOKEN_LOGITS = True _PROBE_PROMPT_TAIL_TOKENS = 16 _PROBE_LOG_MODEL_CONTEXT = True _HEALTH_MONITOR_ENABLED = True _HEALTH_LOG_ON_GENERATE = True _HEALTH_LOG_EVERY_STEP = True _HEALTH_LOG_DETAIL_BUNDLE = True _HEALTH_LOG_ALERTS_IMMEDIATELY = True _RANK = 0 _WORLD_SIZE = 1 _STEP_LABEL = "init" _DETAIL_STEP: Optional[int] = None _CALL_COUNTER = 0 MODE_NAMES = {0: "GRPO", 1: "OPSD", 2: "SFT"} def _env_debug_enabled() -> bool: return os.environ.get("DYME_OPSD_DEBUG", "").strip().lower() in ("1", "true", "yes", "on") def _env_detail_every() -> int: raw = os.environ.get("DYME_OPSD_DETAIL_EVERY", "").strip() if not raw: return 10 try: return max(0, int(raw)) except ValueError: return 10 def _env_probe_on_generate() -> Optional[bool]: raw = os.environ.get("DYME_OPSD_PROBE_ON_GENERATE", "").strip().lower() if not raw: return None return raw in ("1", "true", "yes", "on") def _env_probe_first_token_logits() -> Optional[bool]: raw = os.environ.get("DYME_OPSD_PROBE_FIRST_TOKEN_LOGITS", "").strip().lower() if not raw: return None return raw in ("1", "true", "yes", "on") def _env_probe_prompt_tail_tokens() -> Optional[int]: raw = os.environ.get("DYME_OPSD_PROBE_PROMPT_TAIL_TOKENS", "").strip() if not raw: return None try: return max(1, int(raw)) except ValueError: return 16 def _env_probe_log_model_context() -> Optional[bool]: raw = os.environ.get("DYME_OPSD_PROBE_LOG_MODEL_CONTEXT", "").strip().lower() if not raw: return None return raw in ("1", "true", "yes", "on") def _env_health_monitor_enabled() -> Optional[bool]: raw = os.environ.get("DYME_OPSD_HEALTH_MONITOR", "").strip().lower() if not raw: return None return raw in ("1", "true", "yes", "on") def configure( *, enabled: Optional[bool] = None, detail_every: Optional[int] = None, probe_on_generate: Optional[bool] = None, probe_first_token_logits: Optional[bool] = None, probe_prompt_tail_tokens: Optional[int] = None, probe_log_model_context: Optional[bool] = None, health_monitor_enabled: Optional[bool] = None, health_log_on_generate: Optional[bool] = None, health_log_every_step: Optional[bool] = None, health_log_detail_bundle: Optional[bool] = None, health_log_alerts_immediately: Optional[bool] = None, rank: Optional[int] = None, world_size: Optional[int] = None, ) -> bool: """Configure global OPSD debug logging. Returns whether debug is enabled.""" global _DEBUG_ENABLED, _DETAIL_EVERY, _PROBE_ON_GENERATE global _PROBE_FIRST_TOKEN_LOGITS, _PROBE_PROMPT_TAIL_TOKENS, _PROBE_LOG_MODEL_CONTEXT global _HEALTH_MONITOR_ENABLED, _HEALTH_LOG_ON_GENERATE, _HEALTH_LOG_EVERY_STEP global _HEALTH_LOG_DETAIL_BUNDLE, _HEALTH_LOG_ALERTS_IMMEDIATELY global _RANK, _WORLD_SIZE if enabled is None: enabled = _env_debug_enabled() _DEBUG_ENABLED = bool(enabled) if detail_every is not None: _DETAIL_EVERY = max(0, int(detail_every)) elif _env_detail_every() != 10 or os.environ.get("DYME_OPSD_DETAIL_EVERY"): _DETAIL_EVERY = _env_detail_every() env_probe = _env_probe_on_generate() if probe_on_generate is not None: _PROBE_ON_GENERATE = bool(probe_on_generate) elif env_probe is not None: _PROBE_ON_GENERATE = env_probe env_first_logits = _env_probe_first_token_logits() if probe_first_token_logits is not None: _PROBE_FIRST_TOKEN_LOGITS = bool(probe_first_token_logits) elif env_first_logits is not None: _PROBE_FIRST_TOKEN_LOGITS = env_first_logits env_tail = _env_probe_prompt_tail_tokens() if probe_prompt_tail_tokens is not None: _PROBE_PROMPT_TAIL_TOKENS = max(1, int(probe_prompt_tail_tokens)) elif env_tail is not None: _PROBE_PROMPT_TAIL_TOKENS = env_tail env_model_ctx = _env_probe_log_model_context() if probe_log_model_context is not None: _PROBE_LOG_MODEL_CONTEXT = bool(probe_log_model_context) elif env_model_ctx is not None: _PROBE_LOG_MODEL_CONTEXT = env_model_ctx env_health = _env_health_monitor_enabled() if health_monitor_enabled is not None: _HEALTH_MONITOR_ENABLED = bool(health_monitor_enabled) elif env_health is not None: _HEALTH_MONITOR_ENABLED = env_health if health_log_on_generate is not None: _HEALTH_LOG_ON_GENERATE = bool(health_log_on_generate) if health_log_every_step is not None: _HEALTH_LOG_EVERY_STEP = bool(health_log_every_step) if health_log_detail_bundle is not None: _HEALTH_LOG_DETAIL_BUNDLE = bool(health_log_detail_bundle) if health_log_alerts_immediately is not None: _HEALTH_LOG_ALERTS_IMMEDIATELY = bool(health_log_alerts_immediately) if rank is not None: _RANK = rank if world_size is not None: _WORLD_SIZE = world_size return _DEBUG_ENABLED def detail_every() -> int: return _DETAIL_EVERY def probe_on_generate() -> bool: return _PROBE_ON_GENERATE def probe_first_token_logits() -> bool: return _PROBE_FIRST_TOKEN_LOGITS def probe_prompt_tail_tokens() -> int: return _PROBE_PROMPT_TAIL_TOKENS def probe_log_model_context() -> bool: return _PROBE_LOG_MODEL_CONTEXT def should_log_probe() -> bool: """True when lightweight per-generate probe should run (rank 0 only).""" return _PROBE_ON_GENERATE and _RANK == 0 def should_log_detail(global_step: Optional[int]) -> bool: """True when a full diagnostic bundle should be emitted (rank 0 only).""" if _DETAIL_EVERY <= 0 or _RANK != 0: return False if global_step is None: return False return int(global_step) % _DETAIL_EVERY == 0 def is_enabled() -> bool: return _DEBUG_ENABLED def set_step_label(label: str) -> None: global _STEP_LABEL _STEP_LABEL = label def set_detail_step(global_step: Optional[int]) -> None: global _DETAIL_STEP _DETAIL_STEP = global_step def get_detail_step() -> Optional[int]: return _DETAIL_STEP def _next_call_id(stage: str) -> str: global _CALL_COUNTER _CALL_COUNTER += 1 return f"{_CALL_COUNTER}:{stage}" def _fmt(value: Any, max_len: int = 240) -> str: if value is None: return "None" try: import torch if isinstance(value, torch.Tensor): return ( f"Tensor(shape={tuple(value.shape)}, dtype={value.dtype}, " f"device={value.device}, numel={value.numel()})" ) except ImportError: pass if isinstance(value, (list, tuple)): if len(value) > 12: head = ", ".join(_fmt(v, max_len=40) for v in value[:6]) return f"[{head}, ... +{len(value) - 6} more, total={len(value)}]" return "[" + ", ".join(_fmt(v, max_len=40) for v in value) + "]" if isinstance(value, dict): text = json.dumps(value, ensure_ascii=False, default=str) else: text = repr(value) if len(text) > max_len: return text[: max_len - 3] + "..." return text def _prefix(stage: str, call_id: Optional[str] = None) -> str: cid = call_id or _next_call_id(stage) ts = time.strftime("%Y-%m-%d %H:%M:%S") return f"[OPSD-DEBUG][{ts}][rank={_RANK}/{_WORLD_SIZE}][step={_STEP_LABEL}][{cid}]" def log(stage: str, msg: str, **fields: Any) -> None: if not _DEBUG_ENABLED: return call_id = _next_call_id(stage) extra = "" if fields: extra = " | " + " | ".join(f"{k}={_fmt(v)}" for k, v in fields.items()) print(f"{_prefix(stage, call_id)} {msg}{extra}", flush=True) def _detail_prefix(global_step: int, section: str) -> str: ts = time.strftime("%Y-%m-%d %H:%M:%S") return ( f"[OPSD-DETAIL][{ts}][rank={_RANK}/{_WORLD_SIZE}]" f"[step={global_step}][every={_DETAIL_EVERY}][{section}]" ) def log_detail_banner(global_step: int, title: str) -> None: if not should_log_detail(global_step): return bar = "=" * 20 print(f"{_detail_prefix(global_step, 'BANNER')} {bar} {title} {bar}", flush=True) def log_detail(section: str, msg: str, global_step: Optional[int] = None, **fields: Any) -> None: """Full-detail diagnostic line (periodic, rank 0). Independent of verbose OPSD-DEBUG.""" step = global_step if global_step is not None else _DETAIL_STEP if step is None or isinstance(step, str): return if not should_log_detail(step): return extra = "" if fields: extra = " | " + " | ".join(f"{k}={_fmt(v, max_len=800)}" for k, v in fields.items()) print(f"{_detail_prefix(step, section)} {msg}{extra}", flush=True) def _probe_prefix(section: str) -> str: ts = time.strftime("%Y-%m-%d %H:%M:%S") step = _DETAIL_STEP if _DETAIL_STEP is not None else "?" return ( f"[OPSD-PROBE][{ts}][rank={_RANK}/{_WORLD_SIZE}]" f"[global_step={step}][{_STEP_LABEL}][{section}]" ) def log_probe(section: str, msg: str, **fields: Any) -> None: """Lightweight per-generate diagnostic (rank 0). Independent of OPSD-DEBUG verbosity.""" if not should_log_probe(): return extra = "" if fields: extra = " | " + " | ".join(f"{k}={_fmt(v, max_len=1200)}" for k, v in fields.items()) print(f"{_probe_prefix(section)} {msg}{extra}", flush=True) def _gendbg_prefix(section: str) -> str: ts = time.strftime("%Y-%m-%d %H:%M:%S") step = _DETAIL_STEP if _DETAIL_STEP is not None else "?" return ( f"[OPSD-GENDBG][{ts}][rank={_RANK}/{_WORLD_SIZE}]" f"[global_step={step}][{_STEP_LABEL}][{section}]" ) def should_log_gendbg() -> bool: """True when deep generate diagnostics should run (rank 0 only).""" return _PROBE_ON_GENERATE and _RANK == 0 def health_monitor_enabled() -> bool: return _HEALTH_MONITOR_ENABLED and _RANK == 0 def should_log_health_on_generate() -> bool: return health_monitor_enabled() and _HEALTH_LOG_ON_GENERATE and should_log_probe() def should_log_health_every_step() -> bool: return health_monitor_enabled() and _HEALTH_LOG_EVERY_STEP def should_log_health_detail_bundle() -> bool: return health_monitor_enabled() and _HEALTH_LOG_DETAIL_BUNDLE def should_log_health_alerts_immediately() -> bool: return health_monitor_enabled() and _HEALTH_LOG_ALERTS_IMMEDIATELY def _health_prefix(section: str, global_step: Optional[int] = None) -> str: ts = time.strftime("%Y-%m-%d %H:%M:%S") step = global_step if global_step is not None else (_DETAIL_STEP if _DETAIL_STEP is not None else "?") return f"[OPSD-HEALTH][{ts}][rank={_RANK}/{_WORLD_SIZE}][global_step={step}][{section}]" def log_health(section: str, msg: str, global_step: Optional[int] = None, **fields: Any) -> None: """L1/L2/L4 health lines (rank 0).""" if not health_monitor_enabled(): return extra = "" if fields: extra = " | " + " | ".join(f"{k}={_fmt(v, max_len=1200)}" for k, v in fields.items()) print(f"{_health_prefix(section, global_step)} {msg}{extra}", flush=True) def log_health_detail_banner(global_step: int, title: str) -> None: if not should_log_health_detail_bundle() or not should_log_detail(global_step): return bar = "=" * 20 print(f"{_detail_prefix(global_step, 'health')} {bar} {title} {bar}", flush=True) def log_health_detail(section: str, msg: str, global_step: int, **fields: Any) -> None: """L3 periodic health bundle (rank 0, same cadence as OPSD-DETAIL).""" if not should_log_health_detail_bundle() or not should_log_detail(global_step): return extra = "" if fields: extra = " | " + " | ".join(f"{k}={_fmt(v, max_len=1200)}" for k, v in fields.items()) print(f"{_detail_prefix(global_step, section)} {msg}{extra}", flush=True) def log_gendbg(section: str, msg: str, **fields: Any) -> None: """Deep per-generate diagnostic (rank 0). Uses [OPSD-GENDBG] prefix.""" if not should_log_gendbg(): return extra = "" if fields: extra = " | " + " | ".join(f"{k}={_fmt(v, max_len=1200)}" for k, v in fields.items()) print(f"{_gendbg_prefix(section)} {msg}{extra}", flush=True) def log_config(stage: str, title: str, config: dict[str, Any]) -> None: if not _DEBUG_ENABLED: return log(stage, title, config=_fmt(config, max_len=2000)) def log_sync_point(stage: str, msg: str, **fields: Any) -> None: """Mark a point where all ranks must reach before a collective call.""" log(stage, f"[SYNC_POINT] {msg}", **fields) def log_mode_summary(stage: str, prompt_modes: list[int], completion_modes: Optional[list[int]] = None) -> None: if not _DEBUG_ENABLED: return prompt_names = [MODE_NAMES.get(m, str(m)) for m in prompt_modes] fields: dict[str, Any] = {"prompt_modes": prompt_names} if completion_modes is not None: counts = {name: completion_modes.count(code) for code, name in MODE_NAMES.items()} fields["completion_mode_counts"] = counts log(stage, "mode routing summary", **fields) def log_tensor(stage: str, name: str, tensor: Any) -> None: if not _DEBUG_ENABLED: return log(stage, f"tensor `{name}`", value=_fmt(tensor)) def log_exception(stage: str, msg: str, exc: BaseException) -> None: if not _DEBUG_ENABLED: return tb = traceback.format_exc() log(stage, msg, error=repr(exc), traceback=_fmt(tb, max_len=4000)) @contextmanager def timed(stage: str, msg: str = "", **fields: Any): if not _DEBUG_ENABLED: yield return t0 = time.perf_counter() log(stage, f"START {msg}", **fields) try: yield except Exception as exc: log_exception(stage, f"FAILED {msg}", exc) raise finally: elapsed = time.perf_counter() - t0 log(stage, f"END {msg}", elapsed_s=f"{elapsed:.4f}")