File size: 5,588 Bytes
e6b0e2f
 
 
21da591
e6b0e2f
 
3dc48b7
e6b0e2f
 
 
 
21da591
 
 
b259333
 
 
e6b0e2f
 
 
 
 
 
 
 
 
 
 
 
 
 
21da591
 
 
 
 
 
 
 
 
 
 
 
 
 
e6b0e2f
 
 
 
b259333
 
 
 
e6b0e2f
21da591
e6b0e2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b259333
e6b0e2f
21da591
e6b0e2f
 
 
 
 
 
 
21da591
e6b0e2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21da591
 
 
 
 
 
 
 
e6b0e2f
21da591
e6b0e2f
 
21da591
 
e6b0e2f
21da591
 
e6b0e2f
 
21da591
e6b0e2f
 
 
 
b259333
 
21da591
 
b259333
21da591
 
b259333
21da591
 
b259333
21da591
 
b259333
21da591
 
b259333
21da591
 
b259333
21da591
 
b259333
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Tests for Layer 2 conversation environment."""

import json
import os
import pytest

from layer0.reward import reward_fn
from layer2.customer_sim import CustomerPersona, CustomerSimulator
from layer2.environment import ConversationEnvironment, EnvConfig


# Skip marker for tests that exercise the LLM-backed customer simulator;
# an unset or empty HF_TOKEN environment variable skips them.
requires_hf_token = pytest.mark.skipif(
    os.environ.get("HF_TOKEN", "") == "",
    reason="HF_TOKEN required for LLM-based tests",
)


def make_persona(**kwargs) -> CustomerPersona:
    """Build a CustomerPersona from sensible defaults, overridable per field."""
    base = {
        "id": 0,
        "true_intent": "check_balance",
        "personality": "polite",
        "social_engineering": "none",
        "complexity": "simple",
        "description": "Wants to check balance.",
        "first_message": "Hi, I'd like to check my balance.",
    }
    # kwargs win over the defaults, exactly as dict.update would apply them.
    return CustomerPersona(**{**base, **kwargs})


def _instant_classifier(system_prompt, messages, obs):
    """Test agent that immediately classifies based on keywords."""
    customer_msg = obs.get("customer_message", "").lower()
    keyword_map = {
        "transfer": ["transfer", "send", "move", "wire"],
        "check_balance": ["balance", "check", "how much"],
        "block_card": ["block", "lost", "stolen", "freeze", "card", "missing"],
    }
    for intent, keywords in keyword_map.items():
        if any(kw in customer_msg for kw in keywords):
            return json.dumps({"intent": intent})
    return json.dumps({"intent": "check_balance"})


@pytest.fixture
def env():
    """Environment over three personas covering each distinct intent."""
    persona_specs = [
        dict(id=0, true_intent="check_balance"),
        dict(id=1, true_intent="transfer", first_message="I need to send money."),
        dict(id=2, true_intent="block_card", first_message="I lost my card."),
    ]
    return ConversationEnvironment(
        personas=[make_persona(**spec) for spec in persona_specs],
        simulator=CustomerSimulator(),
    )


class TestEnvironmentReset:
    """reset() should hand back a well-formed first observation."""

    def test_reset_returns_observation(self, env):
        first_obs = env.reset()
        for key in ("customer_message", "domain", "intents"):
            assert key in first_obs
        assert first_obs["domain"] == "banking"

    def test_reset_with_specific_persona(self, env):
        transfer_persona = make_persona(
            true_intent="transfer", first_message="I need to send money."
        )
        first_obs = env.reset(persona=transfer_persona)
        assert first_obs["customer_message"] == "I need to send money."


class TestEnvironmentStep:
    """step() semantics: a JSON classification ends the episode, chat continues it."""

    def test_correct_classification_ends_episode(self, env):
        env.reset(persona=make_persona(true_intent="check_balance"))
        outcome = env.step('{"intent": "check_balance"}')
        assert outcome.done is True
        assert outcome.reward > 0
        assert outcome.info["termination_reason"] == "intent_classified"

    def test_wrong_classification_still_ends(self, env):
        env.reset(persona=make_persona(true_intent="transfer"))
        outcome = env.step('{"intent": "block_card"}')
        assert outcome.done is True
        assert outcome.reward < 0

    @requires_hf_token
    def test_conversation_continues_without_json(self, env):
        env.reset()
        outcome = env.step("How can I help you today?")
        assert outcome.done is False
        assert outcome.reward == 0.0
        assert "customer_message" in outcome.observation

    @requires_hf_token
    def test_max_turns_terminates(self):
        persona = make_persona()
        # Dedicated environment with a tight turn budget to force timeout.
        capped_env = ConversationEnvironment(
            personas=[persona],
            simulator=CustomerSimulator(),
            config=EnvConfig(max_turns=2),
        )
        capped_env.reset(persona=persona)
        capped_env.step("Hello!")
        outcome = capped_env.step("How can I help?")
        assert outcome.done is True
        assert outcome.info["termination_reason"] == "max_turns_exceeded"


class TestRunEpisode:
    """run_episode() drives an agent_fn against a persona to completion."""

    def test_instant_classifier_completes_episode(self, env):
        episode = env.run_episode(
            system_prompt="test",
            agent_fn=_instant_classifier,
            persona=make_persona(true_intent="check_balance"),
        )
        assert episode.turns == 1
        assert episode.intent_captured is True
        assert episode.intent_correct is True

    def test_custom_agent_fn(self, env):
        # An agent that unconditionally classifies as "transfer".
        episode = env.run_episode(
            system_prompt="test",
            agent_fn=lambda system_prompt, messages, obs: '{"intent": "transfer"}',
            persona=make_persona(
                true_intent="transfer", first_message="I need to send money."
            ),
        )
        assert episode.turns == 1
        assert episode.intent_correct is True


class TestRewardDifferentiation:
    """Tests that correct vs incorrect classification produces different rewards."""

    def test_correct_classification_higher_reward(self, env):
        persona = make_persona(true_intent="check_balance")

        # Fixed-output agents; insertion order preserves the original
        # correct-then-wrong episode sequence.
        agents = {
            "correct": lambda system_prompt, messages, obs: '{"intent": "check_balance"}',
            "wrong": lambda system_prompt, messages, obs: '{"intent": "transfer"}',
        }
        rewards = {
            label: reward_fn(
                env.run_episode(system_prompt="test", agent_fn=fn, persona=persona)
            )
            for label, fn in agents.items()
        }

        assert rewards["correct"] > rewards["wrong"], (
            f"Correct ({rewards['correct']:.1f}) should beat wrong ({rewards['wrong']:.1f})"
        )