import logging
import os
from collections.abc import Mapping
from typing import Any, Dict, Optional
| |
|
| |
|
# Set to True by _try_import_wandb() once `import wandb` has succeeded.
_WANDB_AVAILABLE = False
# Holds the value returned by wandb.init() inside init(); None until then.
_WANDB_RUN = None
| |
|
| |
|
def _try_import_wandb():
    """Report whether the ``wandb`` package can be imported.

    A successful import is remembered in the module-level
    ``_WANDB_AVAILABLE`` flag, so later calls return immediately. A failed
    import leaves the flag False and is attempted again on the next call.

    Returns:
        True if ``wandb`` imported (now or on an earlier call), else False.
    """
    global _WANDB_AVAILABLE
    if not _WANDB_AVAILABLE:
        try:
            import wandb  # noqa: F401 -- imported only to probe availability
        except Exception:
            # Any import-time failure is treated like "not installed".
            _WANDB_AVAILABLE = False
            return False
        _WANDB_AVAILABLE = True
    return True
| |
|
| |
|
| | def _safe_get(cfg: Dict[str, Any], path: list[str], default: Any = None) -> Any: |
| | cur: Any = cfg |
| | for key in path: |
| | if not isinstance(cur, dict) or key not in cur: |
| | return default |
| | cur = cur[key] |
| | return cur |
| |
|
| |
|
def is_enabled(cfg: Dict[str, Any]) -> bool:
    """Return True when ``logging.wandb.enabled`` in ``cfg`` is truthy.

    A key missing at any level of that path counts as disabled.
    """
    flag = _safe_get(cfg, ["logging", "wandb", "enabled"], False)
    return bool(flag)
| |
|
| |
|
def init(cfg: Dict[str, Any], run_dir: str, run_name: Optional[str] = None) -> None:
    """Start a Weights & Biases run when the config enables it.

    Does nothing if ``logging.wandb.enabled`` is falsy or ``wandb`` cannot be
    imported. Otherwise it creates ``run_dir``, points the ``WANDB_DIR``
    environment variable at it (unless that variable is already set), and
    stores the object returned by ``wandb.init`` in the module-level
    ``_WANDB_RUN``.

    Args:
        cfg: Experiment config; settings are read from ``logging.wandb.*``
            (``project``, ``entity``, ``mode``, ``tags``, ``notes``,
            ``group``, ``name``).
        run_dir: Directory for W&B files; created if it does not exist.
        run_name: Run name used when ``logging.wandb.name`` is absent or null.
    """
    global _WANDB_RUN
    if not is_enabled(cfg):
        return
    if not _try_import_wandb():
        return

    import wandb

    def setting(key: str, default: Any = None) -> Any:
        # Read one value from the logging.wandb section of the config.
        return _safe_get(cfg, ["logging", "wandb", key], default)

    project = setting("project", "llm-negotiation")
    entity = setting("entity")
    mode = setting("mode", "online")
    tags = setting("tags", []) or []
    notes = setting("notes")
    group = setting("group")
    name = setting("name")
    if name is None:
        # Covers both a missing key and an explicit null placeholder in the
        # config; either way the caller-supplied run name should win.
        name = run_name

    os.makedirs(run_dir, exist_ok=True)
    # os.environ only accepts str values; fspath lets a path-like run_dir work.
    os.environ.setdefault("WANDB_DIR", os.fspath(run_dir))

    # Convert an OmegaConf config to plain containers when possible; if
    # omegaconf is missing or the conversion fails, pass cfg through as-is.
    try:
        from omegaconf import OmegaConf

        cfg_container = OmegaConf.to_container(cfg, resolve=True)
    except Exception:
        cfg_container = cfg

    _WANDB_RUN = wandb.init(
        project=project,
        entity=entity,
        mode=mode,
        name=name,
        group=group,
        tags=tags,
        notes=notes,
        config=cfg_container,
        dir=run_dir,
        reinit=True,
    )
