ZunairaHawwar commited on
Commit
99d2ed6
·
verified ·
1 Parent(s): 2cc98aa

Update chat_manager.py

Browse files
Files changed (1) hide show
  1. chat_manager.py +334 -333
chat_manager.py CHANGED
@@ -1,334 +1,335 @@
1
- # chat_manager.py - Chat Session Management System
2
- import json
3
- import os
4
- from dataclasses import dataclass, asdict
5
- from typing import List, Optional, Dict, Any
6
- from datetime import datetime
7
- from pathlib import Path
8
-
9
- @dataclass
10
- class ChatMessage:
11
- """Individual chat message structure"""
12
- message_id: str
13
- role: str # 'user' or 'assistant'
14
- content: str
15
- timestamp: str
16
- rating: Optional[int] = None # 1 for thumbs up, -1 for thumbs down, None for no rating
17
- is_bookmarked: bool = False
18
- source_documents: List[str] = None
19
-
20
- def __post_init__(self):
21
- if self.source_documents is None:
22
- self.source_documents = []
23
-
24
- @dataclass
25
- class ChatSession:
26
- """Chat session data structure"""
27
- session_id: str
28
- user_id: str
29
- title: str
30
- created_at: str
31
- updated_at: str
32
- messages: List[ChatMessage] = None
33
- is_archived: bool = False
34
- tags: List[str] = None
35
-
36
- def __post_init__(self):
37
- if self.messages is None:
38
- self.messages = []
39
- if self.tags is None:
40
- self.tags = []
41
-
42
- class ChatManager:
43
- """Manages chat sessions and messages"""
44
-
45
- def __init__(self, data_dir: str):
46
- self.data_dir = Path(data_dir)
47
- self.data_dir.mkdir(exist_ok=True)
48
- self.sessions_file = self.data_dir / "sessions.json"
49
- self.ensure_sessions_file()
50
-
51
- def ensure_sessions_file(self):
52
- """Ensure sessions file exists"""
53
- if not self.sessions_file.exists():
54
- with open(self.sessions_file, 'w') as f:
55
- json.dump({}, f)
56
-
57
- def create_session(self, user_id: str, title: str = None) -> str:
58
- """Create a new chat session"""
59
- session_id = str(uuid.uuid4())
60
- timestamp = datetime.now().isoformat()
61
-
62
- if not title:
63
- title = f"Chat {datetime.now().strftime('%Y-%m-%d %H:%M')}"
64
-
65
- session = ChatSession(
66
- session_id=session_id,
67
- user_id=user_id,
68
- title=title,
69
- created_at=timestamp,
70
- updated_at=timestamp
71
- )
72
-
73
- try:
74
- sessions = self.load_all_sessions()
75
- sessions[session_id] = asdict(session)
76
-
77
- with open(self.sessions_file, 'w') as f:
78
- json.dump(sessions, f, indent=2)
79
-
80
- return session_id
81
-
82
- except Exception as e:
83
- raise Exception(f"Failed to create session: {str(e)}")
84
-
85
- def load_all_sessions(self) -> Dict[str, Dict]:
86
- """Load all sessions from storage"""
87
- try:
88
- with open(self.sessions_file, 'r') as f:
89
- return json.load(f)
90
- except (FileNotFoundError, json.JSONDecodeError):
91
- return {}
92
-
93
- def get_session(self, session_id: str) -> Optional[ChatSession]:
94
- """Get chat session by ID"""
95
- sessions = self.load_all_sessions()
96
- session_data = sessions.get(session_id)
97
-
98
- if session_data:
99
- # Convert message dictionaries back to ChatMessage objects
100
- messages = []
101
- for msg_data in session_data.get('messages', []):
102
- messages.append(ChatMessage(**msg_data))
103
- session_data['messages'] = messages
104
- return ChatSession(**session_data)
105
- return None
106
-
107
- def get_user_sessions(self, user_id: str, include_archived: bool = False) -> List[ChatSession]:
108
- """Get all sessions for a user"""
109
- sessions = self.load_all_sessions()
110
- user_sessions = []
111
-
112
- for session_data in sessions.values():
113
- if session_data.get('user_id') == user_id:
114
- if include_archived or not session_data.get('is_archived', False):
115
- # Convert message dictionaries back to ChatMessage objects
116
- messages = []
117
- for msg_data in session_data.get('messages', []):
118
- messages.append(ChatMessage(**msg_data))
119
- session_data['messages'] = messages
120
- user_sessions.append(ChatSession(**session_data))
121
-
122
- # Sort by updated_at descending
123
- user_sessions.sort(key=lambda x: x.updated_at, reverse=True)
124
- return user_sessions
125
-
126
- def add_message(self, session_id: str, role: str, content: str, source_documents: List[str] = None) -> str:
127
- """Add a message to a chat session"""
128
- message_id = str(uuid.uuid4())
129
- timestamp = datetime.now().isoformat()
130
-
131
- message = ChatMessage(
132
- message_id=message_id,
133
- role=role,
134
- content=content,
135
- timestamp=timestamp,
136
- source_documents=source_documents or []
137
- )
138
-
139
- try:
140
- sessions = self.load_all_sessions()
141
-
142
- if session_id not in sessions:
143
- raise ValueError(f"Session {session_id} not found")
144
-
145
- # Convert message to dict for storage
146
- message_dict = asdict(message)
147
- sessions[session_id]['messages'].append(message_dict)
148
- sessions[session_id]['updated_at'] = timestamp
149
-
150
- with open(self.sessions_file, 'w') as f:
151
- json.dump(sessions, f, indent=2)
152
-
153
- return message_id
154
-
155
- except Exception as e:
156
- raise Exception(f"Failed to add message: {str(e)}")
157
-
158
- def rate_message(self, session_id: str, message_id: str, rating: int) -> bool:
159
- """Rate a message (1 for thumbs up, -1 for thumbs down)"""
160
- try:
161
- sessions = self.load_all_sessions()
162
-
163
- if session_id not in sessions:
164
- return False
165
-
166
- for message in sessions[session_id]['messages']:
167
- if message['message_id'] == message_id:
168
- message['rating'] = rating
169
- sessions[session_id]['updated_at'] = datetime.now().isoformat()
170
-
171
- with open(self.sessions_file, 'w') as f:
172
- json.dump(sessions, f, indent=2)
173
-
174
- return True
175
-
176
- return False
177
-
178
- except Exception:
179
- return False
180
-
181
- def bookmark_message(self, session_id: str, message_id: str, is_bookmarked: bool = True) -> bool:
182
- """Bookmark or unbookmark a message"""
183
- try:
184
- sessions = self.load_all_sessions()
185
-
186
- if session_id not in sessions:
187
- return False
188
-
189
- for message in sessions[session_id]['messages']:
190
- if message['message_id'] == message_id:
191
- message['is_bookmarked'] = is_bookmarked
192
- sessions[session_id]['updated_at'] = datetime.now().isoformat()
193
-
194
- with open(self.sessions_file, 'w') as f:
195
- json.dump(sessions, f, indent=2)
196
-
197
- return True
198
-
199
- return False
200
-
201
- except Exception:
202
- return False
203
-
204
- def get_bookmarked_messages(self, user_id: str) -> List[Dict[str, Any]]:
205
- """Get all bookmarked messages for a user"""
206
- sessions = self.load_all_sessions()
207
- bookmarked = []
208
-
209
- for session_data in sessions.values():
210
- if session_data.get('user_id') == user_id:
211
- for message in session_data.get('messages', []):
212
- if message.get('is_bookmarked', False):
213
- bookmarked.append({
214
- 'session_id': session_data['session_id'],
215
- 'session_title': session_data['title'],
216
- 'message': message,
217
- 'timestamp': message['timestamp']
218
- })
219
-
220
- # Sort by timestamp descending
221
- bookmarked.sort(key=lambda x: x['timestamp'], reverse=True)
222
- return bookmarked
223
-
224
- def update_session_title(self, session_id: str, title: str) -> bool:
225
- """Update session title"""
226
- try:
227
- sessions = self.load_all_sessions()
228
-
229
- if session_id not in sessions:
230
- return False
231
-
232
- sessions[session_id]['title'] = title
233
- sessions[session_id]['updated_at'] = datetime.now().isoformat()
234
-
235
- with open(self.sessions_file, 'w') as f:
236
- json.dump(sessions, f, indent=2)
237
-
238
- return True
239
-
240
- except Exception:
241
- return False
242
-
243
- def archive_session(self, session_id: str, is_archived: bool = True) -> bool:
244
- """Archive or unarchive a session"""
245
- try:
246
- sessions = self.load_all_sessions()
247
-
248
- if session_id not in sessions:
249
- return False
250
-
251
- sessions[session_id]['is_archived'] = is_archived
252
- sessions[session_id]['updated_at'] = datetime.now().isoformat()
253
-
254
- with open(self.sessions_file, 'w') as f:
255
- json.dump(sessions, f, indent=2)
256
-
257
- return True
258
-
259
- except Exception:
260
- return False
261
-
262
- def delete_session(self, session_id: str) -> bool:
263
- """Delete a chat session"""
264
- try:
265
- sessions = self.load_all_sessions()
266
-
267
- if session_id in sessions:
268
- del sessions[session_id]
269
-
270
- with open(self.sessions_file, 'w') as f:
271
- json.dump(sessions, f, indent=2)
272
-
273
- return True
274
- return False
275
-
276
- except Exception:
277
- return False
278
-
279
- def export_chat_history(self, user_id: str, session_id: str = None) -> Dict[str, Any]:
280
- """Export chat history for a user or specific session"""
281
- if session_id:
282
- session = self.get_session(session_id)
283
- if session and session.user_id == user_id:
284
- return {
285
- 'export_type': 'single_session',
286
- 'session': asdict(session),
287
- 'exported_at': datetime.now().isoformat()
288
- }
289
- else:
290
- sessions = self.get_user_sessions(user_id, include_archived=True)
291
- return {
292
- 'export_type': 'all_sessions',
293
- 'sessions': [asdict(session) for session in sessions],
294
- 'exported_at': datetime.now().isoformat(),
295
- 'total_sessions': len(sessions)
296
- }
297
-
298
- return {}
299
-
300
- def get_chat_statistics(self, user_id: str) -> Dict[str, Any]:
301
- """Get chat statistics for a user"""
302
- sessions = self.get_user_sessions(user_id, include_archived=True)
303
-
304
- total_messages = 0
305
- total_user_messages = 0
306
- total_assistant_messages = 0
307
- bookmarked_count = 0
308
- rated_messages = {'positive': 0, 'negative': 0}
309
-
310
- for session in sessions:
311
- total_messages += len(session.messages)
312
- for message in session.messages:
313
- if message.role == 'user':
314
- total_user_messages += 1
315
- else:
316
- total_assistant_messages += 1
317
-
318
- if message.is_bookmarked:
319
- bookmarked_count += 1
320
-
321
- if message.rating == 1:
322
- rated_messages['positive'] += 1
323
- elif message.rating == -1:
324
- rated_messages['negative'] += 1
325
-
326
- return {
327
- 'total_sessions': len(sessions),
328
- 'total_messages': total_messages,
329
- 'user_messages': total_user_messages,
330
- 'assistant_messages': total_assistant_messages,
331
- 'bookmarked_messages': bookmarked_count,
332
- 'message_ratings': rated_messages,
333
- 'average_messages_per_session': total_messages / len(sessions) if sessions else 0
 
334
  }
 
