my-env / server /environment.py
JJaniel
refactor: flatten project structure to root for openenv validate compliance
887ef72
import random
import json
import os
from typing import List, Optional
from openenv.core.env_server import Environment
from models import EmailAction, EmailObservation, EmailState, EmailItem
# Load real-world datasets from JSON files
def load_dataset(filename: str):
    """Load and return the parsed JSON content of a dataset file.

    The file is resolved as ``<this module's directory>/../datasets/<filename>``.

    Args:
        filename: Name of the JSON file inside the ``datasets`` directory.

    Returns:
        Whatever ``json.load`` produces for the file's content.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    path = os.path.join(os.path.dirname(__file__), "..", "datasets", filename)
    # JSON is UTF-8 by specification. Without an explicit encoding, open()
    # uses the platform locale, which can garble or reject non-ASCII email
    # text (e.g. on Windows with a cp1252 default).
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
# Per-difficulty task datasets, loaded once at import time in the order
# easy -> medium -> hard. Each is expected to be a sequence of email dicts
# (the environment samples from them and reads "subject", "body", etc.).
EASY_EMAILS, MEDIUM_EMAILS, HARD_EMAILS = (
    load_dataset(dataset_file)
    for dataset_file in ("easy_tasks.json", "medium_tasks.json", "hard_tasks.json")
)
class EmailTriageEnv(Environment[EmailAction, EmailObservation, EmailState]):
    """Environment in which an agent triages a batch of emails one at a time.

    ``reset`` samples emails for the chosen task from the module-level
    datasets; each ``step`` grades one action against the current email and
    advances to the next. Task 3 additionally grades ``extracted_info``.
    """

    def __init__(self):
        """Create the environment with a default-constructed ``EmailState``."""
        super().__init__()
        self.env_state = EmailState()

    def reset(self, task_id: int = 1) -> EmailObservation:
        """Start a new episode for ``task_id`` and return the first observation.

        Args:
            task_id: 1 (easy), 2 (medium) or 3 (hard). Any other value falls
                back to the easy dataset and is scored with the task 1/2
                rules.

        Returns:
            The observation for the first sampled email, or a "DONE"
            placeholder observation if no email could be sampled.
        """
        state = self.env_state
        state.task_id = task_id
        source = {1: EASY_EMAILS, 2: MEDIUM_EMAILS, 3: HARD_EMAILS}.get(
            task_id, EASY_EMAILS
        )
        # Sample without replacement, capped at max_steps emails per episode.
        selected = random.sample(source, min(len(source), state.max_steps))
        state.emails = [
            EmailItem(
                subject=e["subject"],
                body=e["body"],
                true_category=e["category"],
                true_priority=e["priority"],
                required_info=e.get("info", ""),
            )
            for e in selected
        ]
        state.current_step = 0
        state.score = 0.0
        return self._get_obs(f"Task {task_id} started.")

    def step(self, action: EmailAction) -> EmailObservation:
        """Grade ``action`` against the current email and advance one step.

        Scoring per email:
            * task 3: category 0.3 + priority 0.3 + extracted info 0.4
            * any other task: category 0.5 + priority 0.5

        Args:
            action: The agent's predicted category, priority and (task 3
                only) extracted info for the current email.

        Returns:
            The observation for the next email, carrying this step's reward.
            When the last email has been graded, ``done`` is True and the
            message includes the episode's mean score. Calling ``step`` when
            no email is left returns a terminal observation with reward 0.
        """
        state = self.env_state
        total = len(state.emails)

        # Stepping past the end of the episode (or with no emails loaded)
        # used to raise IndexError; answer with a terminal observation.
        if state.current_step >= total:
            return self._get_obs(
                "Episode already finished; call reset() to start a new one.",
                reward=0.0,
                done=True,
            )

        current_email = state.emails[state.current_step]
        # Only task 3 uses the three-part weighting. An unknown task id is
        # served easy emails by reset(), so it is scored like tasks 1 and 2
        # rather than being handed the info reward for free.
        is_hard = state.task_id == 3

        cat_match = action.category_id == current_email.true_category
        prio_match = action.priority == current_email.true_priority

        required_info = ""
        info_match = True
        if is_hard:
            # Robust info matching: ignore surrounding whitespace and case,
            # and treat a missing (None) value as an empty string.
            provided_info = (action.extracted_info or "").strip().upper()
            required_info = (current_email.required_info or "").strip().upper()
            info_match = provided_info == required_info

        reward = 0.0
        if is_hard:
            if cat_match:
                reward += 0.3
            if prio_match:
                reward += 0.3
            if info_match:
                reward += 0.4
        else:
            if cat_match:
                reward += 0.5
            if prio_match:
                reward += 0.5

        state.score += reward
        state.current_step += 1
        done = state.current_step >= total

        msg = f"Cat: {'OK' if cat_match else 'ERR'}, Prio: {'OK' if prio_match else 'ERR'}"
        if is_hard:
            msg += f", Info: {'OK' if info_match else 'ERR'}"
            if not info_match:
                msg += f" (Exp: {required_info})"

        if done:
            # total > 0 is guaranteed by the guard at the top of the method.
            final_score = state.score / total
            msg += f". DONE. Score: {final_score:.2f}"

        return self._get_obs(msg, reward=reward, done=done)

    # NOTE(review): the OpenEnv base class may declare ``state`` as a
    # property rather than a method - confirm against the installed version.
    def state(self) -> EmailState:
        """Return the environment's live ``EmailState`` object (not a copy)."""
        return self.env_state

    def _get_obs(self, msg: str, reward: float = 0.0, done: bool = False) -> EmailObservation:
        """Build the observation for the current step.

        Args:
            msg: Feedback text to attach to the observation.
            reward: Reward earned by the step that led to this observation.
            done: Whether the episode is finished.

        Returns:
            An observation showing the email at ``current_step``; once all
            emails are consumed, subject and body are the placeholder "DONE".
        """
        state = self.env_state
        total = len(state.emails)
        if state.current_step < total:
            email = state.emails[state.current_step]
            subject, body = email.subject, email.body
        else:
            subject, body = "DONE", "DONE"
        return EmailObservation(
            task_id=state.task_id,
            subject=subject,
            body=body,
            current_step=state.current_step,
            total_steps=total,
            reward=reward,
            done=done,
            message=msg,
        )