File size: 14,225 Bytes
36d0b76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
"""Verbose debug logging for the OPSD / TriMode training pipeline."""
from __future__ import annotations

import json
import os
import time
import traceback
from contextlib import contextmanager
from typing import Any, Optional

# Master switch for [OPSD-DEBUG] output; checked by log(), log_config(), timed(), etc.
_DEBUG_ENABLED = False
# Cadence of [OPSD-DETAIL] bundles: a step qualifies when step % _DETAIL_EVERY == 0;
# a value <= 0 disables them (see should_log_detail).
_DETAIL_EVERY = 10
# Enables the per-generate [OPSD-PROBE] and [OPSD-GENDBG] lines (rank 0 only).
_PROBE_ON_GENERATE = False
# Probe options: this module only stores them and exposes them through getters.
_PROBE_FIRST_TOKEN_LOGITS = True
_PROBE_PROMPT_TAIL_TOKENS = 16
_PROBE_LOG_MODEL_CONTEXT = True
# Health-monitor switches; each is combined with the rank-0 check in the
# should_log_health_* predicates.
_HEALTH_MONITOR_ENABLED = True
_HEALTH_LOG_ON_GENERATE = True
_HEALTH_LOG_EVERY_STEP = True
_HEALTH_LOG_DETAIL_BUNDLE = True
_HEALTH_LOG_ALERTS_IMMEDIATELY = True
# Process identity rendered as rank=<_RANK>/<_WORLD_SIZE> in every log prefix.
_RANK = 0
_WORLD_SIZE = 1
# Free-form label included in the DEBUG, PROBE and GENDBG prefixes.
_STEP_LABEL = "init"
# Fallback global step used by log_detail() and the probe/gendbg/health prefixes.
_DETAIL_STEP: Optional[int] = None
# Counter behind the "<n>:<stage>" call ids produced by _next_call_id().
_CALL_COUNTER = 0

# Integer mode code -> display name, used by log_mode_summary().
MODE_NAMES = {0: "GRPO", 1: "OPSD", 2: "SFT"}


def _env_debug_enabled() -> bool:
    return os.environ.get("DYME_OPSD_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")


def _env_detail_every() -> int:
    raw = os.environ.get("DYME_OPSD_DETAIL_EVERY", "").strip()
    if not raw:
        return 10
    try:
        return max(0, int(raw))
    except ValueError:
        return 10


def _env_probe_on_generate() -> Optional[bool]:
    raw = os.environ.get("DYME_OPSD_PROBE_ON_GENERATE", "").strip().lower()
    if not raw:
        return None
    return raw in ("1", "true", "yes", "on")


def _env_probe_first_token_logits() -> Optional[bool]:
    raw = os.environ.get("DYME_OPSD_PROBE_FIRST_TOKEN_LOGITS", "").strip().lower()
    if not raw:
        return None
    return raw in ("1", "true", "yes", "on")


def _env_probe_prompt_tail_tokens() -> Optional[int]:
    raw = os.environ.get("DYME_OPSD_PROBE_PROMPT_TAIL_TOKENS", "").strip()
    if not raw:
        return None
    try:
        return max(1, int(raw))
    except ValueError:
        return 16


def _env_probe_log_model_context() -> Optional[bool]:
    raw = os.environ.get("DYME_OPSD_PROBE_LOG_MODEL_CONTEXT", "").strip().lower()
    if not raw:
        return None
    return raw in ("1", "true", "yes", "on")


def _env_health_monitor_enabled() -> Optional[bool]:
    raw = os.environ.get("DYME_OPSD_HEALTH_MONITOR", "").strip().lower()
    if not raw:
        return None
    return raw in ("1", "true", "yes", "on")


