File size: 3,757 Bytes
02ff91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
ResolutionMemory — ε-greedy bandit over conflict resolution templates.

Tracks (conflict_type, template_key, quality_delta) outcomes and learns
which template produces the best quality improvements per conflict type.
No deep learning required — the arm count is small (4 types × N templates).
"""

from __future__ import annotations
import json
import random
from pathlib import Path
from dataclasses import dataclass, asdict


@dataclass
class ResolutionOutcome:
    """
    One observed result of applying a resolution template to a conflict.

    ``ResolutionBandit`` persists these as one JSON object per line and
    rebuilds them from that file on start-up.
    """

    # The ConflictType.value string identifying the kind of conflict.
    conflict_type: str
    # Key of the resolution template that was applied.
    template_key: str
    # Reward signal: specialist_score - baseline_score for the episode.
    quality_delta: float
    # Index of the episode in which this outcome was observed.
    episode_idx: int


class ResolutionBandit:
    """
    ε-greedy bandit that selects a resolution template for a given conflict type.

    Each arm is a (conflict_type, template_key) pair. Its reward samples are the
    ``quality_delta`` values of recorded ``ResolutionOutcome``s. Outcomes are kept
    in memory and appended to a JSON-lines file at ``memory_path``. That file is
    replayed on construction.

    Selection is uniformly random while no template of the requested conflict
    type has at least ``min_samples`` observations. After that, with probability
    1 - ε the template with the highest mean delta among those that meet the
    threshold is chosen. Otherwise (probability ε) a uniformly random template
    is chosen.

    Config keys (read from agents sub-dict of training config):
      resolution_bandit_epsilon       — exploration rate (default 0.15)
      resolution_bandit_min_samples   — minimum observations before exploiting (default 5)
    """

    def __init__(
        self,
        templates: dict[str, dict[str, str]],
        config: dict,
        memory_path: str,
    ):
        """
        Args:
            templates: {conflict_type_str: {template_key: template_str}}. This
                class uses only the keys, to decide which templates are selectable.
            config: mapping read for the ``resolution_bandit_*`` keys listed above.
            memory_path: JSON-lines file used to persist outcomes. Its parent
                directory is created if missing, and existing records are loaded.
        """
        self._templates    = templates        # {ct_value_str: {template_key: template_str}}
        self._epsilon      = config.get("resolution_bandit_epsilon", 0.15)
        self._min_samples  = config.get("resolution_bandit_min_samples", 5)
        self._memory_path  = Path(memory_path)
        self._memory_path.parent.mkdir(parents=True, exist_ok=True)
        # {conflict_type_str: {template_key: [quality_deltas]}}
        self._stats: dict[str, dict[str, list[float]]] = {}
        self._load()

    def _load(self) -> None:
        """
        Replay persisted outcomes from the memory file into ``self._stats``.

        Malformed lines are skipped: invalid JSON, missing/extra fields, or a
        ``quality_delta`` that cannot be converted to float. One corrupt record
        therefore cannot block start-up or later break the mean computations.
        Unexpected errors (anything other than ValueError/TypeError) propagate.
        """
        if not self._memory_path.exists():
            return
        # json.dumps writes ASCII by default, so UTF-8 decodes it regardless of locale.
        for line in self._memory_path.read_text(encoding="utf-8").splitlines():
            try:
                # json.JSONDecodeError is a ValueError; wrong keys or a non-dict
                # payload raise TypeError from the dataclass constructor.
                rec = ResolutionOutcome(**json.loads(line))
                # Coerce here so a non-numeric delta is rejected now rather than
                # making sum() fail later in select_template / arm_means.
                delta = float(rec.quality_delta)
                (self._stats
                 .setdefault(rec.conflict_type, {})
                 .setdefault(rec.template_key, [])
                 .append(delta))
            except (ValueError, TypeError):
                # Blank, truncated, or hand-edited line: skip it.
                continue

    def select_template(self, conflict_type_str: str) -> str:
        """
        ε-greedy selection over available templates for this conflict type.

        Returns the template key (not the template text). Returns the literal
        string ``"default"`` if ``templates`` has no entries for this type.
        A template with fewer than ``min_samples`` recorded outcomes is picked
        only through random choice (exploration or cold start), never exploited.
        """
        available = list(self._templates.get(conflict_type_str, {}).keys())
        if not available:
            return "default"

        type_stats = self._stats.get(conflict_type_str, {})
        if random.random() < self._epsilon or not type_stats:
            return random.choice(available)

        # Exploit only templates that are still offered and have enough evidence.
        scored = {
            k: sum(v) / len(v)
            for k, v in type_stats.items()
            if k in available and len(v) >= self._min_samples
        }
        if not scored:
            return random.choice(available)
        # On ties, max() keeps the first key in type_stats insertion order.
        return max(scored, key=scored.__getitem__)

    def record_outcome(self, outcome: ResolutionOutcome) -> None:
        """
        Append ``outcome`` to the memory file and add it to the in-memory stats.

        The record is serialized and written before the in-memory stats are
        touched. If serialization or the write raises, the outcome is not
        counted in memory either.
        """
        line = json.dumps(asdict(outcome)) + "\n"
        with open(self._memory_path, "a", encoding="utf-8") as f:
            f.write(line)
        (self._stats
         .setdefault(outcome.conflict_type, {})
         .setdefault(outcome.template_key, [])
         .append(outcome.quality_delta))

    def arm_means(self) -> dict[str, dict[str, float]]:
        """
        Return current mean quality delta per (conflict_type, template_key).

        Templates with no recorded deltas are omitted.
        """
        return {
            ct: {
                tk: sum(deltas) / len(deltas)
                for tk, deltas in tk_map.items()
                if deltas
            }
            for ct, tk_map in self._stats.items()
        }