OffGridSchedula / server /agent.py
ParetoOptimal's picture
Initial Commit
0366d65
Raw
History Blame Contribute Delete
20.4 kB
"""The scheduling agent: thread (+images) -> validated ActionPlan.
Replaces the old one-shot extractor. The model reasons over a whole conversation
and emits a single constrained ActionPlan: events, conflicts (vs the user's
existing calendar), proposed alternative times, a reply draft, and an optional
clarification question. Output is grammar-constrained so it always parses.
"""
from __future__ import annotations
import json
import os
import re
from datetime import datetime, timedelta
from typing import Optional
from dateutil import parser as dtparser
from pydantic import ValidationError
from . import events, memory
from .schema import ActionPlan, Event
# System prompt for the extraction pass (used by build_messages). It tells the
# model to answer with a single JSON object following the ActionPlan schema and
# spells out the per-field rules: one entry per distinct event, calendar-style
# titles, how to derive `end`, the "arrive N minutes early" convention, the
# reminder defaults, conflicts / proposed times, and when to ask a
# clarification question instead of guessing.
SYSTEM = (
    "You are a scheduling assistant reading a chat conversation (text, and sometimes images "
    "such as screenshots, invites, or flyers). Decide what calendar action is warranted and "
    "return ONLY a JSON object matching the ActionPlan schema:\n"
    "- reasoning: one or two sentences of why.\n"
    "- events: concrete events with ISO 8601 datetimes; resolve relative dates from the current "
    "datetime. Empty if there is no real plan. List EVERY distinct event separately — one thread "
    "often holds several (e.g. a drop-off AND a pickup, or two appointments, are separate events).\n"
    "- title: a short, self-contained calendar title summarizing the action and subject "
    "(e.g. \"Pick up Priya — Terminal 4\", \"Mia — dental cleaning\"), not a quote of the "
    "message.\n"
    "- location: the venue or address when one is mentioned (join multi-line addresses into one "
    "string); null otherwise.\n"
    "- end: when a duration is stated (\"Duration: 30–45 min\", \"for 2 hours\", \"runs 90 "
    "minutes\"), set end = start + duration, using the LOWER bound of a range; when an end time "
    "is stated (\"7-9pm\"), use it; otherwise null. Never guess a duration that was not given.\n"
    "- early arrival: if told to arrive N minutes early (\"please arrive 15 minutes early\"), "
    "start = the arrival time (stated time minus N); end still counts from the STATED time; put "
    "the stated time and the reason in notes.\n"
    "- reminder_minutes: a stated lead time always wins (\"remind me 2 hours before\" -> 120); "
    "otherwise 60 for doctor/medical visits, 30 for parties, 45 for carpools or school events; "
    "for anything else use your judgment.\n"
    "- conflicts: for any event that clashes with the provided existing calendar, the event_index, "
    "what it clashes with, and severity (overlap|adjacent|tight).\n"
    "- proposed_times: ISO 8601 alternatives when there is a conflict.\n"
    "- reply_draft: a short, natural reply the user could send back.\n"
    "- needs_clarification: a question if the plan is ambiguous, else null. If something should "
    "be scheduled but its day or time is not yet known (\"TBD\", \"I'll confirm\", \"sometime "
    "soon\"), leave events empty and ASK via needs_clarification instead of guessing.\n"
    "Do not invent events that were not discussed."
)
def _existing_block(existing: list[Event]) -> str:
    """Render the user's current calendar as a bulleted prompt section.

    Each entry becomes ``- <title>: <start>..<end>``; an event without an end
    repeats its start. An empty (or missing) calendar yields a fixed
    placeholder line instead of an empty list.
    """
    if existing:
        bullets = "\n".join(
            f"- {item.title}: {item.start}..{item.end or item.start}"
            for item in existing
        )
        return "Existing calendar:\n" + bullets
    return "Existing calendar: (none provided)"
