agentic-rl-main / opsd_utils /debug_log.py
Jack04810's picture
Add files using upload-large-folder tool
36d0b76 verified
Raw
History Blame Contribute Delete
14.2 kB
"""Verbose debug logging for the OPSD / TriMode training pipeline."""
from __future__ import annotations
import json
import os
import time
import traceback
from contextlib import contextmanager
from typing import Any, Optional
# Master switch for verbose [OPSD-DEBUG] lines emitted by log(); set by configure().
_DEBUG_ENABLED = False
# Step interval for periodic [OPSD-DETAIL] bundles; a value <= 0 disables them
# (see should_log_detail()).
_DETAIL_EVERY = 10
# Gates the per-generate [OPSD-PROBE] and [OPSD-GENDBG] lines.
_PROBE_ON_GENERATE = False
# Probe options below are only exposed through getters in this module;
# presumably they are consumed by the callers that run the probe.
_PROBE_FIRST_TOKEN_LOGITS = True
_PROBE_PROMPT_TAIL_TOKENS = 16
_PROBE_LOG_MODEL_CONTEXT = True
# Health monitor: the first flag gates log_health(); the remaining flags are
# combined with it by the should_log_health_* predicates.
_HEALTH_MONITOR_ENABLED = True
_HEALTH_LOG_ON_GENERATE = True
_HEALTH_LOG_EVERY_STEP = True
_HEALTH_LOG_DETAIL_BUNDLE = True
_HEALTH_LOG_ALERTS_IMMEDIATELY = True
# Process rank / world size shown in every log prefix; the detail, probe and
# health predicates only fire when _RANK == 0.
_RANK = 0
_WORLD_SIZE = 1
# Free-form label embedded in the DEBUG / PROBE / GENDBG prefixes.
_STEP_LABEL = "init"
# Fallback global step used by log_detail() and the probe/health prefixes
# when no explicit step is passed.
_DETAIL_STEP: Optional[int] = None
# Monotonic counter backing _next_call_id().
_CALL_COUNTER = 0
# Integer mode code -> display name, used by log_mode_summary().
MODE_NAMES = {0: "GRPO", 1: "OPSD", 2: "SFT"}
def _env_debug_enabled() -> bool:
    """Report whether ``DYME_OPSD_DEBUG`` is switched on.

    The value is trimmed and lower-cased; ``1``, ``true``, ``yes`` and ``on``
    count as enabled, anything else (including unset) as disabled.
    """
    flag = os.environ.get("DYME_OPSD_DEBUG", "")
    normalized = flag.strip().lower()
    return normalized in ("1", "true", "yes", "on")
def _env_detail_every() -> int:
    """Read the detail cadence from ``DYME_OPSD_DETAIL_EVERY``.

    Returns the parsed integer clamped to be non-negative, or 10 when the
    variable is unset, blank, or not an integer.
    """
    text = os.environ.get("DYME_OPSD_DETAIL_EVERY", "").strip()
    if text:
        try:
            parsed = int(text)
        except ValueError:
            return 10
        return max(0, parsed)
    return 10
def _env_probe_on_generate() -> Optional[bool]:
    """Parse ``DYME_OPSD_PROBE_ON_GENERATE``; ``None`` means unset or blank."""
    setting = os.environ.get("DYME_OPSD_PROBE_ON_GENERATE", "").strip().lower()
    if setting == "":
        return None
    return setting in ("1", "true", "yes", "on")
def _env_probe_first_token_logits() -> Optional[bool]:
    """Parse ``DYME_OPSD_PROBE_FIRST_TOKEN_LOGITS``; ``None`` means unset or blank."""
    setting = os.environ.get("DYME_OPSD_PROBE_FIRST_TOKEN_LOGITS", "").strip().lower()
    if setting == "":
        return None
    return setting in ("1", "true", "yes", "on")
def _env_probe_prompt_tail_tokens() -> Optional[int]:
    """Parse ``DYME_OPSD_PROBE_PROMPT_TAIL_TOKENS``.

    Returns ``None`` when unset or blank, the integer clamped to at least 1
    when it parses, and 16 when the text is not an integer.
    """
    text = os.environ.get("DYME_OPSD_PROBE_PROMPT_TAIL_TOKENS", "").strip()
    if text == "":
        return None
    try:
        count = int(text)
    except ValueError:
        return 16
    return max(1, count)