def configure(
    *,
    enabled: Optional[bool] = None,
    detail_every: Optional[int] = None,
    probe_on_generate: Optional[bool] = None,
    probe_first_token_logits: Optional[bool] = None,
    probe_prompt_tail_tokens: Optional[int] = None,
    probe_log_model_context: Optional[bool] = None,
    health_monitor_enabled: Optional[bool] = None,
    health_log_on_generate: Optional[bool] = None,
    health_log_every_step: Optional[bool] = None,
    health_log_detail_bundle: Optional[bool] = None,
    health_log_alerts_immediately: Optional[bool] = None,
    rank: Optional[int] = None,
    world_size: Optional[int] = None,
) -> bool:
    """Apply debug, probe and health settings to this module's global state.

    All arguments are keyword-only and optional. For ``enabled`` an explicit
    value wins, otherwise the ``DYME_OPSD_DEBUG`` environment flag is used.
    For the detail cadence, the probe options and ``health_monitor_enabled``
    an explicit value wins, otherwise the matching environment variable is
    used when it is set, otherwise the current value is kept. The remaining
    ``health_log_*`` flags, ``rank`` and ``world_size`` change only when
    passed explicitly.

    Returns:
        Whether verbose debug logging is enabled after this call.
    """
    global _DEBUG_ENABLED, _DETAIL_EVERY, _PROBE_ON_GENERATE
    global _PROBE_FIRST_TOKEN_LOGITS, _PROBE_PROMPT_TAIL_TOKENS, _PROBE_LOG_MODEL_CONTEXT
    global _HEALTH_MONITOR_ENABLED, _HEALTH_LOG_ON_GENERATE, _HEALTH_LOG_EVERY_STEP
    global _HEALTH_LOG_DETAIL_BUNDLE, _HEALTH_LOG_ALERTS_IMMEDIATELY
    global _RANK, _WORLD_SIZE

    # Master switch is always (re)assigned on every call.
    _DEBUG_ENABLED = bool(_env_debug_enabled() if enabled is None else enabled)

    # Detail cadence: the environment applies only when the variable is non-empty.
    if detail_every is not None:
        _DETAIL_EVERY = max(0, int(detail_every))
    elif os.environ.get("DYME_OPSD_DETAIL_EVERY"):
        _DETAIL_EVERY = _env_detail_every()

    if probe_on_generate is not None:
        _PROBE_ON_GENERATE = bool(probe_on_generate)
    else:
        from_env = _env_probe_on_generate()
        if from_env is not None:
            _PROBE_ON_GENERATE = from_env

    if probe_first_token_logits is not None:
        _PROBE_FIRST_TOKEN_LOGITS = bool(probe_first_token_logits)
    else:
        from_env = _env_probe_first_token_logits()
        if from_env is not None:
            _PROBE_FIRST_TOKEN_LOGITS = from_env

    if probe_prompt_tail_tokens is not None:
        _PROBE_PROMPT_TAIL_TOKENS = max(1, int(probe_prompt_tail_tokens))
    else:
        from_env = _env_probe_prompt_tail_tokens()
        if from_env is not None:
            _PROBE_PROMPT_TAIL_TOKENS = from_env

    if probe_log_model_context is not None:
        _PROBE_LOG_MODEL_CONTEXT = bool(probe_log_model_context)
    else:
        from_env = _env_probe_log_model_context()
        if from_env is not None:
            _PROBE_LOG_MODEL_CONTEXT = from_env

    if health_monitor_enabled is not None:
        _HEALTH_MONITOR_ENABLED = bool(health_monitor_enabled)
    else:
        from_env = _env_health_monitor_enabled()
        if from_env is not None:
            _HEALTH_MONITOR_ENABLED = from_env

    # These switches have no environment fallback in this function.
    if health_log_on_generate is not None:
        _HEALTH_LOG_ON_GENERATE = bool(health_log_on_generate)
    if health_log_every_step is not None:
        _HEALTH_LOG_EVERY_STEP = bool(health_log_every_step)
    if health_log_detail_bundle is not None:
        _HEALTH_LOG_DETAIL_BUNDLE = bool(health_log_detail_bundle)
    if health_log_alerts_immediately is not None:
        _HEALTH_LOG_ALERTS_IMMEDIATELY = bool(health_log_alerts_immediately)

    # Rank and world size are stored as given.
    if rank is not None:
        _RANK = rank
    if world_size is not None:
        _WORLD_SIZE = world_size

    return _DEBUG_ENABLED