def build_messages(
    thread: str,
    now: datetime,
    existing: list[Event],
    images: Optional[list[str]] = None,
    memory_block: Optional[str] = None,
) -> list[dict]:
    """Assemble the system + user chat messages for the extraction pass.

    The user message carries the current datetime (weekday name plus ISO
    timestamp), the existing-calendar section, an optional memory section and
    the conversation itself.

    ``images`` are base64 data URIs (used from phase 3); when present the user
    content becomes a list of typed parts instead of a plain string.
    ``memory_block`` is the caller's recall block (per-user/localStorage
    memory); when it is None the server-side global ``memory.recall()`` is
    used instead. An empty recall adds nothing to the prompt.
    """
    recalled = memory_block if memory_block is not None else memory.recall()
    memory_section = f"{recalled}\n\n" if recalled else ""
    weekday = now.strftime("%A")
    prompt = (
        f"Current datetime: {weekday}, {now.isoformat()}\n"
        + f"{_existing_block(existing)}\n\n"
        + memory_section
        + f"Conversation:\n{thread}\n\n"
        + "Return the ActionPlan JSON now."
    )

    user_content = prompt
    if images:
        # Multimodal content format understood by llama.cpp vision chat handlers.
        user_content = [{"type": "text", "text": prompt}]
        user_content.extend(
            {"type": "image_url", "image_url": {"url": uri}} for uri in images
        )

    return [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": user_content},
    ]
def run_agent(
    thread: str,
    now: Optional[datetime] = None,
    existing: Optional[list[Event]] = None,
    images: Optional[list[str]] = None,
    memory_block: Optional[str] = None,
) -> ActionPlan:
    """Run one blocking analysis of ``thread`` and return the resulting ActionPlan.

    ``now`` defaults to ``datetime.now()`` and ``existing`` to an empty
    calendar. With ``USE_STUB_EXTRACTOR=1`` the heuristic stub produces the
    plan; otherwise the model is called with the ActionPlan JSON schema and its
    output goes through parsing, optional title polish and the deterministic
    text rules. Progress is reported through ``events.emit`` inside an
    "analyze" run scope.
    """
    now = now or datetime.now()
    existing = existing or []

    with events.run_scope("analyze"):
        if images:
            image_count = len(images)
            events.emit("vision", f"reading {image_count} image(s)", images=image_count)

        if os.environ.get("USE_STUB_EXTRACTOR") != "1":
            from .model import complete_json  # lazy: avoids llama.cpp in stub mode

            messages = build_messages(thread, now, existing, images, memory_block)
            raw = complete_json(messages, json_schema=ActionPlan.model_json_schema())
            parsed = _parse_plan(raw)
            plan = apply_text_rules(thread, _polish_titles(thread, parsed))
        else:
            plan = _stub_plan(thread, now)

        # Global path only: with client-owned (per-user) memory, the UI merges
        # learned contacts itself (memory.learn_from_plan) so we don't pollute the
        # shared server file.
        if memory_block is None:
            memory.observe_plan(plan)  # grows-with-you: learn recurring contacts

        event_count = len(plan.events)
        events.emit("decision", f"{event_count} event(s) detected", events=event_count)

    return plan
def _parse_plan(raw: str) -> ActionPlan:
    """Turn the model's raw JSON text into an ActionPlan without ever raising
    on malformed output.

    Returns the validated plan, or an empty plan whose ``reasoning`` says the
    output could not be parsed when ``raw`` is not valid JSON, is not a JSON
    object, or fails schema validation.
    """
    try:
        return ActionPlan(**json.loads(raw))
    except (json.JSONDecodeError, TypeError, ValidationError):
        # Grammar should prevent this; degrade to an empty plan rather than 500.
        # TypeError covers valid JSON that is not an object (``**`` needs a
        # mapping, so `null`, a list or a bare string would otherwise escape)
        # and a ``raw`` that is not str/bytes at all.
        return ActionPlan(reasoning="Could not parse model output.")