def _env_probe_log_model_context() -> Optional[bool]:
    """Parse ``DYME_OPSD_PROBE_LOG_MODEL_CONTEXT``; ``None`` means unset or blank."""
    setting = os.environ.get("DYME_OPSD_PROBE_LOG_MODEL_CONTEXT", "").strip().lower()
    if setting == "":
        return None
    return setting in ("1", "true", "yes", "on")
def _env_health_monitor_enabled() -> Optional[bool]:
    """Parse ``DYME_OPSD_HEALTH_MONITOR``; ``None`` means unset or blank."""
    setting = os.environ.get("DYME_OPSD_HEALTH_MONITOR", "").strip().lower()
    if setting == "":
        return None
    return setting in ("1", "true", "yes", "on")
def configure(
    *,
    enabled: Optional[bool] = None,
    detail_every: Optional[int] = None,
    probe_on_generate: Optional[bool] = None,
    probe_first_token_logits: Optional[bool] = None,
    probe_prompt_tail_tokens: Optional[int] = None,
    probe_log_model_context: Optional[bool] = None,
    health_monitor_enabled: Optional[bool] = None,
    health_log_on_generate: Optional[bool] = None,
    health_log_every_step: Optional[bool] = None,
    health_log_detail_bundle: Optional[bool] = None,
    health_log_alerts_immediately: Optional[bool] = None,
    rank: Optional[int] = None,
    world_size: Optional[int] = None,
) -> bool:
    """Configure global OPSD debug logging. Returns whether debug is enabled.

    For each setting an explicit (non-``None``) argument wins; otherwise the
    matching ``DYME_OPSD_*`` environment variable is used when it is set, and
    failing that the current module-level value is kept. The ``health_log_*``
    flags, ``rank`` and ``world_size`` have no environment fallback. The
    debug switch itself is always re-evaluated: with ``enabled=None`` it
    follows ``DYME_OPSD_DEBUG``.
    """
    global _DEBUG_ENABLED, _DETAIL_EVERY, _PROBE_ON_GENERATE
    global _PROBE_FIRST_TOKEN_LOGITS, _PROBE_PROMPT_TAIL_TOKENS, _PROBE_LOG_MODEL_CONTEXT
    global _HEALTH_MONITOR_ENABLED, _HEALTH_LOG_ON_GENERATE, _HEALTH_LOG_EVERY_STEP
    global _HEALTH_LOG_DETAIL_BUNDLE, _HEALTH_LOG_ALERTS_IMMEDIATELY
    global _RANK, _WORLD_SIZE

    _DEBUG_ENABLED = bool(_env_debug_enabled() if enabled is None else enabled)

    if detail_every is not None:
        _DETAIL_EVERY = max(0, int(detail_every))
    elif os.environ.get("DYME_OPSD_DETAIL_EVERY"):
        # Any non-empty value counts as "set"; unparsable text yields 10.
        _DETAIL_EVERY = _env_detail_every()

    # Settings with an environment fallback: fill a missing argument from
    # the environment, then store it only if a value was found.
    if probe_on_generate is None:
        probe_on_generate = _env_probe_on_generate()
    if probe_on_generate is not None:
        _PROBE_ON_GENERATE = bool(probe_on_generate)

    if probe_first_token_logits is None:
        probe_first_token_logits = _env_probe_first_token_logits()
    if probe_first_token_logits is not None:
        _PROBE_FIRST_TOKEN_LOGITS = bool(probe_first_token_logits)

    if probe_prompt_tail_tokens is None:
        probe_prompt_tail_tokens = _env_probe_prompt_tail_tokens()
    if probe_prompt_tail_tokens is not None:
        _PROBE_PROMPT_TAIL_TOKENS = max(1, int(probe_prompt_tail_tokens))

    if probe_log_model_context is None:
        probe_log_model_context = _env_probe_log_model_context()
    if probe_log_model_context is not None:
        _PROBE_LOG_MODEL_CONTEXT = bool(probe_log_model_context)

    if health_monitor_enabled is None:
        health_monitor_enabled = _env_health_monitor_enabled()
    if health_monitor_enabled is not None:
        _HEALTH_MONITOR_ENABLED = bool(health_monitor_enabled)

    # Argument-only settings.
    if health_log_on_generate is not None:
        _HEALTH_LOG_ON_GENERATE = bool(health_log_on_generate)
    if health_log_every_step is not None:
        _HEALTH_LOG_EVERY_STEP = bool(health_log_every_step)
    if health_log_detail_bundle is not None:
        _HEALTH_LOG_DETAIL_BUNDLE = bool(health_log_detail_bundle)
    if health_log_alerts_immediately is not None:
        _HEALTH_LOG_ALERTS_IMMEDIATELY = bool(health_log_alerts_immediately)

    if rank is not None:
        _RANK = rank
    if world_size is not None:
        _WORLD_SIZE = world_size

    return _DEBUG_ENABLED
