File size: 4,782 Bytes
b259333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dc48b7
b259333
 
 
3dc48b7
b259333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
OpenEnv 0.2.1 wrapper for the Conversation Environment.

Provides OpenEnv-compatible registration and standardized interface.
If the openenv package is installed, this registers the environment.
Otherwise, it provides a standalone wrapper with the same API contract.
"""

from __future__ import annotations

from layer2.environment import ConversationEnvironment, EnvConfig, StepResult
from layer2.customer_sim import CustomerPersona, CustomerSimulator
from layer0.reward import BANKING_INTENTS

# Environment metadata for OpenEnv registry
ENV_ID = "nested-rl/CustomerSupport-v0"
ENV_METADATA = {
    "id": ENV_ID,
    "description": (
        "Multi-turn customer support conversation environment. "
        "An agent must classify customer intent while resisting social engineering."
    ),
    "action_space": "text",
    "observation_space": {
        "customer_message": "str",
        "domain": "str",
        "intents": "list[str]",
        "turn": "int",
    },
    "reward_range": (-150.0, 130.0),
    "max_episode_steps": 10,
    "domain": "banking",
    "intents": BANKING_INTENTS,
}


class OpenEnvCustomerSupport:
    """
    OpenEnv 0.2.1 compatible environment wrapper.

    Wraps ConversationEnvironment with the standardized OpenEnv interface:
    - reset() -> observation
    - step(action) -> (observation, reward, terminated, truncated, info)
    - metadata property
    """

    metadata = ENV_METADATA

    def __init__(
        self,
        personas: list[CustomerPersona] | None = None,
        simulator: CustomerSimulator | None = None,
        config: EnvConfig | None = None,
        persona_count: int = 100,
    ):
        if personas is None:
            from personas.generate_personas import generate_personas
            personas_data = generate_personas(persona_count)
            personas = [CustomerPersona(**p) for p in personas_data]

        self._simulator = simulator or CustomerSimulator()
        self._env = ConversationEnvironment(
            personas=personas,
            simulator=self._simulator,
            config=config or EnvConfig(),
        )

    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[dict, dict]:
        """
        Reset the environment.

        Args:
            seed: Random seed (for reproducibility)
            options: Optional dict with "persona_id" to select specific persona

        Returns:
            (observation, info)
        """
        import random
        if seed is not None:
            random.seed(seed)

        persona = None
        if options and "persona_id" in options:
            pid = options["persona_id"]
            if 0 <= pid < len(self._env.personas):
                persona = self._env.personas[pid]

        obs = self._env.reset(persona=persona)
        info = {
            "persona_id": self._env._current_persona.id,
            "social_engineering": self._env._current_persona.social_engineering,
            "complexity": self._env._current_persona.complexity,
        }
        return obs, info

    def step(self, action: str) -> tuple[dict, float, bool, bool, dict]:
        """
        Take a step in the environment.

        Args:
            action: Agent's text response

        Returns:
            (observation, reward, terminated, truncated, info)
            - terminated: episode ended due to classification or injection
            - truncated: episode ended due to max turns
        """
        result = self._env.step(action)

        terminated = False
        truncated = False
        if result.done:
            reason = result.info.get("termination_reason", "")
            if reason == "max_turns_exceeded":
                truncated = True
            else:
                terminated = True

        return result.observation, result.reward, terminated, truncated, result.info

    def close(self):
        """Clean up resources."""
        pass

    def render(self) -> str:
        """Render the current conversation as text."""
        if not self._env._messages:
            return "(no conversation in progress)"

        lines = []
        for msg in self._env._messages:
            role = "Customer" if msg["role"] == "customer" else "Agent"
            lines.append(f"[{role}] {msg['content']}")
        return "\n".join(lines)


def make_env(**kwargs) -> OpenEnvCustomerSupport:
    """Factory function for creating the environment (OpenEnv compatible)."""
    return OpenEnvCustomerSupport(**kwargs)


# Register with OpenEnv if available
try:
    import openenv
    openenv.register(
        id=ENV_ID,
        entry_point="layer2.openenv_wrapper:make_env",
        kwargs={},
    )
except (ImportError, AttributeError):
    pass  # OpenEnv not installed; wrapper still works standalone