| |
|
| |
|
def log(metrics: Dict[str, Any], step: Optional[int] = None) -> None:
    """Send a flat dictionary of metrics to W&B; no-op when no run is active.

    Args:
        metrics: Metric name -> value.
        step: Optional step. When given it is added to the logged payload as
            an ordinary entry under the key ``"step"`` (replacing any such key
            from ``metrics``); it is not passed as ``wandb.log``'s own
            ``step`` argument.

    Logging is best-effort: an exception raised while importing wandb,
    building the payload, or calling ``wandb.log`` is reported as a warning
    via the ``logging`` module and otherwise ignored, so the caller is never
    interrupted by a metrics failure.
    """
    if not _WANDB_AVAILABLE or _WANDB_RUN is None:
        return
    try:
        import wandb

        wandb.log(metrics if step is None else dict(metrics, step=step))
    except Exception as exc:
        # Surface the failure instead of silently dropping the metrics.
        logging.getLogger(__name__).warning("wandb.log failed: %s", exc)
| |
|
| |
|
| | def _flatten(prefix: str, data: Dict[str, Any], out: Dict[str, Any]) -> None: |
| | for k, v in data.items(): |
| | key = f"{prefix}.{k}" if prefix else k |
| | if isinstance(v, dict): |
| | _flatten(key, v, out) |
| | else: |
| | out[key] = v |
| |
|
| |
|
| | def _summarize_value(value: Any) -> Dict[str, Any]: |
| | import numpy as np |
| |
|
| | if value is None: |
| | return {"none": 1} |
| | |
| | if isinstance(value, (int, float)): |
| | return {"value": float(value)} |
| | |
| | try: |
| | arr = np.asarray(value) |
| | if arr.size == 0: |
| | return {"size": 0} |
| | return { |
| | "mean": float(np.nanmean(arr)), |
| | "min": float(np.nanmin(arr)), |
| | "max": float(np.nanmax(arr)), |
| | "last": float(arr.reshape(-1)[-1]), |
| | "size": int(arr.size), |
| | } |
| | except Exception: |
| | |
| | return {"text": str(value)} |
| |
|
| |
|
def log_tally(array_tally: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None:
    """Flatten and summarize a nested tally, then log the result to W&B.

    Nested dictionaries are walked recursively. Every non-dict leaf is passed
    to ``_summarize_value`` and each entry of that summary is stored under
    ``<prefix>.<dotted path>.<summary key>``. Does nothing when no W&B run is
    active or when nothing was summarized.

    Args:
        array_tally: Nested dictionary whose leaves are values to summarize.
        prefix: Optional leading component for every logged key.
        step: Forwarded to ``log``.
    """
    if not _WANDB_AVAILABLE or _WANDB_RUN is None:
        return
    summarized: Dict[str, Any] = {}

    def walk(node: Any, path: list[str]) -> None:
        if isinstance(node, dict):
            for k, v in node.items():
                # str() so non-string keys (ints, tuples, ...) cannot make
                # the join below raise TypeError.
                walk(v, path + [str(k)])
            return

        key = ".".join(([prefix] if prefix else []) + path)
        try:
            summary = _summarize_value(node)
        except Exception:
            # Record the failure for this leaf and keep going.
            summarized[f"{key}.error" if key else "error"] = 1
            return
        for sk, sv in summary.items():
            # An empty key (leaf at the root with no prefix) would otherwise
            # produce names with a leading dot.
            summarized[f"{key}.{sk}" if key else sk] = sv

    walk(array_tally, [])
    if summarized:
        log(summarized, step=step)
| |
|
| |
|
def log_flat_stats(stats: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None:
    """Flatten nested ``stats`` into dotted keys and log them to W&B.

    Does nothing when no W&B run is active or when flattening yields no
    entries.

    Args:
        stats: Possibly nested dictionary of values to log.
        prefix: Optional leading component for every logged key.
        step: Forwarded to ``log``.
    """
    wandb_active = _WANDB_AVAILABLE and _WANDB_RUN is not None
    if not wandb_active:
        return
    flattened: Dict[str, Any] = {}
    _flatten(prefix, stats, flattened)
    if not flattened:
        return
    log(flattened, step=step)
| |
|
| |
|
| |
|