| import random |
| import json |
| import os |
| from typing import List, Optional |
| from openenv.core.env_server import Environment |
| from models import EmailAction, EmailObservation, EmailState, EmailItem |
|
|
| |
def load_dataset(filename: str):
    """Load and parse a JSON dataset file.

    The file is looked up in the ``datasets`` directory that sits one level
    above the directory containing this module.

    Args:
        filename: Name of the JSON file inside the ``datasets`` directory.

    Returns:
        The deserialized JSON content, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    path = os.path.join(os.path.dirname(__file__), "..", "datasets", filename)
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, so non-ASCII email text loads identically on every OS/locale.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
# Per-difficulty email datasets, loaded once at import time
# (easy first, then medium, then hard).
EASY_EMAILS, MEDIUM_EMAILS, HARD_EMAILS = map(
    load_dataset,
    ("easy_tasks.json", "medium_tasks.json", "hard_tasks.json"),
)
|
|
class EmailTriageEnv(Environment[EmailAction, EmailObservation, EmailState]):
    """Email-triage environment.

    Each episode presents a random sample of emails from one of three
    datasets. For every email the agent submits an ``EmailAction`` that is
    scored on category and priority; for task 3 the extracted info is scored
    as well. Per-step rewards are accumulated in ``env_state.score``.
    """

    def __init__(self):
        """Initialise the base environment and an empty ``EmailState``."""
        super().__init__()
        self.env_state = EmailState()

    def reset(self, task_id: int = 1) -> EmailObservation:
        """Start a new episode for the given task.

        Args:
            task_id: 1 selects ``EASY_EMAILS``, 2 ``MEDIUM_EMAILS`` and
                3 ``HARD_EMAILS``. Any other value falls back to
                ``EASY_EMAILS`` (and is scored with the easy-task weights).

        Returns:
            The observation for the first email of the episode.
        """
        state = self.env_state
        state.task_id = task_id
        source = {1: EASY_EMAILS, 2: MEDIUM_EMAILS, 3: HARD_EMAILS}.get(
            task_id, EASY_EMAILS
        )

        # Sample without replacement; never ask for more emails than exist.
        selected = random.sample(source, min(len(source), state.max_steps))
        state.emails = [
            EmailItem(
                subject=e["subject"],
                body=e["body"],
                true_category=e["category"],
                true_priority=e["priority"],
                required_info=e.get("info", ""),
            )
            for e in selected
        ]
        state.current_step = 0
        state.score = 0.0

        return self._get_obs(f"Task {task_id} started.")

    def step(self, action: EmailAction) -> EmailObservation:
        """Score ``action`` against the current email and advance one step.

        Reward per step:
            * task 3: 0.3 for category, 0.3 for priority, 0.4 for the
              extracted info (compared case-insensitively after stripping
              surrounding whitespace);
            * every other task id: 0.5 for category, 0.5 for priority.

        Args:
            action: The agent's triage decision for the current email.

        Returns:
            The observation for the next email, or a terminal ``DONE``
            observation once all emails have been processed. Calling
            ``step`` when no email is pending (before ``reset`` or after the
            episode ended) returns a terminal observation with zero reward
            instead of raising ``IndexError``.
        """
        state = self.env_state

        # Guard against stepping with nothing left to score.
        if state.current_step >= len(state.emails):
            return self._get_obs(
                "No active email. Call reset() to start a new episode.",
                reward=0.0,
                done=True,
            )

        current_email = state.emails[state.current_step]
        # Only task 3 scores the extracted info. Keying on "== 3" keeps the
        # reward weights consistent with reset(), which serves the easy
        # dataset for unknown task ids.
        is_hard = state.task_id == 3

        cat_match = action.category_id == current_email.true_category
        prio_match = action.priority == current_email.true_priority

        info_match = True
        required_info = ""
        if is_hard:
            # Treat a missing value as an empty string rather than crashing.
            provided_info = (action.extracted_info or "").strip().upper()
            required_info = (current_email.required_info or "").strip().upper()
            info_match = provided_info == required_info

        reward = 0.0
        if is_hard:
            if cat_match:
                reward += 0.3
            if prio_match:
                reward += 0.3
            if info_match:
                reward += 0.4
        else:
            if cat_match:
                reward += 0.5
            if prio_match:
                reward += 0.5

        state.score += reward
        state.current_step += 1
        done = state.current_step >= len(state.emails)

        msg = f"Cat: {'OK' if cat_match else 'ERR'}, Prio: {'OK' if prio_match else 'ERR'}"
        if is_hard:
            msg += f", Info: {'OK' if info_match else 'ERR'}"
            if not info_match:
                msg += f" (Exp: {required_info})"

        if done:
            # The episode score is the mean per-email reward.
            final_score = state.score / len(state.emails)
            msg += f". DONE. Score: {final_score:.2f}"

        return self._get_obs(msg, reward=reward, done=done)

    # NOTE(review): the base Environment may declare ``state`` as a property;
    # it is kept as a plain method here to preserve the existing interface.
    # Confirm against the openenv version in use.
    def state(self) -> EmailState:
        """Return the live (mutable) environment state object."""
        return self.env_state

    def _get_obs(self, msg: str, reward: float = 0.0, done: bool = False) -> EmailObservation:
        """Build the observation for the current step.

        Args:
            msg: Feedback message to attach to the observation.
            reward: Reward earned by the step that produced this observation.
            done: Whether the episode has finished.

        Returns:
            An observation carrying the subject/body of the pending email,
            or the placeholder ``"DONE"`` for both when no email remains.
        """
        state = self.env_state
        if state.current_step < len(state.emails):
            pending = state.emails[state.current_step]
            subject, body = pending.subject, pending.body
        else:
            subject = body = "DONE"

        return EmailObservation(
            task_id=state.task_id,
            subject=subject,
            body=body,
            current_step=state.current_step,
            total_steps=len(state.emails),
            reward=reward,
            done=done,
            message=msg,
        )
|
|