from __future__ import annotations import json import sqlite3 from dataclasses import asdict, dataclass, field from typing import Any, Dict, Optional @dataclass(slots=True) class UserData: customer_id: Optional[str] = None customer_name: Optional[str] = None customer_email: Optional[str] = None customer_phone: Optional[str] = None purpose_call: Optional[str] = None is_paused: bool = False Timezone: Optional[str] = None agent_memories: Dict[str, Any] = field(default_factory=dict) access_token: Optional[str] = None company_name: Optional[str] = None Booked_appointment: Optional[str] = None agents: Dict[str, Any] = field(default_factory=dict) prev_agent: Any | None = None summary: Optional[str] = None def to_dict(self) -> Dict[str, Any]: return asdict(self) class UserStateManager: """Manages per-thread user/customer state for the multi-agent flow with SQLite persistence.""" def __init__(self, db_path: str = "user_state.db") -> None: self._states: Dict[str, UserData] = {} self.db_path = db_path self.conn = sqlite3.connect(db_path, check_same_thread=False) self._init_db() self._load_all_states() def _init_db(self) -> None: """Initialize the user state database table.""" cursor = self.conn.cursor() cursor.execute(""" CREATE TABLE IF NOT EXISTS user_states ( thread_id TEXT PRIMARY KEY, state_data TEXT NOT NULL, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) self.conn.commit() def _load_all_states(self) -> None: """Load all user states from database into memory on startup.""" cursor = self.conn.cursor() cursor.execute("SELECT thread_id, state_data FROM user_states") for row in cursor.fetchall(): thread_id, state_json = row try: state_dict = json.loads(state_json) self._states[thread_id] = UserData(**state_dict) except Exception as e: print(f"⚠️ Failed to load state for thread {thread_id}: {e}") def _save_state(self, thread_id: str) -> None: """Persist a thread's state to the database.""" state = self._states.get(thread_id) if state is None: return try: state_json = json.dumps(state.to_dict()) cursor = self.conn.cursor() cursor.execute(""" INSERT INTO user_states (thread_id, state_data, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP) ON CONFLICT(thread_id) DO UPDATE SET state_data = excluded.state_data, updated_at = CURRENT_TIMESTAMP """, (thread_id, state_json)) self.conn.commit() except Exception as e: print(f"⚠️ Failed to save state for thread {thread_id}: {e}") def _create_default_state(self) -> UserData: return UserData() def get_user(self, thread_id: str) -> UserData: if thread_id not in self._states: self._states[thread_id] = self._create_default_state() self._save_state(thread_id) return self._states[thread_id] # -- Mutators -------------------------------------------------------- def set_customer_info( self, thread_id: str, *, customer_id: Optional[str] = None, name: Optional[str] = None, email: Optional[str] = None, phone: Optional[str] = None, company_name: Optional[str] = None, ) -> None: data = self.get_user(thread_id) if customer_id is not None: data.customer_id = customer_id if name is not None: data.customer_name = name if email is not None: data.customer_email = email if phone is not None: data.customer_phone = phone if company_name is not None: data.company_name = company_name self._save_state(thread_id) def set_purpose(self, thread_id: str, purpose: Optional[str]) -> None: self.get_user(thread_id).purpose_call = purpose self._save_state(thread_id) def set_timezone(self, thread_id: str, timezone: Optional[str]) -> None: self.get_user(thread_id).Timezone = timezone self._save_state(thread_id) def set_paused(self, thread_id: str, paused: bool) -> None: self.get_user(thread_id).is_paused = bool(paused) self._save_state(thread_id) def set_access_token(self, thread_id: str, token: Optional[str]) -> None: self.get_user(thread_id).access_token = token self._save_state(thread_id) def set_booked_appointment(self, thread_id: str, value: Optional[str]) -> None: self.get_user(thread_id).Booked_appointment = value self._save_state(thread_id) def get_cached_access_token(self, thread_id: str) -> Optional[str]: """Get the access token if it exists and is less than 12 hours old.""" cursor = self.conn.cursor() cursor.execute(""" SELECT state_data, updated_at FROM user_states WHERE thread_id = ? """, (thread_id,)) row = cursor.fetchone() if not row: return None state_json, updated_at_str = row from datetime import datetime, timezone, timedelta try: # updated_at is stored in UTC by the DB updated_at = datetime.fromisoformat(updated_at_str.replace(' ', 'T') + "+00:00") if datetime.now(timezone.utc) - updated_at > timedelta(hours=12): print(f"🕒 Access token for {thread_id} expired (> 12h)") return None state_dict = json.loads(state_json) return state_dict.get("access_token") except Exception as e: print(f"⚠️ Error checking token expiration: {e}") return None def update_access_token(self, thread_id: str, token: str) -> None: """Update the access token and reset the updated_at timestamp.""" data = self.get_user(thread_id) data.access_token = token # _save_state updates updated_at to CURRENT_TIMESTAMP self._save_state(thread_id) def set_summary(self, thread_id: str, summary: str) -> None: self.get_user(thread_id).summary = summary self._save_state(thread_id) def get_summary(self, thread_id: str) -> str: return self.get_user(thread_id).summary or "" # Agent registry helpers def register_agent(self, thread_id: str, key: str, agent: Any) -> None: self.get_user(thread_id).agents[key] = agent self._save_state(thread_id) def get_agent(self, thread_id: str, key: str) -> Any: return self.get_user(thread_id).agents.get(key) def set_prev_agent(self, thread_id: str, agent: Any | None) -> None: self.get_user(thread_id).prev_agent = agent self._save_state(thread_id) # Memories helpers def remember(self, thread_id: str, key: str, value: Any) -> None: self.get_user(thread_id).agent_memories[key] = value self._save_state(thread_id) def recall(self, thread_id: str, key: str, default: Any = None) -> Any: return self.get_user(thread_id).agent_memories.get(key, default) # -- Serialization --------------------------------------------------- def to_dict(self, thread_id: str) -> Dict[str, Any]: return self.get_user(thread_id).to_dict()