# --------------------------------------------------------------------------- #
# Title polish (optional second pass, TITLE_POLISH=1): rewrite each extracted
# event's title into a calendar-ready action+subject summary. The extraction
# pass already gets a title style instruction; this pass gives the model one
# focused job, which helps on echo-prone inputs (flyers, forwarded notices).
# --------------------------------------------------------------------------- #
# System prompt for the polish pass (used by build_title_messages). The reply
# is expected to be {"titles": [...]} with exactly one title per extracted
# event, in the same order.
TITLE_SYSTEM = (
    "You rewrite calendar event titles. Given a conversation and the events extracted from "
    "it, return ONLY a JSON object {\"titles\": [...]} with exactly one title per event, in "
    "the same order. Each title is a short, self-contained calendar entry summarizing the "
    "action and subject (e.g. \"Pick up Priya — Terminal 4\", \"Mia — dental cleaning\"). "
    "Keep names and places; drop filler, hype and sender wording. Never add facts that are "
    "not in the conversation."
)
# JSON schema passed to the model for the polish pass: an object whose required
# "titles" member is an array of strings.
TITLES_SCHEMA = dict(
    type="object",
    properties=dict(
        titles=dict(type="array", items=dict(type="string")),
    ),
    required=["titles"],
)
def build_title_messages(thread: str, events: list[dict]) -> list[dict]:
    """Build the system + user messages for the title polish pass.

    ``events`` are Event-shaped dicts (the parameter shadows the module-level
    ``events`` import inside this function only). Each one is listed as
    ``<n>. <title> @ <start>`` with the location appended in parentheses when
    present; a missing or empty title is shown as ``(untitled)``.
    """
    numbered = []
    for number, event in enumerate(events, start=1):
        label = event.get("title") or "(untitled)"
        entry = f"{number}. {label} @ {event.get('start')}"
        if event.get("location"):
            entry += f" ({event['location']})"
        numbered.append(entry)

    listing = "\n".join(numbered)
    prompt = (
        f"Conversation:\n{thread}\n\n"
        f"Extracted events:\n{listing}\n\n"
        "Return the titles JSON now."
    )
    return [
        {"role": "system", "content": TITLE_SYSTEM},
        {"role": "user", "content": prompt},
    ]
def merge_titles(plan: ActionPlan, raw: str) -> ActionPlan:
    """Apply a polish-pass response onto the plan, in place.

    ``raw`` is the polish model's reply, expected to be JSON of the form
    ``{"titles": [...]}`` with one entry per event. If the reply cannot be
    decoded, is not an object, or the number of titles differs from the number
    of events, the plan is returned untouched (the polish pass must never be
    able to lose an event). Otherwise each non-blank string title replaces the
    event's title, stripped and capped at 80 characters; entries that are not
    non-blank strings leave that event's original title in place.

    Returns the same ``plan`` object that was passed in.
    """
    try:
        titles = json.loads(raw).get("titles")
    except (json.JSONDecodeError, AttributeError, TypeError):
        # JSONDecodeError: not JSON. AttributeError: JSON but not an object.
        # TypeError: ``raw`` is None / not str-or-bytes -- the caller invokes
        # this outside its own try block, so it must not escape from here.
        return plan
    if not isinstance(titles, list) or len(titles) != len(plan.events):
        return plan
    for ev, title in zip(plan.events, titles):
        if isinstance(title, str) and title.strip():
            ev.title = title.strip()[:80]
    return plan
