ZunairaHawwar commited on
Commit
a818a2d
·
verified ·
1 Parent(s): 081a454

Upload chat_manager.py

Browse files
Files changed (1) hide show
  1. chat_manager.py +334 -0
chat_manager.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # chat_manager.py - Chat Session Management System
2
+ import json
3
+ import os
4
+ from dataclasses import dataclass, asdict
5
+ from typing import List, Optional, Dict, Any
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+
9
+ @dataclass
10
+ class ChatMessage:
11
+ """Individual chat message structure"""
12
+ message_id: str
13
+ role: str # 'user' or 'assistant'
14
+ content: str
15
+ timestamp: str
16
+ rating: Optional[int] = None # 1 for thumbs up, -1 for thumbs down, None for no rating
17
+ is_bookmarked: bool = False
18
+ source_documents: List[str] = None
19
+
20
+ def __post_init__(self):
21
+ if self.source_documents is None:
22
+ self.source_documents = []
23
+
24
+ @dataclass
25
+ class ChatSession:
26
+ """Chat session data structure"""
27
+ session_id: str
28
+ user_id: str
29
+ title: str
30
+ created_at: str
31
+ updated_at: str
32
+ messages: List[ChatMessage] = None
33
+ is_archived: bool = False
34
+ tags: List[str] = None
35
+
36
+ def __post_init__(self):
37
+ if self.messages is None:
38
+ self.messages = []
39
+ if self.tags is None:
40
+ self.tags = []
41
+
42
+ class ChatManager:
43
+ """Manages chat sessions and messages"""
44
+
45
+ def __init__(self, data_dir: str):
46
+ self.data_dir = Path(data_dir)
47
+ self.data_dir.mkdir(exist_ok=True)
48
+ self.sessions_file = self.data_dir / "sessions.json"
49
+ self.ensure_sessions_file()
50
+
51
+ def ensure_sessions_file(self):
52
+ """Ensure sessions file exists"""
53
+ if not self.sessions_file.exists():
54
+ with open(self.sessions_file, 'w') as f:
55
+ json.dump({}, f)
56
+
57
+ def create_session(self, user_id: str, title: str = None) -> str:
58
+ """Create a new chat session"""
59
+ session_id = str(uuid.uuid4())
60
+ timestamp = datetime.now().isoformat()
61
+
62
+ if not title:
63
+ title = f"Chat {datetime.now().strftime('%Y-%m-%d %H:%M')}"
64
+
65
+ session = ChatSession(
66
+ session_id=session_id,
67
+ user_id=user_id,
68
+ title=title,
69
+ created_at=timestamp,
70
+ updated_at=timestamp
71
+ )
72
+
73
+ try:
74
+ sessions = self.load_all_sessions()
75
+ sessions[session_id] = asdict(session)
76
+
77
+ with open(self.sessions_file, 'w') as f:
78
+ json.dump(sessions, f, indent=2)
79
+
80
+ return session_id
81
+
82
+ except Exception as e:
83
+ raise Exception(f"Failed to create session: {str(e)}")
84
+
85
+ def load_all_sessions(self) -> Dict[str, Dict]:
86
+ """Load all sessions from storage"""
87
+ try:
88
+ with open(self.sessions_file, 'r') as f:
89
+ return json.load(f)
90
+ except (FileNotFoundError, json.JSONDecodeError):
91
+ return {}
92
+
93
+ def get_session(self, session_id: str) -> Optional[ChatSession]:
94
+ """Get chat session by ID"""
95
+ sessions = self.load_all_sessions()
96
+ session_data = sessions.get(session_id)
97
+
98
+ if session_data:
99
+ # Convert message dictionaries back to ChatMessage objects
100
+ messages = []
101
+ for msg_data in session_data.get('messages', []):
102
+ messages.append(ChatMessage(**msg_data))
103
+ session_data['messages'] = messages
104
+ return ChatSession(**session_data)
105
+ return None
106
+
107
+ def get_user_sessions(self, user_id: str, include_archived: bool = False) -> List[ChatSession]:
108
+ """Get all sessions for a user"""
109
+ sessions = self.load_all_sessions()
110
+ user_sessions = []
111
+
112
+ for session_data in sessions.values():
113
+ if session_data.get('user_id') == user_id:
114
+ if include_archived or not session_data.get('is_archived', False):
115
+ # Convert message dictionaries back to ChatMessage objects
116
+ messages = []
117
+ for msg_data in session_data.get('messages', []):
118
+ messages.append(ChatMessage(**msg_data))
119
+ session_data['messages'] = messages
120
+ user_sessions.append(ChatSession(**session_data))
121
+
122
+ # Sort by updated_at descending
123
+ user_sessions.sort(key=lambda x: x.updated_at, reverse=True)
124
+ return user_sessions
125
+
126
+ def add_message(self, session_id: str, role: str, content: str, source_documents: List[str] = None) -> str:
127
+ """Add a message to a chat session"""
128
+ message_id = str(uuid.uuid4())
129
+ timestamp = datetime.now().isoformat()
130
+
131
+ message = ChatMessage(
132
+ message_id=message_id,
133
+ role=role,
134
+ content=content,
135
+ timestamp=timestamp,
136
+ source_documents=source_documents or []
137
+ )
138
+
139
+ try:
140
+ sessions = self.load_all_sessions()
141
+
142
+ if session_id not in sessions:
143
+ raise ValueError(f"Session {session_id} not found")
144
+
145
+ # Convert message to dict for storage
146
+ message_dict = asdict(message)
147
+ sessions[session_id]['messages'].append(message_dict)
148
+ sessions[session_id]['updated_at'] = timestamp
149
+
150
+ with open(self.sessions_file, 'w') as f:
151
+ json.dump(sessions, f, indent=2)
152
+
153
+ return message_id
154
+
155
+ except Exception as e:
156
+ raise Exception(f"Failed to add message: {str(e)}")
157
+
158
+ def rate_message(self, session_id: str, message_id: str, rating: int) -> bool:
159
+ """Rate a message (1 for thumbs up, -1 for thumbs down)"""
160
+ try:
161
+ sessions = self.load_all_sessions()
162
+
163
+ if session_id not in sessions:
164
+ return False
165
+
166
+ for message in sessions[session_id]['messages']:
167
+ if message['message_id'] == message_id:
168
+ message['rating'] = rating
169
+ sessions[session_id]['updated_at'] = datetime.now().isoformat()
170
+
171
+ with open(self.sessions_file, 'w') as f:
172
+ json.dump(sessions, f, indent=2)
173
+
174
+ return True
175
+
176
+ return False
177
+
178
+ except Exception:
179
+ return False
180
+
181
+ def bookmark_message(self, session_id: str, message_id: str, is_bookmarked: bool = True) -> bool:
182
+ """Bookmark or unbookmark a message"""
183
+ try:
184
+ sessions = self.load_all_sessions()
185
+
186
+ if session_id not in sessions:
187
+ return False
188
+
189
+ for message in sessions[session_id]['messages']:
190
+ if message['message_id'] == message_id:
191
+ message['is_bookmarked'] = is_bookmarked
192
+ sessions[session_id]['updated_at'] = datetime.now().isoformat()
193
+
194
+ with open(self.sessions_file, 'w') as f:
195
+ json.dump(sessions, f, indent=2)
196
+
197
+ return True
198
+
199
+ return False
200
+
201
+ except Exception:
202
+ return False
203
+
204
+ def get_bookmarked_messages(self, user_id: str) -> List[Dict[str, Any]]:
205
+ """Get all bookmarked messages for a user"""
206
+ sessions = self.load_all_sessions()
207
+ bookmarked = []
208
+
209
+ for session_data in sessions.values():
210
+ if session_data.get('user_id') == user_id:
211
+ for message in session_data.get('messages', []):
212
+ if message.get('is_bookmarked', False):
213
+ bookmarked.append({
214
+ 'session_id': session_data['session_id'],
215
+ 'session_title': session_data['title'],
216
+ 'message': message,
217
+ 'timestamp': message['timestamp']
218
+ })
219
+
220
+ # Sort by timestamp descending
221
+ bookmarked.sort(key=lambda x: x['timestamp'], reverse=True)
222
+ return bookmarked
223
+
224
+ def update_session_title(self, session_id: str, title: str) -> bool:
225
+ """Update session title"""
226
+ try:
227
+ sessions = self.load_all_sessions()
228
+
229
+ if session_id not in sessions:
230
+ return False
231
+
232
+ sessions[session_id]['title'] = title
233
+ sessions[session_id]['updated_at'] = datetime.now().isoformat()
234
+
235
+ with open(self.sessions_file, 'w') as f:
236
+ json.dump(sessions, f, indent=2)
237
+
238
+ return True
239
+
240
+ except Exception:
241
+ return False
242
+
243
+ def archive_session(self, session_id: str, is_archived: bool = True) -> bool:
244
+ """Archive or unarchive a session"""
245
+ try:
246
+ sessions = self.load_all_sessions()
247
+
248
+ if session_id not in sessions:
249
+ return False
250
+
251
+ sessions[session_id]['is_archived'] = is_archived
252
+ sessions[session_id]['updated_at'] = datetime.now().isoformat()
253
+
254
+ with open(self.sessions_file, 'w') as f:
255
+ json.dump(sessions, f, indent=2)
256
+
257
+ return True
258
+
259
+ except Exception:
260
+ return False
261
+
262
+ def delete_session(self, session_id: str) -> bool:
263
+ """Delete a chat session"""
264
+ try:
265
+ sessions = self.load_all_sessions()
266
+
267
+ if session_id in sessions:
268
+ del sessions[session_id]
269
+
270
+ with open(self.sessions_file, 'w') as f:
271
+ json.dump(sessions, f, indent=2)
272
+
273
+ return True
274
+ return False
275
+
276
+ except Exception:
277
+ return False
278
+
279
+ def export_chat_history(self, user_id: str, session_id: str = None) -> Dict[str, Any]:
280
+ """Export chat history for a user or specific session"""
281
+ if session_id:
282
+ session = self.get_session(session_id)
283
+ if session and session.user_id == user_id:
284
+ return {
285
+ 'export_type': 'single_session',
286
+ 'session': asdict(session),
287
+ 'exported_at': datetime.now().isoformat()
288
+ }
289
+ else:
290
+ sessions = self.get_user_sessions(user_id, include_archived=True)
291
+ return {
292
+ 'export_type': 'all_sessions',
293
+ 'sessions': [asdict(session) for session in sessions],
294
+ 'exported_at': datetime.now().isoformat(),
295
+ 'total_sessions': len(sessions)
296
+ }
297
+
298
+ return {}
299
+
300
+ def get_chat_statistics(self, user_id: str) -> Dict[str, Any]:
301
+ """Get chat statistics for a user"""
302
+ sessions = self.get_user_sessions(user_id, include_archived=True)
303
+
304
+ total_messages = 0
305
+ total_user_messages = 0
306
+ total_assistant_messages = 0
307
+ bookmarked_count = 0
308
+ rated_messages = {'positive': 0, 'negative': 0}
309
+
310
+ for session in sessions:
311
+ total_messages += len(session.messages)
312
+ for message in session.messages:
313
+ if message.role == 'user':
314
+ total_user_messages += 1
315
+ else:
316
+ total_assistant_messages += 1
317
+
318
+ if message.is_bookmarked:
319
+ bookmarked_count += 1
320
+
321
+ if message.rating == 1:
322
+ rated_messages['positive'] += 1
323
+ elif message.rating == -1:
324
+ rated_messages['negative'] += 1
325
+
326
+ return {
327
+ 'total_sessions': len(sessions),
328
+ 'total_messages': total_messages,
329
+ 'user_messages': total_user_messages,
330
+ 'assistant_messages': total_assistant_messages,
331
+ 'bookmarked_messages': bookmarked_count,
332
+ 'message_ratings': rated_messages,
333
+ 'average_messages_per_session': total_messages / len(sessions) if sessions else 0
334
+ }