cryogenic22 committed on
Commit
b59e212
·
verified ·
1 Parent(s): 69ee452

Create persistence.py

Browse files
Files changed (1) hide show
  1. utils/persistence.py +221 -0
utils/persistence.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from datetime import datetime
4
+ import faiss
5
+ import numpy as np
6
+ import pickle
7
+ from pathlib import Path
8
+ import streamlit as st
9
+ from typing import List, Dict, Any
10
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
11
+
12
+ class PersistenceManager:
13
+ def __init__(self, data_dir: str = "data"):
14
+ """Initialize the persistence manager with paths for storing data.
15
+
16
+ Args:
17
+ data_dir: Base directory for data storage
18
+ """
19
+ self.base_dir = Path(data_dir)
20
+ self.vector_store_dir = self.base_dir / "vector_stores"
21
+ self.chat_history_dir = self.base_dir / "chat_histories"
22
+ self.chunks_dir = self.base_dir / "chunks"
23
+
24
+ # Create necessary directories
25
+ for directory in [self.vector_store_dir, self.chat_history_dir, self.chunks_dir]:
26
+ directory.mkdir(parents=True, exist_ok=True)
27
+
28
+ def save_vector_store(self, vector_store: Any, session_id: str) -> bool:
29
+ """Save FAISS vector store and related metadata.
30
+
31
+ Args:
32
+ vector_store: FAISS vector store instance
33
+ session_id: Unique identifier for the session
34
+ """
35
+ try:
36
+ # Create session-specific directory
37
+ store_path = self.vector_store_dir / session_id
38
+ store_path.mkdir(exist_ok=True)
39
+
40
+ # Save the FAISS index
41
+ faiss.write_index(vector_store.index,
42
+ str(store_path / "index.faiss"))
43
+
44
+ # Save the documents and their metadata
45
+ with open(store_path / "docstore.pkl", "wb") as f:
46
+ pickle.dump(vector_store.docstore, f)
47
+
48
+ return True
49
+ except Exception as e:
50
+ st.error(f"Error saving vector store: {str(e)}")
51
+ return False
52
+
53
+ def load_vector_store(self, session_id: str) -> Any:
54
+ """Load FAISS vector store and related metadata.
55
+
56
+ Args:
57
+ session_id: Unique identifier for the session
58
+ """
59
+ try:
60
+ store_path = self.vector_store_dir / session_id
61
+ if not store_path.exists():
62
+ return None
63
+
64
+ # Load the FAISS index
65
+ index = faiss.read_index(str(store_path / "index.faiss"))
66
+
67
+ # Load the document store
68
+ with open(store_path / "docstore.pkl", "rb") as f:
69
+ docstore = pickle.load(f)
70
+
71
+ # Recreate the vector store
72
+ from langchain.vectorstores import FAISS
73
+ vector_store = FAISS(
74
+ embedding_function=st.session_state.embeddings,
75
+ index=index,
76
+ docstore=docstore,
77
+ index_to_docstore_id=docstore.index_to_docstore_id
78
+ )
79
+
80
+ return vector_store
81
+ except Exception as e:
82
+ st.error(f"Error loading vector store: {str(e)}")
83
+ return None
84
+
85
+ def save_chat_history(
86
+ self,
87
+ messages: List[BaseMessage],
88
+ session_id: str,
89
+ metadata: Dict[str, Any] = None
90
+ ) -> bool:
91
+ """Save chat history with metadata.
92
+
93
+ Args:
94
+ messages: List of chat messages
95
+ session_id: Unique identifier for the chat session
96
+ metadata: Additional metadata about the chat session
97
+ """
98
+ try:
99
+ # Convert messages to serializable format
100
+ serialized_messages = []
101
+ for msg in messages:
102
+ if isinstance(msg, (HumanMessage, AIMessage)):
103
+ serialized_messages.append({
104
+ 'type': msg.__class__.__name__,
105
+ 'content': msg.content,
106
+ 'timestamp': datetime.now().isoformat()
107
+ })
108
+
109
+ # Prepare chat data
110
+ chat_data = {
111
+ 'messages': serialized_messages,
112
+ 'metadata': metadata or {},
113
+ 'last_updated': datetime.now().isoformat()
114
+ }
115
+
116
+ # Save to JSON file
117
+ chat_file = self.chat_history_dir / f"{session_id}.json"
118
+ with open(chat_file, 'w') as f:
119
+ json.dump(chat_data, f, indent=2)
120
+
121
+ return True
122
+ except Exception as e:
123
+ st.error(f"Error saving chat history: {str(e)}")
124
+ return False
125
+
126
+ def load_chat_history(self, session_id: str) -> List[BaseMessage]:
127
+ """Load chat history for a session.
128
+
129
+ Args:
130
+ session_id: Unique identifier for the chat session
131
+ """
132
+ try:
133
+ chat_file = self.chat_history_dir / f"{session_id}.json"
134
+ if not chat_file.exists():
135
+ return []
136
+
137
+ with open(chat_file, 'r') as f:
138
+ chat_data = json.load(f)
139
+
140
+ # Convert back to message objects
141
+ messages = []
142
+ for msg in chat_data['messages']:
143
+ if msg['type'] == 'HumanMessage':
144
+ messages.append(HumanMessage(content=msg['content']))
145
+ elif msg['type'] == 'AIMessage':
146
+ messages.append(AIMessage(content=msg['content']))
147
+
148
+ return messages
149
+ except Exception as e:
150
+ st.error(f"Error loading chat history: {str(e)}")
151
+ return []
152
+
153
+ def save_chunks(
154
+ self,
155
+ chunks: List[str],
156
+ chunk_metadatas: List[Dict],
157
+ session_id: str
158
+ ) -> bool:
159
+ """Save document chunks and their metadata.
160
+
161
+ Args:
162
+ chunks: List of text chunks
163
+ chunk_metadatas: List of metadata dictionaries for each chunk
164
+ session_id: Unique identifier for the session
165
+ """
166
+ try:
167
+ chunk_data = {
168
+ 'chunks': chunks,
169
+ 'metadatas': chunk_metadatas,
170
+ 'created_at': datetime.now().isoformat()
171
+ }
172
+
173
+ chunk_file = self.chunks_dir / f"{session_id}_chunks.pkl"
174
+ with open(chunk_file, 'wb') as f:
175
+ pickle.dump(chunk_data, f)
176
+
177
+ return True
178
+ except Exception as e:
179
+ st.error(f"Error saving chunks: {str(e)}")
180
+ return False
181
+
182
+ def load_chunks(self, session_id: str) -> tuple:
183
+ """Load document chunks and their metadata.
184
+
185
+ Args:
186
+ session_id: Unique identifier for the session
187
+ """
188
+ try:
189
+ chunk_file = self.chunks_dir / f"{session_id}_chunks.pkl"
190
+ if not chunk_file.exists():
191
+ return None, None
192
+
193
+ with open(chunk_file, 'rb') as f:
194
+ chunk_data = pickle.load(f)
195
+
196
+ return chunk_data['chunks'], chunk_data['metadatas']
197
+ except Exception as e:
198
+ st.error(f"Error loading chunks: {str(e)}")
199
+ return None, None
200
+
201
+ def list_available_sessions(self) -> List[Dict[str, Any]]:
202
+ """List all available chat sessions with their metadata."""
203
+ try:
204
+ sessions = []
205
+ for chat_file in self.chat_history_dir.glob("*.json"):
206
+ with open(chat_file, 'r') as f:
207
+ chat_data = json.load(f)
208
+
209
+ session_id = chat_file.stem
210
+ sessions.append({
211
+ 'session_id': session_id,
212
+ 'last_updated': chat_data['last_updated'],
213
+ 'metadata': chat_data['metadata']
214
+ })
215
+
216
+ # Sort by last updated time
217
+ sessions.sort(key=lambda x: x['last_updated'], reverse=True)
218
+ return sessions
219
+ except Exception as e:
220
+ st.error(f"Error listing sessions: {str(e)}")
221
+ return []