# Flickinshots's picture
# Deploy Project Epsilon Space bundle
# 200a73b verified
from __future__ import annotations
import json
import random
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from src.executive_assistant.agent import ActionCatalog, BaselineAgent
from src.executive_assistant.env import ExecutiveAssistantEnv
from src.executive_assistant.models import AssistantAction, PolicyDecision, WorkspaceObservation
from src.executive_assistant.runner import EpisodeRunner, EpisodeTrace
# Names of the action templates the Q-learning policy chooses between; each is
# turned into a concrete AssistantAction by make_action. The order is the
# lookup order used by action_name_from_decision and is also the candidate order
# whenever valid_action_names falls back to "every template".
ACTION_NAMES: list[str] = [
    # Inbox navigation / triage.
    "read_first_unread",
    "archive_first_unread",
    "forward_client_to_manager",
    "reply_meeting_time",
    # Deadline extraction.
    "add_deadline_todo",
    "archive_current_email",
    # Retrieval-augmented reply.
    "search_q3_architecture",
    "reply_with_metrics",
]
def _current_email_sender(observation: WorkspaceObservation) -> str:
return observation.current_email.sender if observation.current_email else "none"
def encode_observation(task_name: str, observation: WorkspaceObservation) -> str:
    """Collapse an observation into a compact string key for the Q-table.

    The key is a "|"-joined list of: the task name, the unread count, the
    sorted unread senders (order-insensitive), the active to-do count, the
    sender of the open e-mail, a 0/1 flag for non-empty search results, and
    the last three entries of the action history.
    """
    sorted_senders = sorted(email.sender for email in observation.unread_emails)
    recent_actions = observation.action_history[-3:]
    has_search_results = 1 if observation.search_results else 0
    fields = (
        task_name,
        f"unread={len(observation.unread_emails)}",
        f"senders={','.join(sorted_senders) or 'none'}",
        f"todos={len(observation.active_todos)}",
        f"current={_current_email_sender(observation)}",
        f"search={has_search_results}",
        f"history={'/'.join(recent_actions) or 'none'}",
    )
    return "|".join(fields)
def valid_action_names(task_name: str, observation: WorkspaceObservation) -> list[str]:
    """Return the action templates that are applicable to ``observation``.

    The rules are task-specific:

    * ``easy_deadline_extraction``: open the first unread e-mail when nothing
      is open; with an e-mail open, add its missing deadline to-do, otherwise
      archive it.
    * ``medium_triage_and_negotiation``: archive the first unread e-mail when
      it is a newsletter, forward a client e-mail, reply to the teammate.
    * ``hard_rag_reply``: open the first unread e-mail, then search, then reply
      with metrics once search results are available.

    When no rule applies (unknown task or nothing applicable), a copy of every
    template in ACTION_NAMES is returned so callers always get candidates.
    """
    valid: list[str] = []
    unread = observation.unread_emails
    current = observation.current_email

    if task_name == "easy_deadline_extraction":
        if current is None and unread:
            valid.append("read_first_unread")
        if current is not None:
            body = current.body.lower()
            existing = {todo.lower() for todo in observation.active_todos}
            # Same markers make_action("add_deadline_todo") knows how to add.
            missing_todo = any(
                marker in body and marker not in existing
                for marker in ("proposal due", "prototype due", "final report due")
            )
            valid.append("add_deadline_todo" if missing_todo else "archive_current_email")
    elif task_name == "medium_triage_and_negotiation":
        newsletter_senders = {
            "news@updates.example",
            "promotions@vendor.example",
            "events@community.example",
        }
        # "archive_first_unread" always archives unread[0], so only offer it when
        # that specific e-mail is a newsletter. Gating on "any newsletter is
        # unread" could archive the client/teammate e-mail, and the state key
        # (sorted senders) gives the learner no way to see the ordering.
        if unread and unread[0].sender in newsletter_senders:
            valid.append("archive_first_unread")
        unread_senders = {email.sender for email in unread}
        if "client@company.com" in unread_senders:
            valid.append("forward_client_to_manager")
        if "teammate@company.com" in unread_senders:
            valid.append("reply_meeting_time")
    elif task_name == "hard_rag_reply":
        if current is None and unread:
            valid.append("read_first_unread")
        if current is not None:
            valid.append("reply_with_metrics" if observation.search_results else "search_q3_architecture")

    return valid or ACTION_NAMES.copy()
