Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import random | |
| from collections import defaultdict | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from src.executive_assistant.agent import ActionCatalog, BaselineAgent | |
| from src.executive_assistant.env import ExecutiveAssistantEnv | |
| from src.executive_assistant.models import AssistantAction, PolicyDecision, WorkspaceObservation | |
| from src.executive_assistant.runner import EpisodeRunner, EpisodeTrace | |
| ACTION_NAMES = [ | |
| "read_first_unread", | |
| "archive_first_unread", | |
| "forward_client_to_manager", | |
| "reply_meeting_time", | |
| "add_deadline_todo", | |
| "archive_current_email", | |
| "search_q3_architecture", | |
| "reply_with_metrics", | |
| ] | |
| def _current_email_sender(observation: WorkspaceObservation) -> str: | |
| return observation.current_email.sender if observation.current_email else "none" | |
| def encode_observation(task_name: str, observation: WorkspaceObservation) -> str: | |
| unread_senders = ",".join(sorted(email.sender for email in observation.unread_emails)) or "none" | |
| return "|".join( | |
| [ | |
| task_name, | |
| f"unread={len(observation.unread_emails)}", | |
| f"senders={unread_senders}", | |
| f"todos={len(observation.active_todos)}", | |
| f"current={_current_email_sender(observation)}", | |
| f"search={int(bool(observation.search_results))}", | |
| f"history={'/'.join(observation.action_history[-3:]) or 'none'}", | |
| ] | |
| ) | |
| def valid_action_names(task_name: str, observation: WorkspaceObservation) -> list[str]: | |
| valid: list[str] = [] | |
| if task_name == "easy_deadline_extraction": | |
| if observation.current_email is None and observation.unread_emails: | |
| valid.append("read_first_unread") | |
| if observation.current_email is not None: | |
| body = observation.current_email.body.lower() | |
| existing = {todo.lower() for todo in observation.active_todos} | |
| missing_todo = False | |
| if "proposal due" in body and "proposal due" not in existing: | |
| valid.append("add_deadline_todo") | |
| missing_todo = True | |
| elif "prototype due" in body and "prototype due" not in existing: | |
| valid.append("add_deadline_todo") | |
| missing_todo = True | |
| elif "final report due" in body and "final report due" not in existing: | |
| valid.append("add_deadline_todo") | |
| missing_todo = True | |
| if not missing_todo: | |
| valid.append("archive_current_email") | |
| elif task_name == "medium_triage_and_negotiation": | |
| newsletter_senders = { | |
| "news@updates.example", | |
| "promotions@vendor.example", | |
| "events@community.example", | |
| } | |
| if any(email.sender in newsletter_senders for email in observation.unread_emails): | |
| valid.append("archive_first_unread") | |
| if any(email.sender == "client@company.com" for email in observation.unread_emails): | |
| valid.append("forward_client_to_manager") | |
| if any(email.sender == "teammate@company.com" for email in observation.unread_emails): | |
| valid.append("reply_meeting_time") | |
| elif task_name == "hard_rag_reply": | |
| if observation.current_email is None and observation.unread_emails: | |
| valid.append("read_first_unread") | |
| if observation.current_email is not None and not observation.search_results: | |
| valid.append("search_q3_architecture") | |
| if observation.current_email is not None and observation.search_results: | |
| valid.append("reply_with_metrics") | |
| return valid or ACTION_NAMES.copy() | |
| def make_action(action_name: str, observation: WorkspaceObservation) -> AssistantAction: | |
| if action_name == "read_first_unread": | |
| if observation.unread_emails: | |
| return AssistantAction(action_type="read_email", target_id=observation.unread_emails[0].id) | |
| elif action_name == "archive_first_unread": | |
| if observation.unread_emails: | |
| return AssistantAction(action_type="archive", target_id=observation.unread_emails[0].id) | |
| elif action_name == "forward_client_to_manager": | |
| for email in observation.unread_emails: | |
| if email.sender == "client@company.com": | |
| return AssistantAction( | |
| action_type="forward", | |
| target_id=email.id, | |
| secondary_payload="manager@company.com", | |
| payload="Urgent client complaint. Please take over immediately.", | |
| ) | |
| elif action_name == "reply_meeting_time": | |
| target_id = observation.current_email.id if observation.current_email else None | |
| if target_id is None: | |
| for email in observation.unread_emails: | |
| if email.sender == "teammate@company.com": | |
| target_id = email.id | |
| break | |
| if target_id is not None: | |
| return AssistantAction( | |
| action_type="reply", | |
| target_id=target_id, | |
| payload="Hello, 3:30 PM IST works for me. Regards, Executive Assistant", | |
| ) | |
| elif action_name == "add_deadline_todo": | |
| if observation.current_email: | |
| body = observation.current_email.body.lower() | |
| candidates = [ | |
| ("Proposal Due", "2026-04-10", "proposal due"), | |
| ("Prototype Due", "2026-04-20", "prototype due"), | |
| ("Final Report Due", "2026-04-30", "final report due"), | |
| ] | |
| existing = {todo.lower() for todo in observation.active_todos} | |
| for task_name, deadline, marker in candidates: | |
| if marker in body and task_name.lower() not in existing: | |
| return AssistantAction( | |
| action_type="add_todo", | |
| payload=task_name, | |
| secondary_payload=deadline, | |
| ) | |
| elif action_name == "archive_current_email": | |
| if observation.current_email: | |
| return AssistantAction(action_type="archive", target_id=observation.current_email.id) | |
| elif action_name == "search_q3_architecture": | |
| return AssistantAction(action_type="search_files", payload="Q3 Architecture") | |
| elif action_name == "reply_with_metrics": | |
| if observation.current_email and observation.search_results: | |
| snippet = observation.search_results[0].snippet | |
| availability = "99.95%" if "99.95%" in snippet else "unknown" | |
| latency = "182ms" if "182ms" in snippet else "unknown" | |
| cost = "14%" if "14%" in snippet else "unknown" | |
| return AssistantAction( | |
| action_type="reply", | |
| target_id=observation.current_email.id, | |
| payload=( | |
| "Hello,\n" | |
| f"Here are the requested Q3 architecture metrics: availability {availability}, " | |
| f"mean API latency {latency}, and infrastructure cost reduction {cost}.\n" | |
| "Regards,\nExecutive Assistant" | |
| ), | |
| ) | |
| return AssistantAction(action_type="search_files") | |
| class QLearningPolicy: | |
| epsilon: float = 0.2 | |
| alpha: float = 0.3 | |
| gamma: float = 0.95 | |
| seed: int = 7 | |
| def __post_init__(self) -> None: | |
| self.q_values: dict[str, dict[str, float]] = defaultdict( | |
| lambda: {action_name: 0.0 for action_name in ACTION_NAMES} | |
| ) | |
| self.random = random.Random(self.seed) | |
| def choose_action(self, task_name: str, observation: WorkspaceObservation) -> PolicyDecision: | |
| state = encode_observation(task_name, observation) | |
| candidates = valid_action_names(task_name, observation) | |
| if self.random.random() < self.epsilon: | |
| action_name = self.random.choice(candidates) | |
| return PolicyDecision( | |
| reasoning=f"Exploring action template {action_name}.", | |
| action=make_action(action_name, observation), | |
| ) | |
| action_name = max(candidates, key=lambda name: self.q_values[state][name]) | |
| return PolicyDecision( | |
| reasoning=f"Selecting greedy action template {action_name}.", | |
| action=make_action(action_name, observation), | |
| ) | |
| def update( | |
| self, | |
| state: str, | |
| action_name: str, | |
| reward: float, | |
| next_state: str, | |
| done: bool, | |
| ) -> None: | |
| next_best = 0.0 if done else max(self.q_values[next_state].values()) | |
| current = self.q_values[state][action_name] | |
| target = reward + self.gamma * next_best | |
| self.q_values[state][action_name] = current + self.alpha * (target - current) | |
| def save(self, path: str | Path) -> Path: | |
| output = Path(path) | |
| output.parent.mkdir(parents=True, exist_ok=True) | |
| payload = { | |
| "metadata": { | |
| "action_names": ACTION_NAMES, | |
| "seed": self.seed, | |
| "alpha": self.alpha, | |
| "gamma": self.gamma, | |
| "epsilon": 0.0, | |
| }, | |
| "q_values": self.q_values, | |
| } | |
| output.write_text(json.dumps(payload, indent=2)) | |
| return output | |
| def load(cls, path: str | Path) -> "QLearningPolicy": | |
| checkpoint_path = Path(path) | |
| policy = cls(epsilon=0.0) | |
| raw_payload = json.loads(checkpoint_path.read_text()) | |
| raw_values = raw_payload["q_values"] if "q_values" in raw_payload else raw_payload | |
| policy.q_values = defaultdict( | |
| lambda: {action_name: 0.0 for action_name in ACTION_NAMES} | |
| ) | |
| for state, action_map in raw_values.items(): | |
| policy.q_values[state] = { | |
| action_name: float(action_map.get(action_name, 0.0)) | |
| for action_name in ACTION_NAMES | |
| } | |
| policy.epsilon = 0.0 | |
| return policy | |
| def action_name_from_decision(decision: PolicyDecision, observation: WorkspaceObservation) -> str: | |
| for action_name in ACTION_NAMES: | |
| candidate = make_action(action_name, observation) | |
| if candidate == decision.action: | |
| return action_name | |
| return "search_q3_architecture" | |
| def warm_start_from_teacher( | |
| learner: QLearningPolicy, | |
| teacher: BaselineAgent, | |
| task_names: list[str], | |
| episodes_per_task: int = 4, | |
| ) -> None: | |
| runner = EpisodeRunner(policy=teacher) | |
| for _ in range(episodes_per_task): | |
| for task_name in task_names: | |
| trace = runner.run(task_name) | |
| for index, step in enumerate(trace.steps): | |
| current_observation = WorkspaceObservation.model_validate(step.observation) | |
| previous_observation = ( | |
| WorkspaceObservation.model_validate(trace.steps[index - 1].observation) | |
| if index > 0 | |
| else None | |
| ) | |
| observation = previous_observation or current_observation | |
| state = encode_observation(task_name, observation) | |
| next_state = encode_observation(task_name, current_observation) | |
| reward_delta = step.reward["total_score"] | |
| action_name = action_name_from_decision( | |
| PolicyDecision( | |
| reasoning=step.reasoning, | |
| action=AssistantAction.model_validate(step.action), | |
| ), | |
| observation, | |
| ) | |
| learner.update( | |
| state=state, | |
| action_name=action_name, | |
| reward=reward_delta, | |
| next_state=next_state, | |
| done=bool(step.reward["is_done"]), | |
| ) | |
| def train_q_learning( | |
| episodes: int = 200, | |
| epsilon: float = 0.15, | |
| teacher: BaselineAgent | None = None, | |
| ) -> tuple[QLearningPolicy, dict[str, float]]: | |
| learner = QLearningPolicy(epsilon=epsilon) | |
| task_names = [ | |
| "easy_deadline_extraction", | |
| "medium_triage_and_negotiation", | |
| "hard_rag_reply", | |
| ] | |
| if teacher is not None: | |
| warm_start_from_teacher(learner, teacher, task_names) | |
| scores: dict[str, float] = {} | |
| for episode in range(episodes): | |
| task_name = task_names[episode % len(task_names)] | |
| env = ExecutiveAssistantEnv(task_name=task_name) | |
| observation = env.reset() | |
| previous_total_score = 0.0 | |
| while True: | |
| state = encode_observation(task_name, observation) | |
| decision = learner.choose_action(task_name, observation) | |
| action_name = action_name_from_decision(decision, observation) | |
| next_observation, reward, _, _ = env.step(decision.action) | |
| next_state = encode_observation(task_name, next_observation) | |
| reward_delta = reward.total_score - previous_total_score - 0.01 | |
| previous_total_score = reward.total_score | |
| learner.update( | |
| state=state, | |
| action_name=action_name, | |
| reward=reward_delta, | |
| next_state=next_state, | |
| done=reward.is_done, | |
| ) | |
| observation = next_observation | |
| if reward.is_done: | |
| scores[task_name] = reward.total_score | |
| break | |
| return learner, scores | |
| def evaluate_q_policy(policy: QLearningPolicy) -> dict[str, float]: | |
| original_epsilon = policy.epsilon | |
| policy.epsilon = 0.0 | |
| try: | |
| traces = { | |
| task_name: EpisodeRunner(policy=policy).run(task_name) | |
| for task_name in [ | |
| "easy_deadline_extraction", | |
| "medium_triage_and_negotiation", | |
| "hard_rag_reply", | |
| ] | |
| } | |
| finally: | |
| policy.epsilon = original_epsilon | |
| return {task_name: trace.final_score for task_name, trace in traces.items()} | |
| def default_checkpoint_path(checkpoint_dir: str | Path, checkpoint_name: str) -> Path: | |
| return Path(checkpoint_dir) / checkpoint_name | |