agent_backend / app /user_state.py
GoutamSachdev's picture
ok
de6fb09 verified
Raw
History Blame Contribute Delete
7.83 kB
from __future__ import annotations

import json
import sqlite3
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Any, Dict, Optional
@dataclass(slots=True)
class UserData:
customer_id: Optional[str] = None
customer_name: Optional[str] = None
customer_email: Optional[str] = None
customer_phone: Optional[str] = None
purpose_call: Optional[str] = None
is_paused: bool = False
Timezone: Optional[str] = None
agent_memories: Dict[str, Any] = field(default_factory=dict)
access_token: Optional[str] = None
company_name: Optional[str] = None
Booked_appointment: Optional[str] = None
agents: Dict[str, Any] = field(default_factory=dict)
prev_agent: Any | None = None
summary: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
class UserStateManager:
    """Manage per-thread user/customer state for the multi-agent flow.

    Each ``thread_id`` maps to a ``UserData`` record held in memory. The
    record is written through to the SQLite table ``user_states`` after
    every mutation, and all stored rows are loaded back into memory by
    ``__init__``. Persistence is best-effort: failures are reported with
    ``print`` and are not raised to the caller.
    """

    def __init__(self, db_path: str = "user_state.db") -> None:
        self._states: Dict[str, UserData] = {}
        self.db_path = db_path
        # One shared connection; check_same_thread=False allows it to be used
        # from threads other than the one that created it.
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self._init_db()
        self._load_all_states()

    # -- Persistence -----------------------------------------------------

    def _init_db(self) -> None:
        """Create the ``user_states`` table, adding columns missing from older DBs."""
        cursor = self.conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS user_states (
                thread_id TEXT PRIMARY KEY,
                state_data TEXT NOT NULL,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                token_updated_at TIMESTAMP
            )
        """)
        # ``updated_at`` moves on every save of any field, so it cannot tell
        # how old the access token is. ``token_updated_at`` is stamped only
        # when the token itself is set. Databases created before this column
        # existed get it added in place (existing rows start out NULL).
        existing = {row[1] for row in cursor.execute("PRAGMA table_info(user_states)")}
        if "token_updated_at" not in existing:
            cursor.execute("ALTER TABLE user_states ADD COLUMN token_updated_at TIMESTAMP")
        self.conn.commit()

    def _load_all_states(self) -> None:
        """Load every persisted thread state into memory (called on startup).

        Stored keys that are not ``UserData`` fields (e.g. from an older or
        newer schema) are ignored rather than discarding the whole record.
        Rows that still cannot be decoded are reported and skipped.
        """
        known = {fld.name for fld in fields(UserData)}
        cursor = self.conn.cursor()
        cursor.execute("SELECT thread_id, state_data FROM user_states")
        for thread_id, state_json in cursor.fetchall():
            try:
                state_dict = json.loads(state_json)
                self._states[thread_id] = UserData(
                    **{k: v for k, v in state_dict.items() if k in known}
                )
            except Exception as e:  # best-effort: one bad row must not block startup
                print(f"⚠️ Failed to load state for thread {thread_id}: {e}")

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """``json.dumps`` hook that encodes nested dataclasses the way ``asdict`` does."""
        if is_dataclass(obj) and not isinstance(obj, type):
            return asdict(obj)
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _serialize_state(self, thread_id: str, state: UserData) -> str:
        """Encode ``state`` as a JSON object, dropping fields that cannot be encoded.

        ``register_agent`` and ``set_prev_agent`` accept arbitrary objects.
        Encoding field by field keeps one such value from making the whole
        record unsaveable. Dropped fields stay in memory only and are
        reported.
        """
        payload: Dict[str, Any] = {}
        for fld in fields(state):
            value = getattr(state, fld.name)
            try:
                json.dumps(value, default=self._json_default)
            except (TypeError, ValueError, RecursionError) as e:
                print(f"⚠️ Not persisting field {fld.name!r} for thread {thread_id}: {e}")
                continue
            payload[fld.name] = value
        return json.dumps(payload, default=self._json_default)

    def _save_state(self, thread_id: str, *, token_changed: bool = False) -> None:
        """Persist a thread's in-memory state to the database (best-effort).

        Args:
            thread_id: Thread whose state is written; no-op if it is not loaded.
            token_changed: When True, also stamp ``token_updated_at`` so that
                ``get_cached_access_token`` measures the token's own age.
        """
        state = self._states.get(thread_id)
        if state is None:
            return
        state_json = self._serialize_state(thread_id, state)
        try:
            self.conn.execute(
                """
                INSERT INTO user_states (thread_id, state_data, updated_at, token_updated_at)
                VALUES (?, ?, CURRENT_TIMESTAMP, CASE WHEN ? THEN CURRENT_TIMESTAMP END)
                ON CONFLICT(thread_id) DO UPDATE SET
                    state_data = excluded.state_data,
                    updated_at = CURRENT_TIMESTAMP,
                    -- NULL means "token not touched by this save": keep the old stamp.
                    token_updated_at = COALESCE(excluded.token_updated_at,
                                                user_states.token_updated_at)
                """,
                (thread_id, state_json, int(token_changed)),
            )
            self.conn.commit()
        except sqlite3.Error as e:
            self.conn.rollback()
            print(f"⚠️ Failed to save state for thread {thread_id}: {e}")

    def _create_default_state(self) -> UserData:
        return UserData()

    def get_user(self, thread_id: str) -> UserData:
        """Return the state for ``thread_id``, creating and saving a default if absent."""
        if thread_id not in self._states:
            self._states[thread_id] = self._create_default_state()
            self._save_state(thread_id)
        return self._states[thread_id]

    # -- Mutators --------------------------------------------------------

    def set_customer_info(
        self,
        thread_id: str,
        *,
        customer_id: Optional[str] = None,
        name: Optional[str] = None,
        email: Optional[str] = None,
        phone: Optional[str] = None,
        company_name: Optional[str] = None,
    ) -> None:
        """Update the given customer fields; arguments left as ``None`` are not changed."""
        data = self.get_user(thread_id)
        if customer_id is not None:
            data.customer_id = customer_id
        if name is not None:
            data.customer_name = name
        if email is not None:
            data.customer_email = email
        if phone is not None:
            data.customer_phone = phone
        if company_name is not None:
            data.company_name = company_name
        self._save_state(thread_id)

    def set_purpose(self, thread_id: str, purpose: Optional[str]) -> None:
        self.get_user(thread_id).purpose_call = purpose
        self._save_state(thread_id)

    def set_timezone(self, thread_id: str, timezone: Optional[str]) -> None:
        self.get_user(thread_id).Timezone = timezone
        self._save_state(thread_id)

    def set_paused(self, thread_id: str, paused: bool) -> None:
        self.get_user(thread_id).is_paused = bool(paused)
        self._save_state(thread_id)

    def set_access_token(self, thread_id: str, token: Optional[str]) -> None:
        """Set (or clear) the access token and restart its age clock."""
        self.get_user(thread_id).access_token = token
        self._save_state(thread_id, token_changed=True)

    def set_booked_appointment(self, thread_id: str, value: Optional[str]) -> None:
        self.get_user(thread_id).Booked_appointment = value
        self._save_state(thread_id)

    def get_cached_access_token(
        self, thread_id: str, *, max_age_hours: float = 12
    ) -> Optional[str]:
        """Return the persisted access token if it is younger than ``max_age_hours``.

        Age is measured from ``token_updated_at``, which is stamped whenever
        the token is set. Rows written before that column existed fall back
        to ``updated_at``, the time of the last save of any field.

        Returns:
            The stored token, or ``None`` if there is no row, the token is
            too old, or the row cannot be parsed.
        """
        from datetime import datetime, timedelta, timezone

        row = self.conn.execute(
            "SELECT state_data, COALESCE(token_updated_at, updated_at) "
            "FROM user_states WHERE thread_id = ?",
            (thread_id,),
        ).fetchone()
        if not row:
            return None
        state_json, stamped_at = row
        try:
            # SQLite CURRENT_TIMESTAMP is UTC text of the form "YYYY-MM-DD HH:MM:SS".
            token_time = datetime.fromisoformat(stamped_at.replace(" ", "T") + "+00:00")
            if datetime.now(timezone.utc) - token_time > timedelta(hours=max_age_hours):
                print(f"🕒 Access token for {thread_id} expired (> {max_age_hours:g}h)")
                return None
            return json.loads(state_json).get("access_token")
        except (ValueError, TypeError, AttributeError) as e:
            print(f"⚠️ Error checking token expiration: {e}")
            return None

    def update_access_token(self, thread_id: str, token: str) -> None:
        """Update the access token and restart its age clock."""
        self.get_user(thread_id).access_token = token
        self._save_state(thread_id, token_changed=True)

    def set_summary(self, thread_id: str, summary: str) -> None:
        self.get_user(thread_id).summary = summary
        self._save_state(thread_id)

    def get_summary(self, thread_id: str) -> str:
        return self.get_user(thread_id).summary or ""

    # -- Agent registry helpers ------------------------------------------

    def register_agent(self, thread_id: str, key: str, agent: Any) -> None:
        """Store ``agent`` under ``key``; values that are not JSON-encodable stay in memory only."""
        self.get_user(thread_id).agents[key] = agent
        self._save_state(thread_id)

    def get_agent(self, thread_id: str, key: str) -> Any:
        return self.get_user(thread_id).agents.get(key)

    def set_prev_agent(self, thread_id: str, agent: Any | None) -> None:
        self.get_user(thread_id).prev_agent = agent
        self._save_state(thread_id)

    # -- Memories helpers ------------------------------------------------

    def remember(self, thread_id: str, key: str, value: Any) -> None:
        self.get_user(thread_id).agent_memories[key] = value
        self._save_state(thread_id)

    def recall(self, thread_id: str, key: str, default: Any = None) -> Any:
        return self.get_user(thread_id).agent_memories.get(key, default)

    # -- Serialization ---------------------------------------------------

    def to_dict(self, thread_id: str) -> Dict[str, Any]:
        return self.get_user(thread_id).to_dict()