def detail_every() -> int:
    """Return the current detail cadence (``_DETAIL_EVERY``)."""
    return _DETAIL_EVERY


def probe_on_generate() -> bool:
    """Return the raw probe-on-generate flag, without the rank-0 check."""
    return _PROBE_ON_GENERATE


def probe_first_token_logits() -> bool:
    """Return the stored ``_PROBE_FIRST_TOKEN_LOGITS`` setting."""
    return _PROBE_FIRST_TOKEN_LOGITS


def probe_prompt_tail_tokens() -> int:
    """Return the stored ``_PROBE_PROMPT_TAIL_TOKENS`` setting."""
    return _PROBE_PROMPT_TAIL_TOKENS


def probe_log_model_context() -> bool:
    """Return the stored ``_PROBE_LOG_MODEL_CONTEXT`` setting."""
    return _PROBE_LOG_MODEL_CONTEXT


def should_log_probe() -> bool:
    """Tell whether the lightweight per-generate probe should run.

    Requires the probe flag to be on and this process to be rank 0.
    """
    if not _PROBE_ON_GENERATE:
        return False
    return _RANK == 0


def should_log_detail(global_step: Optional[int]) -> bool:
    """Tell whether a full diagnostic bundle is due at ``global_step``.

    False when no step is given, when this process is not rank 0, or when the
    detail cadence is zero or negative. Otherwise True exactly when the step
    is a multiple of the cadence.
    """
    if global_step is None or _RANK != 0 or _DETAIL_EVERY <= 0:
        return False
    remainder = int(global_step) % _DETAIL_EVERY
    return remainder == 0


def is_enabled() -> bool:
    """Return whether verbose ``[OPSD-DEBUG]`` logging is switched on."""
    return _DEBUG_ENABLED


def set_step_label(label: str) -> None:
    """Store the label shown in the DEBUG, PROBE and GENDBG log prefixes."""
    global _STEP_LABEL
    _STEP_LABEL = label


def set_detail_step(global_step: Optional[int]) -> None:
    """Store the step used as a fallback when a logger is called without one."""
    global _DETAIL_STEP
    _DETAIL_STEP = global_step


def get_detail_step() -> Optional[int]:
    """Return the step last stored by ``set_detail_step`` (``None`` if unset)."""
    return _DETAIL_STEP


def _next_call_id(stage: str) -> str:
    """Advance the module-wide call counter and return ``"<count>:<stage>"``."""
    global _CALL_COUNTER
    _CALL_COUNTER = _CALL_COUNTER + 1
    sequence = _CALL_COUNTER
    return f"{sequence}:{stage}"


def _fmt(value: Any, max_len: int = 240) -> str:
    if value is None:
        return "None"
    try:
        import torch

        if isinstance(value, torch.Tensor):
            return (
                f"Tensor(shape={tuple(value.shape)}, dtype={value.dtype}, "
                f"device={value.device}, numel={value.numel()})"
            )
    except ImportError:
        pass
    if isinstance(value, (list, tuple)):
        if len(value) > 12:
            head = ", ".join(_fmt(v, max_len=40) for v in value[:6])
            return f"[{head}, ... +{len(value) - 6} more, total={len(value)}]"
        return "[" + ", ".join(_fmt(v, max_len=40) for v in value) + "]"
    if isinstance(value, dict):
        text = json.dumps(value, ensure_ascii=False, default=str)
    else:
        text = repr(value)
    if len(text) > max_len:
        return text[: max_len - 3] + "..."
    return text


def _prefix(stage: str, call_id: Optional[str] = None) -> str:
    """Build the ``[OPSD-DEBUG]`` line prefix.

    A fresh call id is allocated for ``stage`` when ``call_id`` is missing or
    empty. The prefix carries a local timestamp, rank/world size, the current
    step label and the call id.
    """
    if call_id:
        cid = call_id
    else:
        cid = _next_call_id(stage)
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    return f"[OPSD-DEBUG][{stamp}][rank={_RANK}/{_WORLD_SIZE}][step={_STEP_LABEL}][{cid}]"


