| import random |
| from copy import deepcopy |
| from .models import Observation, Action, Reward |
| from .tasks import TASKS, TASK_LIST |
|
|
|
|
class CustomerSupportEnv:
    """Step-based environment in which an agent handles one support ticket.

    An episode is started with :meth:`reset`, which loads a task from
    ``TASKS`` / ``TASK_LIST``.  The agent then calls :meth:`step` with
    actions whose ``action_type`` is one of ``"classify"``, ``"reply"``,
    ``"escalate"`` or ``"close"``.  The episode ends when the agent closes
    the ticket or when the task's step budget is used up.

    Each task is read as a mapping with the keys ``"id"``, ``"input"``
    (keyword arguments for ``Observation``), ``"expected"`` (with
    ``"category"``, ``"keywords"`` and ``"requires_escalation"``) and an
    optional ``"max_steps"``.
    """

    # Every score handed out is clamped into [_MIN_SCORE, _MAX_SCORE].
    _MIN_SCORE = 0.001
    _MAX_SCORE = 0.999
    # Step budget used when the task does not define "max_steps".
    _DEFAULT_MAX_STEPS = 10

    def __init__(self):
        """Create an environment with no active episode."""
        self.current_task = None
        self.state_data = None
        self._clear_progress()

    def _clear_progress(self):
        """Reset the per-episode counters and progress flags."""
        self.done = False
        self.step_count = 0
        self._classified = False
        self._replied = False
        self._escalated = False
        self._closed = False

    @classmethod
    def _clamp_score(cls, score):
        """Clamp *score* into the allowed band and round it to 4 decimals."""
        return round(max(cls._MIN_SCORE, min(cls._MAX_SCORE, score)), 4)

    @staticmethod
    def _reply_score(keywords, content):
        """Return the keyword score of a reply: 0.1 per hit, capped at 0.4.

        Matching is case-insensitive on both sides, mirroring how the
        category is compared.  ``content`` may be ``None`` or empty, which
        is scored as a reply without any keyword hits.
        """
        text = (content or "").lower()  # lowercase once, not once per keyword
        hits = sum(1 for kw in keywords if kw.lower() in text)
        # hits / 10 avoids the float noise of hits * 0.1 (3 * 0.1 != 0.3).
        return min(0.4, hits / 10)

    def reset(self, task_id=None):
        """Start a new episode and return its first observation.

        Args:
            task_id: Key into ``TASKS``.  When falsy (``None`` or an empty
                string) a task is drawn at random from ``TASK_LIST``.

        Returns:
            An ``Observation`` built from a deep copy of the task's input.

        Raises:
            ValueError: If ``task_id`` is given but is not a key of ``TASKS``.
        """
        if task_id:
            if task_id not in TASKS:
                raise ValueError(f"Unknown task '{task_id}'. Pick from: {list(TASKS.keys())}")
            self.current_task = TASKS[task_id]
        else:
            self.current_task = random.choice(TASK_LIST)

        # Deep copy so that the history appended during the episode never
        # leaks back into the shared task definition.
        self.state_data = deepcopy(self.current_task["input"])
        self._clear_progress()

        return Observation(**self.state_data)

    def step(self, action: "Action"):
        """Apply one agent action.

        Args:
            action: Object exposing ``action_type`` and ``content``;
                ``category`` is additionally read for ``"classify"`` actions.

        Returns:
            A tuple ``(observation, reward, done, info)`` where ``info``
            reports the step number, the task id and the progress flags.

        Raises:
            RuntimeError: If no episode has been started with ``reset()``,
                or if the current episode is already finished.
        """
        # Fail early and clearly instead of crashing on ``None["expected"]``
        # after the step counter has already been modified.
        if self.current_task is None:
            raise RuntimeError("No active episode. Call reset() first.")
        if self.done:
            raise RuntimeError("Episode done. Call reset() first.")

        self.step_count += 1
        reward = self._compute_reward(action)

        if action.action_type == "close":
            self.done = True
            self._closed = True

        # Running out of steps without closing ends the episode with a penalty.
        max_steps = self.current_task.get("max_steps", self._DEFAULT_MAX_STEPS)
        if self.step_count >= max_steps and not self.done:
            self.done = True
            reward = Reward(
                # Clamped and rounded like every other score.
                score=self._clamp_score(reward.score - 0.05),
                feedback=reward.feedback + " | time limit hit, -0.05",
                breakdown={**reward.breakdown, "time_penalty": -0.05},
            )

        if action.content:
            self.state_data["history"].append(f"Agent: {action.content}")

        info = {
            "step": self.step_count,
            "task_id": self.current_task["id"],
            "classified": self._classified,
            "replied": self._replied,
            "escalated": self._escalated,
            "closed": self._closed,
        }

        return Observation(**self.state_data), reward, self.done, info

    def state(self):
        """Return the live episode state dict (``None`` before ``reset()``).

        The returned object is the environment's own state, not a copy.
        """
        return self.state_data

    def _compute_reward(self, action: "Action") -> "Reward":
        """Score *action* against the current task and update progress flags.

        Scoring per action type:
            * ``classify``: 0.3 when the category matches (case-insensitive).
            * ``reply``: 0.1 per expected keyword found, capped at 0.4, minus
              0.05 when no classify action has been taken yet.
            * ``escalate``: +0.2 when the task requires escalation, else -0.1.
            * ``close``: 0.1 for each of: classified, replied, and escalated
              when escalation was required.

        The total is clamped into [0.001, 0.999] and rounded to 4 decimals.
        Any other action type earns nothing and ends up at the minimum score.
        """
        correct = self.current_task["expected"]
        kind = action.action_type
        score = 0.0
        breakdown = {}

        if kind == "classify":
            if action.category and action.category.lower() == correct["category"].lower():
                score += 0.3
                breakdown["classify"] = 0.3
            else:
                breakdown["classify"] = 0.0
            # The flag records that a classification was attempted,
            # whether or not it was correct.
            self._classified = True

        elif kind == "reply":
            if not self._classified:
                score -= 0.05
                breakdown["early_reply_penalty"] = -0.05

            reply_score = self._reply_score(correct["keywords"], action.content)
            score += reply_score
            breakdown["reply"] = reply_score
            self._replied = True

        elif kind == "escalate":
            if correct["requires_escalation"]:
                score += 0.2
                breakdown["escalate"] = 0.2
            else:
                score -= 0.1
                breakdown["escalate"] = -0.1
            self._escalated = True

        elif kind == "close":
            earned = sum(
                (
                    bool(self._classified),
                    bool(self._replied),
                    bool(correct["requires_escalation"] and self._escalated),
                )
            )
            # earned / 10 keeps the breakdown free of float noise
            # (0.1 + 0.1 + 0.1 != 0.3).
            bonus = earned / 10
            score += bonus
            breakdown["close_bonus"] = bonus

        score = self._clamp_score(score)
        feedback = self._make_feedback(action, breakdown, correct)

        return Reward(score=score, feedback=feedback, breakdown=breakdown)

    def _make_feedback(self, action, breakdown, correct):
        """Turn a reward breakdown into a short human-readable summary.

        Args:
            action: The scored action (not inspected here).
            breakdown: Mapping of reward component name to value.
            correct: The task's expected answers; only ``"category"`` is read.

        Returns:
            Comma-separated feedback parts, or ``"ok"`` when there are none.
        """
        parts = []

        if breakdown.get("classify") == 0.3:
            parts.append("correct category")
        elif "classify" in breakdown:
            parts.append(f"wrong category (expected {correct['category']})")

        if "early_reply_penalty" in breakdown:
            parts.append("replied before classifying")

        if "reply" in breakdown:
            parts.append(f"reply score {breakdown['reply']:.2f}")

        escalate = breakdown.get("escalate")
        if escalate == 0.2:
            parts.append("escalated correctly")
        elif escalate == -0.1:
            parts.append("unnecessary escalation")

        if "close_bonus" in breakdown:
            parts.append(f"close bonus {breakdown['close_bonus']:.2f}")

        return ", ".join(parts) if parts else "ok"