def detail_every() -> int:
    """Return the configured step interval for detail bundles."""
    interval = _DETAIL_EVERY
    return interval
def probe_on_generate() -> bool:
    """Return the current per-generate probe flag."""
    flag = _PROBE_ON_GENERATE
    return flag
def probe_first_token_logits() -> bool:
    """Return the current first-token-logits probe flag."""
    flag = _PROBE_FIRST_TOKEN_LOGITS
    return flag
def probe_prompt_tail_tokens() -> int:
    """Return the configured prompt-tail token count for probes."""
    count = _PROBE_PROMPT_TAIL_TOKENS
    return count
def probe_log_model_context() -> bool:
    """Return the current model-context probe flag."""
    flag = _PROBE_LOG_MODEL_CONTEXT
    return flag
def should_log_probe() -> bool:
    """True when lightweight per-generate probe should run (rank 0 only)."""
    if not _PROBE_ON_GENERATE:
        return _PROBE_ON_GENERATE
    return _RANK == 0
def should_log_detail(global_step: Optional[int]) -> bool:
    """True when a full diagnostic bundle should be emitted (rank 0 only).

    Requires a positive detail interval, rank 0, and a step that is a
    multiple of the interval; a ``None`` step never qualifies.
    """
    active = _DETAIL_EVERY > 0 and _RANK == 0
    if not active or global_step is None:
        return False
    remainder = int(global_step) % _DETAIL_EVERY
    return remainder == 0
def is_enabled() -> bool:
    """Return whether verbose OPSD-DEBUG logging is switched on."""
    state = _DEBUG_ENABLED
    return state
def set_step_label(label: str) -> None:
    """Store the label shown in the ``[step=...]`` part of debug prefixes."""
    global _STEP_LABEL
    _STEP_LABEL = label
    return None
def set_detail_step(global_step: Optional[int]) -> None:
    """Record the global step used as the default by step-aware log helpers."""
    global _DETAIL_STEP
    _DETAIL_STEP = global_step
    return None
def get_detail_step() -> Optional[int]:
    """Return the global step last stored by ``set_detail_step`` (or ``None``)."""
    current = _DETAIL_STEP
    return current
def _next_call_id(stage: str) -> str:
    """Bump the module-wide call counter and return ``"<counter>:<stage>"``."""
    global _CALL_COUNTER
    _CALL_COUNTER = _CALL_COUNTER + 1
    return "{}:{}".format(_CALL_COUNTER, stage)
def _fmt(value: Any, max_len: int = 240) -> str:
    """Render *value* as a compact, single-string summary for log lines.

    * ``None`` becomes ``"None"``.
    * ``torch.Tensor`` values (when torch is importable) are summarised by
      shape, dtype, device and element count instead of their contents.
    * Lists/tuples are rendered element by element (each capped at 40
      characters); beyond 12 items only the first 6 are shown plus a count.
    * Dicts are JSON-encoded, falling back to ``repr`` when the dict cannot
      be serialised.
    * Everything else uses ``repr``.

    Dict and ``repr`` output longer than *max_len* is cut and suffixed with
    ``"..."``; tensor and list/tuple summaries are not truncated.
    """
    if value is None:
        return "None"
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and isinstance(value, torch.Tensor):
        return (
            f"Tensor(shape={tuple(value.shape)}, dtype={value.dtype}, "
            f"device={value.device}, numel={value.numel()})"
        )
    if isinstance(value, (list, tuple)):
        if len(value) > 12:
            head = ", ".join(_fmt(v, max_len=40) for v in value[:6])
            return f"[{head}, ... +{len(value) - 6} more, total={len(value)}]"
        return "[" + ", ".join(_fmt(v, max_len=40) for v in value) + "]"
    if isinstance(value, dict):
        try:
            text = json.dumps(value, ensure_ascii=False, default=str)
        except (TypeError, ValueError):
            # ``default=str`` does not cover dict *keys*: non-JSON keys such
            # as tuples raise TypeError, and circular references raise
            # ValueError. A debug formatter must not crash the caller.
            text = repr(value)
    else:
        text = repr(value)
    if len(text) > max_len:
        # Clamp so a tiny max_len cannot turn into a negative slice index.
        return text[: max(0, max_len - 3)] + "..."
    return text
def _prefix(stage: str, call_id: Optional[str] = None) -> str:
    """Build the ``[OPSD-DEBUG]`` line prefix.

    A fresh call id is allocated for *stage* when *call_id* is missing or
    empty.
    """
    if not call_id:
        call_id = _next_call_id(stage)
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    return f"[OPSD-DEBUG][{stamp}][rank={_RANK}/{_WORLD_SIZE}][step={_STEP_LABEL}][{call_id}]"
def log(stage: str, msg: str, **fields: Any) -> None:
    """Print one ``[OPSD-DEBUG]`` line; a no-op unless debug is enabled.

    Keyword *fields* are appended as ``key=value`` pairs separated by
    `` | ``, each value rendered through ``_fmt``.
    """
    if not _DEBUG_ENABLED:
        return
    call_id = _next_call_id(stage)
    rendered = [f"{key}={_fmt(val)}" for key, val in fields.items()]
    suffix = " | " + " | ".join(rendered) if rendered else ""
    print(f"{_prefix(stage, call_id)} {msg}{suffix}", flush=True)
def _detail_prefix(global_step: int, section: str) -> str:
    """Build the ``[OPSD-DETAIL]`` prefix for *section* at *global_step*."""
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    head = f"[OPSD-DETAIL][{stamp}][rank={_RANK}/{_WORLD_SIZE}]"
    tail = f"[step={global_step}][every={_DETAIL_EVERY}][{section}]"
    return head + tail
def log_detail_banner(global_step: int, title: str) -> None:
    """Print a ``BANNER`` detail line framing *title*, on detail steps only."""
    if should_log_detail(global_step):
        rule = "=" * 20
        print(f"{_detail_prefix(global_step, 'BANNER')} {rule} {title} {rule}", flush=True)
def log_detail(section: str, msg: str, global_step: Optional[int] = None, **fields: Any) -> None:
    """Full-detail diagnostic line (periodic, rank 0). Independent of verbose OPSD-DEBUG.

    Falls back to the step stored by ``set_detail_step`` when *global_step*
    is omitted; nothing is printed if the resulting step is ``None``, a
    string, or not a detail step. Field values are capped at 800 characters.
    """
    step = _DETAIL_STEP if global_step is None else global_step
    if step is None or isinstance(step, str) or not should_log_detail(step):
        return
    rendered = [f"{key}={_fmt(val, max_len=800)}" for key, val in fields.items()]
    suffix = " | " + " | ".join(rendered) if rendered else ""
    print(f"{_detail_prefix(step, section)} {msg}{suffix}", flush=True)
def _probe_prefix(section: str) -> str:
    """Build the ``[OPSD-PROBE]`` prefix; an unset detail step shows as ``?``."""
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    step = "?" if _DETAIL_STEP is None else _DETAIL_STEP
    head = f"[OPSD-PROBE][{stamp}][rank={_RANK}/{_WORLD_SIZE}]"
    tail = f"[global_step={step}][{_STEP_LABEL}][{section}]"
    return head + tail
def log_probe(section: str, msg: str, **fields: Any) -> None:
    """Lightweight per-generate diagnostic (rank 0). Independent of OPSD-DEBUG verbosity.

    Field values are rendered through ``_fmt`` with a 1200-character cap.
    """
    if not should_log_probe():
        return
    rendered = [f"{key}={_fmt(val, max_len=1200)}" for key, val in fields.items()]
    suffix = " | " + " | ".join(rendered) if rendered else ""
    print(f"{_probe_prefix(section)} {msg}{suffix}", flush=True)
