Spaces:
Sleeping
Sleeping
| # chat_manager.py - Chat Session Management System | |
| import json | |
| import os | |
| import uuid | |
| from dataclasses import dataclass, asdict | |
| from typing import List, Optional, Dict, Any | |
| from datetime import datetime | |
| from pathlib import Path | |
@dataclass
class ChatMessage:
    """Individual chat message structure.

    Must be a dataclass: ChatManager builds instances with keyword
    arguments and serializes them with ``dataclasses.asdict``.
    """

    message_id: str  # ChatManager.add_message stores str(uuid.uuid4()) here
    role: str  # 'user' or 'assistant'
    content: str
    timestamp: str  # ChatManager.add_message stores datetime.now().isoformat() here
    rating: Optional[int] = None  # 1 for thumbs up, -1 for thumbs down, None for no rating
    is_bookmarked: bool = False
    # None means "not supplied"; __post_init__ swaps it for a fresh list so
    # that instances never share one mutable default.
    source_documents: Optional[List[str]] = None

    def __post_init__(self):
        """Replace a missing ``source_documents`` with a new empty list."""
        if self.source_documents is None:
            self.source_documents = []
@dataclass
class ChatSession:
    """Chat session data structure.

    Must be a dataclass: ChatManager builds instances with keyword
    arguments and serializes them with ``dataclasses.asdict``.
    """

    session_id: str  # ChatManager.create_session stores str(uuid.uuid4()) here
    user_id: str
    title: str
    created_at: str  # ISO-8601 string as written by ChatManager.create_session
    updated_at: str  # refreshed by ChatManager whenever the session is modified
    # None means "not supplied"; __post_init__ swaps it for a fresh list.
    # The element type is a forward reference so this class does not require
    # ChatMessage to be defined before it.
    messages: Optional[List["ChatMessage"]] = None
    is_archived: bool = False
    tags: Optional[List[str]] = None

    def __post_init__(self):
        """Replace missing ``messages`` / ``tags`` with new empty lists."""
        if self.messages is None:
            self.messages = []
        if self.tags is None:
            self.tags = []