def log(stage: str, msg: str, **fields: Any) -> None:
    """Print one ``[OPSD-DEBUG]`` line; a no-op unless debug logging is enabled.

    Each keyword field is appended as `` | key=value`` with the value rendered
    by ``_fmt`` at its default length limit. Output is flushed immediately.
    """
    if not _DEBUG_ENABLED:
        return
    call_id = _next_call_id(stage)
    suffix = "".join(f" | {key}={_fmt(val)}" for key, val in fields.items())
    print(f"{_prefix(stage, call_id)} {msg}{suffix}", flush=True)


def _detail_prefix(global_step: int, section: str) -> str:
    """Build the ``[OPSD-DETAIL]`` prefix for ``section`` at ``global_step``.

    Includes a local timestamp, rank/world size, the step, the configured
    detail cadence and the section name.
    """
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    identity = f"[OPSD-DETAIL][{stamp}][rank={_RANK}/{_WORLD_SIZE}]"
    location = f"[step={global_step}][every={_DETAIL_EVERY}][{section}]"
    return identity + location


def log_detail_banner(global_step: int, title: str) -> None:
    """Print a ``BANNER`` detail line framing ``title`` when a bundle is due."""
    if should_log_detail(global_step):
        rule = "=" * 20
        print(f"{_detail_prefix(global_step, 'BANNER')} {rule} {title} {rule}", flush=True)


def log_detail(section: str, msg: str, global_step: Optional[int] = None, **fields: Any) -> None:
    """Print one ``[OPSD-DETAIL]`` line when the step falls on the detail cadence.

    The step is ``global_step`` when given, otherwise the value stored by
    ``set_detail_step``. Nothing is printed when no step is available, when
    the step is a string, or when ``should_log_detail`` rejects it. This path
    does not consult the verbose debug switch. Field values are rendered with
    ``_fmt`` capped at 800 characters.
    """
    step = _DETAIL_STEP if global_step is None else global_step
    if step is None or isinstance(step, str) or not should_log_detail(step):
        return
    suffix = "".join(f" | {key}={_fmt(val, max_len=800)}" for key, val in fields.items())
    print(f"{_detail_prefix(step, section)} {msg}{suffix}", flush=True)


def _probe_prefix(section: str) -> str:
    """Build the ``[OPSD-PROBE]`` prefix for ``section``.

    Uses the stored detail step, or ``?`` when none has been set, plus the
    current step label.
    """
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    if _DETAIL_STEP is None:
        step = "?"
    else:
        step = _DETAIL_STEP
    identity = f"[OPSD-PROBE][{stamp}][rank={_RANK}/{_WORLD_SIZE}]"
    location = f"[global_step={step}][{_STEP_LABEL}][{section}]"
    return identity + location


def log_probe(section: str, msg: str, **fields: Any) -> None:
    """Print one ``[OPSD-PROBE]`` line when ``should_log_probe`` allows it.

    Does not consult the verbose debug switch. Field values are rendered with
    ``_fmt`` capped at 1200 characters.
    """
    if should_log_probe():
        suffix = "".join(f" | {key}={_fmt(val, max_len=1200)}" for key, val in fields.items())
        print(f"{_probe_prefix(section)} {msg}{suffix}", flush=True)


def _gendbg_prefix(section: str) -> str:
    """Build the ``[OPSD-GENDBG]`` prefix for ``section``.

    Uses the stored detail step, or ``?`` when none has been set, plus the
    current step label.
    """
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    step = "?" if _DETAIL_STEP is None else _DETAIL_STEP
    identity = f"[OPSD-GENDBG][{stamp}][rank={_RANK}/{_WORLD_SIZE}]"
    location = f"[global_step={step}][{_STEP_LABEL}][{section}]"
    return identity + location