def make_action(action_name: str, observation: WorkspaceObservation) -> AssistantAction:
    """Materialize an action template into a concrete AssistantAction.

    Unknown template names, and templates whose preconditions are not met by
    ``observation`` (e.g. "read_first_unread" with an empty inbox), all yield
    the same fallback: ``AssistantAction(action_type="search_files")`` with no
    payload.
    """
    unread = observation.unread_emails
    current = observation.current_email

    if action_name == "read_first_unread" and unread:
        return AssistantAction(action_type="read_email", target_id=unread[0].id)

    if action_name == "archive_first_unread" and unread:
        return AssistantAction(action_type="archive", target_id=unread[0].id)

    if action_name == "forward_client_to_manager":
        client_email = next(
            (email for email in unread if email.sender == "client@company.com"),
            None,
        )
        if client_email is not None:
            return AssistantAction(
                action_type="forward",
                target_id=client_email.id,
                payload="Urgent client complaint. Please take over immediately.",
                secondary_payload="manager@company.com",
            )

    if action_name == "reply_meeting_time":
        # Prefer the open e-mail; otherwise reply to the teammate's unread e-mail.
        reply_to = current.id if current else None
        if reply_to is None:
            reply_to = next(
                (email.id for email in unread if email.sender == "teammate@company.com"),
                None,
            )
        if reply_to is not None:
            return AssistantAction(
                action_type="reply",
                target_id=reply_to,
                payload="Hello, 3:30 PM IST works for me. Regards, Executive Assistant",
            )

    if action_name == "add_deadline_todo" and current:
        lowered_body = current.body.lower()
        already_tracked = {todo.lower() for todo in observation.active_todos}
        # (to-do title, due date, marker searched for in the lower-cased body)
        deadline_table = (
            ("Proposal Due", "2026-04-10", "proposal due"),
            ("Prototype Due", "2026-04-20", "prototype due"),
            ("Final Report Due", "2026-04-30", "final report due"),
        )
        for title, due_date, marker in deadline_table:
            if marker in lowered_body and title.lower() not in already_tracked:
                return AssistantAction(
                    action_type="add_todo",
                    payload=title,
                    secondary_payload=due_date,
                )

    if action_name == "archive_current_email" and current:
        return AssistantAction(action_type="archive", target_id=current.id)

    if action_name == "search_q3_architecture":
        return AssistantAction(action_type="search_files", payload="Q3 Architecture")

    if action_name == "reply_with_metrics" and current and observation.search_results:
        snippet = observation.search_results[0].snippet

        def _reported(figure: str) -> str:
            # Echo a metric only if the top search snippet actually contains it.
            return figure if figure in snippet else "unknown"

        availability = _reported("99.95%")
        latency = _reported("182ms")
        cost = _reported("14%")
        return AssistantAction(
            action_type="reply",
            target_id=current.id,
            payload=(
                "Hello,\n"
                f"Here are the requested Q3 architecture metrics: availability {availability}, "
                f"mean API latency {latency}, and infrastructure cost reduction {cost}.\n"
                "Regards,\nExecutive Assistant"
            ),
        )

    return AssistantAction(action_type="search_files")