def apply_text_rules(thread: str, plan: ActionPlan) -> ActionPlan:
    """Deterministic guarantees for explicitly-communicated logistics (same
    philosophy as conflict detection: don't leave must-hold rules to the model).

    Single-event plans only -- multi-event threads keep per-event model judgment.

    - "arrive N minutes early" -> start = arrival time, but ONLY when the model
      demonstrably did not shift already (its start equals the stated time).
    - end = STATED time + stated duration: a self-shifting model often counts
      the duration from the arrival time (10:15+30=10:45 instead of 11:00).
    - reminder: an explicit stated lead time always wins; else type defaults
      (medical 60 / party 30 / carpool-school 45); else the model's judgment.

    The single event is modified in place and the same ``plan`` is returned.
    An event whose ``start`` cannot be parsed as an ISO datetime skips the
    arrival/end adjustments but still gets the reminder rule.
    """
    if len(plan.events) != 1:
        return plan
    ev = plan.events[0]

    early = _EARLY_RE.search(thread)
    stated = _find_time(thread)
    if early and stated:
        try:
            start_dt = datetime.fromisoformat(ev.start)
        except (TypeError, ValueError):
            # Unparseable (or non-string) start: leave the model's times alone.
            start_dt = None
        if start_dt is not None:
            mins = int(early.group(1))
            # The stated appointment time, on the day the model chose.
            appt_dt = start_dt.replace(hour=stated[0], minute=stated[1])
            if start_dt == appt_dt:  # model did not shift -> start at arrival
                start_dt = appt_dt - timedelta(minutes=mins)
                ev.start = start_dt.isoformat()
            if start_dt == appt_dt - timedelta(minutes=mins):
                # The event covers arrival (we or the model shifted it): anchor
                # the END to the stated time + stated duration, and make sure
                # the official time survives in the notes.
                duration = _find_duration_minutes(thread)
                if duration:
                    ev.end = (appt_dt + timedelta(minutes=duration)).isoformat()
                hhmm = appt_dt.strftime("%H:%M")
                if hhmm not in (ev.notes or ""):
                    note = f"Appointment at {hhmm}; arrive {mins} min early"
                    # Put the note on its own line: gluing it straight onto the
                    # model's notes ran the two sentences together.
                    previous = (ev.notes or "").rstrip()
                    ev.notes = f"{previous}\n{note}" if previous else note

    m = _REMIND_EXPLICIT_RE.search(thread)
    if m:
        n = int(m.group(1))
        ev.reminder_minutes = n * 60 if m.group(2).lower().startswith("h") else n
    elif _MEDICAL_RE.search(thread):
        ev.reminder_minutes = 60
    elif _PARTY_RE.search(thread):
        ev.reminder_minutes = 30
    elif _CARPOOL_SCHOOL_RE.search(thread):
        ev.reminder_minutes = 45
    return plan
def _polish_titles(thread: str, plan: ActionPlan) -> ActionPlan:
    """Optionally run the second, title-only model pass over ``plan``.

    Does nothing unless the plan has events and ``TITLE_POLISH=1`` is set. Any
    exception raised while building the request or calling the model leaves the
    plan unchanged; a successful reply is handed to ``merge_titles``.
    """
    if not (plan.events and os.environ.get("TITLE_POLISH") == "1"):
        return plan

    from .model import complete_json  # lazy: avoids llama.cpp in stub mode

    try:
        event_dicts = [event.model_dump() for event in plan.events]
        messages = build_title_messages(thread, event_dicts)
        reply = complete_json(messages, json_schema=TITLES_SCHEMA, max_tokens=256)
    except Exception:  # noqa: BLE001 polish is best-effort, never fatal
        return plan
    return merge_titles(plan, reply)
def run_agent_stream(
    thread: str,
    now: Optional[datetime] = None,
    existing: Optional[list[Event]] = None,
    images: Optional[list[str]] = None,
    busy=None,
    memory_block: Optional[str] = None,
):
    """Generator for the UI: yields (partial_text, plan_or_None). Streams the
    model output for a live 'thinking' panel, then yields the final ActionPlan
    (with deterministic conflicts annotated if ``busy`` intervals are given).
    ``memory_block`` carries the caller's per-user (localStorage) memory.

    Every intermediate item is ``(accumulated_text_so_far, None)``; the last
    item is ``(pretty-printed plan JSON, plan)``. ``now`` defaults to
    ``datetime.now()`` and ``existing`` to an empty list. With
    ``USE_STUB_EXTRACTOR=1`` no model is loaded: the stub plan's JSON is
    replayed in 24-character slices to imitate streaming.
    """
    now = now or datetime.now()
    existing = existing or []
    with events.run_scope("analyze"):
        if images:
            events.emit("vision", f"reading {len(images)} image(s)", images=len(images))
        if os.environ.get("USE_STUB_EXTRACTOR") == "1":
            plan = _stub_plan(thread, now)
            text = json.dumps(plan.model_dump(), indent=2)
            events.emit("model", "stub inference", latency_ms=0)
            acc = ""
            for i in range(0, len(text), 24):  # simulate token streaming
                acc += text[i : i + 24]
                yield acc, None
        else:
            from .model import stream_complete_json

            # Accumulate the streamed deltas; the full text is parsed only once
            # the stream has ended.
            acc = ""
            for delta in stream_complete_json(
                build_messages(thread, now, existing, images, memory_block),
                ActionPlan.model_json_schema(),
            ):
                acc += delta
                yield acc, None
            plan = apply_text_rules(thread, _polish_titles(thread, _parse_plan(acc)))
        # Global path only (see run_agent): client memory is merged by the UI.
        if memory_block is None:
            memory.observe_plan(plan)  # grows-with-you: learn recurring contacts
        events.emit("decision", f"{len(plan.events)} event(s) detected", events=len(plan.events))
        if busy:
            from calendar_out.freebusy import annotate_conflicts  # lazy: avoid cycle

            plan = annotate_conflicts(plan, busy)
        # Final item: the only one whose second element is the plan itself.
        yield (json.dumps(plan.model_dump(), indent=2), plan)
