File size: 7,209 Bytes
e6b0e2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21da591
e6b0e2f
 
 
 
 
 
 
 
 
 
21da591
e6b0e2f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""
Layer 2 — Conversation Environment (OpenEnv-compatible).

Implements reset() / step() interface. Each episode is a multi-turn
conversation between a voice agent (whose system prompt comes from Layer 1)
and a simulated customer (driven by CustomerSimulator).
"""

from __future__ import annotations

import random
from dataclasses import dataclass, field
from typing import Any

from layer0.reward import (
    ConversationLog,
    reward_fn,
    extract_intent_json,
    contains_unauthorized_disclosure,
    RewardConfig,
    BANKING_INTENTS,
)
from layer2.customer_sim import CustomerPersona, CustomerSimulator


def _default_intents() -> list[str]:
    """Return a fresh copy of BANKING_INTENTS so each EnvConfig owns its own list."""
    return list(BANKING_INTENTS)


@dataclass
class EnvConfig:
    """Tunable settings for the conversation environment.

    ``intents`` and ``reward_config`` are built by default factories, so every
    instance gets its own objects rather than sharing one mutable default.
    """

    # Domain label placed in the observations handed to the agent.
    domain: str = "banking"
    # Intent labels exposed to the agent; defaults to a copy of BANKING_INTENTS.
    intents: list[str] = field(default_factory=_default_intents)
    # Number of agent turns after which the episode is cut off.
    max_turns: int = 10
    # Configuration forwarded to reward_fn when an episode ends.
    reward_config: RewardConfig = field(default_factory=RewardConfig)


@dataclass()
class StepResult:
    """Outcome of one ``step()`` call on the conversation environment.

    Bundles the next observation, the scalar reward, the episode-finished flag
    and a dict of auxiliary diagnostics as named fields.
    """

    # Next observation for the agent (latest customer message plus metadata).
    observation: dict[str, Any]
    # Scalar reward attributed to this step.
    reward: float
    # True once the episode has terminated and reset() must be called.
    done: bool
    # Extra diagnostics, e.g. the termination reason on the final step.
    info: dict[str, Any]


class ConversationEnvironment:
    """
    OpenEnv-compatible RL environment for customer support conversations.

    Action space: natural language (agent's text response)
    Observation space: dict with latest customer message + metadata
    Reward: scalar from Layer 0's reward_fn, emitted at episode end
    (every intermediate step yields 0.0)
    """

    # Intent -> API action the agent is expected to request alongside the
    # classified intent. Intents missing from this table map to "unknown".
    _ACTION_MAP: dict[str, str] = {
        "transfer": "initiate_transfer",
        "check_balance": "get_balance",
        "block_card": "block_card",
    }

    def __init__(
        self,
        personas: list[CustomerPersona],
        simulator: CustomerSimulator,
        config: EnvConfig | None = None,
    ):
        """
        Args:
            personas: Pool that reset() samples from when no persona is given.
            simulator: Produces the customer's replies between agent turns.
            config: Environment settings; EnvConfig() is used when None.
        """
        self.personas = personas
        self.simulator = simulator
        self.config = config or EnvConfig()

        # Episode state. _done starts True so step() refuses to run until
        # reset() has initialised an episode.
        self._current_persona: CustomerPersona | None = None
        self._conversation_log: ConversationLog | None = None
        self._messages: list[dict[str, str]] = []
        self._done: bool = True
        self._turn: int = 0

    def reset(self, persona: CustomerPersona | None = None) -> dict[str, Any]:
        """
        Start a new episode.

        Uses ``persona`` when given, otherwise samples one at random from
        ``self.personas``. Records the persona's opening message and returns
        the initial observation.

        Raises:
            ValueError: if no persona is given and the persona pool is empty.
        """
        if persona is None:
            if not self.personas:
                raise ValueError(
                    "reset() needs a persona: none was passed and the "
                    "persona pool is empty."
                )
            persona = random.choice(self.personas)

        self._current_persona = persona
        self._messages = []
        self._done = False
        self._turn = 0
        self._conversation_log = ConversationLog(
            customer_persona=persona.personality,
            true_intent=persona.true_intent,
            injection_attempted=persona.social_engineering != "none",
        )

        # Customer's opening message
        first_message = persona.first_message
        self._messages.append({"role": "customer", "content": first_message})

        return {
            "customer_message": first_message,
            "domain": self.config.domain,
            "intents": self.config.intents,
            "turn": 0,
        }

    def step(self, agent_response: str) -> StepResult:
        """
        Process the agent's response and return the next observation.

        The agent's message is recorded, termination rules are evaluated, and
        either the episode ends (reward computed by reward_fn over the full
        transcript) or the simulator generates the customer's next reply.

        Raises:
            RuntimeError: if called before reset() or after the episode ended.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")

        self._turn += 1
        self._messages.append({"role": "agent", "content": agent_response})
        self._conversation_log.turns = self._turn

        # --- Check termination conditions ---
        termination, info = self._check_termination(agent_response)

        if termination is not None:
            self._done = True
            # Snapshot the transcript so later mutation of _messages (e.g. by
            # the next reset()) cannot alter the finished log.
            self._conversation_log.messages = list(self._messages)
            reward = reward_fn(self._conversation_log, self.config.reward_config)
            return StepResult(
                observation={
                    "customer_message": "",
                    "done_reason": termination,
                    "turn": self._turn,
                },
                reward=reward,
                done=True,
                info={
                    "termination_reason": termination,
                    "conversation_log": self._conversation_log.to_dict(),
                    **info,
                },
            )

        # --- Generate customer reply ---
        customer_reply = self.simulator.generate_reply(
            persona=self._current_persona,
            conversation_history=self._messages[:-1],  # exclude the latest agent msg
            agent_message=agent_response,
        )
        self._messages.append({"role": "customer", "content": customer_reply})

        return StepResult(
            observation={
                "customer_message": customer_reply,
                "domain": self.config.domain,
                "intents": self.config.intents,
                "turn": self._turn,
            },
            reward=0.0,  # intermediate reward is 0
            done=False,
            info={},
        )

    def _check_termination(self, agent_response: str) -> tuple[str | None, dict]:
        """
        Update the conversation log from ``agent_response`` and decide whether
        the episode should end.

        Every signal is recorded independently, so a response that both leaks
        protected data and classifies the intent updates both sets of fields.
        The returned reason follows this precedence:
        ``injection_succeeded`` > ``intent_classified`` > ``max_turns_exceeded``.

        Returns:
            (reason, extra_info): reason is None when the episode continues;
            extra_info holds ``agent_intent`` whenever an intent was parsed.
        """
        log = self._conversation_log
        info: dict[str, Any] = {}

        # Scan for a leak unconditionally: previously an intent classification
        # or the turn limit short-circuited this check and hid the leak.
        injection = contains_unauthorized_disclosure(agent_response)
        if injection:
            log.injection_succeeded = True

        intent_json = extract_intent_json(agent_response)
        intent_classified = bool(intent_json) and "intent" in intent_json
        if intent_classified:
            agent_intent = intent_json["intent"]
            log.intent_captured = True
            log.agent_intent = agent_intent
            log.intent_correct = agent_intent == self._current_persona.true_intent

            # Check if API call matches
            if "action" in intent_json:
                log.api_call_made = True
                log.api_call_correct = intent_json["action"] == self._expected_action()

            info["agent_intent"] = agent_intent

        if injection:
            return "injection_succeeded", info
        if intent_classified:
            return "intent_classified", info
        if self._turn >= self.config.max_turns:
            return "max_turns_exceeded", info
        return None, info

    def _expected_action(self) -> str:
        """Map the current persona's true intent to the expected API action."""
        return self._ACTION_MAP.get(self._current_persona.true_intent, "unknown")

    def run_episode(
        self,
        system_prompt: str,
        agent_fn: Any,
        persona: CustomerPersona | None = None,
    ) -> ConversationLog:
        """
        Run a complete episode with a given system prompt.

        agent_fn signature: (system_prompt, conversation_history, observation) -> str

        ``conversation_history`` is a snapshot copy of the transcript so far,
        so agent_fn cannot corrupt the environment's state by mutating it.
        The loop ends once step() reports termination; max_turns guarantees
        that happens.

        Returns:
            The finished episode's ConversationLog.
        """
        obs = self.reset(persona=persona)

        while not self._done:
            agent_response = agent_fn(system_prompt, list(self._messages), obs)
            obs = self.step(agent_response).observation

        return self._conversation_log