File size: 4,800 Bytes
1c8c60e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
from collections.abc import Mapping
from typing import Any, Dict, Optional


# Flipped to True by _try_import_wandb() once `import wandb` has succeeded.
_WANDB_AVAILABLE: bool = False
# Handle returned by wandb.init() inside init(); stays None until a run is started.
_WANDB_RUN: Optional[Any] = None


def _try_import_wandb():
    """Report whether the optional ``wandb`` package can be imported.

    A successful import is remembered in the module-level ``_WANDB_AVAILABLE``
    flag, so later calls return True immediately. Any exception raised while
    importing (not only ``ImportError``) counts as "unavailable"; in that case
    the flag is reset to False and False is returned.
    """
    global _WANDB_AVAILABLE
    if not _WANDB_AVAILABLE:
        try:
            import wandb  # type: ignore  # noqa: F401  (imported only to probe availability)
        except Exception:
            _WANDB_AVAILABLE = False
            return False
        _WANDB_AVAILABLE = True
    return True


def _safe_get(cfg: Dict[str, Any], path: list[str], default: Any = None) -> Any:
    """Return the value stored at the nested key *path* inside *cfg*, or *default*.

    Each element of *path* indexes one level deeper. Any ``Mapping`` is
    traversed, not only plain ``dict``. This means mapping-like config objects
    (e.g. OmegaConf's ``DictConfig``, which implements ``MutableMapping``) are
    read correctly instead of silently yielding *default*.

    *default* is returned as soon as a level is not a mapping or lacks the next
    key. An empty *path* returns *cfg* itself. A key that is present with a
    ``None`` value returns ``None``, not *default*.
    """
    node: Any = cfg
    for key in path:
        if not isinstance(node, Mapping) or key not in node:
            return default
        node = node[key]
    return node


def is_enabled(cfg: Dict[str, Any]) -> bool:
    """Return whether W&B logging is switched on in *cfg*.

    Reads ``logging.wandb.enabled`` (treated as ``False`` when missing) and
    coerces the result with ``bool``.
    """
    enabled_flag = _safe_get(cfg, ["logging", "wandb", "enabled"], False)
    return bool(enabled_flag)


def init(cfg: Dict[str, Any], run_dir: str, run_name: Optional[str] = None) -> None:
    """
    Initialize Weights & Biases if enabled in config. No-op if disabled or wandb not installed.

    Settings are read from ``cfg["logging"]["wandb"]``:

    - ``project`` (default ``"llm-negotiation"``) and ``name`` (default
      *run_name*): a missing key *or* an explicit null/empty value falls back
      to the default, so ``name: null`` in a YAML config does not discard the
      caller-supplied *run_name*.
    - ``entity``, ``group``, ``notes`` (default ``None``) and ``mode``
      (default ``"online"``).
    - ``tags``: a sequence of tags, or a single string treated as one tag.

    *run_dir* is created if needed and passed as W&B's ``dir``. It also becomes
    the ``WANDB_DIR`` environment variable unless that is already set. The
    whole config is sent as the run config: it is converted to plain containers
    via OmegaConf when that succeeds, otherwise it is passed through unchanged.
    The handle returned by ``wandb.init`` is stored in the module-level
    ``_WANDB_RUN``, which activates ``log`` and the other logging helpers.
    """
    global _WANDB_RUN
    if not is_enabled(cfg):
        return
    if not _try_import_wandb():
        return

    import wandb  # type: ignore

    def setting(key: str, default: Any = None) -> Any:
        # Look up one key under logging.wandb.
        return _safe_get(cfg, ["logging", "wandb", key], default)

    # `or` (instead of only the lookup default) so keys present with a null
    # value -- common in YAML/Hydra configs -- still receive the fallback.
    project = setting("project") or "llm-negotiation"
    name = setting("name") or run_name
    entity = setting("entity")
    mode = setting("mode", "online")
    notes = setting("notes")
    group = setting("group")
    tags = setting("tags") or []
    if isinstance(tags, str):
        # A bare string is one tag, not a sequence of tags.
        tags = [tags]
    tags = list(tags)

    # Ensure files are written into the hydra run directory
    os.makedirs(run_dir, exist_ok=True)
    os.environ.setdefault("WANDB_DIR", run_dir)

    # Convert cfg to plain types for W&B config; fall back to the raw object
    # (e.g. when omegaconf is missing or cfg is already a plain dict).
    try:
        from omegaconf import OmegaConf  # type: ignore

        cfg_container = OmegaConf.to_container(cfg, resolve=True)  # type: ignore
    except Exception:
        cfg_container = cfg

    _WANDB_RUN = wandb.init(
        project=project,
        entity=entity,
        mode=mode,
        name=name,
        group=group,
        tags=tags,
        notes=notes,
        config=cfg_container,
        dir=run_dir,
        reinit=True,
    )


