cryogenic22 commited on
Commit
219a4c2
·
verified ·
1 Parent(s): 2260e72

Update utils/database.py

Browse files
Files changed (1) hide show
  1. utils/database.py +351 -0
utils/database.py CHANGED
@@ -364,6 +364,357 @@ def add_query(conn: sqlite3.Connection, query: str, response: str, document_id:
364
  st.error(f"Error adding query: {e}")
365
  return False
366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  def add_annotation(conn: sqlite3.Connection, document_id: int, annotation: str, page_number: Optional[int] = None) -> bool:
368
  """Add an annotation to a document."""
369
  try:
 
364
  st.error(f"Error adding query: {e}")
365
  return False
366
 
367
+ # Add to utils/database.py
368
+
369
+ import sqlite3
370
+ from typing import List, Dict, Optional
371
+ from datetime import datetime
372
+ from langchain_core.messages import HumanMessage, AIMessage
373
+ import streamlit as st
374
+
375
+ def create_chat_tables(conn: sqlite3.Connection) -> None:
376
+ """Create necessary tables for chat management."""
377
+ try:
378
+ with conn_lock:
379
+ cursor = conn.cursor()
380
+
381
+ # Create chats table
382
+ cursor.execute('''
383
+ CREATE TABLE IF NOT EXISTS chats (
384
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
385
+ title TEXT NOT NULL,
386
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
387
+ last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
388
+ collection_id INTEGER,
389
+ FOREIGN KEY (collection_id) REFERENCES collections (id) ON DELETE SET NULL
390
+ )
391
+ ''')
392
+
393
+ # Create chat messages table
394
+ cursor.execute('''
395
+ CREATE TABLE IF NOT EXISTS chat_messages (
396
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
397
+ chat_id INTEGER NOT NULL,
398
+ role TEXT NOT NULL,
399
+ content TEXT NOT NULL,
400
+ timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
401
+ metadata TEXT, -- Store metadata as JSON string
402
+ FOREIGN KEY (chat_id) REFERENCES chats (id) ON DELETE CASCADE
403
+ )
404
+ ''')
405
+
406
+ conn.commit()
407
+
408
+ except sqlite3.Error as e:
409
+ st.error(f"Error creating chat tables: {e}")
410
+ raise
411
+
412
+ def create_new_chat(conn: sqlite3.Connection, title: str, collection_id: Optional[int] = None) -> Optional[int]:
413
+ """Create a new chat session."""
414
+ try:
415
+ with conn_lock:
416
+ cursor = conn.cursor()
417
+ cursor.execute('''
418
+ INSERT INTO chats (title, collection_id, created_at, last_updated)
419
+ VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
420
+ ''', (title, collection_id))
421
+
422
+ conn.commit()
423
+ return cursor.lastrowid
424
+
425
+ except sqlite3.Error as e:
426
+ st.error(f"Error creating new chat: {e}")
427
+ return None
428
+
429
+ def save_chat_message(conn: sqlite3.Connection,
430
+ chat_id: int,
431
+ role: str,
432
+ content: str,
433
+ metadata: Optional[Dict] = None) -> Optional[int]:
434
+ """Save a chat message to the database."""
435
+ try:
436
+ with conn_lock:
437
+ cursor = conn.cursor()
438
+
439
+ # Convert metadata to JSON string if present
440
+ metadata_str = json.dumps(metadata) if metadata else None
441
+
442
+ # Insert message
443
+ cursor.execute('''
444
+ INSERT INTO chat_messages (chat_id, role, content, metadata, timestamp)
445
+ VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
446
+ ''', (chat_id, role, content, metadata_str))
447
+
448
+ # Update chat last_updated timestamp
449
+ cursor.execute('''
450
+ UPDATE chats
451
+ SET last_updated = CURRENT_TIMESTAMP
452
+ WHERE id = ?
453
+ ''', (chat_id,))
454
+
455
+ conn.commit()
456
+ return cursor.lastrowid
457
+
458
+ except sqlite3.Error as e:
459
+ st.error(f"Error saving chat message: {e}")
460
+ return None
461
+
462
+ def get_all_chats(conn: sqlite3.Connection) -> List[Dict]:
463
+ """Retrieve all chat sessions."""
464
+ try:
465
+ with conn_lock:
466
+ cursor = conn.cursor()
467
+ cursor.execute('''
468
+ SELECT
469
+ c.id,
470
+ c.title,
471
+ c.created_at,
472
+ c.last_updated,
473
+ c.collection_id,
474
+ COUNT(m.id) as message_count,
475
+ MAX(m.timestamp) as last_message
476
+ FROM chats c
477
+ LEFT JOIN chat_messages m ON c.id = m.chat_id
478
+ GROUP BY c.id
479
+ ORDER BY c.last_updated DESC
480
+ ''')
481
+
482
+ chats = []
483
+ for row in cursor.fetchall():
484
+ chats.append({
485
+ 'id': row[0],
486
+ 'title': row[1],
487
+ 'created_at': row[2],
488
+ 'last_updated': row[3],
489
+ 'collection_id': row[4],
490
+ 'message_count': row[5],
491
+ 'last_message': row[6]
492
+ })
493
+ return chats
494
+
495
+ except sqlite3.Error as e:
496
+ st.error(f"Error retrieving chats: {e}")
497
+ return []
498
+
499
+ def get_chat_messages(conn: sqlite3.Connection, chat_id: int) -> List[Dict]:
500
+ """Retrieve all messages for a specific chat."""
501
+ try:
502
+ with conn_lock:
503
+ cursor = conn.cursor()
504
+ cursor.execute('''
505
+ SELECT id, role, content, metadata, timestamp
506
+ FROM chat_messages
507
+ WHERE chat_id = ?
508
+ ORDER BY timestamp
509
+ ''', (chat_id,))
510
+
511
+ messages = []
512
+ for row in cursor.fetchall():
513
+ # Parse metadata JSON if present
514
+ metadata = json.loads(row[3]) if row[3] else None
515
+
516
+ # Convert to appropriate message type
517
+ if row[1] == 'human':
518
+ message = HumanMessage(content=row[2])
519
+ else:
520
+ message = AIMessage(content=row[2], additional_kwargs={'metadata': metadata})
521
+
522
+ messages.append(message)
523
+
524
+ return messages
525
+
526
+ except sqlite3.Error as e:
527
+ st.error(f"Error retrieving chat messages: {e}")
528
+ return []
529
+
530
+ def delete_chat(conn: sqlite3.Connection, chat_id: int) -> bool:
531
+ """Delete a chat session and all its messages."""
532
+ try:
533
+ with conn_lock:
534
+ cursor = conn.cursor()
535
+ # Messages will be automatically deleted due to CASCADE
536
+ cursor.execute('DELETE FROM chats WHERE id = ?', (chat_id,))
537
+ conn.commit()
538
+ return True
539
+
540
+ except sqlite3.Error as e:
541
+ st.error(f"Error deleting chat: {e}")
542
+ return False
543
+
544
+ def update_chat_title(conn: sqlite3.Connection, chat_id: int, new_title: str) -> bool:
545
+ """Update the title of a chat session."""
546
+ try:
547
+ with conn_lock:
548
+ cursor = conn.cursor()
549
+ cursor.execute('''
550
+ UPDATE chats
551
+ SET title = ?, last_updated = CURRENT_TIMESTAMP
552
+ WHERE id = ?
553
+ ''', (new_title, chat_id))
554
+ conn.commit()
555
+ return True
556
+
557
+ except sqlite3.Error as e:
558
+ st.error(f"Error updating chat title: {e}")
559
+ return False
560
+
561
+ def get_chat_by_id(conn: sqlite3.Connection, chat_id: int) -> Optional[Dict]:
562
+ """Retrieve a specific chat session by ID."""
563
+ try:
564
+ with conn_lock:
565
+ cursor = conn.cursor()
566
+ cursor.execute('''
567
+ SELECT
568
+ c.id,
569
+ c.title,
570
+ c.created_at,
571
+ c.last_updated,
572
+ c.collection_id,
573
+ COUNT(m.id) as message_count
574
+ FROM chats c
575
+ LEFT JOIN chat_messages m ON c.id = m.chat_id
576
+ WHERE c.id = ?
577
+ GROUP BY c.id
578
+ ''', (chat_id,))
579
+
580
+ row = cursor.fetchone()
581
+ if row:
582
+ return {
583
+ 'id': row[0],
584
+ 'title': row[1],
585
+ 'created_at': row[2],
586
+ 'last_updated': row[3],
587
+ 'collection_id': row[4],
588
+ 'message_count': row[5]
589
+ }
590
+ return None
591
+
592
+ except sqlite3.Error as e:
593
+ st.error(f"Error retrieving chat: {e}")
594
+ return None
595
+
596
+ def export_chat_history(conn: sqlite3.Connection, chat_id: int) -> Optional[Dict]:
597
+ """Export a chat session with all its messages."""
598
+ try:
599
+ chat = get_chat_by_id(conn, chat_id)
600
+ if not chat:
601
+ return None
602
+
603
+ messages = get_chat_messages(conn, chat_id)
604
+
605
+ return {
606
+ 'chat_info': chat,
607
+ 'messages': [
608
+ {
609
+ 'role': 'human' if isinstance(msg, HumanMessage) else 'assistant',
610
+ 'content': msg.content,
611
+ 'metadata': msg.additional_kwargs.get('metadata') if isinstance(msg, AIMessage) else None
612
+ }
613
+ for msg in messages
614
+ ]
615
+ }
616
+
617
+ except Exception as e:
618
+ st.error(f"Error exporting chat history: {e}")
619
+ return None
620
+
621
+ def import_chat_history(conn: sqlite3.Connection, chat_data: Dict) -> Optional[int]:
622
+ """Import a chat session from exported data."""
623
+ try:
624
+ with conn_lock:
625
+ # Create new chat
626
+ chat_id = create_new_chat(
627
+ conn,
628
+ chat_data['chat_info']['title'],
629
+ chat_data['chat_info'].get('collection_id')
630
+ )
631
+
632
+ if not chat_id:
633
+ return None
634
+
635
+ # Import messages
636
+ for msg in chat_data['messages']:
637
+ save_chat_message(
638
+ conn,
639
+ chat_id,
640
+ msg['role'],
641
+ msg['content'],
642
+ msg.get('metadata')
643
+ )
644
+
645
+ return chat_id
646
+
647
+ except Exception as e:
648
+ st.error(f"Error importing chat history: {e}")
649
+ return None
650
+
651
+
652
+ # utils/database.py
653
+ def create_chat_tables(conn):
654
+ """Create tables for chat management."""
655
+ cursor = conn.cursor()
656
+ cursor.execute('''
657
+ CREATE TABLE IF NOT EXISTS chats (
658
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
659
+ title TEXT NOT NULL,
660
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
661
+ )
662
+ ''')
663
+
664
+ cursor.execute('''
665
+ CREATE TABLE IF NOT EXISTS chat_messages (
666
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
667
+ chat_id INTEGER,
668
+ role TEXT NOT NULL,
669
+ content TEXT NOT NULL,
670
+ timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
671
+ FOREIGN KEY (chat_id) REFERENCES chats (id) ON DELETE CASCADE
672
+ )
673
+ ''')
674
+ conn.commit()
675
+
676
+ def save_chat(conn, chat_title: str, messages: List[Dict]):
677
+ """Save chat history."""
678
+ cursor = conn.cursor()
679
+ cursor.execute('INSERT INTO chats (title) VALUES (?)', (chat_title,))
680
+ chat_id = cursor.lastrowid
681
+
682
+ for msg in messages:
683
+ cursor.execute('''
684
+ INSERT INTO chat_messages (chat_id, role, content)
685
+ VALUES (?, ?, ?)
686
+ ''', (chat_id, msg['role'], msg['content']))
687
+
688
+ conn.commit()
689
+ return chat_id
690
+
691
+ # components/chat.py
692
+ def display_chat_manager():
693
+ """Display chat management interface."""
694
+ st.sidebar.markdown("### Chat Management")
695
+
696
+ # New chat button
697
+ if st.sidebar.button("New Chat"):
698
+ st.session_state.messages = []
699
+ st.session_state.current_chat_id = None
700
+
701
+ # Save current chat
702
+ if st.session_state.messages and st.sidebar.button("Save Chat"):
703
+ chat_title = st.sidebar.text_input("Chat Title",
704
+ value=f"Chat {datetime.now().strftime('%Y-%m-%d %H:%M')}")
705
+ if chat_title:
706
+ save_chat(st.session_state.db_conn, chat_title, st.session_state.messages)
707
+ st.sidebar.success("Chat saved!")
708
+
709
+ # Load previous chats
710
+ chats = get_all_chats(st.session_state.db_conn)
711
+ if chats:
712
+ st.sidebar.markdown("### Previous Chats")
713
+ for chat in chats:
714
+ if st.sidebar.button(f"📜 {chat['title']}", key=f"chat_{chat['id']}"):
715
+ st.session_state.messages = get_chat_messages(st.session_state.db_conn, chat['id'])
716
+ st.session_state.current_chat_id = chat['id']
717
+ st.rerun()
718
  def add_annotation(conn: sqlite3.Connection, document_id: int, annotation: str, page_number: Optional[int] = None) -> bool:
719
  """Add an annotation to a document."""
720
  try: