File size: 5,203 Bytes
c6535db | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 | import time
import math
def _fmt_hms(seconds: float) -> str:
s = int(max(seconds, 0))
h = s // 3600
m = (s % 3600) // 60
ss = s % 60
return f"{h:02d}:{m:02d}:{ss:02d}"
def print_step_timing(sampler_name: str, step_index: int, start_time: float, total_steps: int):
"""Standard per-step timing line printed at the start of each step."""
steps_done = step_index
elapsed = time.time() - start_time
avg = elapsed / max(steps_done, 1) if steps_done > 0 else 0.0
remaining = max(total_steps - steps_done, 0)
eta = avg * remaining
print(f"\n{sampler_name} step {step_index}: time elapsed {_fmt_hms(elapsed)} | time left {_fmt_hms(eta)}")
def _fmt(val, fmt: str = ".4g") -> str:
try:
if val is None:
return "-"
if hasattr(val, "item"):
v = float(val.item())
else:
v = float(val)
if not math.isfinite(v):
return "nan"
return f"{v:{fmt}}"
except Exception:
return "-"
def print_step_diag(
sampler: str,
step_index: int,
sigma_current,
sigma_next,
*,
target_sigma=None,
sigma_up=None,
alpha_ratio=None,
h=None,
c2=None,
b1=None,
b2=None,
eps_norm=None,
eps_prev_norm=None,
x_rms=None,
flags: str = "",
):
"""Compact per-step diagnostics used when debug/verbose is enabled.
Accepts whatever fields the caller has; missing ones are shown as '-'.
Safe to call from any sampler/model; prints a single concise line.
"""
parts = [
f"{sampler} diag {step_index}:",
f"σ={_fmt(sigma_current)}→{_fmt(sigma_next)}",
]
if target_sigma is not None:
parts.append(f"tgt={_fmt(target_sigma)}")
if h is not None:
parts.append(f"h={_fmt(h)}")
if c2 is not None:
parts.append(f"c2={_fmt(c2)}")
if b1 is not None or b2 is not None:
parts.append(f"b1={_fmt(b1)} b2={_fmt(b2)}")
if sigma_up is not None:
parts.append(f"up={_fmt(sigma_up)}")
try:
sn = float(sigma_next.item()) if hasattr(sigma_next, "item") else float(sigma_next)
su = float(sigma_up.item()) if hasattr(sigma_up, "item") else float(sigma_up)
if sn != 0:
parts.append(f"up/next={su/sn:.2f}")
except Exception:
pass
if alpha_ratio is not None:
parts.append(f"α={_fmt(alpha_ratio)}")
if eps_norm is not None:
if eps_prev_norm is not None:
parts.append(f"|ε|={_fmt(eps_norm)}({_fmt(eps_prev_norm)})")
else:
parts.append(f"|ε|={_fmt(eps_norm)}")
if x_rms is not None:
parts.append(f"x_rms={_fmt(x_rms)}")
if flags:
parts.append(f"[{flags}]")
# Compute a coarse risk tag (LOW/MED/HIGH) from available fields
score = 0
# Noise fraction risk
try:
sn = float(sigma_next.item()) if hasattr(sigma_next, "item") else (float(sigma_next) if sigma_next is not None else None)
su = float(sigma_up.item()) if hasattr(sigma_up, "item") else (float(sigma_up) if sigma_up is not None else None)
if sn is not None and su is not None and sn > 0:
ratio = su / sn
if ratio > 0.8:
score += 2
elif ratio > 0.5:
score += 1
# ODE-like mixing with large noise
try:
ar = float(alpha_ratio.item()) if hasattr(alpha_ratio, "item") else (float(alpha_ratio) if alpha_ratio is not None else None)
if ar is not None and ar >= 0.95 and ratio > 0.5:
score += 1
except Exception:
pass
except Exception:
pass
# Step size and geometry risk
try:
if h is not None:
hv = abs(float(h.item()) if hasattr(h, "item") else float(h))
if hv < 1e-6:
score += 2
elif hv < 1e-3:
score += 1
except Exception:
pass
try:
if c2 is not None:
c2v = abs(float(c2.item()) if hasattr(c2, "item") else float(c2))
if c2v < 0.1 or c2v > 10.0:
score += 2
elif c2v < 0.3 or c2v > 3.0:
score += 1
except Exception:
pass
try:
if b1 is not None or b2 is not None:
b1v = abs(float(b1.item()) if hasattr(b1, "item") else float(b1) if b1 is not None else 0.0)
b2v = abs(float(b2.item()) if hasattr(b2, "item") else float(b2) if b2 is not None else 0.0)
s = b1v + b2v
if s > 10.0:
score += 2
elif s > 5.0:
score += 1
except Exception:
pass
try:
if eps_norm is not None and eps_prev_norm is not None:
en = float(eps_norm)
ep = float(eps_prev_norm)
if ep > 0:
rr = en / ep
if rr > 5.0:
score += 2
elif rr > 2.0:
score += 1
except Exception:
pass
risk = "LOW" if score <= 1 else ("MED" if score <= 3 else "HIGH")
parts.append(f"[RISK={risk}]")
print(" ".join(parts))
|