# Source: customer-support-openenv/env/environment.py
# Last commit: "Fix score bounds for validator" (bc8710c), author Adit1Sharma
import random
from copy import deepcopy
from .models import Observation, Action, Reward
from .tasks import TASKS, TASK_LIST
class CustomerSupportEnv:
    """Episode-based environment in which an agent handles one support ticket.

    An episode starts with :meth:`reset`, which selects a task and exposes a
    copy of its ``"input"`` as an ``Observation``. The agent then calls
    :meth:`step` with ``classify`` / ``reply`` / ``escalate`` / ``close``
    actions. Every step is scored on its own; the episode ends on ``close``
    or when the task's step budget is used up.

    Task dicts are read for the keys ``"id"``, ``"input"``, ``"expected"``
    (holding ``"category"``, ``"keywords"`` and ``"requires_escalation"``)
    and the optional ``"max_steps"``.
    """

    # Every reported score is clamped into [_MIN_SCORE, _MAX_SCORE], so a
    # penalised action reports the floor instead of zero or a negative value.
    # NOTE(review): presumably required by an external score validator - confirm.
    _MIN_SCORE = 0.001
    _MAX_SCORE = 0.999
    # Decimal places kept on scores so float noise (0.15000000000000002)
    # never leaks into rewards or breakdowns.
    _SCORE_DIGITS = 4

    def __init__(self):
        """Create an environment with no active episode; call reset() next."""
        self.current_task = None
        self.state_data = None
        self._clear_progress()

    def _clear_progress(self):
        """Reset the per-episode counters and progress flags."""
        self.done = False
        self.step_count = 0
        self._classified = False
        self._replied = False
        self._escalated = False
        self._closed = False

    def reset(self, task_id=None):
        """Start a new episode and return its first observation.

        Args:
            task_id: Key into ``TASKS``. Any falsy value (``None``, ``""``)
                selects a random task from ``TASK_LIST`` instead.

        Returns:
            An ``Observation`` built from a deep copy of the task's input.

        Raises:
            ValueError: If ``task_id`` is truthy but not a known task.
        """
        if task_id:
            if task_id not in TASKS:
                raise ValueError(f"Unknown task '{task_id}'. Pick from: {list(TASKS.keys())}")
            self.current_task = TASKS[task_id]
        else:
            self.current_task = random.choice(TASK_LIST)

        # Deep copy so that history appended during the episode never mutates
        # the shared task definition.
        self.state_data = deepcopy(self.current_task["input"])
        self._clear_progress()
        return Observation(**self.state_data)

    def step(self, action: "Action"):
        """Apply one agent action and score it.

        Args:
            action: Object exposing ``action_type``, ``category`` and
                ``content``.

        Returns:
            A tuple ``(observation, reward, done, info)``.

        Raises:
            RuntimeError: If no episode has been started with reset(), or the
                current episode has already finished.
        """
        # Fail with a clear message instead of a TypeError on None["expected"].
        if self.current_task is None:
            raise RuntimeError("No active episode. Call reset() first.")
        if self.done:
            raise RuntimeError("Episode done. Call reset() first.")

        self.step_count += 1
        reward = self._compute_reward(action)

        if action.action_type == "close":
            self.done = True
            self._closed = True

        # Running out of steps without closing ends the episode and costs a
        # small penalty on this final step's reward.
        max_steps = self.current_task.get("max_steps", 10)
        if self.step_count >= max_steps and not self.done:
            self.done = True
            penalised = max(self._MIN_SCORE, reward.score - 0.05)
            reward = Reward(
                score=round(penalised, self._SCORE_DIGITS),
                feedback=reward.feedback + " | time limit hit, -0.05",
                breakdown={**reward.breakdown, "time_penalty": -0.05},
            )

        # The agent's text is recorded only after scoring, so it shows up in
        # the observation returned from this step.
        if action.content:
            self.state_data["history"].append(f"Agent: {action.content}")

        info = {
            "step": self.step_count,
            "task_id": self.current_task["id"],
            "classified": self._classified,
            "replied": self._replied,
            "escalated": self._escalated,
            "closed": self._closed,
        }
        return Observation(**self.state_data), reward, self.done, info

    def state(self):
        """Return the live episode state dict (not a copy), or None before reset()."""
        return self.state_data

    def _compute_reward(self, action: "Action") -> "Reward":
        """Score ``action`` and wrap the clamped result in a ``Reward``."""
        raw_score, breakdown = self._score_action(action)
        bounded = max(self._MIN_SCORE, min(self._MAX_SCORE, raw_score))
        feedback = self._make_feedback(action, breakdown, self.current_task["expected"])
        return Reward(
            score=round(bounded, self._SCORE_DIGITS),
            feedback=feedback,
            breakdown=breakdown,
        )

    def _score_action(self, action: "Action"):
        """Return ``(raw_score, breakdown)`` for ``action`` and update progress flags.

        The raw score is unclamped and may be negative. Action types other
        than classify / reply / escalate / close score 0.0 with an empty
        breakdown.
        """
        expected = self.current_task["expected"]
        score = 0.0
        breakdown = {}
        kind = action.action_type

        if kind == "classify":
            given = action.category
            matched = bool(given) and given.lower() == expected["category"].lower()
            points = 0.3 if matched else 0.0
            score += points
            breakdown["classify"] = points
            self._classified = True

        elif kind == "reply":
            if not self._classified:
                score -= 0.05
                breakdown["early_reply_penalty"] = -0.05
            text = (action.content or "").lower()
            # Lower-case both sides, as the category comparison does, so a
            # keyword written with capitals can still match.
            hits = sum(1 for kw in expected["keywords"] if kw.lower() in text)
            reply_score = round(min(0.4, hits * 0.1), self._SCORE_DIGITS)
            score += reply_score
            breakdown["reply"] = reply_score
            self._replied = True

        elif kind == "escalate":
            points = 0.2 if expected["requires_escalation"] else -0.1
            score += points
            breakdown["escalate"] = points
            self._escalated = True

        elif kind == "close":
            # 0.1 for each workflow stage completed before closing; the
            # escalation stage only counts when the task actually required it.
            stages_done = 0
            if self._classified:
                stages_done += 1
            if self._replied:
                stages_done += 1
            if expected["requires_escalation"] and self._escalated:
                stages_done += 1
            bonus = round(0.1 * stages_done, self._SCORE_DIGITS)
            score += bonus
            breakdown["close_bonus"] = bonus

        return score, breakdown

    def _make_feedback(self, action, breakdown, correct):
        """Render ``breakdown`` as a short comma-separated message.

        Args:
            action: Unused; kept for interface compatibility.
            breakdown: Mapping of score component name to points.
            correct: The task's expected values; only ``"category"`` is read.

        Returns:
            The joined message, or ``"ok"`` when nothing applies.
        """
        parts = []

        if "classify" in breakdown:
            if breakdown["classify"] > 0:
                parts.append("correct category")
            else:
                parts.append(f"wrong category (expected {correct['category']})")

        if "early_reply_penalty" in breakdown:
            parts.append("replied before classifying")
        if "reply" in breakdown:
            parts.append(f"reply score {breakdown['reply']:.2f}")

        # The sign tells the two escalation outcomes apart without relying on
        # exact float equality.
        escalate_points = breakdown.get("escalate")
        if escalate_points is not None:
            if escalate_points > 0:
                parts.append("escalated correctly")
            elif escalate_points < 0:
                parts.append("unnecessary escalation")

        if "close_bonus" in breakdown:
            parts.append(f"close bonus {breakdown['close_bonus']:.2f}")

        return ", ".join(parts) if parts else "ok"