File size: 5,516 Bytes
af3f703
aa4f7bc
 
 
 
 
 
 
 
 
 
af3f703
aa4f7bc
 
 
 
af3f703
aa4f7bc
 
 
 
 
 
 
af3f703
aa4f7bc
 
 
 
af3f703
aa4f7bc
 
 
 
 
 
 
 
 
 
 
 
 
af3f703
 
aa4f7bc
 
 
 
 
 
 
 
 
 
 
 
 
af3f703
 
1724801
 
aa4f7bc
 
 
 
 
 
af3f703
 
aa4f7bc
 
 
31f4f64
1724801
 
 
 
 
aa4f7bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba2722e
 
 
 
 
aa4f7bc
ba2722e
1724801
aa4f7bc
ba2722e
aa4f7bc
 
 
1724801
 
aa4f7bc
 
 
af3f703
aa4f7bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import logging
from typing import Tuple, Dict, Any, Optional, cast

from .models import Action, Observation, EnvironmentState, TicketInfo, UserData
from .tasks import TASKS
from .graders import grade

class SupportTicketEnv:
    """Customer-support ticket environment with a ``reset``/``step`` API.

    Each episode works the single ticket defined by ``TASKS[task_id]``. The
    agent calls tools (fetch user data, check policy, refund, reply) and ends
    the episode by escalating or closing the ticket. Alternatively the episode
    is ended for it once ``max_steps`` actions have been taken. The score comes
    from ``grade(state)``. ``step`` returns the change in that score since the
    previous step, so the per-step rewards of an episode sum to its final score.
    """

    # Action names advertised to the agent in every observation.
    AVAILABLE_ACTIONS: Tuple[str, ...] = (
        "fetch_user_data", "check_policy", "issue_refund",
        "reply_to_customer", "escalate", "close_ticket",
    )

    # Diagnostics go through logging rather than print() so the environment
    # never writes to stdout on its own.
    _logger = logging.getLogger(__name__)

    def __init__(self, task_id: str = "task_easy_1", max_steps: int = 10):
        """Create an environment for ``task_id`` and start its first episode.

        Args:
            task_id: Key into ``TASKS`` selecting the ticket, user data,
                policy and difficulty for the episode.
            max_steps: Action budget per episode. Once ``step_count`` reaches
                it the episode is forced to end. Defaults to 10.

        Raises:
            ValueError: If ``task_id`` is not a key of ``TASKS``.
        """
        if task_id not in TASKS:
            raise ValueError(f"Unknown task_id: {task_id}")
        self.task_id = task_id
        self.task_data = TASKS[task_id]
        self.state: Optional[EnvironmentState] = None
        self.max_steps = max_steps
        self.reset()

    def reset(self) -> Observation:
        """Start a new episode from the task data and return its first observation.

        A fresh ``TicketInfo`` is built from the task's ticket dict. Status
        changes made during a previous episode therefore do not carry over.
        """
        ticket_data = cast(Dict[str, Any], self.task_data["ticket"])
        self.state = EnvironmentState(
            current_task_id=self.task_id,
            step_count=0,
            ticket=TicketInfo(**ticket_data),
            action_history=[],
            is_done=False,
            final_reward=0.0,
            task_difficulty=str(self.task_data["difficulty"]),
        )
        return self._get_observation("System initialized. Ticket assigned.")

    def _get_observation(self, system_message: str, tool_output: Optional[str] = None) -> Observation:
        """Build the agent-facing observation from the current state.

        Args:
            system_message: Status line describing the last transition.
            tool_output: Text produced by the last tool call, if any.
        """
        assert self.state is not None
        return Observation(
            ticket=self.state.ticket,
            available_actions=list(self.AVAILABLE_ACTIONS),
            system_message=system_message,
            # Past actions rendered as "name(params)" strings.
            history=[f"{a.action_type}({a.parameters})" for a in self.state.action_history],
            tool_output=tool_output,
            step_count=self.state.step_count,
        )

    def _execute_action(self, action: Action) -> Tuple[str, Optional[str]]:
        """Apply one action to the current state.

        ``escalate`` and ``close_ticket`` also set the ticket status and mark
        the episode done. ``fetch_user_data`` stores the task's user data on
        the state when the ``user_id`` parameter matches the ticket.

        Returns:
            ``(system_message, tool_output)`` describing the outcome.
        """
        assert self.state is not None
        state = self.state
        params = action.parameters
        kind = action.action_type
        system_message = f"Action {kind} executed."

        if kind == "fetch_user_data":
            if params.get("user_id") != state.ticket.user_id:
                return "Failed to fetch user data.", "Error: Invalid user_id."
            user = UserData(**cast(Dict[str, Any], self.task_data["user_data"]))
            state.user_data = user
            # Only mention chargebacks when the value is actually present.
            chargebacks = getattr(user, "chargeback_history", None)
            chargeback_info = f", Chargebacks = {chargebacks}" if chargebacks is not None else ""
            return system_message, (
                f"User Data: Tier = {user.account_tier}, "
                f"Joined = {user.join_date}{chargeback_info}"
            )

        if kind == "check_policy":
            # Default to the ticket's own issue type when none is given.
            issue_type = params.get("issue_type", state.ticket.issue_type)
            policy_map = cast(Dict[str, str], self.task_data["policy"])
            policy = policy_map.get(issue_type, "No specific policy found.")
            return system_message, f"Policy for {issue_type}: {policy}"

        if kind == "issue_refund":
            # Refunds are blocked only when user data has been fetched and
            # shows a positive chargeback count. Without a prior fetch the
            # refund is issued.
            chargebacks = getattr(state.user_data, "chargeback_history", None)
            if chargebacks is not None and chargebacks > 0:
                return "Refund action blocked.", "Refund denied due to chargeback history."
            amount = params.get("amount", "fully")
            return system_message, f"Refund issued for {amount}."

        if kind == "reply_to_customer":
            msg = params.get("message", "")
            return system_message, f"Replied: '{msg}'"

        if kind == "escalate":
            reason = params.get("reason", "support_tier2")
            state.ticket.status = "escalated"
            state.is_done = True
            return system_message, f"Escalated to {reason}."

        if kind == "close_ticket":
            res = params.get("resolution", "")
            state.ticket.status = "closed"
            state.is_done = True
            return system_message, f"Ticket closed. Resolution: {res}"

        return "Action unrecognized.", "Invalid action."

    def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Execute one agent action.

        Args:
            action: The action to apply. It is recorded in the history even
                when unrecognized.

        Returns:
            ``(observation, reward, done, info)``. ``reward`` is the change in
            the graded score caused by this step. ``info`` carries the
            cumulative score (``current_reward``) and ``step_count``. After
            the episode has ended, further calls return reward 0.0,
            ``done=True`` and an empty ``info``.
        """
        assert self.state is not None

        if self.state.is_done:
            return self._get_observation("Episode is over."), 0.0, True, {}

        self.state.step_count += 1
        self.state.action_history.append(action)

        system_message, tool_output = self._execute_action(action)

        # Enforce the step budget. Report it only when the budget is what
        # ended the episode, so a final escalate/close keeps its own message.
        if not self.state.is_done and self.state.step_count >= self.max_steps:
            self.state.is_done = True
            system_message = "Max steps reached."

        # grade() scores the whole state. Returning the delta means the
        # per-step rewards add up to the final score (final_reward starts at 0.0).
        new_total_reward = grade(self.state)
        reward = new_total_reward - self.state.final_reward
        self.state.final_reward = new_total_reward

        if self.state.is_done:
            self._logger.info("Final reward calculated: %s", self.state.final_reward)

        info: Dict[str, Any] = {
            "current_reward": self.state.final_reward,
            "step_count": self.state.step_count,
        }
        self._logger.debug("Updated info dictionary: %s", info)

        return self._get_observation(system_message, tool_output), reward, self.state.is_done, info

    def get_state(self) -> EnvironmentState:
        """Return the live episode state (the same object, not a copy)."""
        assert self.state is not None
        return self.state