def should_log_gendbg() -> bool:
    """Tell whether deep generate diagnostics should run.

    Uses the same condition as the per-generate probe: probe flag on and rank 0.
    """
    if not _PROBE_ON_GENERATE:
        return False
    return _RANK == 0


def health_monitor_enabled() -> bool:
    """Tell whether health monitoring is active: flag on and this is rank 0."""
    if not _HEALTH_MONITOR_ENABLED:
        return False
    return _RANK == 0


def should_log_health_on_generate() -> bool:
    """Tell whether health lines should be emitted on each generate call.

    Requires the health monitor, the on-generate health flag, and the
    per-generate probe to all be active.
    """
    if not health_monitor_enabled():
        return False
    if not _HEALTH_LOG_ON_GENERATE:
        return False
    return should_log_probe()


def should_log_health_every_step() -> bool:
    """Tell whether per-step health lines are wanted (monitor on and flag set)."""
    if not health_monitor_enabled():
        return False
    return _HEALTH_LOG_EVERY_STEP


def should_log_health_detail_bundle() -> bool:
    """Tell whether the periodic health bundle is wanted (monitor on and flag set)."""
    if not health_monitor_enabled():
        return False
    return _HEALTH_LOG_DETAIL_BUNDLE


def should_log_health_alerts_immediately() -> bool:
    """Tell whether immediate health alerts are wanted (monitor on and flag set)."""
    if not health_monitor_enabled():
        return False
    return _HEALTH_LOG_ALERTS_IMMEDIATELY


def _health_prefix(section: str, global_step: Optional[int] = None) -> str:
    """Build the ``[OPSD-HEALTH]`` prefix for ``section``.

    The step shown is ``global_step`` when given, else the stored detail
    step, else ``?``.
    """
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    if global_step is not None:
        step = global_step
    elif _DETAIL_STEP is not None:
        step = _DETAIL_STEP
    else:
        step = "?"
    return f"[OPSD-HEALTH][{stamp}][rank={_RANK}/{_WORLD_SIZE}][global_step={step}][{section}]"


def log_health(section: str, msg: str, global_step: Optional[int] = None, **fields: Any) -> None:
    """Print one ``[OPSD-HEALTH]`` line when the health monitor is active.

    Field values are rendered with ``_fmt`` capped at 1200 characters.
    """
    if health_monitor_enabled():
        suffix = "".join(f" | {key}={_fmt(val, max_len=1200)}" for key, val in fields.items())
        print(f"{_health_prefix(section, global_step)} {msg}{suffix}", flush=True)


def log_health_detail_banner(global_step: int, title: str) -> None:
    """Print a banner for the periodic health bundle.

    Emitted only when the health bundle is enabled and a detail bundle is due
    at ``global_step``. The line uses the ``[OPSD-DETAIL]`` prefix with
    section ``health``.
    """
    if should_log_health_detail_bundle() and should_log_detail(global_step):
        rule = "=" * 20
        print(f"{_detail_prefix(global_step, 'health')} {rule} {title} {rule}", flush=True)


def log_health_detail(section: str, msg: str, global_step: int, **fields: Any) -> None:
    """Print one line of the periodic health bundle.

    Emitted only when the health bundle is enabled and a detail bundle is due
    at ``global_step``. The line uses the ``[OPSD-DETAIL]`` prefix. Field
    values are rendered with ``_fmt`` capped at 1200 characters.
    """
    if should_log_health_detail_bundle() and should_log_detail(global_step):
        suffix = "".join(f" | {key}={_fmt(val, max_len=1200)}" for key, val in fields.items())
        print(f"{_detail_prefix(global_step, section)} {msg}{suffix}", flush=True)


def log_gendbg(section: str, msg: str, **fields: Any) -> None:
    """Print one ``[OPSD-GENDBG]`` line when ``should_log_gendbg`` allows it.

    Field values are rendered with ``_fmt`` capped at 1200 characters.
    """
    if should_log_gendbg():
        suffix = "".join(f" | {key}={_fmt(val, max_len=1200)}" for key, val in fields.items())
        print(f"{_gendbg_prefix(section)} {msg}{suffix}", flush=True)