def log(metrics: Dict[str, Any], step: Optional[int] = None) -> None:
    """Log a flat dictionary of metrics to W&B if active.

    Does nothing unless wandb imported successfully and a run was started by
    ``init``. When *step* is given, it is merged into a copy of the payload
    under the ``"step"`` key, replacing any existing entry. It is not passed as
    wandb's own ``step=`` argument. Any exception raised while logging is
    swallowed, so metric reporting never interrupts the caller.
    """
    run_active = _WANDB_AVAILABLE and _WANDB_RUN is not None
    if not run_active:
        return
    try:
        import wandb  # type: ignore

        if step is None:
            payload = metrics
        else:
            payload = dict(metrics)
            payload["step"] = step
        wandb.log(payload)
    except Exception:
        pass


def _flatten(prefix: str, data: Dict[str, Any], out: Dict[str, Any]) -> None:
    """Copy every leaf of the nested dict *data* into *out* under a dotted key.

    Keys are joined with ``"."``. When *prefix* is empty, top-level keys are
    stored unchanged (not stringified). Only ``dict`` values are descended
    into; everything else is stored as-is. *out* is mutated in place and the
    function returns nothing.
    """
    for name, item in data.items():
        dotted = name if not prefix else f"{prefix}.{name}"
        if not isinstance(item, dict):
            out[dotted] = item
            continue
        _flatten(dotted, item, out)


def _summarize_value(value: Any) -> Dict[str, Any]:
    import numpy as np  # local import to avoid hard dependency during disabled mode

    if value is None:
        return {"none": 1}
    # Scalars
    if isinstance(value, (int, float)):
        return {"value": float(value)}
    # Lists or arrays
    try:
        arr = np.asarray(value)
        if arr.size == 0:
            return {"size": 0}
        return {
            "mean": float(np.nanmean(arr)),
            "min": float(np.nanmin(arr)),
            "max": float(np.nanmax(arr)),
            "last": float(arr.reshape(-1)[-1]),
            "size": int(arr.size),
        }
    except Exception:
        # Fallback: string repr
        return {"text": str(value)}


def log_tally(array_tally: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None:
    """
    Flatten and summarize Tally.array_tally and log to WandB.

    Nested dicts are walked recursively. Each non-dict leaf (presumably a list
    of values accumulated over time) is summarized with ``_summarize_value``.
    Each resulting stat is emitted under ``<prefix>.<k1>.<k2>...<stat>``;
    *prefix* is omitted when empty. Dict keys are stringified, so non-string
    keys such as integer ids do not abort the whole call. If summarizing a leaf
    raises, ``<key>.error = 1`` is recorded instead. No-op unless a W&B run is
    active; nothing is logged if no leaves are found.
    """
    if not _WANDB_AVAILABLE or _WANDB_RUN is None:
        return
    summarized: Dict[str, Any] = {}
    head = [prefix] if prefix else []

    def walk(node: Any, path: list[str]) -> None:
        if isinstance(node, dict):
            for k, v in node.items():
                walk(v, path + [str(k)])
            return
        key = ".".join(head + path)
        try:
            summary = _summarize_value(node)
        except Exception:
            # Keep logging the remaining leaves even if one summary fails.
            summarized[f"{key}.error"] = 1
            return
        for sk, sv in summary.items():
            summarized[f"{key}.{sk}"] = sv

    walk(array_tally, [])
    if summarized:
        log(summarized, step=step)


def log_flat_stats(stats: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None:
    """Flatten a nested stats dict into dotted keys and log it to W&B.

    Nested ``dict`` values are expanded with ``_flatten``: keys are joined with
    ``"."`` and prefixed by *prefix* when it is non-empty. Nothing is logged
    when no W&B run is active or when flattening yields no entries.
    """
    wandb_active = _WANDB_AVAILABLE and _WANDB_RUN is not None
    if not wandb_active:
        return
    flat_stats: Dict[str, Any] = {}
    _flatten(prefix, stats, flat_stats)
    if not flat_stats:
        return
    log(flat_stats, step=step)