def _gendbg_prefix(section: str) -> str:
    """Build the ``[OPSD-GENDBG]`` prefix; an unset detail step shows as ``?``."""
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    step = "?" if _DETAIL_STEP is None else _DETAIL_STEP
    head = f"[OPSD-GENDBG][{stamp}][rank={_RANK}/{_WORLD_SIZE}]"
    tail = f"[global_step={step}][{_STEP_LABEL}][{section}]"
    return head + tail
def should_log_gendbg() -> bool:
    """True when deep generate diagnostics should run (rank 0 only)."""
    if not _PROBE_ON_GENERATE:
        return _PROBE_ON_GENERATE
    return _RANK == 0
def health_monitor_enabled() -> bool:
    """True when the health monitor is switched on and this process is rank 0."""
    if not _HEALTH_MONITOR_ENABLED:
        return _HEALTH_MONITOR_ENABLED
    return _RANK == 0
def should_log_health_on_generate() -> bool:
    """True when health lines should accompany each generate probe.

    Requires the health monitor, the on-generate health flag, and the
    per-generate probe to all be active.
    """
    if not health_monitor_enabled():
        return False
    if not _HEALTH_LOG_ON_GENERATE:
        return _HEALTH_LOG_ON_GENERATE
    return should_log_probe()
def should_log_health_every_step() -> bool:
    """True when per-step health lines are enabled and the monitor is active."""
    if not health_monitor_enabled():
        return False
    return _HEALTH_LOG_EVERY_STEP
def should_log_health_detail_bundle() -> bool:
    """True when the periodic health bundle is enabled and the monitor is active."""
    if not health_monitor_enabled():
        return False
    return _HEALTH_LOG_DETAIL_BUNDLE
def should_log_health_alerts_immediately() -> bool:
    """True when immediate health alerts are enabled and the monitor is active."""
    if not health_monitor_enabled():
        return False
    return _HEALTH_LOG_ALERTS_IMMEDIATELY
def _health_prefix(section: str, global_step: Optional[int] = None) -> str:
    """Build the ``[OPSD-HEALTH]`` prefix.

    The step shown is *global_step* if given, else the stored detail step,
    else ``?``.
    """
    if global_step is not None:
        step = global_step
    elif _DETAIL_STEP is not None:
        step = _DETAIL_STEP
    else:
        step = "?"
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    return f"[OPSD-HEALTH][{stamp}][rank={_RANK}/{_WORLD_SIZE}][global_step={step}][{section}]"
def log_health(section: str, msg: str, global_step: Optional[int] = None, **fields: Any) -> None:
    """L1/L2/L4 health lines (rank 0).

    Printed only while the health monitor is enabled; field values are
    rendered through ``_fmt`` with a 1200-character cap.
    """
    if not health_monitor_enabled():
        return
    rendered = [f"{key}={_fmt(val, max_len=1200)}" for key, val in fields.items()]
    suffix = " | " + " | ".join(rendered) if rendered else ""
    print(f"{_health_prefix(section, global_step)} {msg}{suffix}", flush=True)
def log_health_detail_banner(global_step: int, title: str) -> None:
    """Print a ``health`` banner line on detail steps when the health bundle is enabled."""
    if should_log_health_detail_bundle() and should_log_detail(global_step):
        rule = "=" * 20
        print(f"{_detail_prefix(global_step, 'health')} {rule} {title} {rule}", flush=True)
def log_health_detail(section: str, msg: str, global_step: int, **fields: Any) -> None:
    """L3 periodic health bundle (rank 0, same cadence as OPSD-DETAIL).

    Uses the ``[OPSD-DETAIL]`` prefix; field values are capped at 1200
    characters.
    """
    if not (should_log_health_detail_bundle() and should_log_detail(global_step)):
        return
    rendered = [f"{key}={_fmt(val, max_len=1200)}" for key, val in fields.items()]
    suffix = " | " + " | ".join(rendered) if rendered else ""
    print(f"{_detail_prefix(global_step, section)} {msg}{suffix}", flush=True)