# --- Heuristic text patterns used by apply_text_rules and the stub extractor ---

# Clock time: 1-2 digit hour, optional ":MM", optional am/pm. Bare integers
# match too; _find_time discards hits with neither minutes nor a meridiem.
_TIME_RE = re.compile(r"\b(\d{1,2})(?::(\d{2}))?\s*(am|pm)?\b", re.IGNORECASE)
# A "Time: ..." / "Time - ..." line; group 1 is the rest of that line.
_TIME_LABEL_RE = re.compile(r"(?im)^\s*time\s*[:\-]\s*(.+)$")
# Month name (full or abbreviated, optional trailing dot) followed by a day
# number, with optional ordinal suffix and optional 4-digit year.
_MONTH_DATE_RE = re.compile(
    r"\b(?:jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|jul(?:y)?|"
    r"aug(?:ust)?|sep(?:t(?:ember)?)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?)\.?\s+"
    r"\d{1,2}(?:st|nd|rd|th)?(?:,?\s*\d{4})?\b", re.IGNORECASE)
# A full weekday name (group 1).
_WEEKDAY_RE = re.compile(
    r"\b(monday|tuesday|wednesday|thursday|friday|saturday|sunday)\b", re.IGNORECASE)
# A line opening with "Location:", "Where:", "Address:" (":" or "-") or the
# U+1F4CD pin emoji; group 1 is the rest of the line. Not MULTILINE: it is
# matched against one line at a time.
_LOCATION_RE = re.compile(
    r"(?i)^\s*(?:(?:location|where|address)\s*[:\-]|\U0001F4CD)\s*(.*)$")
# Any short "Label: value" line; marks the end of a wrapped address.
_LABEL_LINE_RE = re.compile(r"^\s*[A-Za-z][A-Za-z ]{0,20}:\s")  # "Time: ...", "Notes: ..."
# A "Duration: ..." / "Duration - ..." line; group 1 is the rest of that line.
_DURATION_RE = re.compile(r"(?im)^\s*duration\s*[:\-]\s*(.*)$")
# "arrive N min(s)/minute(s) early"; group 1 is N.
_EARLY_RE = re.compile(r"(?i)arrive\s+(\d{1,3})\s*min(?:ute)?s?\s+early")
# Explicit reminder lead time, e.g. "remind me 2 hours before";
# group 1 is the amount, group 2 the unit (minutes or hours).
_REMIND_EXPLICIT_RE = re.compile(
    r"(?i)\b(?:remind(?:er)?|notify|alert)\s*(?:me\s+)?(?:for\s+)?"
    r"(\d{1,3})\s*(min(?:ute)?s?|h(?:ou)?rs?)\s*(?:before|ahead|prior|early)")
# Keyword families that select the default reminder lead time.
_MEDICAL_RE = re.compile(
    r"(?i)\b(?:doctor|dr\b\.?|clinic|dentist|dental|pediatric\w*|physician|"
    r"medical|check-?up|primary\s+care|intake\s+forms?)")
_PARTY_RE = re.compile(  # "party of 4" is a group size, not a party
    r"(?i)\b(?:birthday|bday)\b|\bparty\b(?!\s+of\s+\d)")
