Spaces:
Paused
Paused
| """ | |
| ResolutionMemory — ε-greedy bandit over conflict resolution templates. | |
| Tracks (conflict_type, template_key, quality_delta) outcomes and learns | |
| which template produces the best quality improvements per conflict type. | |
| No deep learning required — the arm count is small (4 types × N templates). | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import random | |
| from pathlib import Path | |
| from dataclasses import dataclass, asdict | |
@dataclass
class ResolutionOutcome:
    """One observed result of applying a resolution template to a conflict.

    Must be a dataclass: ResolutionBandit persists instances with
    ``dataclasses.asdict`` and rebuilds them with ``ResolutionOutcome(**record)``.
    Without the decorator both calls raise TypeError.
    """

    conflict_type: str  # ConflictType.value string
    template_key: str
    quality_delta: float  # specialist_score - baseline_score for the episode
    episode_idx: int
class ResolutionBandit:
    """
    ε-greedy bandit that selects a resolution template for a given conflict type.

    Each (conflict_type, template_key) pair is an arm. An arm's value is the
    mean quality delta observed for it. An arm becomes eligible for
    exploitation only after it has at least ``min_samples`` observations.
    Until some arm of a conflict type qualifies, selection for that type is
    uniformly random over its registered templates.

    Outcomes are appended as JSON lines to ``memory_path`` and replayed on
    construction, so learned statistics persist across instances.

    Config keys (read from agents sub-dict of training config):
        resolution_bandit_epsilon      — exploration rate (default 0.15)
        resolution_bandit_min_samples  — minimum observations per template
                                         before it can be exploited (default 5)
    """

    def __init__(
        self,
        templates: dict[str, dict[str, str]],
        config: dict,
        memory_path: str,
        rng: random.Random | None = None,
    ):
        """
        Args:
            templates: ``{conflict_type_value: {template_key: template_text}}``.
            config: Mapping holding the optional bandit config keys listed on
                the class.
            memory_path: JSONL file used to persist outcomes. Its parent
                directory is created if missing.
            rng: Optional random source for reproducible selection. Defaults
                to the module-level ``random`` functions (the global RNG), so
                ``random.seed(...)`` keeps controlling behaviour as before.
        """
        self._templates = templates  # {ct_value_str: {template_key: template_str}}
        self._epsilon = config.get("resolution_bandit_epsilon", 0.15)
        self._min_samples = config.get("resolution_bandit_min_samples", 5)
        # The `random` module exposes the same .random()/.choice() API as a
        # random.Random instance, so either can serve as the RNG.
        self._rng = rng if rng is not None else random
        self._memory_path = Path(memory_path)
        self._memory_path.parent.mkdir(parents=True, exist_ok=True)
        # {conflict_type_str: {template_key: [quality_deltas]}}
        self._stats: dict[str, dict[str, list[float]]] = {}
        self._load()

    def _add(self, conflict_type: str, template_key: str, delta: float) -> None:
        """Append one observed quality delta to the in-memory arm statistics."""
        (self._stats
            .setdefault(conflict_type, {})
            .setdefault(template_key, [])
            .append(delta))

    def _load(self) -> None:
        """
        Replay persisted outcomes from the JSONL memory file into ``_stats``.

        A missing file is treated as an empty history. The following lines are
        skipped rather than aborting construction:
          - blank lines
          - lines that are not valid JSON
          - lines whose fields do not match ResolutionOutcome
          - lines whose quality delta is not numeric
        A truncated line is one example of such a record.
        """
        if not self._memory_path.exists():
            return
        with self._memory_path.open(encoding="utf-8") as fh:
            for line in fh:
                if not line.strip():
                    continue
                try:
                    rec = ResolutionOutcome(**json.loads(line))
                    # Coerce now so a bad value cannot later break sum()/mean.
                    delta = float(rec.quality_delta)
                except (ValueError, TypeError, OverflowError):
                    # Which exception covers which failure:
                    #   ValueError: JSONDecodeError or an unparsable number.
                    #   TypeError: non-object JSON or wrong/missing keys.
                    #   OverflowError: an integer too large for float.
                    continue
                self._add(rec.conflict_type, rec.template_key, delta)

    def select_template(self, conflict_type_str: str) -> str:
        """
        ε-greedy selection over available templates for this conflict type.

        Returns:
            The template key (not the template text). Returns ``"default"``
            when no templates are registered for ``conflict_type_str``.

        A template is drawn uniformly at random in two cases:
          - with probability ε;
          - when no registered template of this type has at least
            ``min_samples`` observations.
        Otherwise the eligible template with the highest mean quality delta
        is returned. Ties go to the arm that was recorded first.
        """
        available = list(self._templates.get(conflict_type_str, {}).keys())
        if not available:
            return "default"
        type_stats = self._stats.get(conflict_type_str, {})
        if self._rng.random() < self._epsilon or not type_stats:
            return self._rng.choice(available)
        available_set = set(available)
        scored = {
            key: sum(deltas) / len(deltas)
            for key, deltas in type_stats.items()
            # Ignore stats for templates no longer registered, and arms that
            # have too few observations to be trusted yet.
            if key in available_set and len(deltas) >= self._min_samples
        }
        if not scored:
            return self._rng.choice(available)
        return max(scored, key=scored.__getitem__)

    def record_outcome(self, outcome: ResolutionOutcome) -> None:
        """
        Record ``outcome`` in the in-memory statistics.

        The outcome is also appended as one JSON line to the memory file, so
        that future instances replay it.
        """
        self._add(outcome.conflict_type, outcome.template_key, outcome.quality_delta)
        with self._memory_path.open("a", encoding="utf-8") as fh:
            fh.write(json.dumps(asdict(outcome)) + "\n")

    def arm_means(self) -> dict[str, dict[str, float]]:
        """Return current mean quality delta per (conflict_type, template_key)."""
        return {
            ct: {
                tk: sum(deltas) / len(deltas)
                for tk, deltas in tk_map.items()
                if deltas
            }
            for ct, tk_map in self._stats.items()
        }