# Source uploaded by Muqeeth via the upload-large-folder tool (commit 1c8c60e, verified).
import os
from typing import Any, Dict, Optional
# Module-level state for the optional Weights & Biases integration.
_WANDB_AVAILABLE = False  # set True once `import wandb` has succeeded (see _try_import_wandb)
_WANDB_RUN = None  # active wandb run handle, set by init(); None when logging is off
def _try_import_wandb():
    """Attempt to import wandb, caching success in ``_WANDB_AVAILABLE``.

    Returns:
        True when wandb is importable, False otherwise. A prior success
        short-circuits; a failure is re-attempted on the next call.
    """
    global _WANDB_AVAILABLE
    if _WANDB_AVAILABLE:
        return True
    try:
        import wandb  # noqa: F401  # type: ignore
        _WANDB_AVAILABLE = True
    except Exception:
        # Any import failure (missing package, broken install) disables logging.
        _WANDB_AVAILABLE = False
    return _WANDB_AVAILABLE
def _safe_get(cfg: Dict[str, Any], path: list[str], default: Any = None) -> Any:
cur: Any = cfg
for key in path:
if not isinstance(cur, dict) or key not in cur:
return default
cur = cur[key]
return cur
def is_enabled(cfg: Dict[str, Any]) -> bool:
    """Return True when ``cfg["logging"]["wandb"]["enabled"]`` exists and is truthy."""
    node: Any = cfg
    for key in ("logging", "wandb", "enabled"):
        if not isinstance(node, dict) or key not in node:
            return False
        node = node[key]
    return bool(node)
def init(cfg: Dict[str, Any], run_dir: str, run_name: Optional[str] = None) -> None:
    """
    Initialize Weights & Biases if enabled in config. No-op if disabled or wandb not installed.

    Args:
        cfg: run configuration; W&B settings live under ``logging.wandb``.
        run_dir: directory (the hydra run dir) where W&B files are written.
        run_name: fallback run name when ``logging.wandb.name`` is unset.
    """
    global _WANDB_RUN
    if not (is_enabled(cfg) and _try_import_wandb()):
        return
    import wandb  # type: ignore

    def wandb_opt(key: str, default: Any = None) -> Any:
        # Shorthand for reading a single logging.wandb.* setting.
        return _safe_get(cfg, ["logging", "wandb", key], default)

    # Ensure files are written into the hydra run directory.
    os.makedirs(run_dir, exist_ok=True)
    os.environ.setdefault("WANDB_DIR", run_dir)

    # Convert cfg to plain types for W&B config; fallback to minimal dictionary.
    try:
        from omegaconf import OmegaConf  # type: ignore
        cfg_container = OmegaConf.to_container(cfg, resolve=True)  # type: ignore
    except Exception:
        cfg_container = cfg

    _WANDB_RUN = wandb.init(
        project=wandb_opt("project", "llm-negotiation"),
        entity=wandb_opt("entity"),
        mode=wandb_opt("mode", "online"),
        name=wandb_opt("name", run_name),
        group=wandb_opt("group"),
        tags=wandb_opt("tags", []) or [],
        notes=wandb_opt("notes"),
        config=cfg_container,
        dir=run_dir,
        reinit=True,
    )
def log(metrics: Dict[str, Any], step: Optional[int] = None) -> None:
    """Log a flat dictionary of metrics to W&B if active; silent no-op otherwise.

    Args:
        metrics: mapping of metric name -> value.
        step: optional global step. Forwarded as wandb.log's ``step=`` keyword
            so W&B indexes the metrics at that step.
    """
    if not _WANDB_AVAILABLE or _WANDB_RUN is None:
        return
    try:
        import wandb  # type: ignore
        # BUGFIX: previously step was merged into the metrics dict
        # (dict(metrics, step=step)), which logs a spurious metric named
        # "step" instead of setting the run's history index. Pass it as the
        # dedicated keyword instead (step=None is accepted and means "auto").
        wandb.log(metrics, step=step)
    except Exception:
        # Deliberate best-effort: metric logging must never crash the caller.
        pass
def _flatten(prefix: str, data: Dict[str, Any], out: Dict[str, Any]) -> None:
for k, v in data.items():
key = f"{prefix}.{k}" if prefix else k
if isinstance(v, dict):
_flatten(key, v, out)
else:
out[key] = v
def _summarize_value(value: Any) -> Dict[str, Any]:
import numpy as np # local import to avoid hard dependency during disabled mode
if value is None:
return {"none": 1}
# Scalars
if isinstance(value, (int, float)):
return {"value": float(value)}
# Lists or arrays
try:
arr = np.asarray(value)
if arr.size == 0:
return {"size": 0}
return {
"mean": float(np.nanmean(arr)),
"min": float(np.nanmin(arr)),
"max": float(np.nanmax(arr)),
"last": float(arr.reshape(-1)[-1]),
"size": int(arr.size),
}
except Exception:
# Fallback: string repr
return {"text": str(value)}
def log_tally(array_tally: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None:
    """
    Flatten and summarize Tally.array_tally and log to WandB.

    Nested dicts are walked depth-first; each leaf (a list/array accumulated
    over time) is reduced via _summarize_value (mean/min/max/last/size) and
    logged under a dot-joined key, optionally rooted at *prefix*.
    """
    if not _WANDB_AVAILABLE or _WANDB_RUN is None:
        return
    summarized: Dict[str, Any] = {}

    def visit(node: Any, trail: list[str]) -> None:
        if isinstance(node, dict):
            for child_key, child in node.items():
                visit(child, trail + [child_key])
            return
        # Leaf: a sequence of values accumulated over time.
        leaf_key = ".".join(([prefix] if prefix else []) + trail)
        try:
            for stat_name, stat_value in _summarize_value(node).items():
                summarized[f"{leaf_key}.{stat_name}"] = stat_value
        except Exception:
            # Record that summarization failed rather than dropping the key.
            summarized[f"{leaf_key}.error"] = 1

    visit(array_tally, [])
    if summarized:
        log(summarized, step=step)
def log_flat_stats(stats: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None:
    """Flatten nested *stats* into dot-joined keys and log them to W&B (no-op if inactive)."""
    if not _WANDB_AVAILABLE or _WANDB_RUN is None:
        return
    flattened: Dict[str, Any] = {}
    _flatten(prefix, stats, flattened)
    if not flattened:
        return
    log(flattened, step=step)