_CARPOOL_SCHOOL_RE = re.compile(r"(?i)\bcarpool\w*\b|\bschool\b|drive\s+the\s+kids")
def _find_time(thread: str) -> Optional[tuple[int, int]]:
    """Return the first plausible clock time as ``(hour, minute)`` in 24-hour
    form, or None when there is none.

    When the text has a "Time:" line, only the remainder of that line is
    searched. A bare integer ("June 22", "112A") is not a time: a candidate
    needs a minute component or an am/pm marker, and must be within
    0-23 / 0-59.
    """
    labelled = _TIME_LABEL_RE.search(thread)
    haystack = labelled.group(1) if labelled else thread

    for hit in _TIME_RE.finditer(haystack):
        hour_txt, minute_txt, meridiem = hit.groups()
        if not minute_txt and not meridiem:
            continue
        hour = int(hour_txt)
        minute = int(minute_txt) if minute_txt else 0
        if hour > 23 or minute > 59:
            continue
        meridiem = (meridiem or "").lower()
        if meridiem == "pm" and hour < 12:
            hour += 12
        elif meridiem == "am" and hour == 12:
            hour = 0
        return hour, minute
    return None
def _find_date(thread: str, now: datetime):
    """Pick the calendar day the thread refers to, as a ``date``.

    Checked in order: an explicit "<month> <day>" mention, the word
    "tomorrow", "today"/"tonight", then a weekday name. Month and weekday
    mentions are resolved by dateutil with ``now`` supplying the missing
    fields. When nothing matches, or dateutil rejects the fragment, the result
    falls through to the next rule and finally to tomorrow.
    """

    def _resolve(fragment: str):
        # dateutil fills missing fields from `now`; None signals "could not parse".
        try:
            return dtparser.parse(fragment, default=now).date()
        except (ValueError, OverflowError):
            return None

    explicit = _MONTH_DATE_RE.search(thread)
    if explicit:
        resolved = _resolve(explicit.group(0))
        if resolved is not None:
            return resolved

    if re.search(r"\btomorrow\b", thread, re.IGNORECASE):
        return (now + timedelta(days=1)).date()
    if re.search(r"\btoday\b|\btonight\b", thread, re.IGNORECASE):
        return now.date()

    weekday = _WEEKDAY_RE.search(thread)
    if weekday:
        resolved = _resolve(weekday.group(1))  # next-or-same day
        if resolved is not None:
            return resolved

    return (now + timedelta(days=1)).date()
def _find_location(lines: list[str]) -> tuple[Optional[str], set[int]]:
    """Extract the location from the first "Location:"-style line.

    Following lines are treated as a wrapped address and appended until a blank
    line, a line starting with "(", or the next "Label:" line. Returns the
    pieces joined with ", " (None if they are all empty) together with the
    indexes of every line consumed; ``(None, set())`` when no location line
    exists.
    """
    for head_idx, head in enumerate(lines):
        header = _LOCATION_RE.match(head)
        if header:
            break
    else:
        return None, set()

    pieces = [header.group(1).strip()]
    consumed = {head_idx}
    for idx, raw_line in enumerate(lines[head_idx + 1:], start=head_idx + 1):
        text = raw_line.strip()
        if not text or text.startswith("(") or _LABEL_LINE_RE.match(raw_line):
            break
        pieces.append(text)
        consumed.add(idx)

    joined = ", ".join(filter(None, pieces))
    return (joined or None), consumed
# Value part of a "Duration:" line: the first number (group 1), an optional
# upper bound of a range ("30-45", "1 to 2"), an optional hour unit (group 2)
# and, after an hour unit only, optional trailing minutes (group 3, "1h 30m").
_DURATION_VALUE_RE = re.compile(
    r"(\d+(?:\.\d+)?)"
    r"(?:\s*(?:[-\u2013\u2014]|to)\s*\d+(?:\.\d+)?)?"
    r"\s*(?:(h(?:(?:ou)?rs?)?)\b"
    r"(?:\s*(?:and\s+)?(\d+)\s*m(?:in(?:ute)?s?)?\b)?)?",
    re.IGNORECASE,
)


def _find_duration_minutes(thread: str) -> Optional[int]:
    """Return the duration stated on the first "Duration:" line, in minutes.

    The first number on that line is used, so a range ("30-45 min") resolves
    to its lower bound. A number followed by an hour unit ("2 hours",
    "1.5 hrs", "2h") is converted to minutes, and minutes that directly follow
    the hour part ("1h 30m", "1 hour and 15 minutes") are added. Any other
    number is taken as minutes, with a fractional part truncated.

    Returns None when there is no "Duration:" line or it contains no number.
    """
    labelled = _DURATION_RE.search(thread)
    if not labelled:
        return None
    value = _DURATION_VALUE_RE.search(labelled.group(1))
    if not value:
        return None
    amount, hour_unit, extra_minutes = value.groups()
    if hour_unit:
        # round() rather than int(): 1.1 * 60 is 66.00000000000001 in floats.
        return int(round(float(amount) * 60)) + int(extra_minutes or 0)
    return int(float(amount))
