| """Chat history manager for per-user conversation storage.""" |
|
|
| from dataclasses import dataclass, field |
| from datetime import datetime |
| from collections import defaultdict |
|
|
|
|
@dataclass()
class ChatMessage:
    """One conversation turn: who spoke, what was said, and when.

    Attributes:
        role: Speaker label, e.g. "user" or "assistant".
        content: Message text.
        timestamp: Creation time, taken from ``datetime.now()`` (naive local time).
    """

    role: str  # speaker label
    content: str  # message body
    # A factory (not a shared default) so every instance records its own creation time.
    timestamp: datetime = field(default_factory=lambda: datetime.now())
|
|
|
|
@dataclass
class ChatSession:
    """A single conversation: an ordered list of messages plus its creation time."""

    session_id: str
    messages: list[ChatMessage] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)

    def add_message(self, role: str, content: str) -> None:
        """Append a new message with the given role and content to this session."""
        self.messages.append(ChatMessage(role=role, content=content))

    def get_history_text(self, max_messages: int = 10) -> str:
        """
        Format the most recent messages for prompt injection.

        Each message becomes one ``"<Speaker>: <content>"`` line. The speaker
        is ``"User"`` for role ``"user"`` and ``"Assistant"`` for every other
        role.

        Args:
            max_messages: Maximum number of most-recent messages to include.
                Zero or a negative value yields an empty string.

        Returns:
            Newline-joined history, or ``""`` when there is nothing to show.
        """
        # Guard explicitly: messages[-0:] is the *whole* list, not an empty one.
        if max_messages <= 0:
            return ""
        # NOTE(review): content is inserted verbatim, so content containing
        # newlines or "User:"/"Assistant:" can imitate additional turns.
        return "\n".join(
            f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}"
            for msg in self.messages[-max_messages:]
        )
|
|
|
|
class ChatHistoryManager:
    """
    Manages chat history per user with multiple sessions.

    Each user can have up to ``max_sessions`` sessions. When a user exceeds
    that limit, the user's oldest sessions (by ``created_at``) are removed,
    except the session currently being requested. Each session keeps at most
    ``max_messages`` messages, and the oldest messages are dropped first.

    Read-only accessors (``get_history``, ``get_messages``,
    ``get_session_ids``) never create or evict sessions.
    """

    def __init__(self, max_sessions_per_user: int = 3, max_messages_per_session: int = 20):
        """
        Initialize the chat history manager.

        Args:
            max_sessions_per_user: Maximum sessions to keep per user (default 3).
                Values below 1 behave like 1: the requested session is always kept.
            max_messages_per_session: Maximum messages per session (default 20).
                Values of 0 or less keep no messages.
        """
        self.max_sessions = max_sessions_per_user
        self.max_messages = max_messages_per_session
        # user_id -> {session_id -> ChatSession}
        self._sessions: dict[str, dict[str, ChatSession]] = defaultdict(dict)

    def _find_session(self, user_id: str, session_id: str | None) -> ChatSession | None:
        """Return an existing session, or None, without creating anything."""
        # .get() on a defaultdict does not trigger the default factory.
        user_sessions = self._sessions.get(user_id)
        if user_sessions is None:
            return None
        return user_sessions.get(session_id or "default")

    def _evict_oldest(self, user_sessions: dict[str, ChatSession], keep: str) -> None:
        """Drop the oldest sessions until within the limit, never dropping ``keep``."""
        limit = max(self.max_sessions, 1)
        # len > limit >= 1 guarantees at least one candidate other than `keep`.
        while len(user_sessions) > limit:
            oldest_id = min(
                (sid for sid in user_sessions if sid != keep),
                key=lambda sid: user_sessions[sid].created_at,
            )
            del user_sessions[oldest_id]

    def get_or_create_session(self, user_id: str, session_id: str | None = None) -> ChatSession:
        """
        Get existing session or create a new one.

        Creating a session may evict the user's oldest other sessions to stay
        within ``max_sessions``. The returned session itself is never evicted,
        even if its ``created_at`` is not the newest (e.g. after a clock change).

        Args:
            user_id: User identifier
            session_id: Optional session ID (uses "default" if not provided)

        Returns:
            ChatSession instance
        """
        session_id = session_id or "default"
        user_sessions = self._sessions[user_id]

        session = user_sessions.get(session_id)
        if session is None:
            session = ChatSession(session_id=session_id)
            user_sessions[session_id] = session

        self._evict_oldest(user_sessions, keep=session_id)
        return session

    def add_message(
        self,
        user_id: str,
        role: str,
        content: str,
        session_id: str | None = None,
    ) -> None:
        """
        Add a message to user's session, creating the session if needed.

        Oldest messages beyond ``max_messages`` are removed.

        Args:
            user_id: User identifier
            role: "user" or "assistant"
            content: Message content
            session_id: Optional session ID
        """
        session = self.get_or_create_session(user_id, session_id)
        session.add_message(role, content)

        # Trim in place (rather than rebinding) so a list previously returned
        # by get_messages() stays in sync; also correct when max_messages <= 0.
        excess = len(session.messages) - self.max_messages
        if excess > 0:
            del session.messages[:excess]

    def get_history(
        self,
        user_id: str,
        session_id: str | None = None,
        max_messages: int = 10,
    ) -> str:
        """
        Get formatted chat history for prompt.

        Does not create a session (and therefore cannot evict one) when the
        requested session does not exist.

        Args:
            user_id: User identifier
            session_id: Optional session ID
            max_messages: Maximum messages to include

        Returns:
            Formatted history string ("" for an unknown session)
        """
        session = self._find_session(user_id, session_id)
        if session is None:
            return ""
        return session.get_history_text(max_messages)

    def get_messages(
        self,
        user_id: str,
        session_id: str | None = None,
    ) -> list[ChatMessage]:
        """
        Get all messages for a session.

        Returns the session's own message list (not a copy), or a new empty
        list when the session does not exist; no session is created.
        """
        session = self._find_session(user_id, session_id)
        return session.messages if session is not None else []

    def clear_session(self, user_id: str, session_id: str | None = None) -> None:
        """Clear a specific session; a missing user or session is a no-op."""
        user_sessions = self._sessions.get(user_id)
        if user_sessions is None:
            return
        user_sessions.pop(session_id or "default", None)
        # Drop the per-user entry once empty so idle users do not accumulate.
        if not user_sessions:
            del self._sessions[user_id]

    def clear_all_sessions(self, user_id: str) -> None:
        """Clear all sessions for a user; an unknown user is a no-op."""
        self._sessions.pop(user_id, None)

    def get_session_ids(self, user_id: str) -> list[str]:
        """Get all session IDs for a user, in creation order."""
        return list(self._sessions.get(user_id, {}).keys())
|
|
|
|
| |
# Shared module-level manager: keeps 3 sessions per user and 20 messages per session.
chat_history = ChatHistoryManager(
    max_messages_per_session=20,
    max_sessions_per_user=3,
)
|
|