def log_gendbg(section: str, msg: str, **fields: Any) -> None:
    """Deep per-generate diagnostic (rank 0). Uses [OPSD-GENDBG] prefix.

    Field values are rendered through ``_fmt`` with a 1200-character cap.
    """
    if not should_log_gendbg():
        return
    rendered = [f"{key}={_fmt(val, max_len=1200)}" for key, val in fields.items()]
    suffix = " | " + " | ".join(rendered) if rendered else ""
    print(f"{_gendbg_prefix(section)} {msg}{suffix}", flush=True)
def log_config(stage: str, title: str, config: dict[str, Any]) -> None:
    """Print *config* on one ``[OPSD-DEBUG]`` line, capped at 2000 characters.

    The line is emitted directly rather than through ``log()``: ``log()``
    formats every field again with the default 240-character cap, which
    would quote the already-rendered text and cancel the larger limit that
    a configuration dump needs. No-op unless debug logging is enabled.
    """
    if not _DEBUG_ENABLED:
        return
    rendered = _fmt(config, max_len=2000)
    # _prefix() allocates the call id, exactly as log() would.
    print(f"{_prefix(stage)} {title} | config={rendered}", flush=True)
def log_sync_point(stage: str, msg: str, **fields: Any) -> None:
    """Mark a point where all ranks must reach before a collective call.

    Delegates to ``log`` with the message tagged ``[SYNC_POINT]``.
    """
    tagged = "[SYNC_POINT] {}".format(msg)
    log(stage, tagged, **fields)
def log_mode_summary(stage: str, prompt_modes: list[int], completion_modes: Optional[list[int]] = None) -> None:
    """Log the per-prompt mode names and, if given, per-mode completion counts.

    Mode codes are translated through ``MODE_NAMES``; unknown prompt codes
    are shown as their string form. No-op unless debug logging is enabled.
    """
    if not _DEBUG_ENABLED:
        return
    summary: dict[str, Any] = {
        "prompt_modes": [MODE_NAMES.get(code, str(code)) for code in prompt_modes],
    }
    if completion_modes is not None:
        tally = {}
        for code, name in MODE_NAMES.items():
            tally[name] = completion_modes.count(code)
        summary["completion_mode_counts"] = tally
    log(stage, "mode routing summary", **summary)
def log_tensor(stage: str, name: str, tensor: Any) -> None:
    """Log a one-line summary of *tensor* under the label *name*.

    The raw value is handed to ``log()``, which formats each field exactly
    once. Pre-formatting it here would make ``log()`` format the resulting
    string a second time and wrap the summary in quotes. No-op unless debug
    logging is enabled.
    """
    if not _DEBUG_ENABLED:
        return
    log(stage, f"tensor `{name}`", value=tensor)
def log_exception(stage: str, msg: str, exc: BaseException) -> None:
    """Print *exc* and its traceback on one ``[OPSD-DEBUG]`` line.

    The traceback is built from *exc* itself, so the function also works
    when called outside an ``except`` block (where ``traceback.format_exc``
    would only yield ``NoneType: None``). The line is emitted directly
    rather than through ``log()``, because ``log()`` re-formats every field
    with a 240-character cap and would cancel the 4000-character limit
    used for the traceback. No-op unless debug logging is enabled.
    """
    if not _DEBUG_ENABLED:
        return
    tb_text = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    # _fmt() reprs the text, keeping the multi-line traceback on one log line.
    print(
        f"{_prefix(stage)} {msg} | error={_fmt(exc)}"
        f" | traceback={_fmt(tb_text, max_len=4000)}",
        flush=True,
    )
@contextmanager
def timed(stage: str, msg: str = "", **fields: Any):
    """Context manager that logs START/END lines around the wrapped block.

    With debug disabled it simply yields. Otherwise it logs ``START`` with
    *fields*, logs ``FAILED`` through ``log_exception`` if the body raises
    an ``Exception`` (which is then re-raised), and always logs ``END``
    with the elapsed wall time in seconds.
    """
    if not _DEBUG_ENABLED:
        yield
        return
    started = time.perf_counter()
    log(stage, f"START {msg}", **fields)
    try:
        yield
    except Exception as err:
        log_exception(stage, f"FAILED {msg}", err)
        raise
    finally:
        # Runs on both success and failure, so END always pairs with START.
        duration = time.perf_counter() - started
        log(stage, f"END {msg}", elapsed_s=f"{duration:.4f}")