1
+ # chat_manager.py - Chat Session Management System
2
+ import json
3
+ import os
4
+ import uuid
5
+ from dataclasses import dataclass, asdict
6
+ from typing import List, Optional, Dict, Any
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+
10
+ @dataclass
11
+ class ChatMessage:
12
+ """Individual chat message structure"""
13
+ message_id: str
14
+ role: str # 'user' or 'assistant'
15
+ content: str
16
+ timestamp: str
17
+ rating: Optional[int] = None # 1 for thumbs up, -1 for thumbs down, None for no rating
18
+ is_bookmarked: bool = False
19
+ source_documents: List[str] = None
20
+
21
+ def __post_init__(self):
22
+ if self.source_documents is None:
23
+ self.source_documents = []
24
+
25
+ @dataclass
26
+ class ChatSession:
27
+ """Chat session data structure"""
28
+ session_id: str
29
+ user_id: str
30
+ title: str
31
+ created_at: str
32
+ updated_at: str
33
+ messages: List[ChatMessage] = None
34
+ is_archived: bool = False
35
+ tags: List[str] = None
36
+
37
+ def __post_init__(self):
38
+ if self.messages is None:
39
+ self.messages = []
40
+ if self.tags is None:
41
+ self.tags = []
42
+
43
+ class ChatManager:
44
+ """Manages chat sessions and messages"""
45
+
46
+ def __init__(self, data_dir: str):
47
+ self.data_dir = Path(data_dir)
48
+ self.data_dir.mkdir(exist_ok=True)
49
+ self.sessions_file = self.data_dir / "sessions.json"
50
+ self.ensure_sessions_file()
51
+
52
+ def ensure_sessions_file(self):
53
+ """Ensure sessions file exists"""
54
+ if not self.sessions_file.exists():
55
+ with open(self.sessions_file, 'w') as f:
56
+ json.dump({}, f)
57
+
58
+ def create_session(self, user_id: str, title: str = None) -> str:
59
+ """Create a new chat session"""
60
+ session_id = str(uuid.uuid4())
61
+ timestamp = datetime.now().isoformat()
62
+
63
+ if not title:
64
+ title = f"Chat {datetime.now().strftime('%Y-%m-%d %H:%M')}"
65
+
66
+ session = ChatSession(
67
+ session_id=session_id,
68
+ user_id=user_id,
69
+ title=title,
70
+ created_at=timestamp,
71
+ updated_at=timestamp
72
+ )
73
+
74
+ try:
75
+ sessions = self.load_all_sessions()
76
+ sessions[session_id] = asdict(session)
77
+
78
+ with open(self.sessions_file, 'w') as f:
79
+ json.dump(sessions, f, indent=2)
80
+
81
+ return session_id
82
+
83
+ except Exception as e:
84
+ raise Exception(f"Failed to create session: {str(e)}")
85
+
86
+ def load_all_sessions(self) -> Dict[str, Dict]:
87
+ """Load all sessions from storage"""
88
+ try:
89
+ with open(self.sessions_file, 'r') as f:
90
+ return json.load(f)
91
+ except (FileNotFoundError, json.JSONDecodeError):
92
+ return {}
93
+
94
+ def get_session(self, session_id: str) -> Optional[ChatSession]:
95
+ """Get chat session by ID"""
96
+ sessions = self.load_all_sessions()
97
+ session_data = sessions.get(session_id)
98
+
99
+ if session_data:
100
+ # Convert message dictionaries back to ChatMessage objects
101
+ messages = []
102
+ for msg_data in session_data.get('messages', []):
103
+ messages.append(ChatMessage(**msg_data))
104
+ session_data['messages'] = messages
105
+ return ChatSession(**session_data)
106
+ return None
107
+
108
+ def get_user_sessions(self, user_id: str, include_archived: bool = False) -> List[ChatSession]:
109
+ """Get all sessions for a user"""
110
+ sessions = self.load_all_sessions()
111
+ user_sessions = []
112
+
113
+ for session_data in sessions.values():
114
+ if session_data.get('user_id') == user_id:
115
+ if include_archived or not session_data.get('is_archived', False):
116
+ # Convert message dictionaries back to ChatMessage objects
117
+ messages = []
118
+ for msg_data in session_data.get('messages', []):
119
+ messages.append(ChatMessage(**msg_data))
120
+ session_data['messages'] = messages
121
+ user_sessions.append(ChatSession(**session_data))
122
+
123
+ # Sort by updated_at descending
124
+ user_sessions.sort(key=lambda x: x.updated_at, reverse=True)
125
+ return user_sessions
126
+
127
+ def add_message(self, session_id: str, role: str, content: str, source_documents: List[str] = None) -> str:
128
+ """Add a message to a chat session"""
129
+ message_id = str(uuid.uuid4())
130
+ timestamp = datetime.now().isoformat()
131
+
132
+ message = ChatMessage(
133
+ message_id=message_id,
134
+ role=role,
135
+ content=content,
136
+ timestamp=timestamp,
137
+ source_documents=source_documents or []
138
+ )
139
+
140
+ try:
141
+ sessions = self.load_all_sessions()
142
+
143
+ if session_id not in sessions:
144
+ raise ValueError(f"Session {session_id} not found")
145
+
146
+ # Convert message to dict for storage
147
+ message_dict = asdict(message)
148
+ sessions[session_id]['messages'].append(message_dict)
149
+ sessions[session_id]['updated_at'] = timestamp
150
+
151
+ with open(self.sessions_file, 'w') as f:
152
+ json.dump(sessions, f, indent=2)
153
+
154
+ return message_id
155
+
156
+ except Exception as e:
157
+ raise Exception(f"Failed to add message: {str(e)}")
158
+
159
+ def rate_message(self, session_id: str, message_id: str, rating: int) -> bool:
160
+ """Rate a message (1 for thumbs up, -1 for thumbs down)"""
161
+ try:
162
+ sessions = self.load_all_sessions()
163
+
164
+ if session_id not in sessions:
165
+ return False
166
+
167
+ for message in sessions[session_id]['messages']:
168
+ if message['message_id'] == message_id:
169
+ message['rating'] = rating
170
+ sessions[session_id]['updated_at'] = datetime.now().isoformat()
171
+
172
+ with open(self.sessions_file, 'w') as f:
173
+ json.dump(sessions, f, indent=2)
174
+
175
+ return True
176
+
177
+ return False
178
+
179
+ except Exception:
180
+ return False
181
+
182
+ def bookmark_message(self, session_id: str, message_id: str, is_bookmarked: bool = True) -> bool:
183
+ """Bookmark or unbookmark a message"""
184
+ try:
185
+ sessions = self.load_all_sessions()
186
+
187
+ if session_id not in sessions:
188
+ return False
189
+
190
+ for message in sessions[session_id]['messages']:
191
+ if message['message_id'] == message_id:
192
+ message['is_bookmarked'] = is_bookmarked
193
+ sessions[session_id]['updated_at'] = datetime.now().isoformat()
194
+
195
+ with open(self.sessions_file, 'w') as f:
196
+ json.dump(sessions, f, indent=2)
197
+
198
+ return True
199
+
200
+ return False
201
+
202
+ except Exception:
203
+ return False
204
+
205
+ def get_bookmarked_messages(self, user_id: str) -> List[Dict[str, Any]]:
206
+ """Get all bookmarked messages for a user"""
207
+ sessions = self.load_all_sessions()
208
+ bookmarked = []
209
+
210
+ for session_data in sessions.values():
211
+ if session_data.get('user_id') == user_id:
212
+ for message in session_data.get('messages', []):
213
+ if message.get('is_bookmarked', False):
214
+ bookmarked.append({
215
+ 'session_id': session_data['session_id'],
216
+ 'session_title': session_data['title'],
217
+ 'message': message,
218
+ 'timestamp': message['timestamp']
219
+ })
220
+
221
+ # Sort by timestamp descending
222
+ bookmarked.sort(key=lambda x: x['timestamp'], reverse=True)
223
+ return bookmarked
224
+
225
+ def update_session_title(self, session_id: str, title: str) -> bool:
226
+ """Update session title"""
227
+ try:
228
+ sessions = self.load_all_sessions()
229
+
230
+ if session_id not in sessions:
231
+ return False
232
+
233
+ sessions[session_id]['title'] = title
234
+ sessions[session_id]['updated_at'] = datetime.now().isoformat()
235
+
236
+ with open(self.sessions_file, 'w') as f:
237
+ json.dump(sessions, f, indent=2)
238
+
239
+ return True
240
+
241
+ except Exception:
242
+ return False
243
+
244
+ def archive_session(self, session_id: str, is_archived: bool = True) -> bool:
245
+ """Archive or unarchive a session"""
246
+ try:
247
+ sessions = self.load_all_sessions()
248
+
249
+ if session_id not in sessions:
250
+ return False
251
+
252
+ sessions[session_id]['is_archived'] = is_archived
253
+ sessions[session_id]['updated_at'] = datetime.now().isoformat()
254
+
255
+ with open(self.sessions_file, 'w') as f:
256
+ json.dump(sessions, f, indent=2)
257
+
258
+ return True
259
+
260
+ except Exception:
261
+ return False
262
+
263
+ def delete_session(self, session_id: str) -> bool:
264
+ """Delete a chat session"""
265
+ try:
266
+ sessions = self.load_all_sessions()
267
+
268
+ if session_id in sessions:
269
+ del sessions[session_id]
270
+
271
+ with open(self.sessions_file, 'w') as f:
272
+ json.dump(sessions, f, indent=2)
273
+
274
+ return True
275
+ return False
276
+
277
+ except Exception:
278
+ return False
279
+
280
+ def export_chat_history(self, user_id: str, session_id: str = None) -> Dict[str, Any]:
281
+ """Export chat history for a user or specific session"""
282
+ if session_id:
283
+ session = self.get_session(session_id)
284
+ if session and session.user_id == user_id:
285
+ return {
286
+ 'export_type': 'single_session',
287
+ 'session': asdict(session),
288
+ 'exported_at': datetime.now().isoformat()
289
+ }
290
+ else:
291
+ sessions = self.get_user_sessions(user_id, include_archived=True)
292
+ return {
293
+ 'export_type': 'all_sessions',
294
+ 'sessions': [asdict(session) for session in sessions],
295
+ 'exported_at': datetime.now().isoformat(),
296
+ 'total_sessions': len(sessions)
297
+ }
298
+
299
+ return {}
300
+
301
+ def get_chat_statistics(self, user_id: str) -> Dict[str, Any]:
302
+ """Get chat statistics for a user"""
303
+ sessions = self.get_user_sessions(user_id, include_archived=True)
304
+
305
+ total_messages = 0
306
+ total_user_messages = 0
307
+ total_assistant_messages = 0
308
+ bookmarked_count = 0
309
+ rated_messages = {'positive': 0, 'negative': 0}
310
+
311
+ for session in sessions:
312
+ total_messages += len(session.messages)
313
+ for message in session.messages:
314
+ if message.role == 'user':
315
+ total_user_messages += 1
316
+ else:
317
+ total_assistant_messages += 1
318
+
319
+ if message.is_bookmarked:
320
+ bookmarked_count += 1
321
+
322
+ if message.rating == 1:
323
+ rated_messages['positive'] += 1
324
+ elif message.rating == -1:
325
+ rated_messages['negative'] += 1
326
+
327
+ return {
328
+ 'total_sessions': len(sessions),
329
+ 'total_messages': total_messages,
330
+ 'user_messages': total_user_messages,
331
+ 'assistant_messages': total_assistant_messages,
332
+ 'bookmarked_messages': bookmarked_count,
333
+ 'message_ratings': rated_messages,
334
+ 'average_messages_per_session': total_messages / len(sessions) if sessions else 0
335
  }