def log_config(stage: str, title: str, config: dict[str, Any]) -> None:
    """Print a configuration mapping on one ``[OPSD-DEBUG]`` line.

    A no-op unless debug logging is enabled. The mapping is rendered once by
    ``_fmt`` with a 2000-character limit.

    The line is printed directly rather than routed through ``log``: ``log``
    re-formats every field with the default 240-character limit, which would
    turn the pre-rendered text into a quoted string repr and truncate it,
    defeating the larger limit.
    """
    if not _DEBUG_ENABLED:
        return
    rendered = _fmt(config, max_len=2000)
    # _prefix allocates the call id itself when none is supplied.
    print(f"{_prefix(stage)} {title} | config={rendered}", flush=True)


def log_sync_point(stage: str, msg: str, **fields: Any) -> None:
    """Log a marker for a point every rank must reach before a collective call.

    The message is tagged with ``[SYNC_POINT]`` and forwarded to ``log``.
    """
    tagged = f"[SYNC_POINT] {msg}"
    log(stage, tagged, **fields)


def log_mode_summary(stage: str, prompt_modes: list[int], completion_modes: Optional[list[int]] = None) -> None:
    """Log how prompts (and optionally completions) were routed across modes.

    A no-op unless debug logging is enabled. Prompt mode codes are translated
    through ``MODE_NAMES``, falling back to the code's string form. When
    ``completion_modes`` is given, a per-name count is added for every code
    known to ``MODE_NAMES``.
    """
    if not _DEBUG_ENABLED:
        return
    summary: dict[str, Any] = {
        "prompt_modes": [MODE_NAMES.get(code, str(code)) for code in prompt_modes],
    }
    if completion_modes is not None:
        tally = {}
        for code, label in MODE_NAMES.items():
            tally[label] = completion_modes.count(code)
        summary["completion_mode_counts"] = tally
    log(stage, "mode routing summary", **summary)


def log_tensor(stage: str, name: str, tensor: Any) -> None:
    """Log a one-line summary of ``tensor`` under the label ``name``.

    A no-op unless debug logging is enabled. The object is passed to ``log``
    unformatted, because ``log`` already renders every field with ``_fmt``.
    Formatting it here as well would emit the summary as a quoted, escaped
    string repr.
    """
    if not _DEBUG_ENABLED:
        return
    log(stage, f"tensor `{name}`", value=tensor)


def log_exception(stage: str, msg: str, exc: BaseException) -> None:
    """Print ``exc`` and its traceback on one ``[OPSD-DEBUG]`` line.

    A no-op unless debug logging is enabled.

    The traceback is built from ``exc`` itself rather than from the exception
    currently being handled, so the function also works outside an ``except``
    block. The line is printed directly rather than routed through ``log``:
    ``log`` re-formats every field with a 240-character limit, which would
    discard most of the 4000-character traceback budget.
    """
    if not _DEBUG_ENABLED:
        return
    tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    error_text = _fmt(repr(exc))
    trace_text = _fmt(tb, max_len=4000)
    # _prefix allocates the call id itself when none is supplied.
    print(f"{_prefix(stage)} {msg} | error={error_text} | traceback={trace_text}", flush=True)


@contextmanager
def timed(stage: str, msg: str = "", **fields: Any):
    """Context manager that logs START/END lines around the wrapped block.

    With debug logging disabled it simply runs the body. Otherwise it logs
    ``START`` with the given fields, logs ``FAILED`` and re-raises when the
    body raises an ``Exception``, and always logs ``END`` with the elapsed
    time in seconds (four decimals), including after a failure.
    """
    if _DEBUG_ENABLED:
        started = time.perf_counter()
        log(stage, f"START {msg}", **fields)
        try:
            yield
        except Exception as exc:
            log_exception(stage, f"FAILED {msg}", exc)
            raise
        finally:
            elapsed = time.perf_counter() - started
            log(stage, f"END {msg}", elapsed_s=f"{elapsed:.4f}")
    else:
        yield