class ChatManager:
    """Manages chat sessions and messages.

    Sessions are persisted in one JSON file, ``<data_dir>/sessions.json``,
    holding a mapping of session id to the ``asdict`` form of a
    ``ChatSession``. Every mutating method re-reads that file, applies its
    change and rewrites the whole file. No locking is done here, so
    concurrent writers are not coordinated.
    """

    def __init__(self, data_dir: str):
        """Prepare the storage directory and make sure the sessions file exists.

        Args:
            data_dir: Directory that holds ``sessions.json``. It is created,
                including missing parent directories, when absent.
        """
        self.data_dir = Path(data_dir)
        # parents=True so that a nested, not-yet-existing path works too.
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.sessions_file = self.data_dir / "sessions.json"
        self.ensure_sessions_file()

    def ensure_sessions_file(self):
        """Ensure sessions file exists, creating it as an empty JSON object."""
        if not self.sessions_file.exists():
            self._save_sessions({})

    def _save_sessions(self, sessions: Dict[str, Dict]) -> None:
        """Rewrite the sessions file with ``sessions`` without truncating it in place.

        The data is written to a sibling temporary file first and then moved
        over the real file with ``os.replace``. If serialization or writing
        fails, the previous file content is left untouched and the temporary
        file is removed before the error is re-raised.
        """
        tmp_path = self.sessions_file.with_name(self.sessions_file.name + ".tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(sessions, f, indent=2)
            os.replace(tmp_path, self.sessions_file)
        except Exception:
            if tmp_path.exists():
                tmp_path.unlink()
            raise

    @staticmethod
    def _session_from_dict(session_data: Dict[str, Any]) -> "ChatSession":
        """Build a ChatSession (with ChatMessage objects) from its stored dict."""
        messages = [
            ChatMessage(**msg_data)
            for msg_data in session_data.get('messages', [])
        ]
        return ChatSession(**{**session_data, 'messages': messages})

    def _update_session_fields(self, session_id: str, **changes: Any) -> bool:
        """Set top-level fields on a stored session and refresh ``updated_at``.

        Returns False when the session does not exist or anything fails.
        """
        try:
            sessions = self.load_all_sessions()
            if session_id not in sessions:
                return False
            sessions[session_id].update(changes)
            sessions[session_id]['updated_at'] = datetime.now().isoformat()
            self._save_sessions(sessions)
            return True
        except Exception:
            # Best-effort API: callers only get a success flag.
            return False

    def _update_message_field(self, session_id: str, message_id: str,
                              key: str, value: Any) -> bool:
        """Set one field on a stored message and refresh the session's ``updated_at``.

        Returns False when the session or message does not exist or anything fails.
        """
        try:
            sessions = self.load_all_sessions()
            if session_id not in sessions:
                return False
            for message in sessions[session_id]['messages']:
                if message['message_id'] == message_id:
                    message[key] = value
                    sessions[session_id]['updated_at'] = datetime.now().isoformat()
                    self._save_sessions(sessions)
                    return True
            return False
        except Exception:
            # Best-effort API: callers only get a success flag.
            return False

    def create_session(self, user_id: str, title: Optional[str] = None) -> str:
        """Create a new chat session.

        Args:
            user_id: Owner of the session.
            title: Session title; when empty or None a title of the form
                ``Chat YYYY-MM-DD HH:MM`` is generated.

        Returns:
            The id of the new session (a UUID4 string).

        Raises:
            Exception: If the session could not be stored.
        """
        session_id = str(uuid.uuid4())
        # Take the clock once so the generated title and timestamps agree.
        now = datetime.now()
        timestamp = now.isoformat()
        if not title:
            title = f"Chat {now.strftime('%Y-%m-%d %H:%M')}"
        session = ChatSession(
            session_id=session_id,
            user_id=user_id,
            title=title,
            created_at=timestamp,
            updated_at=timestamp
        )
        try:
            sessions = self.load_all_sessions()
            sessions[session_id] = asdict(session)
            self._save_sessions(sessions)
            return session_id
        except Exception as e:
            raise Exception(f"Failed to create session: {str(e)}") from e

    def load_all_sessions(self) -> Dict[str, Dict]:
        """Load all sessions from storage.

        Returns:
            The stored mapping of session id to session dict, or an empty
            dict when the file is missing or does not contain valid JSON.
        """
        try:
            with open(self.sessions_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return {}

    def get_session(self, session_id: str) -> Optional["ChatSession"]:
        """Get chat session by ID, or None when no such session is stored."""
        session_data = self.load_all_sessions().get(session_id)
        if not session_data:
            return None
        return self._session_from_dict(session_data)

    def get_user_sessions(self, user_id: str, include_archived: bool = False) -> List["ChatSession"]:
        """Get all sessions for a user, most recently updated first.

        Args:
            user_id: Owner whose sessions are returned.
            include_archived: When False, archived sessions are skipped.
        """
        user_sessions = []
        for session_data in self.load_all_sessions().values():
            if session_data.get('user_id') != user_id:
                continue
            if not include_archived and session_data.get('is_archived', False):
                continue
            user_sessions.append(self._session_from_dict(session_data))
        # Sort by updated_at descending
        user_sessions.sort(key=lambda x: x.updated_at, reverse=True)
        return user_sessions

    def add_message(self, session_id: str, role: str, content: str,
                    source_documents: Optional[List[str]] = None) -> str:
        """Add a message to a chat session.

        Args:
            session_id: Session that receives the message.
            role: 'user' or 'assistant'.
            content: Message text.
            source_documents: Optional list stored with the message; an
                empty list is stored when omitted.

        Returns:
            The id of the new message (a UUID4 string).

        Raises:
            Exception: If the session does not exist or the message could
                not be stored.
        """
        message_id = str(uuid.uuid4())
        timestamp = datetime.now().isoformat()
        message = ChatMessage(
            message_id=message_id,
            role=role,
            content=content,
            timestamp=timestamp,
            source_documents=source_documents or []
        )
        try:
            sessions = self.load_all_sessions()
            if session_id not in sessions:
                raise ValueError(f"Session {session_id} not found")
            # Convert message to dict for storage
            sessions[session_id]['messages'].append(asdict(message))
            sessions[session_id]['updated_at'] = timestamp
            self._save_sessions(sessions)
            return message_id
        except Exception as e:
            raise Exception(f"Failed to add message: {str(e)}") from e

    def rate_message(self, session_id: str, message_id: str, rating: int) -> bool:
        """Rate a message (1 for thumbs up, -1 for thumbs down).

        Returns True when the rating was stored, False otherwise.
        """
        return self._update_message_field(session_id, message_id, 'rating', rating)

    def bookmark_message(self, session_id: str, message_id: str, is_bookmarked: bool = True) -> bool:
        """Bookmark or unbookmark a message.

        Returns True when the flag was stored, False otherwise.
        """
        return self._update_message_field(
            session_id, message_id, 'is_bookmarked', is_bookmarked
        )

    def get_bookmarked_messages(self, user_id: str) -> List[Dict[str, Any]]:
        """Get all bookmarked messages for a user, newest first.

        Each entry holds ``session_id``, ``session_title``, the stored
        ``message`` dict and that message's ``timestamp``.
        """
        bookmarked = []
        for session_data in self.load_all_sessions().values():
            if session_data.get('user_id') != user_id:
                continue
            for message in session_data.get('messages', []):
                if message.get('is_bookmarked', False):
                    bookmarked.append({
                        'session_id': session_data['session_id'],
                        'session_title': session_data['title'],
                        'message': message,
                        'timestamp': message['timestamp']
                    })
        # Sort by timestamp descending
        bookmarked.sort(key=lambda x: x['timestamp'], reverse=True)
        return bookmarked

    def update_session_title(self, session_id: str, title: str) -> bool:
        """Update session title. Returns True on success, False otherwise."""
        return self._update_session_fields(session_id, title=title)

    def archive_session(self, session_id: str, is_archived: bool = True) -> bool:
        """Archive or unarchive a session. Returns True on success, False otherwise."""
        return self._update_session_fields(session_id, is_archived=is_archived)

    def delete_session(self, session_id: str) -> bool:
        """Delete a chat session.

        Returns True when the session existed and was removed, False otherwise.
        """
        try:
            sessions = self.load_all_sessions()
            if session_id not in sessions:
                return False
            del sessions[session_id]
            self._save_sessions(sessions)
            return True
        except Exception:
            # Best-effort API: callers only get a success flag.
            return False

    def export_chat_history(self, user_id: str, session_id: Optional[str] = None) -> Dict[str, Any]:
        """Export chat history for a user or specific session.

        Args:
            user_id: Owner of the exported data.
            session_id: When given, only that session is exported, and only
                if it belongs to ``user_id``.

        Returns:
            A dict with ``export_type`` 'single_session' or 'all_sessions'
            (the latter includes archived sessions), or an empty dict when
            the requested session is missing or owned by someone else.
        """
        if not session_id:
            sessions = self.get_user_sessions(user_id, include_archived=True)
            return {
                'export_type': 'all_sessions',
                'sessions': [asdict(session) for session in sessions],
                'exported_at': datetime.now().isoformat(),
                'total_sessions': len(sessions)
            }
        session = self.get_session(session_id)
        if session is None or session.user_id != user_id:
            return {}
        return {
            'export_type': 'single_session',
            'session': asdict(session),
            'exported_at': datetime.now().isoformat()
        }

    def get_chat_statistics(self, user_id: str) -> Dict[str, Any]:
        """Get chat statistics for a user, archived sessions included.

        Any message whose role is not 'user' is counted as an assistant
        message.
        """
        sessions = self.get_user_sessions(user_id, include_archived=True)
        total_messages = 0
        total_user_messages = 0
        total_assistant_messages = 0
        bookmarked_count = 0
        rated_messages = {'positive': 0, 'negative': 0}
        for session in sessions:
            total_messages += len(session.messages)
            for message in session.messages:
                if message.role == 'user':
                    total_user_messages += 1
                else:
                    total_assistant_messages += 1
                if message.is_bookmarked:
                    bookmarked_count += 1
                if message.rating == 1:
                    rated_messages['positive'] += 1
                elif message.rating == -1:
                    rated_messages['negative'] += 1
        return {
            'total_sessions': len(sessions),
            'total_messages': total_messages,
            'user_messages': total_user_messages,
            'assistant_messages': total_assistant_messages,
            'bookmarked_messages': bookmarked_count,
            'message_ratings': rated_messages,
            'average_messages_per_session': total_messages / len(sessions) if sessions else 0
        }