def _reminder_minutes(thread: str) -> int:
    """Choose the notification lead time, in minutes, for the stub extractor.

    An explicit request ("remind me 2 hours before") wins, with hours converted
    to minutes. Otherwise the first matching event type decides -- medical 60,
    party 30, carpool/school 45, checked in that order -- and 30 is the
    fallback.
    """
    explicit = _REMIND_EXPLICIT_RE.search(thread)
    if explicit:
        amount = int(explicit.group(1))
        unit = explicit.group(2).lower()
        return amount * 60 if unit.startswith("h") else amount

    type_defaults = (
        (_MEDICAL_RE, 60),
        (_PARTY_RE, 30),
        (_CARPOOL_SCHOOL_RE, 45),
    )
    for pattern, lead_minutes in type_defaults:
        if pattern.search(thread):
            return lead_minutes
    return 30
def _is_date_line(line: str, now: datetime) -> bool:
    """Tell whether ``line`` as a whole parses as a date/time.

    Parsing is strict (non-fuzzy), so ordinary chatter is rejected by the
    parser and reported as False; ``now`` only supplies default fields.
    """
    try:
        dtparser.parse(line, default=now)  # non-fuzzy: chatter raises ParserError
    except (ValueError, OverflowError):
        return False
    return True
def _pick_title(lines: list[str], now: datetime, location_idx: set[int]) -> str:
    """Choose a title (at most 60 characters) for the stub event.

    The first non-blank line is used unless it is only a date. In that case the
    next line that is not part of the location, not a "Label:" line, not a
    parenthetical and not itself a date is used. Falls back to "Event" for
    empty input and "Appointment" when nothing descriptive is found.
    """
    candidates = []
    for idx, raw in enumerate(lines):
        text = raw.strip()
        if text:
            candidates.append((idx, text))
    if not candidates:
        return "Event"

    (_, opener), rest = candidates[0], candidates[1:]
    if not _is_date_line(opener, now):
        return opener[:60]

    # First line is just the date -- find a more descriptive line instead.
    for idx, text in rest:
        skip = idx in location_idx or _LABEL_LINE_RE.match(text) or text.startswith("(")
        if not skip and not _is_date_line(text, now):
            return text[:60]
    return "Appointment"
def _stub_plan(thread: str, now: datetime) -> ActionPlan:
    """Heuristic ActionPlan so phases without a model still demo end to end.

    Without a recognisable clock time the plan has no events. Otherwise a
    single event is built from the regex helpers: day, time, location, title
    and reminder, with a 60-minute duration when none is stated.
    """
    clock = _find_time(thread)
    if not clock:
        return ActionPlan(reasoning="No time found.", reply_draft="Got it, thanks!")
    hour, minute = clock

    lines = thread.strip().splitlines()
    location, loc_idx = _find_location(lines)
    day = _find_date(thread, now)
    appt = now.replace(
        year=day.year, month=day.month, day=day.day,
        hour=hour, minute=minute, second=0, microsecond=0,
    )
    duration = _find_duration_minutes(thread) or 60

    # "Arrive N minutes early" -> start at the ARRIVAL time; the end (and the
    # notes) stay anchored to the stated appointment time.
    early = _EARLY_RE.search(thread)
    if early:
        start = appt - timedelta(minutes=int(early.group(1)))
        notes = f"Appointment at {appt.strftime('%H:%M')}; arrive {early.group(1)} min early"
    else:
        start = appt
        notes = "(stub agent — wire the model to replace this)"

    event = Event(
        title=_pick_title(lines, now, loc_idx),
        start=start.isoformat(),
        end=(appt + timedelta(minutes=duration)).isoformat(),
        location=location,
        reminder_minutes=_reminder_minutes(thread),
        notes=notes,
    )
    return ActionPlan(
        reasoning="(stub) parsed time/date/location heuristically.",
        events=[event],
        reply_draft="Sounds good, see you then!",
    )