@dataclass
class QLearningPolicy:
    """Tabular epsilon-greedy Q-learning over the action templates in ACTION_NAMES.

    States are the string keys produced by ``encode_observation``; each state
    maps to one Q-value per template name. Rows for unseen states are created
    lazily with every template at 0.0.
    """

    epsilon: float = 0.2  # probability of picking a random candidate instead of the greedy one
    alpha: float = 0.3  # learning rate of the temporal-difference update
    gamma: float = 0.95  # discount applied to the next state's best Q-value
    seed: int = 7  # seed for this policy's private random.Random instance

    def __post_init__(self) -> None:
        # Unseen states lazily get an all-zero row (one entry per template).
        self.q_values: dict[str, dict[str, float]] = defaultdict(self._empty_action_values)
        self.random = random.Random(self.seed)

    @staticmethod
    def _empty_action_values() -> dict[str, float]:
        """Return a fresh all-zero Q-value row with one entry per template."""
        return {action_name: 0.0 for action_name in ACTION_NAMES}

    def choose_action(self, task_name: str, observation: WorkspaceObservation) -> PolicyDecision:
        """Pick a valid template epsilon-greedily and materialize it.

        With probability ``epsilon`` a uniformly random candidate from
        ``valid_action_names`` is chosen; otherwise the candidate with the
        highest Q-value for the encoded state (first one on ties). The chosen
        template name is embedded at the end of the reasoning string.
        """
        state = encode_observation(task_name, observation)
        candidates = valid_action_names(task_name, observation)
        if self.random.random() < self.epsilon:
            action_name = self.random.choice(candidates)
            reasoning = f"Exploring action template {action_name}."
        else:
            action_name = max(candidates, key=lambda name: self.q_values[state][name])
            reasoning = f"Selecting greedy action template {action_name}."
        return PolicyDecision(
            reasoning=reasoning,
            action=make_action(action_name, observation),
        )

    def update(
        self,
        state: str,
        action_name: str,
        reward: float,
        next_state: str,
        done: bool,
    ) -> None:
        """Apply one Q-learning step: Q(s,a) += alpha * (target - Q(s,a)).

        ``target = reward + gamma * max_a' Q(next_state, a')``; the bootstrap
        term is dropped when ``done``. Note the max ranges over every
        template, not only those valid in ``next_state``.
        """
        next_best = 0.0 if done else max(self.q_values[next_state].values())
        current = self.q_values[state][action_name]
        target = reward + self.gamma * next_best
        self.q_values[state][action_name] = current + self.alpha * (target - current)

    def save(self, path: str | Path) -> Path:
        """Write the Q-table and hyperparameters to ``path`` as JSON.

        Parent directories are created as needed. ``epsilon`` is always
        recorded as 0.0 so a loaded checkpoint acts greedily.

        Returns:
            The path that was written.
        """
        output = Path(path)
        output.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "metadata": {
                "action_names": ACTION_NAMES,
                "seed": self.seed,
                "alpha": self.alpha,
                "gamma": self.gamma,
                "epsilon": 0.0,
            },
            "q_values": self.q_values,
        }
        output.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        return output

    @classmethod
    def load(cls, path: str | Path) -> "QLearningPolicy":
        """Rebuild a greedy (``epsilon=0.0``) policy from a JSON checkpoint.

        Accepts both the ``{"metadata": ..., "q_values": ...}`` format written
        by ``save`` and a legacy bare ``{state: {action: value}}`` mapping.
        ``alpha``, ``gamma`` and ``seed`` are restored from the metadata when
        present (previously they were silently reset to the class defaults);
        legacy checkpoints keep the defaults. Each state's row is normalized
        to exactly the templates in ACTION_NAMES, missing ones at 0.0.
        """
        checkpoint_path = Path(path)
        raw_payload = json.loads(checkpoint_path.read_text(encoding="utf-8"))
        if "q_values" in raw_payload:
            raw_values = raw_payload["q_values"]
            metadata = raw_payload.get("metadata") or {}
        else:
            # Legacy checkpoint: the file is the Q-table itself.
            raw_values = raw_payload
            metadata = {}

        casts = {"alpha": float, "gamma": float, "seed": int}
        overrides = {name: cast(metadata[name]) for name, cast in casts.items() if name in metadata}
        policy = cls(epsilon=0.0, **overrides)

        for state, action_map in raw_values.items():
            policy.q_values[state] = {
                action_name: float(action_map.get(action_name, 0.0))
                for action_name in ACTION_NAMES
            }
        return policy
def action_name_from_decision(decision: PolicyDecision, observation: WorkspaceObservation) -> str:
    """Map a decision back to the action-template name that produces its action.

    Several templates can yield the same concrete action for one observation;
    in particular every template whose preconditions fail degenerates to the
    bare ``search_files`` fallback of ``make_action``. A plain reverse lookup
    then always credits the first such template in ACTION_NAMES order, even
    when the policy actually chose another one, so Q-updates land on the
    wrong entry. To avoid that, when the decision's reasoning ends with a
    template name (as ``QLearningPolicy.choose_action`` writes it:
    "... action template <name>.") and that template reproduces the
    decision's action, it is returned.

    Otherwise the first template in ACTION_NAMES whose action equals the
    decision's action is returned, and "search_q3_architecture" when none
    matches (e.g. a teacher action no template reproduces).
    """
    reasoning_words = (decision.reasoning or "").strip().rstrip(".").split()
    hinted_name = reasoning_words[-1] if reasoning_words else ""
    if hinted_name in ACTION_NAMES and make_action(hinted_name, observation) == decision.action:
        return hinted_name
    for action_name in ACTION_NAMES:
        if make_action(action_name, observation) == decision.action:
            return action_name
    return "search_q3_architecture"
