Spaces:
Sleeping
Sleeping
from __future__ import annotations

import json
import sqlite3
import threading
from dataclasses import asdict, dataclass, field, fields, replace
from typing import Any, Dict, Optional
@dataclass
class UserData:
    """Per-thread customer and conversation state for the multi-agent flow.

    Every field has a default, so ``UserData()`` is a valid empty state.
    Field names double as the keys returned by :meth:`to_dict`, whose output
    is JSON-encoded by ``UserStateManager`` for persistence, so the existing
    names (including ``Timezone`` and ``Booked_appointment``) are kept as-is.
    """

    # Customer identity / contact details (optional strings, not validated).
    customer_id: Optional[str] = None
    customer_name: Optional[str] = None
    customer_email: Optional[str] = None
    customer_phone: Optional[str] = None
    # Presumably the stated purpose of the call; stored as an opaque string.
    purpose_call: Optional[str] = None
    is_paused: bool = False
    # Timezone as supplied by the caller; the format is not validated here.
    Timezone: Optional[str] = None
    # Key/value memory for agents. default_factory gives every instance its
    # own dict instead of one dict shared through the class.
    agent_memories: Dict[str, Any] = field(default_factory=dict)
    access_token: Optional[str] = None
    company_name: Optional[str] = None
    # Booked appointment, stored as an opaque string.
    Booked_appointment: Optional[str] = None
    # Agent registry keyed by a caller-chosen name; values may be arbitrary
    # objects (not necessarily JSON-serializable).
    agents: Dict[str, Any] = field(default_factory=dict)
    # Previously active agent, if any (arbitrary object).
    prev_agent: Any | None = None
    summary: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a dict, recursively copied via ``dataclasses.asdict``."""
        return asdict(self)
class UserStateManager:
    """Manage per-thread user/customer state for the multi-agent flow.

    State is kept in an in-memory dict keyed by thread id and mirrored to the
    SQLite table ``user_states`` as one JSON document per thread. Every
    mutator writes the affected thread back to the database immediately;
    write failures are printed rather than raised (best effort).
    """

    def __init__(self, db_path: str = "user_state.db") -> None:
        self._states: Dict[str, UserData] = {}
        self.db_path = db_path
        # The connection is opened with check_same_thread=False, i.e. it may be
        # used from several threads; sqlite3 leaves serializing writes to the
        # caller, so all DB access and default-state creation take this lock.
        # Re-entrant because get_user() calls _save_state() while holding it.
        self._lock = threading.RLock()
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self._init_db()
        self._load_all_states()

    def _init_db(self) -> None:
        """Create the ``user_states`` table and add columns missing from older DBs."""
        with self._lock, self.conn:
            self.conn.execute("""
                CREATE TABLE IF NOT EXISTS user_states (
                    thread_id TEXT PRIMARY KEY,
                    state_data TEXT NOT NULL,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    token_updated_at TIMESTAMP
                )
            """)
            # Databases created before token_updated_at existed need a migration.
            # PRAGMA table_info rows are (cid, name, type, notnull, dflt, pk).
            columns = {row[1] for row in self.conn.execute("PRAGMA table_info(user_states)")}
            if "token_updated_at" not in columns:
                self.conn.execute("ALTER TABLE user_states ADD COLUMN token_updated_at TIMESTAMP")

    def _load_all_states(self) -> None:
        """Load every persisted thread state into memory; unreadable rows are skipped."""
        with self._lock:
            rows = self.conn.execute("SELECT thread_id, state_data FROM user_states").fetchall()
        for thread_id, state_json in rows:
            try:
                self._states[thread_id] = self._deserialize_state(state_json)
            except Exception as e:
                print(f"⚠️ Failed to load state for thread {thread_id}: {e}")

    @staticmethod
    def _deserialize_state(state_json: str) -> UserData:
        """Build a UserData from stored JSON, ignoring keys UserData no longer defines."""
        raw = json.loads(state_json)
        known = {f.name for f in fields(UserData)}
        return UserData(**{key: value for key, value in raw.items() if key in known})

    @staticmethod
    def _serialize_state(state: UserData) -> str:
        """Encode *state* as JSON.

        ``agents`` and ``prev_agent`` may hold arbitrary runtime objects. They
        are written only when JSON-serializable and are otherwise omitted (they
        reload as their defaults), so one live object cannot make the whole
        save fail.
        """
        runtime_fields = {"agents": state.agents, "prev_agent": state.prev_agent}
        # replace() copies the remaining fields by reference, so asdict() never
        # deep-copies the runtime objects.
        payload = replace(state, agents={}, prev_agent=None).to_dict()
        for name, value in runtime_fields.items():
            try:
                json.dumps(value)
            except (TypeError, ValueError):
                del payload[name]
            else:
                payload[name] = value
        return json.dumps(payload)

    def _save_state(self, thread_id: str, *, touch_token: bool = False) -> None:
        """Persist one thread's in-memory state (best effort; errors are printed).

        Args:
            thread_id: Thread whose state should be written. No-op if unknown.
            touch_token: Also stamp ``token_updated_at`` with the current time.
                Only the access-token setters pass True, so unrelated writes
                do not extend a token's lifetime.
        """
        state = self._states.get(thread_id)
        if state is None:
            return
        try:
            with self._lock:
                state_json = self._serialize_state(state)
                with self.conn:  # commits on success, rolls back on error
                    self.conn.execute(
                        """
                        INSERT INTO user_states (thread_id, state_data, updated_at, token_updated_at)
                        VALUES (?, ?, CURRENT_TIMESTAMP, CASE WHEN ? THEN CURRENT_TIMESTAMP END)
                        ON CONFLICT(thread_id) DO UPDATE SET
                            state_data = excluded.state_data,
                            updated_at = CURRENT_TIMESTAMP,
                            token_updated_at = COALESCE(excluded.token_updated_at,
                                                        user_states.token_updated_at)
                        """,
                        (thread_id, state_json, int(touch_token)),
                    )
        except Exception as e:
            print(f"⚠️ Failed to save state for thread {thread_id}: {e}")

    def _create_default_state(self) -> UserData:
        return UserData()

    def get_user(self, thread_id: str) -> UserData:
        """Return the live state for *thread_id*, creating and persisting a default if absent."""
        with self._lock:
            state = self._states.get(thread_id)
            if state is None:
                state = self._states[thread_id] = self._create_default_state()
                self._save_state(thread_id)
        return state

    # -- Mutators --------------------------------------------------------
    def set_customer_info(
        self,
        thread_id: str,
        *,
        customer_id: Optional[str] = None,
        name: Optional[str] = None,
        email: Optional[str] = None,
        phone: Optional[str] = None,
        company_name: Optional[str] = None,
    ) -> None:
        """Update the given customer fields; arguments left as None are not changed."""
        data = self.get_user(thread_id)
        if customer_id is not None:
            data.customer_id = customer_id
        if name is not None:
            data.customer_name = name
        if email is not None:
            data.customer_email = email
        if phone is not None:
            data.customer_phone = phone
        if company_name is not None:
            data.company_name = company_name
        self._save_state(thread_id)

    def set_purpose(self, thread_id: str, purpose: Optional[str]) -> None:
        self.get_user(thread_id).purpose_call = purpose
        self._save_state(thread_id)

    def set_timezone(self, thread_id: str, timezone: Optional[str]) -> None:
        self.get_user(thread_id).Timezone = timezone
        self._save_state(thread_id)

    def set_paused(self, thread_id: str, paused: bool) -> None:
        self.get_user(thread_id).is_paused = bool(paused)
        self._save_state(thread_id)

    def set_access_token(self, thread_id: str, token: Optional[str]) -> None:
        """Store *token* and restart its 12-hour freshness window."""
        self.get_user(thread_id).access_token = token
        self._save_state(thread_id, touch_token=True)

    def set_booked_appointment(self, thread_id: str, value: Optional[str]) -> None:
        self.get_user(thread_id).Booked_appointment = value
        self._save_state(thread_id)

    def get_cached_access_token(self, thread_id: str) -> Optional[str]:
        """Return the persisted access token if it was set less than 12 hours ago.

        Age is measured from ``token_updated_at``, which only the token setters
        stamp. Rows that never had it stamped (e.g. written before the column
        existed) fall back to ``updated_at``. Returns None when the thread has
        no row, the token is too old, or the row cannot be parsed.
        """
        from datetime import datetime, timedelta, timezone

        with self._lock:
            row = self.conn.execute(
                """
                SELECT state_data, COALESCE(token_updated_at, updated_at)
                FROM user_states WHERE thread_id = ?
                """,
                (thread_id,),
            ).fetchone()
        if not row:
            return None
        state_json, stamped_at = row
        try:
            # SQLite's CURRENT_TIMESTAMP is UTC, formatted "YYYY-MM-DD HH:MM:SS".
            issued_at = datetime.fromisoformat(stamped_at.replace(' ', 'T') + "+00:00")
            if datetime.now(timezone.utc) - issued_at > timedelta(hours=12):
                print(f"🕒 Access token for {thread_id} expired (> 12h)")
                return None
            return json.loads(state_json).get("access_token")
        except Exception as e:
            print(f"⚠️ Error checking token expiration: {e}")
            return None

    def update_access_token(self, thread_id: str, token: str) -> None:
        """Update the access token and restart its 12-hour freshness window."""
        self.get_user(thread_id).access_token = token
        self._save_state(thread_id, touch_token=True)

    def set_summary(self, thread_id: str, summary: str) -> None:
        self.get_user(thread_id).summary = summary
        self._save_state(thread_id)

    def get_summary(self, thread_id: str) -> str:
        """Return the stored summary, or "" when none has been set."""
        return self.get_user(thread_id).summary or ""

    # Agent registry helpers
    def register_agent(self, thread_id: str, key: str, agent: Any) -> None:
        """Register *agent* under *key*; it is persisted only if JSON-serializable."""
        self.get_user(thread_id).agents[key] = agent
        self._save_state(thread_id)

    def get_agent(self, thread_id: str, key: str) -> Any:
        return self.get_user(thread_id).agents.get(key)

    def set_prev_agent(self, thread_id: str, agent: Any | None) -> None:
        self.get_user(thread_id).prev_agent = agent
        self._save_state(thread_id)

    # Memories helpers
    def remember(self, thread_id: str, key: str, value: Any) -> None:
        self.get_user(thread_id).agent_memories[key] = value
        self._save_state(thread_id)

    def recall(self, thread_id: str, key: str, default: Any = None) -> Any:
        return self.get_user(thread_id).agent_memories.get(key, default)

    # -- Serialization ---------------------------------------------------
    def to_dict(self, thread_id: str) -> Dict[str, Any]:
        return self.get_user(thread_id).to_dict()

    def close(self) -> None:
        """Close the underlying SQLite connection; the manager is unusable afterwards."""
        with self._lock:
            self.conn.close()