") # split-agent-context
+
+ last_time_step_chat = None
+ for original_index, turn in indexed_turns:
+ # Use agent index for CSS class (agent-0 or agent-1) instead of agent ID
+ agent_index = agent_id_to_index.get(turn.agent_id, 0)
+ agent_class = f"agent-{agent_index}"
+ role_class = f"role-{turn.role}"
+
+ # Add time step divider and beginning context
+ if last_time_step_chat is None or turn.time_step != last_time_step_chat:
+ # Add end contexts for previous round (only regular context, not prompt summary)
+ if last_time_step_chat is not None:
+ add_context_area("end", last_time_step_chat)
+
+ html_parts.append(
+ f'
'
+ f'⏱ Round {turn.time_step + 1}'
+ f"
"
+ )
+
+ # Add beginning contexts for new round (both context and prompt summary)
+ add_context_area("beginning", turn.time_step)
+ add_split_agent_contexts("beginning", turn.time_step)
+
+ last_time_step_chat = turn.time_step
+
+ # Build chat message with merge controls
+ html_parts.append(
+ f'
") # chat-message
+
+ # Add end contexts for the last round (only regular context, not prompt summary)
+ if last_time_step_chat is not None:
+ add_context_area("end", last_time_step_chat)
+
+ html_parts.append("
") # flow-chat
+ html_parts.extend(["", ""])
+
+ return "\n".join(html_parts)
+
+
+def export_html_from_rollout_tree(path: Path, outdir: Path, main_only: bool = False):
+ """Process a rollout tree file and generate HTML files for each path.
+ Creates separate HTML files for the main path and each branch path.
+ The main path is saved in the root output directory, while branch paths
+ are saved in a 'branches' subdirectory.
+
+ Args:
+ path: Path to the rollout tree JSON file
+ outdir: Output directory for HTML files
+ main_only: If True, only export the main trajectory (default: False)
+ """
+ root = load_rollout_tree(path)
+ mgid = root.id
+
+ main_path, branch_paths = get_rollout_tree_paths(root)
+
+ outdir.mkdir(parents=True, exist_ok=True)
+
+ # Create branches subdirectory if we have branch paths
+ if not main_only and branch_paths:
+ branches_dir = outdir / f"mgid:{mgid}_branches_html_renders"
+ branches_dir.mkdir(parents=True, exist_ok=True)
+
+ # Generate HTML for the main path
+ chat_turns = gather_all_chat_turns_for_path(main_path)
+ html_content = html_from_chat_turns(chat_turns)
+ output_file = outdir / f"mgid:{mgid}_main_html_render.render.html"
+ with open(output_file, "w", encoding="utf-8") as f:
+ f.write(html_content)
+
+ # Generate HTML for each branch path
+ for path_obj in branch_paths:
+ chat_turns = gather_all_chat_turns_for_path(path_obj)
+
+ html_content = html_from_chat_turns(chat_turns)
+
+ path_id: str = path_obj.id
+ output_filename = f"{path_id}_html_render.render.html"
+
+ output_file = branches_dir / output_filename
+
+ with open(output_file, "w", encoding="utf-8") as f:
+ f.write(html_content)
diff --git a/src_code_for_reproducibility/utils/rollout_tree_stats.py b/src_code_for_reproducibility/utils/rollout_tree_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..4725160156230d7efb89588c765fb5b63a7bbbe1
--- /dev/null
+++ b/src_code_for_reproducibility/utils/rollout_tree_stats.py
@@ -0,0 +1,55 @@
+"""
+File: mllm/utils/rollout_tree_stats.py
+Summary: Computes descriptive statistics from rollout tree collections.
+"""
+
+from typing import Any, Callable, List, Tuple
+
+from mllm.markov_games.rollout_tree import RolloutTreeRootNode
+from mllm.markov_games.simulation import SimulationStepLog
+from mllm.utils.rollout_tree_gather_utils import (
+ gather_simulation_step_logs,
+ get_rollout_tree_paths,
+)
+from mllm.utils.stat_pack import StatPack
+
+
+def get_rollout_tree_stat_tally(
+ rollout_tree: RolloutTreeRootNode,
+ metrics: List[Callable[[SimulationStepLog], List[Tuple[str, float]]]],
+) -> StatPack:
+ stat_tally = StatPack()
+ # get simulation step logs
+ node_list = get_rollout_tree_paths(rollout_tree)[0]
+ simulation_step_logs = gather_simulation_step_logs(node_list)
+ for simulation_step_log in simulation_step_logs:
+ for metric in metrics:
+ metric_result = metric(simulation_step_log)
+ if metric_result is not None:
+ for key, value in metric_result:
+ stat_tally.add_stat(key, value)
+ return stat_tally
+
+
+def get_rollout_tree_mean_stats(
+ rollout_tree: RolloutTreeRootNode, metrics: List[Callable[[SimulationStepLog], Any]]
+) -> StatPack:
+ """Get the mean stats for a rollout tree."""
+ stat_tally = get_rollout_tree_stat_tally(rollout_tree, metrics)
+ return stat_tally.mean()
+
+
+def get_mean_rollout_tree_stats(
+ rollout_trees: List[RolloutTreeRootNode],
+ metrics: List[Callable[[SimulationStepLog], Any]],
+) -> StatPack:
+ """Get the mean stats for a list of rollout trees."""
+ # Compute per-rollout means first, then aggregate them across the entire batch.
+ stat_tallies = [
+ get_rollout_tree_mean_stats(rollout_tree, metrics)
+ for rollout_tree in rollout_trees
+ ]
+ mean_stat_tally = StatPack()
+ for stat_tally in stat_tallies:
+ mean_stat_tally.add_stats(stat_tally)
+ return mean_stat_tally.mean()
diff --git a/src_code_for_reproducibility/utils/short_id_gen.py b/src_code_for_reproducibility/utils/short_id_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c08ffdc3362c767ea8916496ea5b0e1c01dbd7e
--- /dev/null
+++ b/src_code_for_reproducibility/utils/short_id_gen.py
@@ -0,0 +1,16 @@
+"""
+File: mllm/utils/short_id_gen.py
+Summary: Generates short unique identifiers for experiment assets.
+"""
+
+import uuid
+
+
+def generate_short_id() -> int:
+ """
+ Generates a short unique ID for tracking adapter versions.
+
+ Returns:
+ int: An 8-digit integer ID.
+ """
+ return int(str(uuid.uuid4().int)[:8])
diff --git a/src_code_for_reproducibility/utils/stat_pack.py b/src_code_for_reproducibility/utils/stat_pack.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4da475dafa8e3290ba9be10922be5687ac2c862
--- /dev/null
+++ b/src_code_for_reproducibility/utils/stat_pack.py
@@ -0,0 +1,117 @@
+"""
+File: mllm/utils/stat_pack.py
+Summary: Implements the StatPack container for incremental statistics.
+"""
+
+import csv
+import json
+import os
+import pickle
+from collections import Counter
+from copy import deepcopy
+from locale import strcoll
+from statistics import mean
+from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+style_path = os.environ.get("ADALIGN_MPLSTYLE")
+if style_path:
+ plt.style.use(style_path)
+
+import wandb
+
+from . import wandb_utils
+
+
+class StatPack:
+ def __init__(self):
+ self.data = {}
+
+ def add_stat(self, key: str, value: float | int | None):
+ assert (
+ isinstance(value, float) or isinstance(value, int) or value is None
+ ), f"Value {value} is not a valid type"
+ if key not in self.data:
+ self.data[key] = []
+ self.data[key].append(value)
+
+ def add_stats(self, other: "StatPack"):
+ for key in other.keys():
+ self.add_stat(key, other[key])
+
+ def __getitem__(self, key: str):
+ return self.data[key]
+
+ def __setitem__(self, key: str, value: Any):
+ self.data[key] = value
+
+ def __contains__(self, key: str):
+ return key in self.data
+
+ def __len__(self):
+ return len(self.data)
+
+ def __iter__(self):
+ return iter(self.data)
+
+ def keys(self):
+ return self.data.keys()
+
+ def values(self):
+ return self.data.values()
+
+ def items(self):
+ return self.data.items()
+
+ def mean(self):
+ mean_st = StatPack()
+ for key in self.keys():
+ if isinstance(self[key], list):
+ # Ignore None entries so missing measurements do not bias the mean.
+ non_none_values = [v for v in self[key] if v is not None]
+ if non_none_values:
+ mean_st[key] = np.mean(np.array(non_none_values))
+ else:
+ mean_st[key] = None
+ return mean_st
+
+ def store_plots(self, folder: str):
+ os.makedirs(folder, exist_ok=True)
+ for key in self.keys():
+ plt.figure(figsize=(10, 5))
+ plt.plot(self[key])
+ plt.title(key)
+ plt.savefig(os.path.join(folder, f"{key}.pdf"))
+ plt.close()
+
+ def store_numpy(self, folder: str):
+ os.makedirs(folder, exist_ok=True)
+ for key in self.keys():
+ # Sanitize filename components (avoid slashes, spaces, etc.)
+ safe_key = str(key).replace(os.sep, "_").replace("/", "_").replace(" ", "_")
+ values = self[key]
+ # Convert None to NaN for numpy compatibility
+ arr = np.array(
+ [(np.nan if (v is None) else v) for v in values], dtype=float
+ )
+ np.save(os.path.join(folder, f"{safe_key}.npy"), arr)
+
+ def store_json(self, folder: str, filename: str = "stats.json"):
+ os.makedirs(folder, exist_ok=True)
+ with open(os.path.join(folder, filename), "w") as f:
+ json.dump(self.data, f, indent=4)
+
+ def store_csv(self, folder: str):
+ os.makedirs(folder, exist_ok=True)
+ for key in self.keys():
+ with open(os.path.join(folder, f"stats.csv"), "w") as f:
+ writer = csv.writer(f)
+ writer.writerow([key] + self[key])
+
+ def store_pickle(self, folder: str):
+ os.makedirs(folder, exist_ok=True)
+ for key in self.keys():
+ with open(os.path.join(folder, f"stats.pkl"), "wb") as f:
+ pickle.dump(self[key], f)
diff --git a/src_code_for_reproducibility/utils/wandb_utils.py b/src_code_for_reproducibility/utils/wandb_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..46289bfdbb48b72fa3fe3b531d150447cfc1eb01
--- /dev/null
+++ b/src_code_for_reproducibility/utils/wandb_utils.py
@@ -0,0 +1,170 @@
+"""
+File: mllm/utils/wandb_utils.py
+Summary: Shared Weights & Biases helper functions.
+"""
+
+import os
+from typing import Any, Dict, Optional
+
+_WANDB_AVAILABLE = False
+_WANDB_RUN = None
+
+
+def _try_import_wandb():
+ global _WANDB_AVAILABLE
+ if _WANDB_AVAILABLE:
+ return True
+ try:
+ import wandb # type: ignore
+
+ _WANDB_AVAILABLE = True
+ return True
+ except Exception:
+ _WANDB_AVAILABLE = False
+ return False
+
+
+def _safe_get(cfg: Dict[str, Any], path: list[str], default: Any = None) -> Any:
+ cur: Any = cfg
+ for key in path:
+ if not isinstance(cur, dict) or key not in cur:
+ return default
+ cur = cur[key]
+ return cur
+
+
+def is_enabled(cfg: Dict[str, Any]) -> bool:
+ return bool(_safe_get(cfg, ["logging", "wandb", "enabled"], False))
+
+
+def init(cfg: Dict[str, Any], run_dir: str, run_name: Optional[str] = None) -> None:
+ """
+ Initialize Weights & Biases if enabled in config. No-op if disabled or wandb not installed.
+ """
+ global _WANDB_RUN
+ if not is_enabled(cfg):
+ return
+ if not _try_import_wandb():
+ return
+
+ import wandb # type: ignore
+
+ project = _safe_get(cfg, ["logging", "wandb", "project"], "llm-negotiation")
+ entity = _safe_get(cfg, ["logging", "wandb", "entity"], None)
+ mode = _safe_get(cfg, ["logging", "wandb", "mode"], "online")
+ tags = _safe_get(cfg, ["logging", "wandb", "tags"], []) or []
+ notes = _safe_get(cfg, ["logging", "wandb", "notes"], None)
+ group = _safe_get(cfg, ["logging", "wandb", "group"], None)
+ name = _safe_get(cfg, ["logging", "wandb", "name"], run_name)
+
+ # Ensure files are written into the hydra run directory
+ os.makedirs(run_dir, exist_ok=True)
+ os.environ.setdefault("WANDB_DIR", run_dir)
+
+ # Convert cfg to plain types for W&B config; fallback to minimal dictionary
+ try:
+ from omegaconf import OmegaConf # type: ignore
+
+ cfg_container = OmegaConf.to_container(cfg, resolve=True) # type: ignore
+ except Exception:
+ cfg_container = cfg
+
+ _WANDB_RUN = wandb.init(
+ project=project,
+ entity=entity,
+ mode=mode,
+ name=name,
+ group=group,
+ tags=tags,
+ notes=notes,
+ config=cfg_container,
+ dir=run_dir,
+ reinit=True,
+ )
+
+
+def log(metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+ """Log a flat dictionary of metrics to W&B if active."""
+ if not _WANDB_AVAILABLE or _WANDB_RUN is None:
+ return
+ try:
+ import wandb # type: ignore
+
+ wandb.log(metrics if step is None else dict(metrics, step=step))
+ except Exception:
+ pass
+
+
+def _flatten(prefix: str, data: Dict[str, Any], out: Dict[str, Any]) -> None:
+ for k, v in data.items():
+ key = f"{prefix}.{k}" if prefix else k
+ if isinstance(v, dict):
+ _flatten(key, v, out)
+ else:
+ out[key] = v
+
+
+def _summarize_value(value: Any) -> Dict[str, Any]:
+ import numpy as np # local import to avoid hard dependency during disabled mode
+
+ if value is None:
+ return {"none": 1}
+ # Scalars
+ if isinstance(value, (int, float)):
+ return {"value": float(value)}
+ # Lists or arrays
+ try:
+ arr = np.asarray(value)
+ if arr.size == 0:
+ return {"size": 0}
+ return {
+ "mean": float(np.nanmean(arr)),
+ "min": float(np.nanmin(arr)),
+ "max": float(np.nanmax(arr)),
+ "last": float(arr.reshape(-1)[-1]),
+ "size": int(arr.size),
+ }
+ except Exception:
+ # Fallback: string repr
+ return {"text": str(value)}
+
+
+def log_tally(
+ array_tally: Dict[str, Any], prefix: str = "", step: Optional[int] = None
+) -> None:
+ """
+ Flatten and summarize Tally.array_tally and log to WandB.
+ Each leaf list/array is summarized with mean/min/max/last/size.
+ """
+ if not _WANDB_AVAILABLE or _WANDB_RUN is None:
+ return
+ summarized: Dict[str, Any] = {}
+
+ def walk(node: Any, path: list[str]):
+ if isinstance(node, dict):
+ for k, v in node.items():
+ walk(v, path + [k])
+ return
+ # node is a list of values accumulated over time
+ key = ".".join([p for p in ([prefix] if prefix else []) + path])
+ try:
+ summary = _summarize_value(node)
+ for sk, sv in summary.items():
+ summarized[f"{key}.{sk}"] = sv
+ except Exception:
+ summarized[f"{key}.error"] = 1
+
+ walk(array_tally, [])
+ if summarized:
+ log(summarized, step=step)
+
+
+def log_flat_stats(
+ stats: Dict[str, Any], prefix: str = "", step: Optional[int] = None
+) -> None:
+ if not _WANDB_AVAILABLE or _WANDB_RUN is None:
+ return
+ flat: Dict[str, Any] = {}
+ _flatten(prefix, stats, flat)
+ if flat:
+ log(flat, step=step)