def warm_start_from_teacher(
    learner: QLearningPolicy,
    teacher: BaselineAgent,
    task_names: list[str],
    episodes_per_task: int = 4,
    step_penalty: float = 0.01,
) -> None:
    """Seed ``learner``'s Q-table by replaying episodes run with ``teacher``.

    Runs ``episodes_per_task`` rounds of one teacher episode per task and
    applies one ``learner.update`` per recorded step. The step reward is the
    increase of the cumulative ``total_score`` over the previous step minus
    ``step_penalty``. This is the same shaping ``train_q_learning`` uses, so
    warm-start and online Q-values share a scale. Previously the raw cumulative
    score was used as the per-step reward.

    NOTE(review): each step's recorded observation is treated as the
    observation *after* its action (the state is the previous step's
    observation, or the step's own one for the first step) — confirm against
    EpisodeRunner's trace format.

    Args:
        learner: Policy whose Q-table is updated in place.
        teacher: Agent used to generate the demonstration episodes.
        task_names: Tasks to replay, each once per round.
        episodes_per_task: Number of rounds.
        step_penalty: Constant subtracted from every step's reward; the
            default matches the penalty hard-coded in ``train_q_learning``.
    """
    runner = EpisodeRunner(policy=teacher)
    for _ in range(episodes_per_task):
        for task_name in task_names:
            trace = runner.run(task_name)
            previous_observation: WorkspaceObservation | None = None
            previous_total_score = 0.0
            for step in trace.steps:
                current_observation = WorkspaceObservation.model_validate(step.observation)
                observation = (
                    previous_observation if previous_observation is not None else current_observation
                )
                state = encode_observation(task_name, observation)
                next_state = encode_observation(task_name, current_observation)

                # total_score is cumulative; reward only this step's increment.
                total_score = float(step.reward["total_score"])
                reward_delta = total_score - previous_total_score - step_penalty
                previous_total_score = total_score

                action_name = action_name_from_decision(
                    PolicyDecision(
                        reasoning=step.reasoning,
                        action=AssistantAction.model_validate(step.action),
                    ),
                    observation,
                )
                learner.update(
                    state=state,
                    action_name=action_name,
                    reward=reward_delta,
                    next_state=next_state,
                    done=bool(step.reward["is_done"]),
                )
                previous_observation = current_observation
def train_q_learning(
    episodes: int = 200,
    epsilon: float = 0.15,
    teacher: BaselineAgent | None = None,
) -> tuple[QLearningPolicy, dict[str, float]]:
    """Train a QLearningPolicy online, cycling through the three tasks.

    If ``teacher`` is given, the Q-table is first warm-started from its
    episodes. Each step is rewarded with the increase in cumulative
    ``total_score`` minus a 0.01 step penalty.

    Returns:
        The trained policy and, per task, the final ``total_score`` of the
        last episode run for that task.
    """
    learner = QLearningPolicy(epsilon=epsilon)
    task_names = [
        "easy_deadline_extraction",
        "medium_triage_and_negotiation",
        "hard_rag_reply",
    ]
    if teacher is not None:
        warm_start_from_teacher(learner, teacher, task_names)

    final_scores: dict[str, float] = {}
    for episode_index in range(episodes):
        task_name = task_names[episode_index % len(task_names)]
        env = ExecutiveAssistantEnv(task_name=task_name)
        observation = env.reset()
        running_total = 0.0
        done = False
        while not done:
            state = encode_observation(task_name, observation)
            decision = learner.choose_action(task_name, observation)
            action_name = action_name_from_decision(decision, observation)
            observation_after, reward, _, _ = env.step(decision.action)
            shaped_reward = reward.total_score - running_total - 0.01
            running_total = reward.total_score
            done = reward.is_done
            learner.update(
                state=state,
                action_name=action_name,
                reward=shaped_reward,
                next_state=encode_observation(task_name, observation_after),
                done=done,
            )
            observation = observation_after
        final_scores[task_name] = reward.total_score
    return learner, final_scores
def evaluate_q_policy(policy: QLearningPolicy) -> dict[str, float]:
    """Run one greedy episode per task and return each trace's final score.

    ``policy.epsilon`` is forced to 0.0 for the evaluation and restored
    afterwards, even if an episode raises.
    """
    saved_epsilon = policy.epsilon
    policy.epsilon = 0.0
    final_scores: dict[str, float] = {}
    try:
        for task_name in (
            "easy_deadline_extraction",
            "medium_triage_and_negotiation",
            "hard_rag_reply",
        ):
            final_scores[task_name] = EpisodeRunner(policy=policy).run(task_name).final_score
    finally:
        policy.epsilon = saved_epsilon
    return final_scores
def default_checkpoint_path(checkpoint_dir: str | Path, checkpoint_name: str) -> Path:
    """Return the path of ``checkpoint_name`` inside ``checkpoint_dir`` (nothing is created)."""
    return Path(checkpoint_dir).joinpath(checkpoint_name)