# supermail/server/environment.py
"""Supermail OpenEnv environment implementation."""
from __future__ import annotations
import json
from dataclasses import dataclass
from uuid import uuid4

try:
from openenv.core.env_server.interfaces import Environment
except ImportError: # pragma: no cover - local fallback when OpenEnv is absent
class Environment:
"""Fallback OpenEnv Environment base class."""
try:
from ..models import SupportAction, SupportObservation, SupportState
from ..tasks import ALL_TASKS, FIELD_OPTIONS, TASKS_BY_ID, TaskDefinition
except ImportError: # pragma: no cover
from models import SupportAction, SupportObservation, SupportState
    from tasks import ALL_TASKS, FIELD_OPTIONS, TASKS_BY_ID, TaskDefinition


@dataclass(frozen=True)
class StepAssessment:
"""Internal grading result for one agent action."""
reward: float
score: float
done: bool
success: bool
feedback: str
error: str | None
    matched_fields: set[str]


class SupermailEnvironment(Environment):
"""Deterministic customer support email triage environment."""
SUPPORTS_CONCURRENT_SESSIONS: bool = True
MIN_SCORE: float = 0.01
    MAX_SCORE: float = 0.99

    def __init__(self, task_id: str | None = None) -> None:
self._requested_task_id = task_id
self._task_order = [task.task_id for task in ALL_TASKS]
self._next_task_index = 0
self._task: TaskDefinition | None = None
self._matched_fields: set[str] = set()
self._history: list[str] = []
self._score = self._bounded_score(0.0)
self._state = SupportState(
episode_id=str(uuid4()),
step_count=0,
score=self._score,
        )

@property
def benchmark(self) -> str:
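        """Benchmark name for this environment."""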
return "supermail"
@property
def task_name(self) -> str:
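        """Return the active task's id, or the id the next reset() will select."""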
if self._task is not None:
return self._task.task_id
if self._requested_task_id:
return self._requested_task_id
        return self._task_order[self._next_task_index % len(self._task_order)]

def reset(self) -> SupportObservation:
"""Start a fresh episode."""
self._task = self._select_task()
self._matched_fields = set()
self._history = []
self._score = self._bounded_score(0.0)
self._state = SupportState(
episode_id=str(uuid4()),
step_count=0,
task_id=self._task.task_id,
difficulty=self._task.difficulty,
score=self._score,
matched_fields=[],
attempts_remaining=self._task.max_attempts,
)
return self._build_observation(
feedback=(
f"{self._task.guidance} Required fields: "
f"{', '.join(self._task.required_fields)}."
),
reward=0.0,
done=False,
last_action_error=None,
success=False,
        )

def step(self, action: SupportAction) -> SupportObservation: # type: ignore[override]
"""Grade one classification attempt and return the next observation."""
if self._task is None:
raise RuntimeError("Call reset() before step().")
self._state.step_count += 1
decision = self._extract_decision(action)
assessment = self._assess(decision)
self._matched_fields = assessment.matched_fields
self._score = assessment.score
self._state.score = assessment.score
self._state.matched_fields = sorted(self._matched_fields)
self._state.attempts_remaining = max(
self._task.max_attempts - self._state.step_count,
0,
)
compact_decision = json.dumps(decision, sort_keys=True)
        self._history.append(
            f"step={self._state.step_count} decision={compact_decision} "
            f"reward={assessment.reward:.2f} score={assessment.score:.2f} "
            f"feedback={assessment.feedback}"
        )
return self._build_observation(
feedback=assessment.feedback,
reward=assessment.reward,
done=assessment.done,
last_action_error=assessment.error,
success=assessment.success,
        )

@property
def state(self) -> SupportState:
"""Return the current environment state."""
        return self._state

def close(self) -> None:
"""No-op close hook for API symmetry."""
def _select_task(self) -> TaskDefinition:
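        """Return the pinned task if one was requested, else round-robin over ALL_TASKS."""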
if self._requested_task_id:
return TASKS_BY_ID[self._requested_task_id]
task_id = self._task_order[self._next_task_index % len(self._task_order)]
self._next_task_index += 1
        return TASKS_BY_ID[task_id]

def _extract_decision(self, action: SupportAction) -> dict[str, str]:
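        """Collect the non-empty decision fields submitted on the action."""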
decision: dict[str, str] = {}
for field_name in ("priority", "category", "action"):
value = getattr(action, field_name, None)
if value:
decision[field_name] = value
        return decision

def _bounded_score(self, raw_score: float) -> float:
"""Map raw progress into the open interval (0, 1)."""
clamped_raw_score = min(max(raw_score, 0.0), 1.0)
scaled_score = self.MIN_SCORE + (
clamped_raw_score * (self.MAX_SCORE - self.MIN_SCORE)
)
        return round(scaled_score, 2)

def _assess(self, decision: dict[str, str]) -> StepAssessment:
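        """Grade one decision against the task's expected field values."""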
if self._task is None:
raise RuntimeError("Task not initialized.")
if not decision:
return StepAssessment(
reward=-0.10,
score=round(self._score, 2),
done=self._state.step_count >= self._task.max_attempts,
success=False,
feedback=(
"No decision fields were submitted. Provide "
+ ", ".join(self._task.required_fields)
+ "."
),
error="empty_action",
matched_fields=set(self._matched_fields),
)
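        # Correct fields persist across attempts, so grading resumes from prior progress.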
matched_fields = set(self._matched_fields)
newly_matched: list[str] = []
mismatched_fields: list[str] = []
for field_name in self._task.required_fields:
predicted = decision.get(field_name)
if predicted is None:
continue
if predicted == self._task.expected[field_name]:
if field_name not in matched_fields:
newly_matched.append(field_name)
matched_fields.add(field_name)
else:
mismatched_fields.append(field_name)
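        # Reward scheme: credit each newly matched field by its weight; penalize an
        # attempt that only got fields wrong (-0.10) or made no progress (-0.02), and
        # dock a small lateness penalty (-0.05) past step 3 while fields remain.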
reward = sum(self._task.field_weights[field] for field in newly_matched)
if mismatched_fields and not newly_matched:
reward -= 0.10
elif not newly_matched and not mismatched_fields:
reward -= 0.02
        required = set(self._task.required_fields)
        if self._state.step_count > 3 and matched_fields != required:
            reward -= 0.05
        raw_score = sum(self._task.field_weights[field] for field in matched_fields)
        score = self._bounded_score(raw_score)
        success = matched_fields == required
        done = success or self._state.step_count >= self._task.max_attempts
feedback_parts: list[str] = []
if newly_matched:
feedback_parts.append("Matched " + ", ".join(newly_matched) + ".")
if mismatched_fields:
feedback_parts.append("Incorrect " + ", ".join(mismatched_fields) + ".")
remaining_fields = [
field for field in self._task.required_fields if field not in matched_fields
]
if success:
feedback_parts.append("All required fields are correct.")
elif remaining_fields:
feedback_parts.append("Still need " + ", ".join(remaining_fields) + ".")
if done and not success:
feedback_parts.append("Max attempts reached.")
if not feedback_parts:
feedback_parts.append("No new progress.")
return StepAssessment(
reward=round(reward, 2),
score=score,
done=done,
success=success,
feedback=" ".join(feedback_parts),
error=None,
matched_fields=matched_fields,
        )

def _build_observation(
self,
*,
feedback: str,
reward: float,
done: bool,
last_action_error: str | None,
success: bool,
) -> SupportObservation:
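        """Assemble the observation returned to the agent."""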
if self._task is None:
raise RuntimeError("Task not initialized.")
required_allowed_values = {
field_name: FIELD_OPTIONS[field_name]
for field_name in self._task.required_fields
}
return SupportObservation(
task_id=self._task.task_id,
task_type=self._task.difficulty,
benchmark=self._task.benchmark,
objective=self._task.objective,
email=self._task.email,
context=dict(self._task.context),
required_fields=list(self._task.required_fields),
allowed_values=required_allowed_values,
history=list(self._history),
feedback=feedback,
score=round(self._score, 2),
attempts_remaining=max(
self._task.max_attempts - self._state.step_count,
0,
),
done=done,
reward=round(reward, 2),
metadata={
"last_action_error": last_action_error,
"success": success,
"score": round(self._score, 2),
"matched_fields": sorted(self._matched_fields),
},
        )


# Alias so callers can import the environment as SupportSimEnvironment.
SupportSimEnvironment = SupermailEnvironment
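

if __name__ == "__main__":  # pragma: no cover - manual smoke test
    # Minimal sketch of one episode, assuming SupportAction accepts the decision
    # fields ("priority", "category", "action") as keyword arguments; the naive
    # policy below simply guesses the first allowed value for each required field.
    env = SupermailEnvironment()
    observation = env.reset()
    print(observation.objective)
    while not observation.done:
        guess = {
            field: options[0]
            for field, options in observation.allowed_values.items()
        }
        observation = env.step(SupportAction(**guess))
        print(observation.feedback, observation.reward, observation.score)
    env.close()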