"""Periodic full-detail diagnostics for weak reward / gradient signals.""" from __future__ import annotations import os import re from typing import Any, Optional import torch import torch.nn as nn import torch.nn.functional as F from opsd_utils import debug_log as opsd_debug from opsd_utils.vocab_align import align_cross_model_logits PAREN_TOKEN_ID = 340 # Reuse logits from the OPSD loss path instead of running extra forwards. _OPSD_JSD_DETAIL_CAPTURE: dict[str, Any] = { "active": False, "global_step": None, "target_indices": set(), "entries": [], "skipped_memory": False, "skip_reason": "", "max_samples": 2, } def _detail_min_free_gib() -> float: raw = os.environ.get("DYME_OPSD_DETAIL_MIN_FREE_GB", "").strip() if not raw: return 4.0 try: return max(0.0, float(raw)) except ValueError: return 4.0 def cuda_free_gib(device: Optional[torch.device | int] = None) -> Optional[float]: if not torch.cuda.is_available(): return None try: if device is None: free_bytes, _ = torch.cuda.mem_get_info() else: dev = torch.device(device) if not isinstance(device, torch.device) else device with torch.cuda.device(dev): free_bytes, _ = torch.cuda.mem_get_info() return free_bytes / (1024**3) except Exception: return None def check_detail_cuda_memory( min_free_gib: Optional[float] = None, device: Optional[torch.device | int] = None, ) -> tuple[bool, str, Optional[float]]: """Return (ok, reason, free_gib). Skips heavy detail work when GPU headroom is low.""" threshold = _detail_min_free_gib() if min_free_gib is None else max(0.0, float(min_free_gib)) if not torch.cuda.is_available(): return True, "", None free_gib = cuda_free_gib(device) if free_gib is None: return True, "", None if free_gib < threshold: return ( False, f"cuda_free_gib={free_gib:.2f} < min_free_gib={threshold:.2f}", free_gib, ) return True, "", free_gib def begin_opsd_jsd_detail_capture( global_step: int, opsd_indices: list[int], max_samples: int = 2, ) -> None: """Prepare to record JSD stats during OPSD loss (no extra model forwards).""" _OPSD_JSD_DETAIL_CAPTURE.update( active=False, global_step=global_step, target_indices=set(), entries=[], skipped_memory=False, skip_reason="", max_samples=max(1, int(max_samples)), ) if not opsd_debug.should_log_detail(global_step) or not opsd_indices: return ok, reason, free_gib = check_detail_cuda_memory() if not ok: _OPSD_JSD_DETAIL_CAPTURE["skipped_memory"] = True _OPSD_JSD_DETAIL_CAPTURE["skip_reason"] = reason opsd_debug.log_detail( "opsd_jsd", "skip JSD detail capture (CUDA memory guard)", global_step=global_step, reason=reason, cuda_free_gib=free_gib, min_free_gib=_detail_min_free_gib(), ) return _OPSD_JSD_DETAIL_CAPTURE["active"] = True _OPSD_JSD_DETAIL_CAPTURE["target_indices"] = set(opsd_indices[: _OPSD_JSD_DETAIL_CAPTURE["max_samples"]]) def maybe_capture_opsd_jsd_detail( *, global_idx: int, student_logits: torch.Tensor, teacher_logits: torch.Tensor, completion_mask: torch.Tensor, completion_ids: torch.Tensor, beta: float, tokenizer: Any = None, student_prompt_len: Optional[int] = None, teacher_prompt_len: Optional[int] = None, ) -> None: """Record token-level JSD stats from logits already computed in the loss path.""" capture = _OPSD_JSD_DETAIL_CAPTURE if not capture["active"] or global_idx not in capture["target_indices"]: return try: with torch.no_grad(): s_logits, t_logits = align_cross_model_logits( student_logits.detach(), teacher_logits.detach(), ) stats = jsd_token_stats(s_logits, t_logits, completion_mask.float(), beta=beta) stats["sample_index"] = global_idx if student_prompt_len is not None: stats["student_prompt_len"] = int(student_prompt_len) if teacher_prompt_len is not None: stats["teacher_prompt_len"] = int(teacher_prompt_len) if tokenizer is not None: decoded = tokenizer.decode( completion_ids[0][completion_mask[0].bool()], skip_special_tokens=True, ) stats["completion_text"] = _preview_text(decoded) capture["entries"].append(stats) except RuntimeError as exc: if "out of memory" in str(exc).lower(): if torch.cuda.is_available(): torch.cuda.empty_cache() opsd_debug.log_detail( "opsd_jsd", f"skip JSD detail for sample[{global_idx}] (OOM during stats)", global_step=capture.get("global_step"), error=repr(exc), ) return raise def _tensor_stats(t: torch.Tensor, name: str) -> dict[str, Any]: if t is None or not isinstance(t, torch.Tensor) or t.numel() == 0: return {name: "empty"} with torch.no_grad(): flat = t.detach().float().reshape(-1) return { f"{name}/shape": tuple(t.shape), f"{name}/mean": float(flat.mean().item()), f"{name}/std": float(flat.std(unbiased=False).item()) if flat.numel() > 1 else 0.0, f"{name}/min": float(flat.min().item()), f"{name}/max": float(flat.max().item()), f"{name}/abs_mean": float(flat.abs().mean().item()), } def _preview_text(text: str, max_len: int = 320) -> str: text = (text or "").replace("\n", "\\n") if len(text) > max_len: return text[: max_len - 3] + "..." return text or "" def _generation_config_summary(generation_config: Any) -> dict[str, Any]: if generation_config is None: return {} keys = ( "max_new_tokens", "do_sample", "temperature", "top_p", "top_k", "min_p", "repetition_penalty", "eos_token_id", "pad_token_id", "bos_token_id", ) out: dict[str, Any] = {} for k in keys: if hasattr(generation_config, k): v = getattr(generation_config, k) out[k] = v.tolist() if isinstance(v, torch.Tensor) else v return out def _slice_generate_inputs(batch: dict[str, Any], index: int, batch_size: int) -> dict[str, Any]: """Take one row from batched generate/forward tensors (VLM-safe).""" sliced: dict[str, Any] = {} for key, value in batch.items(): if isinstance(value, torch.Tensor) and value.dim() >= 1 and value.size(0) == batch_size: sliced[key] = value[index : index + 1] else: sliced[key] = value return sliced def _last_valid_prompt_positions(prompt_mask: torch.Tensor) -> torch.Tensor: """Return index of last valid (non-pad) prompt token per row; supports left padding.""" with torch.no_grad(): seq_len = prompt_mask.size(1) positions = torch.arange(seq_len, device=prompt_mask.device).expand_as(prompt_mask) masked = torch.where(prompt_mask.bool(), positions, torch.full_like(positions, -1)) return masked.max(dim=1).values def _max_same_token_run(token_ids: list[int]) -> tuple[int, Optional[int]]: """Longest run of identical consecutive token ids.""" if not token_ids: return 0, None best_run = 1 best_tok = token_ids[0] run = 1 for i in range(1, len(token_ids)): if token_ids[i] == token_ids[i - 1]: run += 1 if run > best_run: best_run = run best_tok = token_ids[i] else: run = 1 return best_run, best_tok def _detect_single_token_repeat(token_ids: list[int], min_run: int = 8) -> bool: return _max_same_token_run(token_ids)[0] >= min_run def _detect_char_repeat(text: str, min_run: int = 6) -> bool: """Detect consecutive repeated characters (CJK or ASCII), e.g. 其其其.""" if not text or len(text) < min_run: return False run = 1 for i in range(1, len(text)): if text[i] == text[i - 1] and not text[i].isspace(): run += 1 if run >= min_run: return True else: run = 1 return False def _count_char_repeat_samples(completions: Optional[list[str]], min_run: int = 6) -> int: if not completions: return 0 return sum(1 for c in completions if _detect_char_repeat(c or "", min_run=min_run)) def _detect_repeat_loop(token_ids: list[int], min_repeats: int = 4, ngram: int = 3) -> bool: if len(token_ids) < ngram * min_repeats: return False last_start = len(token_ids) - ngram * min_repeats + 1 for start in range(max(0, last_start)): gram = token_ids[start : start + ngram] repeats = 1 pos = start + ngram while pos + ngram <= len(token_ids) and token_ids[pos : pos + ngram] == gram: repeats += 1 pos += ngram if repeats >= min_repeats: return True return False def _detect_degeneration( token_ids: list[int], text: str, *, answer_flag: str = "Answer:", min_single_token_run: int = 8, require_answer_flag: bool = True, ) -> tuple[bool, list[str]]: """Heuristics for repetitive / format-broken completions.""" reasons: list[str] = [] if _detect_single_token_repeat(token_ids, min_run=min_single_token_run): run_len, tok = _max_same_token_run(token_ids) reasons.append(f"SINGLE_TOKEN_REPEAT(run={run_len},tok={tok})") if _detect_repeat_loop(token_ids): reasons.append("NGRAM_REPEAT") if len(token_ids) >= 16: unique_ratio = len(set(token_ids)) / len(token_ids) if unique_ratio < 0.12: reasons.append(f"LOW_UNIQUE_RATIO({unique_ratio:.3f})") if require_answer_flag: answer_count = len(re.findall(f"(?i){re.escape(answer_flag)}", text or "")) if answer_count != 1: reasons.append(f"ANSWER_FLAG_COUNT({answer_count})") if _detect_char_repeat(text or ""): reasons.append("CHAR_REPEAT") return bool(reasons), reasons def _count_answer_flag(text: str, answer_flag: str = "Answer:") -> int: return len(re.findall(f"(?i){re.escape(answer_flag)}", text or "")) def is_degenerate_completion( token_ids: list[int], text: str, *, answer_flag: str = "Answer:", min_single_token_run: int = 8, require_answer_flag: bool = True, ) -> bool: """Return True when completion looks like a repetition / format-broken sample.""" is_deg, _ = _detect_degeneration( token_ids, text, answer_flag=answer_flag, min_single_token_run=min_single_token_run, require_answer_flag=require_answer_flag, ) return is_deg def _count_paren_then_eos( completion_ids: torch.Tensor, completion_mask: torch.Tensor, eos_id: Optional[int], ) -> int: if eos_id is None or completion_ids.size(0) == 0: return 0 count = 0 with torch.no_grad(): lengths = completion_mask.sum(dim=1) for i in range(completion_ids.size(0)): eff = int(lengths[i].item()) if eff <= 0: continue first = int(completion_ids[i, 0].item()) if first != PAREN_TOKEN_ID: continue if eff <= 2: count += 1 elif eff >= 2 and int(completion_ids[i, 1].item()) == eos_id: count += 1 return count def summarize_generate_probe_stats( completion_ids: torch.Tensor, completion_mask: torch.Tensor, is_eos: torch.Tensor, eos_id: Optional[int], completions: Optional[list[str]] = None, answer_flag: str = "Answer:", max_completion_length: Optional[int] = None, ) -> dict[str, Any]: with torch.no_grad(): lengths = completion_mask.sum(dim=1).float() has_eos = is_eos.any(dim=1) paren_then_eos = _count_paren_then_eos(completion_ids, completion_mask, eos_id) repeat_loop = 0 degenerate_count = 0 max_run_lengths: list[float] = [] unique_ratios: list[float] = [] answer_flag_ok = 0 clipped_count = 0 degenerate_format_count = 0 degenerate_repeat_count = 0 format_without_thinking = 0 for i in range(completion_ids.size(0)): eff = int(lengths[i].item()) if eff <= 0: continue ids = completion_ids[i, :eff].tolist() text = completions[i] if completions and i < len(completions) else "" if _detect_repeat_loop(ids) or _detect_single_token_repeat(ids): repeat_loop += 1 run_len, _ = _max_same_token_run(ids) max_run_lengths.append(float(run_len)) unique_ratios.append(len(set(ids)) / max(len(ids), 1)) if _count_answer_flag(text, answer_flag) == 1: answer_flag_ok += 1 thinking = (text or "").lower().split(answer_flag.lower())[0] if len(thinking.strip()) < 8: format_without_thinking += 1 else: degenerate_format_count += 1 if max_completion_length is not None and eff >= max_completion_length - 1: clipped_count += 1 is_deg, reasons = _detect_degeneration(ids, text, answer_flag=answer_flag) if is_deg: degenerate_count += 1 non_flag = [r for r in reasons if not r.startswith("ANSWER_FLAG")] if non_flag: degenerate_repeat_count += 1 char_repeat_count = _count_char_repeat_samples(completions) n = max(completion_ids.size(0), 1) return { "effective_tokens_mean": float(lengths.mean().item()), "char_repeat_count": char_repeat_count, "one_token_count": int((lengths == 1).sum().item()), "paren_then_eos_count": paren_then_eos, "repeat_loop_count": repeat_loop, "eos_terminated_rate": float(has_eos.float().mean().item()), "degenerate_count": degenerate_count, "degenerate_rate": degenerate_count / n, "degenerate_rate_format": degenerate_format_count / n, "degenerate_rate_repeat": degenerate_repeat_count / n, "format_without_thinking_rate": format_without_thinking / n, "max_token_run_mean": float(sum(max_run_lengths) / len(max_run_lengths)) if max_run_lengths else 0.0, "unique_token_ratio_mean": float(sum(unique_ratios) / len(unique_ratios)) if unique_ratios else 0.0, "answer_flag_exactly_once_rate": answer_flag_ok / n, "clipped_count": clipped_count, "clipped_rate": clipped_count / n, } def log_generate_context( *, global_step: int, trainer_step: Optional[int], generate_call_index: int, model: Any, model_wrapped: Any, gradient_checkpointing: bool, generation_config: Any, is_fsdp_enabled: bool, generate_runs_under_no_grad: bool, ) -> None: if not opsd_debug.should_log_gendbg() or not opsd_debug.probe_log_model_context(): return opsd_debug.set_detail_step(global_step) dropout_in_train = sum( 1 for m in model.modules() if isinstance(m, nn.Dropout) and m.training ) opsd_debug.log_gendbg( "context", "generate model context", generate_call_index=generate_call_index, trainer_step=trainer_step, global_step=global_step, model_training=bool(getattr(model, "training", None)), model_wrapped_training=bool(getattr(model_wrapped, "training", None)), gradient_checkpointing=gradient_checkpointing, generation_use_cache=getattr(generation_config, "use_cache", None), is_fsdp_enabled=is_fsdp_enabled, generate_runs_under_no_grad=generate_runs_under_no_grad, dropout_modules_in_train=dropout_in_train, generation_config=_generation_config_summary(generation_config), ) def log_prompt_tail_probe( *, global_step: int, trainer_step: Optional[int], generate_call_index: int, prompt_ids: torch.Tensor, prompt_mask: torch.Tensor, tokenizer: Any, sample_count: int = 4, tail_tokens: Optional[int] = None, ) -> None: if not opsd_debug.should_log_gendbg(): return opsd_debug.set_detail_step(global_step) n_tail = tail_tokens if tail_tokens is not None else opsd_debug.probe_prompt_tail_tokens() last_pos = _last_valid_prompt_positions(prompt_mask) n = min(sample_count, prompt_ids.size(0)) for i in range(n): end = int(last_pos[i].item()) + 1 start = max(0, end - n_tail) tail_ids = prompt_ids[i, start:end].tolist() eff_len = int(prompt_mask[i].sum().item()) tail_decode = tokenizer.decode(tail_ids, skip_special_tokens=False) opsd_debug.log_gendbg( "prompt_tail", f"sample[{i}]", generate_call_index=generate_call_index, trainer_step=trainer_step, prompt_effective_len=eff_len, last_valid_idx=end - 1, prompt_tail_token_ids=tail_ids, prompt_tail_decode=_preview_text(tail_decode, 400), ) def summarize_first_token_logits_stats( p_greedy_values: list[float], p_eos_values: list[float], entropy_values: list[float], p_answer_values: Optional[list[float]] = None, ) -> dict[str, float]: """Aggregate first-token logit probe scalars across samples.""" def _mean(vals: list[float]) -> float: return float(sum(vals) / len(vals)) if vals else 0.0 out = { "p_greedy_first": _mean(p_greedy_values), "p_eos_first": _mean(p_eos_values), "entropy_first": _mean(entropy_values), } if p_answer_values: out["p_answer_first"] = _mean(p_answer_values) return out def answer_first_token_id(tokenizer: Any) -> Optional[int]: for piece in ("Answer", "Answer:"): ids = tokenizer.encode(piece, add_special_tokens=False) if ids: return int(ids[0]) return None def log_first_token_logits_probe( *, global_step: int, trainer_step: Optional[int], generate_call_index: int, unwrapped_model: Any, prompt_inputs_generate: dict[str, Any], prompt_mask: torch.Tensor, tokenizer: Any, sample_count: int = 4, ) -> dict[str, Any]: """Forward once before generate; return greedy ids and aggregated logit stats.""" greedy_by_sample: dict[int, int] = {} p_greedy_vals: list[float] = [] p_eos_vals: list[float] = [] entropy_vals: list[float] = [] p_answer_vals: list[float] = [] if not opsd_debug.should_log_gendbg() or not opsd_debug.probe_first_token_logits(): return { "greedy_by_sample": greedy_by_sample, **summarize_first_token_logits_stats([], [], [], []), } opsd_debug.set_detail_step(global_step) eos_id = getattr(tokenizer, "eos_token_id", None) answer_tid = answer_first_token_id(tokenizer) forward_inputs = {k: v for k, v in prompt_inputs_generate.items() if k != "labels"} batch_size = prompt_mask.size(0) n = min(sample_count, batch_size) last_pos = _last_valid_prompt_positions(prompt_mask) for i in range(n): try: sample_inputs = _slice_generate_inputs(forward_inputs, i, batch_size) with torch.no_grad(): outputs = unwrapped_model(**sample_inputs, use_cache=False) logits = outputs.logits pos = int(last_pos[i].item()) next_logits = logits[0, pos, :].float() probs = F.softmax(next_logits, dim=-1) greedy_id = int(next_logits.argmax().item()) greedy_by_sample[i] = greedy_id entropy = float(-(probs * (probs + 1e-12).log()).sum().item()) fields: dict[str, Any] = { "generate_call_index": generate_call_index, "trainer_step": trainer_step, "last_prompt_idx": pos, "greedy_token_id": greedy_id, "p_greedy": float(probs[greedy_id].item()), "entropy": entropy, "probe_mode": "per_sample_forward", } if eos_id is not None: fields["p_eos"] = float(probs[eos_id].item()) if answer_tid is not None and answer_tid < probs.size(0): fields["p_answer_first"] = float(probs[answer_tid].item()) p_answer_vals.append(fields["p_answer_first"]) if PAREN_TOKEN_ID < probs.size(0): fields["p_token_340"] = float(probs[PAREN_TOKEN_ID].item()) topk = torch.topk(probs, k=min(5, probs.size(0))) fields["top5"] = [ (int(tid), float(p)) for tid, p in zip(topk.indices.tolist(), topk.values.tolist()) ] p_greedy_vals.append(fields["p_greedy"]) entropy_vals.append(entropy) if eos_id is not None: p_eos_vals.append(fields["p_eos"]) opsd_debug.log_gendbg("first_token_logits", f"sample[{i}]", **fields) except Exception as exc: opsd_debug.log_gendbg( "first_token_logits", f"FAILED forward for sample[{i}]", generate_call_index=generate_call_index, error=repr(exc), ) if torch.cuda.is_available(): torch.cuda.empty_cache() logits_stats = summarize_first_token_logits_stats( p_greedy_vals, p_eos_vals, entropy_vals, p_answer_vals ) return {"greedy_by_sample": greedy_by_sample, **logits_stats} def log_first_token_logits_match( *, generate_call_index: int, completion_ids: torch.Tensor, greedy_by_sample: dict[int, int], sample_count: int = 4, ) -> None: """Compare pre-generate greedy next token vs actual first generated token.""" if not opsd_debug.should_log_gendbg() or not opsd_debug.probe_first_token_logits(): return n = min(sample_count, completion_ids.size(0)) for i in range(n): if i not in greedy_by_sample or completion_ids.size(1) == 0: continue greedy_id = greedy_by_sample[i] actual_id = int(completion_ids[i, 0].item()) opsd_debug.log_gendbg( "first_token_match", f"sample[{i}]", generate_call_index=generate_call_index, greedy_token_id=greedy_id, actual_first_token=actual_id, greedy_matches_actual=(greedy_id == actual_id), ) def log_generate_delta( *, generate_call_index: int, current_stats: dict[str, Any], previous_stats: Optional[dict[str, Any]], ) -> None: if not opsd_debug.should_log_gendbg(): return fields: dict[str, Any] = { "generate_call_index": generate_call_index, "current": current_stats, } if previous_stats is not None: prev_idx = previous_stats.get("generate_call_index", generate_call_index - 1) fields["prev_generate_call_index"] = prev_idx for key in ( "effective_tokens_mean", "one_token_count", "paren_then_eos_count", "repeat_loop_count", "eos_terminated_rate", "degenerate_count", "degenerate_rate", "clipped_rate", "answer_flag_exactly_once_rate", ): cur = current_stats.get(key) prev = previous_stats.get(key) if cur is not None and prev is not None: if isinstance(cur, (int, float)) and isinstance(prev, (int, float)): fields[f"delta_{key}"] = cur - prev opsd_debug.log_gendbg("delta", "generate stats vs previous regenerate", **fields) def log_generate_probe( *, global_step: int, trainer_step: Optional[int], prompt_length: int, prompt_completion_ids: torch.Tensor, completion_ids: torch.Tensor, completion_mask: torch.Tensor, is_eos: torch.Tensor, eos_idx: torch.Tensor, completions: list[str], tokenizer: Any, generation_config: Any, max_completion_length: int, num_generations: int, sample_count: int = 4, generate_call_index: Optional[int] = None, answer_flag: str = "Answer:", source: str = "generate", ) -> dict[str, Any]: """Emit [OPSD-PROBE] on every (re)generate — catches 1-token / empty completions early.""" stats = summarize_generate_probe_stats( completion_ids, completion_mask, is_eos, getattr(tokenizer, "eos_token_id", None), completions=completions, answer_flag=answer_flag, max_completion_length=max_completion_length, ) if generate_call_index is not None: stats["generate_call_index"] = generate_call_index if not opsd_debug.should_log_probe(): return stats opsd_debug.set_detail_step(global_step) tok = tokenizer eos_id = getattr(tok, "eos_token_id", None) pad_id = getattr(tok, "pad_token_id", None) bos_id = getattr(tok, "bos_token_id", None) with torch.no_grad(): lengths = completion_mask.sum(dim=1).float() has_eos = is_eos.any(dim=1) raw_gen_len = completion_ids.size(1) opsd_debug.log_probe( "generate", "sft cold-start GT batch" if source == "sft_cold_start" else "raw generate summary", trainer_step=trainer_step, global_step=global_step, source=source, generate_call_index=generate_call_index, prompt_length=prompt_length, prompt_completion_shape=tuple(prompt_completion_ids.shape), completion_ids_shape=tuple(completion_ids.shape), raw_gen_tokens=raw_gen_len, max_completion_length=max_completion_length, num_generations=num_generations, batch_size=completion_ids.size(0), effective_tokens_mean=stats["effective_tokens_mean"], effective_tokens_min=float(lengths.min().item()), effective_tokens_max=float(lengths.max().item()), zero_length_count=int((lengths == 0).sum().item()), one_token_count=stats["one_token_count"], paren_then_eos_count=stats["paren_then_eos_count"], repeat_loop_count=stats["repeat_loop_count"], eos_terminated_rate=stats["eos_terminated_rate"], degenerate_count=stats["degenerate_count"], degenerate_rate=stats["degenerate_rate"], max_token_run_mean=stats["max_token_run_mean"], unique_token_ratio_mean=stats["unique_token_ratio_mean"], answer_flag_exactly_once_rate=stats["answer_flag_exactly_once_rate"], clipped_count=stats["clipped_count"], clipped_rate=stats["clipped_rate"], tokenizer_eos_id=eos_id, tokenizer_pad_id=pad_id, tokenizer_bos_id=bos_id, generation_config=_generation_config_summary(generation_config), ) suspicious: list[str] = [] n = min(sample_count, len(completions)) for i in range(n): eff = int(lengths[i].item()) eidx = int(eos_idx[i].item()) ids_head = completion_ids[i, : min(16, completion_ids.size(1))].tolist() ids_all = completion_ids[i].tolist() decode_skip = completions[i] decode_keep = tok.decode(completion_ids[i], skip_special_tokens=False) first_tok = int(completion_ids[i, 0].item()) if completion_ids.size(1) > 0 else None first_is_eos = first_tok is not None and first_tok == eos_id flags: list[str] = [] patterns: list[str] = [] if eff <= 0: flags.append("ZERO_LEN") if eff == 1: flags.append("ONE_TOKEN") if not (decode_skip or "").strip(): flags.append("EMPTY_DECODE") if first_is_eos: flags.append("FIRST_IS_EOS") if eff <= 2 and first_tok == PAREN_TOKEN_ID: patterns.append("PAREN_THEN_EOS") if eff > 0: ids_eff = completion_ids[i, :eff].tolist() if _detect_repeat_loop(ids_eff): patterns.append("REPEAT_LOOP") if _detect_single_token_repeat(ids_eff): run_len, repeat_tok_id = _max_same_token_run(ids_eff) patterns.append(f"SINGLE_TOKEN_REPEAT({run_len}x{repeat_tok_id})") is_deg, deg_reasons = _detect_degeneration(ids_eff, decode_skip, answer_flag=answer_flag) if is_deg: patterns.extend(deg_reasons) if flags: suspicious.append(f"sample[{i}]:{','.join(flags)}") opsd_debug.log_probe( "generate", f"sample[{i}]", group=i // num_generations if num_generations else 0, effective_tokens=eff, eos_idx=eidx, has_eos=bool(has_eos[i].item()), first_token_id=first_tok, first_is_eos=first_is_eos, token_ids_head=ids_head, token_ids_all=ids_all if len(ids_all) <= 32 else ids_head + ["..."], decode_skip_special=_preview_text(decode_skip, 400), decode_keep_special=_preview_text(decode_keep, 400), flags=flags or None, patterns=patterns or None, ) if suspicious: opsd_debug.log_probe( "generate", "ALERT suspicious completions", count=len(suspicious), samples=suspicious, hint="check eos_token_id, max_new_tokens, model.generate output, or training-time unwrap/FSDP", ) if stats.get("degenerate_rate", 0) >= 0.25: opsd_debug.log_probe( "generate", "ALERT repetition / format degeneration", degenerate_count=stats["degenerate_count"], degenerate_rate=stats["degenerate_rate"], clipped_rate=stats["clipped_rate"], answer_flag_exactly_once_rate=stats["answer_flag_exactly_once_rate"], max_token_run_mean=stats["max_token_run_mean"], hint="typical RL collapse: raise repetition_penalty, lower temperature, or shorten max_completion_length", ) return stats def log_cross_rank_generate_summary( *, accelerator: Any, one_token_count: int, effective_tokens_mean: float, generate_call_index: int, ) -> None: """Gather per-rank generate stats; all ranks must call, log on rank 0 only.""" if not opsd_debug.probe_on_generate(): return from accelerate.utils import gather as accel_gather device = accelerator.device world_size = accelerator.num_processes local = torch.tensor( [float(one_token_count), effective_tokens_mean], dtype=torch.float32, device=device, ) # gather_for_metrics on 1D [2] concatenates ranks into [world_size*2]; use gather + view instead. gathered = accel_gather(local.unsqueeze(0)) if not opsd_debug.should_log_gendbg(): return if gathered.dim() == 1: gathered = gathered.view(world_size, -1) elif gathered.size(0) != world_size and gathered.numel() == world_size * 2: gathered = gathered.reshape(world_size, 2) one_tokens = gathered[:, 0].tolist() means = gathered[:, 1].tolist() opsd_debug.log_gendbg( "cross_rank", "generate summary across ranks", generate_call_index=generate_call_index, world_size=world_size, one_token_count_per_rank=one_tokens, effective_tokens_mean_per_rank=means, one_token_count_total=int(sum(one_tokens)), ) def log_routed_completion_probe( *, global_step: int, trainer_step: Optional[int], raw_completion_shape: tuple[int, ...], final_completion_ids: torch.Tensor, final_completion_mask: torch.Tensor, opsd_mask_list: list[bool], sample_count: int = 4, tokenizer: Any, sft_replaced_list: Optional[list[bool]] = None, raw_completion_ids: Optional[torch.Tensor] = None, ) -> None: """Compare raw vs post-routing padded completion shapes (detect routing truncation).""" if not opsd_debug.should_log_probe(): return with torch.no_grad(): final_lengths = final_completion_mask.sum(dim=1).float() opsd_debug.log_probe( "route", "post-routing completion shapes", trainer_step=trainer_step, global_step=global_step, raw_completion_shape=raw_completion_shape, final_completion_shape=tuple(final_completion_ids.shape), final_mask_shape=tuple(final_completion_mask.shape), final_effective_tokens_mean=float(final_lengths.mean().item()), final_effective_tokens_min=float(final_lengths.min().item()), final_effective_tokens_max=float(final_lengths.max().item()), opsd_mask_true=sum(opsd_mask_list), opsd_mask_false=len(opsd_mask_list) - sum(opsd_mask_list), sft_replaced_count=sum(sft_replaced_list) if sft_replaced_list else None, ) n = min(sample_count, final_completion_ids.size(0)) for i in range(n): eff = int(final_lengths[i].item()) head = final_completion_ids[i, : min(12, final_completion_ids.size(1))].tolist() text = tokenizer.decode( final_completion_ids[i][final_completion_mask[i].bool()], skip_special_tokens=True, ) raw_head = None raw_matches_final = None if raw_completion_ids is not None and i < raw_completion_ids.size(0): raw_eff = min(int(raw_completion_ids.size(1)), eff) raw_head = raw_completion_ids[i, : min(12, raw_eff)].tolist() raw_matches_final = raw_head == head if raw_head and head else None opsd_debug.log_probe( "route", f"routed_sample[{i}]", opsd_mask=opsd_mask_list[i] if i < len(opsd_mask_list) else None, sft_replaced=sft_replaced_list[i] if sft_replaced_list and i < len(sft_replaced_list) else None, effective_tokens=eff, token_ids_head=head, raw_token_ids_head=raw_head, raw_matches_routed_head=raw_matches_final, decode=_preview_text(text, 300), ) def jsd_token_stats( student_logits: torch.Tensor, teacher_logits: torch.Tensor, mask: torch.Tensor, beta: float = 0.5, top_k: int = 5, ) -> dict[str, Any]: """Token-level JSD breakdown without building the graph.""" with torch.no_grad(): student_log_probs = F.log_softmax(student_logits, dim=-1) teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) student_probs = student_log_probs.exp() teacher_probs = teacher_log_probs.exp() if beta == 0: jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) elif beta == 1: jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) else: beta_t = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device) mixture_log_probs = torch.logsumexp( torch.stack([student_log_probs + torch.log1p(-beta_t), teacher_log_probs + torch.log(beta_t)]), dim=0, ) kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True) kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True) jsd = beta_t * kl_teacher + (1 - beta_t) * kl_student per_token_jsd = jsd.sum(dim=-1) m = mask.float() valid = m.sum().clamp(min=1.0) per_token_jsd_masked = per_token_jsd * m # Top-1 agreement on completion tokens s_top = student_logits.argmax(dim=-1) t_top = teacher_logits.argmax(dim=-1) agree = ((s_top == t_top).float() * m).sum() / valid # Mean L2 distance of log-probs on gold completion tokens (if provided elsewhere) logprob_l2 = ((student_log_probs - teacher_log_probs) ** 2).sum(dim=-1) logprob_l2_masked = (logprob_l2 * m).sum() / valid # Entropy gap s_ent = -(student_probs * student_log_probs).sum(dim=-1) t_ent = -(teacher_probs * teacher_log_probs).sum(dim=-1) ent_gap = ((s_ent - t_ent).abs() * m).sum() / valid stats: dict[str, Any] = { "jsd_per_token_mean": float((per_token_jsd_masked.sum() / valid).item()), "jsd_per_token_max": float(per_token_jsd[m.bool()].max().item()) if m.any() else 0.0, "jsd_per_token_min": float(per_token_jsd[m.bool()].min().item()) if m.any() else 0.0, "top1_agreement": float(agree.item()), "logprob_l2_mean": float(logprob_l2_masked.item()), "entropy_gap_mean": float(ent_gap.item()), "mask_valid_tokens": int(valid.item()), } if m.any(): idx = int((per_token_jsd * m).argmax().item()) pos = idx % per_token_jsd.size(-1) stats["max_jsd_token_pos"] = pos stats["max_jsd_value"] = float(per_token_jsd.reshape(-1)[idx].item()) s_topk = torch.topk(student_probs[0, pos], k=min(top_k, student_probs.size(-1))) t_topk = torch.topk(teacher_probs[0, pos], k=min(top_k, teacher_probs.size(-1))) stats["student_topk_at_max_jsd"] = [ (int(i), float(p)) for i, p in zip(s_topk.indices.tolist(), s_topk.values.tolist()) ] stats["teacher_topk_at_max_jsd"] = [ (int(i), float(p)) for i, p in zip(t_topk.indices.tolist(), t_topk.values.tolist()) ] return stats def log_generation_diagnostics( *, global_step: int, completions: list[str], completion_ids: torch.Tensor, completion_mask: torch.Tensor, is_eos: torch.Tensor, max_completion_length: int, num_generations: int, sample_count: int = 4, ) -> None: if not opsd_debug.should_log_detail(global_step): return opsd_debug.log_detail_banner(global_step, "GENERATION & COMPLETION") with torch.no_grad(): lengths = completion_mask.sum(dim=1).float() has_eos = is_eos.any(dim=1) eos_rate = has_eos.float().mean().item() clipped = (~has_eos).float().mean().item() all_pad_or_zero = (lengths == 0).float().mean().item() at_max_len = (lengths >= max_completion_length - 1).float().mean().item() fields: dict[str, Any] = { "batch_size": completion_ids.size(0), "num_generations": num_generations, "completion_max_len": completion_ids.size(1), "effective_tokens_mean": float(lengths.mean().item()), "effective_tokens_min": float(lengths.min().item()), "effective_tokens_max": float(lengths.max().item()), "eos_terminated_rate": eos_rate, "clipped_no_eos_rate": clipped, "zero_length_rate": all_pad_or_zero, "at_max_length_rate": at_max_len, } opsd_debug.log_detail("generation", "completion mask summary", **fields) n = min(sample_count, len(completions)) for i in range(n): gen_group = i // num_generations if num_generations else 0 opsd_debug.log_detail( "generation", f"sample[{i}]", group=gen_group, effective_tokens=int(lengths[i].item()), has_eos=bool(has_eos[i].item()), text=_preview_text(completions[i]), raw_token_head=completion_ids[i, :12].tolist(), ) def log_reward_diagnostics( *, global_step: int, format_rewards: torch.Tensor, acc_rewards: torch.Tensor, context_rewards: torch.Tensor, all_rewards: torch.Tensor, advantages: torch.Tensor, reward_weights: torch.Tensor, num_generations: int, answers: Optional[list[str]] = None, completions: Optional[list[str]] = None, sample_count: int = 4, ) -> None: if not opsd_debug.should_log_detail(global_step): return opsd_debug.log_detail_banner(global_step, "REWARD & ADVANTAGE") w = reward_weights.detach().float() weighted = ( format_rewards * w[0] + context_rewards * w[1] + acc_rewards * w[2] ) fields: dict[str, Any] = { "reward_weights": w.tolist(), "format_sum": float(format_rewards.sum().item()), "acc_sum": float(acc_rewards.sum().item()), "context_sum": float(context_rewards.sum().item()), "total_sum": float(all_rewards.sum().item()), "format_zero_rate": float((format_rewards == 0).float().mean().item()), "acc_zero_rate": float((acc_rewards == 0).float().mean().item()), "context_zero_rate": float((context_rewards == 0).float().mean().item()), "weighted_mean": float(weighted.mean().item()), "weighted_std": float(weighted.std(unbiased=False).item()) if weighted.numel() > 1 else 0.0, } adv_flat = advantages.reshape(-1) fields.update( { "advantage_mean": float(adv_flat.mean().item()), "advantage_std": float(adv_flat.std(unbiased=False).item()) if adv_flat.numel() > 1 else 0.0, "advantage_zero_rate": float((adv_flat.abs() < 1e-8).float().mean().item()), "advantage_min": float(adv_flat.min().item()), "advantage_max": float(adv_flat.max().item()), } ) opsd_debug.log_detail("reward", "aggregate reward stats", **fields) n = min(sample_count, format_rewards.numel()) for i in range(n): g = i // num_generations if num_generations else 0 extra: dict[str, Any] = { "group": g, "format": float(format_rewards.view(-1)[i].item()), "acc": float(acc_rewards.view(-1)[i].item()), "context": float(context_rewards.view(-1)[i].item()), "weighted": float(weighted.view(-1)[i].item()), "advantage": float(adv_flat.view(-1)[i].item()) if i < adv_flat.numel() else None, } if answers and i < len(answers): extra["gold_answer"] = _preview_text(str(answers[i // num_generations]), 80) if completions and i < len(completions): extra["completion"] = _preview_text(completions[i], 160) opsd_debug.log_detail("reward", f"per_sample[{i}]", **extra) def log_routing_diagnostics( *, global_step: int, opsd_active: bool, opsd_mask_list: list[bool], has_correct: torch.Tensor, completion_modes: Optional[list[int]] = None, recoverable_flags: Optional[list[bool]] = None, completion_advantages: Optional[torch.Tensor] = None, completion_mask: Optional[torch.Tensor] = None, ) -> None: if not opsd_debug.should_log_detail(global_step): return opsd_debug.log_detail_banner(global_step, "OPSD ROUTING & MASK") opsd_true = sum(opsd_mask_list) fields: dict[str, Any] = { "opsd_active": opsd_active, "opsd_mask_true": opsd_true, "opsd_mask_false": len(opsd_mask_list) - opsd_true, "opsd_mask_ratio": opsd_true / max(len(opsd_mask_list), 1), "has_correct": has_correct.tolist() if hasattr(has_correct, "tolist") else has_correct, } if recoverable_flags is not None: fields["recoverable_flags"] = recoverable_flags if completion_modes is not None: from opsd_utils.debug_log import MODE_NAMES counts = {MODE_NAMES.get(c, str(c)): completion_modes.count(c) for c in set(completion_modes)} fields["completion_mode_counts"] = counts if completion_advantages is not None and completion_mask is not None: with torch.no_grad(): pos = ((completion_advantages > 0) & (completion_mask > 0)).float().sum(dim=1) neg = ((completion_advantages < 0) & (completion_mask > 0)).float().sum(dim=1) zero_adv = ((completion_advantages.abs() < 1e-8) & (completion_mask > 0)).float().sum(dim=1) fields["adv_pos_tokens_mean"] = float(pos.mean().item()) fields["adv_neg_tokens_mean"] = float(neg.mean().item()) fields["adv_zero_tokens_mean"] = float(zero_adv.mean().item()) opsd_debug.log_detail("routing", "mode routing summary", **fields) def log_loss_diagnostics( *, global_step: int, grpo_loss: torch.Tensor, per_token_logps: torch.Tensor, old_per_token_logps: torch.Tensor, completion_mask: torch.Tensor, advantages: torch.Tensor, coef_1: torch.Tensor, per_token_loss: torch.Tensor, opsd_loss: Optional[torch.Tensor] = None, combined_loss: Optional[torch.Tensor] = None, opsd_mask: Optional[torch.Tensor] = None, epsilon_low: float = 0.2, epsilon_high: float = 0.2, ) -> None: if not opsd_debug.should_log_detail(global_step): return opsd_debug.log_detail_banner(global_step, "LOSS & GRADIENT SIGNAL") with torch.no_grad(): m = completion_mask.float() valid_per_sample = m.sum(dim=1).clamp(min=1.0) sample_grpo = (per_token_loss * m).sum(dim=1) / valid_per_sample fields: dict[str, Any] = { "grpo_loss_scalar": float(grpo_loss.detach().item()), "grpo_per_sample_mean": float(sample_grpo.mean().item()), "grpo_per_sample_max": float(sample_grpo.max().item()), "completion_mask_tokens_mean": float(valid_per_sample.mean().item()), } fields.update(_tensor_stats(advantages, "advantages")) fields.update(_tensor_stats(per_token_logps, "per_token_logps")) fields.update(_tensor_stats(old_per_token_logps, "old_per_token_logps")) fields.update(_tensor_stats((per_token_logps - old_per_token_logps) * m, "logps_delta_masked")) fields.update(_tensor_stats(coef_1 * m, "coef_1_masked")) low_clip = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0) & (m > 0)).float().sum() high_clip = ((coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0) & (m > 0)).float().sum() fields["clipped_low_tokens"] = int(low_clip.item()) fields["clipped_high_tokens"] = int(high_clip.item()) fields["grpo_zero_loss_rate"] = float((sample_grpo.abs() < 1e-12).float().mean().item()) if opsd_loss is not None: fields["opsd_loss_scalar"] = float(opsd_loss.detach().item()) if combined_loss is not None: fields["combined_loss_scalar"] = float(combined_loss.detach().item()) if opsd_mask is not None: fields["opsd_batch_count"] = int(opsd_mask.sum().item()) # Weak-signal hints hints: list[str] = [] if fields.get("advantages/abs_mean", 1.0) < 1e-6: hints.append("advantages≈0 → GRPO per-token loss≈0") if fields.get("grpo_zero_loss_rate", 0) > 0.9: hints.append("most samples have ~0 GRPO loss") if fields.get("completion_mask_tokens_mean", 0) < 2: hints.append("very few effective completion tokens in mask") if opsd_loss is not None and abs(fields.get("opsd_loss_scalar", 0)) < 1e-8: hints.append("OPSD JSD≈0 → student/teacher nearly identical on completion") fields["weak_signal_hints"] = hints or ["none"] opsd_debug.log_detail("loss", "GRPO / OPSD loss breakdown", **fields) def summarize_loss_health( *, grpo_loss: torch.Tensor, per_token_logps: torch.Tensor, completion_mask: torch.Tensor, advantages: torch.Tensor, per_token_loss: torch.Tensor, opsd_loss: Optional[torch.Tensor] = None, combined_loss: Optional[torch.Tensor] = None, ) -> dict[str, float]: """Lightweight loss/signal summary for every-step health monitoring.""" with torch.no_grad(): m = completion_mask.float() valid_per_sample = m.sum(dim=1).clamp(min=1.0) sample_grpo = (per_token_loss * m).sum(dim=1) / valid_per_sample adv_flat = advantages.detach().float().reshape(-1) fields: dict[str, float] = { "grpo_loss_scalar": float(grpo_loss.detach().item()), "grpo_zero_loss_rate": float((sample_grpo.abs() < 1e-12).float().mean().item()), "advantages_abs_mean": float(adv_flat.abs().mean().item()) if adv_flat.numel() else 0.0, "completion_mask_tokens_mean": float(valid_per_sample.mean().item()), } if opsd_loss is not None: fields["opsd_loss_scalar"] = float(opsd_loss.detach().item()) if combined_loss is not None: fields["combined_loss_scalar"] = float(combined_loss.detach().item()) logps_delta = (per_token_logps * m).sum() / m.sum().clamp(min=1.0) fields["logps_delta_mean"] = float(logps_delta.item()) if m.sum() > 0 else 0.0 return fields def summarize_batch_data_health( samples: list[dict[str, Any]], *, prompt_mask: Optional[torch.Tensor] = None, pixel_values: Optional[torch.Tensor] = None, ) -> dict[str, Any]: """Batch-level data I/O sanity for health monitoring.""" n = max(len(samples), 1) vf_empty = 0 prompt_lens: list[int] = [] for sample in samples: vf = ( sample.get("visual_fact_hint") or sample.get("visual_fact") or sample.get("visual_facts") or "" ) if not str(vf).strip(): vf_empty += 1 if sample.get("prompt"): prompt_lens.append(len(str(sample["prompt"]))) out: dict[str, Any] = { "visual_fact_empty_rate": vf_empty / n, "batch_size": len(samples), "prompt_len_mean": float(sum(prompt_lens) / len(prompt_lens)) if prompt_lens else 0.0, } if prompt_mask is not None: with torch.no_grad(): lengths = prompt_mask.sum(dim=1).float() out["prompt_tokens_mean"] = float(lengths.mean().item()) out["prompt_tokens_max"] = float(lengths.max().item()) if pixel_values is not None and isinstance(pixel_values, torch.Tensor) and pixel_values.numel() > 0: with torch.no_grad(): flat = pixel_values.detach().float().reshape(-1) out["pixel_mean"] = float(flat.mean().item()) out["pixel_has_nan"] = bool(torch.isnan(flat).any().item()) out["pixel_has_inf"] = bool(torch.isinf(flat).any().item()) return out def log_opsd_jsd_diagnostics(*, global_step: int) -> None: """Emit cached JSD stats recorded during OPSD loss (zero extra model forwards).""" if not opsd_debug.should_log_detail(global_step): return capture = _OPSD_JSD_DETAIL_CAPTURE if capture.get("global_step") != global_step: return opsd_debug.log_detail_banner(global_step, "OPSD JSD DECOMPOSITION") if capture.get("skipped_memory"): opsd_debug.log_detail( "opsd_jsd", "JSD detail skipped (CUDA memory guard)", global_step=global_step, reason=capture.get("skip_reason", ""), min_free_gib=_detail_min_free_gib(), cuda_free_gib=cuda_free_gib(), ) _OPSD_JSD_DETAIL_CAPTURE["active"] = False return entries = capture.get("entries") or [] if not entries: opsd_debug.log_detail( "opsd_jsd", "no JSD detail captured (no OPSD samples on this rank or capture disabled)", global_step=global_step, ) _OPSD_JSD_DETAIL_CAPTURE["active"] = False return for k, stats in enumerate(entries): opsd_debug.log_detail("opsd_jsd", f"sample[{k}] token-level JSD", global_step=global_step, **stats) _OPSD_JSD_DETAIL_CAPTURE["active"] = False