| import time |
| import math |
|
|
|
|
def _fmt_hms(seconds: float) -> str:
    """Render a duration in seconds as zero-padded ``HH:MM:SS``.

    Negative durations are clamped to zero and fractional seconds are
    truncated. Hours are not wrapped at 24 (90000 s -> ``25:00:00``).
    """
    total = int(max(seconds, 0))
    hours, rest = divmod(total, 3600)
    minutes, secs = divmod(rest, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
|
|
|
def print_step_timing(sampler_name: str, step_index: int, start_time: float, total_steps: int):
    """Print the standard per-step timing line, meant for the start of each step.

    Args:
        sampler_name: Label printed at the beginning of the line.
        step_index: Index of the step about to run; used as the number of
            steps already completed.
        start_time: Timestamp on the ``time.time()`` scale taken when the run
            began.
        total_steps: Total number of steps in the run.

    The remaining-time estimate is the mean duration of the completed steps
    multiplied by the number of steps still to run. It reads ``00:00:00``
    until at least one step has completed.
    """
    completed = step_index
    elapsed = time.time() - start_time

    per_step = 0.0
    if completed > 0:
        per_step = elapsed / max(completed, 1)

    steps_left = max(total_steps - completed, 0)
    time_left = per_step * steps_left

    line = f"\n{sampler_name} step {step_index}: time elapsed {_fmt_hms(elapsed)} | time left {_fmt_hms(time_left)}"
    print(line)
|
|
|
|
def _fmt(val, fmt: str = ".4g") -> str:
    """Format a scalar-like value compactly for diagnostic output.

    Args:
        val: A number, an object exposing ``.item()`` (e.g. a one-element
            tensor or array), or None.
        fmt: Format spec applied to the float value (default ``.4g``).

    Returns:
        ``"-"`` for None, for values that cannot be converted to float, or
        when ``fmt`` is not a valid format spec. ``"nan"`` for NaN and
        ``"inf"``/``"-inf"`` for infinities, so the two are no longer
        conflated. Otherwise ``format(value, fmt)``.
    """
    if val is None:
        return "-"
    try:
        v = float(val.item()) if hasattr(val, "item") else float(val)
    except Exception:
        # Best-effort: arbitrary objects (multi-element tensors, strings, ...)
        # must never break a diagnostic print.
        return "-"
    if math.isnan(v):
        return "nan"
    if math.isinf(v):
        # Report the sign of an overflow instead of mislabeling it as NaN.
        return "inf" if v > 0 else "-inf"
    try:
        return f"{v:{fmt}}"
    except (ValueError, TypeError):
        # Invalid format spec supplied by the caller.
        return "-"
|
|
|
|
def _as_float(val):
    """Convert a scalar-like value to a Python float, passing None through.

    Objects exposing ``.item()`` (e.g. one-element tensors or arrays) are
    unwrapped through it first; anything else goes through ``float()``.
    Conversion errors propagate so each caller decides how to degrade.
    """
    if val is None:
        return None
    if hasattr(val, "item"):
        return float(val.item())
    return float(val)


def _diag_risk_score(
    *,
    sigma_next,
    sigma_up,
    alpha_ratio,
    h,
    c2,
    b1,
    b2,
    eps_norm,
    eps_prev_norm,
) -> int:
    """Return a heuristic instability score for one step (higher = riskier).

    Each factor is scored independently and best-effort. A field that is
    missing, or cannot be converted to a float, contributes 0 instead of
    raising, so the diagnostic line is always printed.
    """
    score = 0

    # sigma_up / sigma_next: > 0.8 -> +2, > 0.5 -> +1. An alpha_ratio >= 0.95
    # on top of a ratio > 0.5 adds one more point.
    try:
        sn = _as_float(sigma_next)
        su = _as_float(sigma_up)
        if sn is not None and su is not None and sn > 0:
            ratio = su / sn
            if ratio > 0.8:
                score += 2
            elif ratio > 0.5:
                score += 1
            # Evaluated only here, where `ratio` is guaranteed to be bound.
            try:
                ar = _as_float(alpha_ratio)
                if ar is not None and ar >= 0.95 and ratio > 0.5:
                    score += 1
            except Exception:
                pass
    except Exception:
        pass

    # |h|: < 1e-6 -> +2, < 1e-3 -> +1.
    try:
        hv = _as_float(h)
        if hv is not None:
            hv = abs(hv)
            if hv < 1e-6:
                score += 2
            elif hv < 1e-3:
                score += 1
    except Exception:
        pass

    # |c2| outside [0.1, 10] -> +2, outside [0.3, 3] -> +1.
    try:
        c2v = _as_float(c2)
        if c2v is not None:
            c2v = abs(c2v)
            if c2v < 0.1 or c2v > 10.0:
                score += 2
            elif c2v < 0.3 or c2v > 3.0:
                score += 1
    except Exception:
        pass

    # |b1| + |b2| (an absent coefficient counts as 0): > 10 -> +2, > 5 -> +1.
    # If either present coefficient fails to convert, the whole factor is skipped.
    try:
        if b1 is not None or b2 is not None:
            b1v = _as_float(b1)
            b2v = _as_float(b2)
            total = (0.0 if b1v is None else abs(b1v)) + (0.0 if b2v is None else abs(b2v))
            if total > 10.0:
                score += 2
            elif total > 5.0:
                score += 1
    except Exception:
        pass

    # eps_norm / eps_prev_norm: > 5 -> +2, > 2 -> +1. The norms go through the
    # same `.item()`-aware conversion as every other field.
    try:
        if eps_norm is not None and eps_prev_norm is not None:
            en = _as_float(eps_norm)
            ep = _as_float(eps_prev_norm)
            if ep > 0:
                rr = en / ep
                if rr > 5.0:
                    score += 2
                elif rr > 2.0:
                    score += 1
    except Exception:
        pass

    return score


def print_step_diag(
    sampler: str,
    step_index: int,
    sigma_current,
    sigma_next,
    *,
    target_sigma=None,
    sigma_up=None,
    alpha_ratio=None,
    h=None,
    c2=None,
    b1=None,
    b2=None,
    eps_norm=None,
    eps_prev_norm=None,
    x_rms=None,
    flags: str = "",
) -> None:
    """Compact per-step diagnostics used when debug/verbose is enabled.

    Accepts whatever fields the caller has; missing ones are shown as '-'.
    Safe to call from any sampler/model; prints a single concise line.

    Optional keyword fields are appended only when not None. When ``sigma_up``
    is given and ``sigma_next`` is non-zero, ``up/next`` (their ratio) is also
    shown. The line always ends with ``[RISK=LOW|MED|HIGH]``, derived from a
    heuristic score (LOW <= 1, MED <= 3, HIGH otherwise) over sigma_up/sigma_next,
    alpha_ratio, h, c2, b1/b2 and the eps-norm ratio.
    """
    parts = [
        f"{sampler} diag {step_index}:",
        f"σ={_fmt(sigma_current)}→{_fmt(sigma_next)}",
    ]
    if target_sigma is not None:
        parts.append(f"tgt={_fmt(target_sigma)}")
    if h is not None:
        parts.append(f"h={_fmt(h)}")
    if c2 is not None:
        parts.append(f"c2={_fmt(c2)}")
    if b1 is not None or b2 is not None:
        parts.append(f"b1={_fmt(b1)} b2={_fmt(b2)}")
    if sigma_up is not None:
        parts.append(f"up={_fmt(sigma_up)}")
        try:
            sn = _as_float(sigma_next)
            su = _as_float(sigma_up)
            if sn is not None and sn != 0:
                parts.append(f"up/next={su/sn:.2f}")
        except Exception:
            # Best-effort: an unconvertible sigma just omits the ratio.
            pass
    if alpha_ratio is not None:
        parts.append(f"α={_fmt(alpha_ratio)}")
    if eps_norm is not None:
        if eps_prev_norm is not None:
            parts.append(f"|ε|={_fmt(eps_norm)}({_fmt(eps_prev_norm)})")
        else:
            parts.append(f"|ε|={_fmt(eps_norm)}")
    if x_rms is not None:
        parts.append(f"x_rms={_fmt(x_rms)}")
    if flags:
        parts.append(f"[{flags}]")

    score = _diag_risk_score(
        sigma_next=sigma_next,
        sigma_up=sigma_up,
        alpha_ratio=alpha_ratio,
        h=h,
        c2=c2,
        b1=b1,
        b2=b2,
        eps_norm=eps_norm,
        eps_prev_norm=eps_prev_norm,
    )
    risk = "LOW" if score <= 1 else ("MED" if score <= 3 else "HIGH")
    parts.append(f"[RISK={risk}]")

    print(" ".join(parts))
|
|