"""Disk persistence for a Streamlit RAG app.

Saves and restores per-session FAISS vector stores, chat histories,
and document chunks under a local data directory.
"""
import json
import os
import pickle
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import faiss
import numpy as np
import streamlit as st
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
class PersistenceManager:
    """Persist and restore RAG session artifacts on local disk.

    Manages three kinds of per-session data:
      * FAISS vector stores (raw index + pickled docstore + id mapping)
      * chat histories (JSON)
      * raw document chunks and their metadata (pickle)

    Errors are surfaced via ``st.error`` and signalled through return
    values (``False`` / ``None`` / empty list) rather than raised, so
    Streamlit callers can degrade gracefully instead of crashing the app.
    """

    def __init__(self, data_dir: str = "data"):
        """Initialize the persistence manager with paths for storing data.

        Args:
            data_dir: Base directory for data storage.
        """
        self.base_dir = Path(data_dir)
        self.vector_store_dir = self.base_dir / "vector_stores"
        self.chat_history_dir = self.base_dir / "chat_histories"
        self.chunks_dir = self.base_dir / "chunks"
        # Create all storage directories up front so save/load methods
        # can assume their parent directories exist.
        for directory in (self.vector_store_dir, self.chat_history_dir, self.chunks_dir):
            directory.mkdir(parents=True, exist_ok=True)

    def save_vector_store(self, vector_store: Any, session_id: str) -> bool:
        """Save a FAISS vector store and related metadata.

        Args:
            vector_store: LangChain FAISS vector store instance.
            session_id: Unique identifier for the session.

        Returns:
            True on success, False on failure (error shown via st.error).
        """
        try:
            # Create session-specific directory.
            store_path = self.vector_store_dir / session_id
            store_path.mkdir(parents=True, exist_ok=True)
            # Save the raw FAISS index.
            faiss.write_index(vector_store.index,
                              str(store_path / "index.faiss"))
            # Save the documents and their metadata.
            with open(store_path / "docstore.pkl", "wb") as f:
                pickle.dump(vector_store.docstore, f)
            # BUGFIX: the index->docstore-id mapping lives on the vector
            # store itself, not on the docstore, so it must be persisted
            # separately or load_vector_store cannot rebuild the store.
            with open(store_path / "index_to_id.pkl", "wb") as f:
                pickle.dump(vector_store.index_to_docstore_id, f)
            return True
        except Exception as e:
            st.error(f"Error saving vector store: {str(e)}")
            return False

    def load_vector_store(self, session_id: str) -> Optional[Any]:
        """Load a FAISS vector store and related metadata.

        Args:
            session_id: Unique identifier for the session.

        Returns:
            The reconstructed FAISS vector store, or None if the session
            has no saved store or loading fails.
        """
        try:
            store_path = self.vector_store_dir / session_id
            if not store_path.exists():
                return None
            # Load the raw FAISS index.
            index = faiss.read_index(str(store_path / "index.faiss"))
            # SECURITY: pickle.load executes arbitrary code on malicious
            # input — only load files this application wrote itself.
            with open(store_path / "docstore.pkl", "rb") as f:
                docstore = pickle.load(f)
            # Prefer the separately-saved id mapping (current format);
            # fall back to an attribute on the docstore for stores saved
            # by older versions of this code.
            mapping_file = store_path / "index_to_id.pkl"
            if mapping_file.exists():
                with open(mapping_file, "rb") as f:
                    index_to_docstore_id = pickle.load(f)
            else:
                index_to_docstore_id = docstore.index_to_docstore_id
            # Recreate the vector store; the embeddings object is kept in
            # Streamlit session state by the surrounding app.
            try:
                from langchain_community.vectorstores import FAISS
            except ImportError:  # older LangChain layouts
                from langchain.vectorstores import FAISS
            vector_store = FAISS(
                embedding_function=st.session_state.embeddings,
                index=index,
                docstore=docstore,
                index_to_docstore_id=index_to_docstore_id
            )
            return vector_store
        except Exception as e:
            st.error(f"Error loading vector store: {str(e)}")
            return None

    def save_chat_history(
        self,
        messages: List[BaseMessage],
        session_id: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """Save chat history with metadata as JSON.

        Args:
            messages: List of chat messages (only Human/AI messages are
                persisted; other message types are skipped).
            session_id: Unique identifier for the chat session.
            metadata: Additional metadata about the chat session.

        Returns:
            True on success, False on failure.
        """
        try:
            # Convert messages to a JSON-serializable format; the
            # timestamp records when the message was saved, not sent.
            serialized_messages = []
            for msg in messages:
                if isinstance(msg, (HumanMessage, AIMessage)):
                    serialized_messages.append({
                        'type': msg.__class__.__name__,
                        'content': msg.content,
                        'timestamp': datetime.now().isoformat()
                    })
            chat_data = {
                'messages': serialized_messages,
                'metadata': metadata or {},
                'last_updated': datetime.now().isoformat()
            }
            chat_file = self.chat_history_dir / f"{session_id}.json"
            # Explicit UTF-8 so chat content round-trips regardless of
            # the platform's default locale encoding.
            with open(chat_file, 'w', encoding='utf-8') as f:
                json.dump(chat_data, f, indent=2)
            return True
        except Exception as e:
            st.error(f"Error saving chat history: {str(e)}")
            return False

    def load_chat_history(self, session_id: str) -> List[BaseMessage]:
        """Load chat history for a session.

        Args:
            session_id: Unique identifier for the chat session.

        Returns:
            The reconstructed message list; empty if the session has no
            saved history or loading fails.
        """
        try:
            chat_file = self.chat_history_dir / f"{session_id}.json"
            if not chat_file.exists():
                return []
            with open(chat_file, 'r', encoding='utf-8') as f:
                chat_data = json.load(f)
            # Convert serialized dicts back to message objects; unknown
            # message types are silently skipped (mirrors save side).
            messages: List[BaseMessage] = []
            for msg in chat_data['messages']:
                if msg['type'] == 'HumanMessage':
                    messages.append(HumanMessage(content=msg['content']))
                elif msg['type'] == 'AIMessage':
                    messages.append(AIMessage(content=msg['content']))
            return messages
        except Exception as e:
            st.error(f"Error loading chat history: {str(e)}")
            return []

    def save_chunks(
        self,
        chunks: List[str],
        chunk_metadatas: List[Dict],
        session_id: str
    ) -> bool:
        """Save document chunks and their metadata.

        Args:
            chunks: List of text chunks.
            chunk_metadatas: List of metadata dictionaries, one per chunk.
            session_id: Unique identifier for the session.

        Returns:
            True on success, False on failure.
        """
        try:
            chunk_data = {
                'chunks': chunks,
                'metadatas': chunk_metadatas,
                'created_at': datetime.now().isoformat()
            }
            chunk_file = self.chunks_dir / f"{session_id}_chunks.pkl"
            with open(chunk_file, 'wb') as f:
                pickle.dump(chunk_data, f)
            return True
        except Exception as e:
            st.error(f"Error saving chunks: {str(e)}")
            return False

    def load_chunks(self, session_id: str) -> Tuple[Optional[List[str]], Optional[List[Dict]]]:
        """Load document chunks and their metadata.

        Args:
            session_id: Unique identifier for the session.

        Returns:
            ``(chunks, metadatas)``, or ``(None, None)`` if nothing is
            saved for this session or loading fails.
        """
        try:
            chunk_file = self.chunks_dir / f"{session_id}_chunks.pkl"
            if not chunk_file.exists():
                return None, None
            # SECURITY: pickle.load runs arbitrary code — only load
            # files this application wrote itself.
            with open(chunk_file, 'rb') as f:
                chunk_data = pickle.load(f)
            return chunk_data['chunks'], chunk_data['metadatas']
        except Exception as e:
            st.error(f"Error loading chunks: {str(e)}")
            return None, None

    def list_available_sessions(self) -> List[Dict[str, Any]]:
        """List all available chat sessions with their metadata.

        Returns:
            Session descriptors sorted by last-updated time, newest
            first. A corrupt or unreadable session file is skipped with
            a warning instead of aborting the whole listing.
        """
        try:
            sessions = []
            for chat_file in self.chat_history_dir.glob("*.json"):
                # ROBUSTNESS: one malformed JSON file must not prevent
                # listing every other session.
                try:
                    with open(chat_file, 'r', encoding='utf-8') as f:
                        chat_data = json.load(f)
                except (OSError, json.JSONDecodeError) as e:
                    st.warning(f"Skipping unreadable session file {chat_file.name}: {e}")
                    continue
                sessions.append({
                    'session_id': chat_file.stem,
                    'last_updated': chat_data.get('last_updated', ''),
                    'metadata': chat_data.get('metadata', {})
                })
            # Newest first; ISO-8601 strings sort chronologically.
            sessions.sort(key=lambda x: x['last_updated'], reverse=True)
            return sessions
        except Exception as e:
            st.error(f"Error listing sessions: {str(e)}")
            return []