Spaces:
Build error
Build error
Update utils/database.py
Browse files- utils/database.py +351 -0
utils/database.py
CHANGED
|
@@ -364,6 +364,357 @@ def add_query(conn: sqlite3.Connection, query: str, response: str, document_id:
|
|
| 364 |
st.error(f"Error adding query: {e}")
|
| 365 |
return False
|
| 366 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
def add_annotation(conn: sqlite3.Connection, document_id: int, annotation: str, page_number: Optional[int] = None) -> bool:
|
| 368 |
"""Add an annotation to a document."""
|
| 369 |
try:
|
|
|
|
| 364 |
st.error(f"Error adding query: {e}")
|
| 365 |
return False
|
| 366 |
|
| 367 |
+
# Add to utils/database.py
|
| 368 |
+
|
| 369 |
+
import sqlite3
|
| 370 |
+
from typing import List, Dict, Optional
|
| 371 |
+
from datetime import datetime
|
| 372 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 373 |
+
import streamlit as st
|
| 374 |
+
|
| 375 |
+
def create_chat_tables(conn: sqlite3.Connection) -> None:
|
| 376 |
+
"""Create necessary tables for chat management."""
|
| 377 |
+
try:
|
| 378 |
+
with conn_lock:
|
| 379 |
+
cursor = conn.cursor()
|
| 380 |
+
|
| 381 |
+
# Create chats table
|
| 382 |
+
cursor.execute('''
|
| 383 |
+
CREATE TABLE IF NOT EXISTS chats (
|
| 384 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 385 |
+
title TEXT NOT NULL,
|
| 386 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 387 |
+
last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 388 |
+
collection_id INTEGER,
|
| 389 |
+
FOREIGN KEY (collection_id) REFERENCES collections (id) ON DELETE SET NULL
|
| 390 |
+
)
|
| 391 |
+
''')
|
| 392 |
+
|
| 393 |
+
# Create chat messages table
|
| 394 |
+
cursor.execute('''
|
| 395 |
+
CREATE TABLE IF NOT EXISTS chat_messages (
|
| 396 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 397 |
+
chat_id INTEGER NOT NULL,
|
| 398 |
+
role TEXT NOT NULL,
|
| 399 |
+
content TEXT NOT NULL,
|
| 400 |
+
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 401 |
+
metadata TEXT, -- Store metadata as JSON string
|
| 402 |
+
FOREIGN KEY (chat_id) REFERENCES chats (id) ON DELETE CASCADE
|
| 403 |
+
)
|
| 404 |
+
''')
|
| 405 |
+
|
| 406 |
+
conn.commit()
|
| 407 |
+
|
| 408 |
+
except sqlite3.Error as e:
|
| 409 |
+
st.error(f"Error creating chat tables: {e}")
|
| 410 |
+
raise
|
| 411 |
+
|
| 412 |
+
def create_new_chat(conn: sqlite3.Connection, title: str, collection_id: Optional[int] = None) -> Optional[int]:
|
| 413 |
+
"""Create a new chat session."""
|
| 414 |
+
try:
|
| 415 |
+
with conn_lock:
|
| 416 |
+
cursor = conn.cursor()
|
| 417 |
+
cursor.execute('''
|
| 418 |
+
INSERT INTO chats (title, collection_id, created_at, last_updated)
|
| 419 |
+
VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
| 420 |
+
''', (title, collection_id))
|
| 421 |
+
|
| 422 |
+
conn.commit()
|
| 423 |
+
return cursor.lastrowid
|
| 424 |
+
|
| 425 |
+
except sqlite3.Error as e:
|
| 426 |
+
st.error(f"Error creating new chat: {e}")
|
| 427 |
+
return None
|
| 428 |
+
|
| 429 |
+
def save_chat_message(conn: sqlite3.Connection,
|
| 430 |
+
chat_id: int,
|
| 431 |
+
role: str,
|
| 432 |
+
content: str,
|
| 433 |
+
metadata: Optional[Dict] = None) -> Optional[int]:
|
| 434 |
+
"""Save a chat message to the database."""
|
| 435 |
+
try:
|
| 436 |
+
with conn_lock:
|
| 437 |
+
cursor = conn.cursor()
|
| 438 |
+
|
| 439 |
+
# Convert metadata to JSON string if present
|
| 440 |
+
metadata_str = json.dumps(metadata) if metadata else None
|
| 441 |
+
|
| 442 |
+
# Insert message
|
| 443 |
+
cursor.execute('''
|
| 444 |
+
INSERT INTO chat_messages (chat_id, role, content, metadata, timestamp)
|
| 445 |
+
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
|
| 446 |
+
''', (chat_id, role, content, metadata_str))
|
| 447 |
+
|
| 448 |
+
# Update chat last_updated timestamp
|
| 449 |
+
cursor.execute('''
|
| 450 |
+
UPDATE chats
|
| 451 |
+
SET last_updated = CURRENT_TIMESTAMP
|
| 452 |
+
WHERE id = ?
|
| 453 |
+
''', (chat_id,))
|
| 454 |
+
|
| 455 |
+
conn.commit()
|
| 456 |
+
return cursor.lastrowid
|
| 457 |
+
|
| 458 |
+
except sqlite3.Error as e:
|
| 459 |
+
st.error(f"Error saving chat message: {e}")
|
| 460 |
+
return None
|
| 461 |
+
|
| 462 |
+
def get_all_chats(conn: sqlite3.Connection) -> List[Dict]:
|
| 463 |
+
"""Retrieve all chat sessions."""
|
| 464 |
+
try:
|
| 465 |
+
with conn_lock:
|
| 466 |
+
cursor = conn.cursor()
|
| 467 |
+
cursor.execute('''
|
| 468 |
+
SELECT
|
| 469 |
+
c.id,
|
| 470 |
+
c.title,
|
| 471 |
+
c.created_at,
|
| 472 |
+
c.last_updated,
|
| 473 |
+
c.collection_id,
|
| 474 |
+
COUNT(m.id) as message_count,
|
| 475 |
+
MAX(m.timestamp) as last_message
|
| 476 |
+
FROM chats c
|
| 477 |
+
LEFT JOIN chat_messages m ON c.id = m.chat_id
|
| 478 |
+
GROUP BY c.id
|
| 479 |
+
ORDER BY c.last_updated DESC
|
| 480 |
+
''')
|
| 481 |
+
|
| 482 |
+
chats = []
|
| 483 |
+
for row in cursor.fetchall():
|
| 484 |
+
chats.append({
|
| 485 |
+
'id': row[0],
|
| 486 |
+
'title': row[1],
|
| 487 |
+
'created_at': row[2],
|
| 488 |
+
'last_updated': row[3],
|
| 489 |
+
'collection_id': row[4],
|
| 490 |
+
'message_count': row[5],
|
| 491 |
+
'last_message': row[6]
|
| 492 |
+
})
|
| 493 |
+
return chats
|
| 494 |
+
|
| 495 |
+
except sqlite3.Error as e:
|
| 496 |
+
st.error(f"Error retrieving chats: {e}")
|
| 497 |
+
return []
|
| 498 |
+
|
| 499 |
+
def get_chat_messages(conn: sqlite3.Connection, chat_id: int) -> List[Dict]:
|
| 500 |
+
"""Retrieve all messages for a specific chat."""
|
| 501 |
+
try:
|
| 502 |
+
with conn_lock:
|
| 503 |
+
cursor = conn.cursor()
|
| 504 |
+
cursor.execute('''
|
| 505 |
+
SELECT id, role, content, metadata, timestamp
|
| 506 |
+
FROM chat_messages
|
| 507 |
+
WHERE chat_id = ?
|
| 508 |
+
ORDER BY timestamp
|
| 509 |
+
''', (chat_id,))
|
| 510 |
+
|
| 511 |
+
messages = []
|
| 512 |
+
for row in cursor.fetchall():
|
| 513 |
+
# Parse metadata JSON if present
|
| 514 |
+
metadata = json.loads(row[3]) if row[3] else None
|
| 515 |
+
|
| 516 |
+
# Convert to appropriate message type
|
| 517 |
+
if row[1] == 'human':
|
| 518 |
+
message = HumanMessage(content=row[2])
|
| 519 |
+
else:
|
| 520 |
+
message = AIMessage(content=row[2], additional_kwargs={'metadata': metadata})
|
| 521 |
+
|
| 522 |
+
messages.append(message)
|
| 523 |
+
|
| 524 |
+
return messages
|
| 525 |
+
|
| 526 |
+
except sqlite3.Error as e:
|
| 527 |
+
st.error(f"Error retrieving chat messages: {e}")
|
| 528 |
+
return []
|
| 529 |
+
|
| 530 |
+
def delete_chat(conn: sqlite3.Connection, chat_id: int) -> bool:
|
| 531 |
+
"""Delete a chat session and all its messages."""
|
| 532 |
+
try:
|
| 533 |
+
with conn_lock:
|
| 534 |
+
cursor = conn.cursor()
|
| 535 |
+
# Messages will be automatically deleted due to CASCADE
|
| 536 |
+
cursor.execute('DELETE FROM chats WHERE id = ?', (chat_id,))
|
| 537 |
+
conn.commit()
|
| 538 |
+
return True
|
| 539 |
+
|
| 540 |
+
except sqlite3.Error as e:
|
| 541 |
+
st.error(f"Error deleting chat: {e}")
|
| 542 |
+
return False
|
| 543 |
+
|
| 544 |
+
def update_chat_title(conn: sqlite3.Connection, chat_id: int, new_title: str) -> bool:
|
| 545 |
+
"""Update the title of a chat session."""
|
| 546 |
+
try:
|
| 547 |
+
with conn_lock:
|
| 548 |
+
cursor = conn.cursor()
|
| 549 |
+
cursor.execute('''
|
| 550 |
+
UPDATE chats
|
| 551 |
+
SET title = ?, last_updated = CURRENT_TIMESTAMP
|
| 552 |
+
WHERE id = ?
|
| 553 |
+
''', (new_title, chat_id))
|
| 554 |
+
conn.commit()
|
| 555 |
+
return True
|
| 556 |
+
|
| 557 |
+
except sqlite3.Error as e:
|
| 558 |
+
st.error(f"Error updating chat title: {e}")
|
| 559 |
+
return False
|
| 560 |
+
|
| 561 |
+
def get_chat_by_id(conn: sqlite3.Connection, chat_id: int) -> Optional[Dict]:
|
| 562 |
+
"""Retrieve a specific chat session by ID."""
|
| 563 |
+
try:
|
| 564 |
+
with conn_lock:
|
| 565 |
+
cursor = conn.cursor()
|
| 566 |
+
cursor.execute('''
|
| 567 |
+
SELECT
|
| 568 |
+
c.id,
|
| 569 |
+
c.title,
|
| 570 |
+
c.created_at,
|
| 571 |
+
c.last_updated,
|
| 572 |
+
c.collection_id,
|
| 573 |
+
COUNT(m.id) as message_count
|
| 574 |
+
FROM chats c
|
| 575 |
+
LEFT JOIN chat_messages m ON c.id = m.chat_id
|
| 576 |
+
WHERE c.id = ?
|
| 577 |
+
GROUP BY c.id
|
| 578 |
+
''', (chat_id,))
|
| 579 |
+
|
| 580 |
+
row = cursor.fetchone()
|
| 581 |
+
if row:
|
| 582 |
+
return {
|
| 583 |
+
'id': row[0],
|
| 584 |
+
'title': row[1],
|
| 585 |
+
'created_at': row[2],
|
| 586 |
+
'last_updated': row[3],
|
| 587 |
+
'collection_id': row[4],
|
| 588 |
+
'message_count': row[5]
|
| 589 |
+
}
|
| 590 |
+
return None
|
| 591 |
+
|
| 592 |
+
except sqlite3.Error as e:
|
| 593 |
+
st.error(f"Error retrieving chat: {e}")
|
| 594 |
+
return None
|
| 595 |
+
|
| 596 |
+
def export_chat_history(conn: sqlite3.Connection, chat_id: int) -> Optional[Dict]:
|
| 597 |
+
"""Export a chat session with all its messages."""
|
| 598 |
+
try:
|
| 599 |
+
chat = get_chat_by_id(conn, chat_id)
|
| 600 |
+
if not chat:
|
| 601 |
+
return None
|
| 602 |
+
|
| 603 |
+
messages = get_chat_messages(conn, chat_id)
|
| 604 |
+
|
| 605 |
+
return {
|
| 606 |
+
'chat_info': chat,
|
| 607 |
+
'messages': [
|
| 608 |
+
{
|
| 609 |
+
'role': 'human' if isinstance(msg, HumanMessage) else 'assistant',
|
| 610 |
+
'content': msg.content,
|
| 611 |
+
'metadata': msg.additional_kwargs.get('metadata') if isinstance(msg, AIMessage) else None
|
| 612 |
+
}
|
| 613 |
+
for msg in messages
|
| 614 |
+
]
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
except Exception as e:
|
| 618 |
+
st.error(f"Error exporting chat history: {e}")
|
| 619 |
+
return None
|
| 620 |
+
|
| 621 |
+
def import_chat_history(conn: sqlite3.Connection, chat_data: Dict) -> Optional[int]:
|
| 622 |
+
"""Import a chat session from exported data."""
|
| 623 |
+
try:
|
| 624 |
+
with conn_lock:
|
| 625 |
+
# Create new chat
|
| 626 |
+
chat_id = create_new_chat(
|
| 627 |
+
conn,
|
| 628 |
+
chat_data['chat_info']['title'],
|
| 629 |
+
chat_data['chat_info'].get('collection_id')
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
if not chat_id:
|
| 633 |
+
return None
|
| 634 |
+
|
| 635 |
+
# Import messages
|
| 636 |
+
for msg in chat_data['messages']:
|
| 637 |
+
save_chat_message(
|
| 638 |
+
conn,
|
| 639 |
+
chat_id,
|
| 640 |
+
msg['role'],
|
| 641 |
+
msg['content'],
|
| 642 |
+
msg.get('metadata')
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
return chat_id
|
| 646 |
+
|
| 647 |
+
except Exception as e:
|
| 648 |
+
st.error(f"Error importing chat history: {e}")
|
| 649 |
+
return None
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
# utils/database.py
|
| 653 |
+
def create_chat_tables(conn):
|
| 654 |
+
"""Create tables for chat management."""
|
| 655 |
+
cursor = conn.cursor()
|
| 656 |
+
cursor.execute('''
|
| 657 |
+
CREATE TABLE IF NOT EXISTS chats (
|
| 658 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 659 |
+
title TEXT NOT NULL,
|
| 660 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 661 |
+
)
|
| 662 |
+
''')
|
| 663 |
+
|
| 664 |
+
cursor.execute('''
|
| 665 |
+
CREATE TABLE IF NOT EXISTS chat_messages (
|
| 666 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 667 |
+
chat_id INTEGER,
|
| 668 |
+
role TEXT NOT NULL,
|
| 669 |
+
content TEXT NOT NULL,
|
| 670 |
+
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 671 |
+
FOREIGN KEY (chat_id) REFERENCES chats (id) ON DELETE CASCADE
|
| 672 |
+
)
|
| 673 |
+
''')
|
| 674 |
+
conn.commit()
|
| 675 |
+
|
| 676 |
+
def save_chat(conn, chat_title: str, messages: List[Dict]):
|
| 677 |
+
"""Save chat history."""
|
| 678 |
+
cursor = conn.cursor()
|
| 679 |
+
cursor.execute('INSERT INTO chats (title) VALUES (?)', (chat_title,))
|
| 680 |
+
chat_id = cursor.lastrowid
|
| 681 |
+
|
| 682 |
+
for msg in messages:
|
| 683 |
+
cursor.execute('''
|
| 684 |
+
INSERT INTO chat_messages (chat_id, role, content)
|
| 685 |
+
VALUES (?, ?, ?)
|
| 686 |
+
''', (chat_id, msg['role'], msg['content']))
|
| 687 |
+
|
| 688 |
+
conn.commit()
|
| 689 |
+
return chat_id
|
| 690 |
+
|
| 691 |
+
# components/chat.py
|
| 692 |
+
def display_chat_manager():
|
| 693 |
+
"""Display chat management interface."""
|
| 694 |
+
st.sidebar.markdown("### Chat Management")
|
| 695 |
+
|
| 696 |
+
# New chat button
|
| 697 |
+
if st.sidebar.button("New Chat"):
|
| 698 |
+
st.session_state.messages = []
|
| 699 |
+
st.session_state.current_chat_id = None
|
| 700 |
+
|
| 701 |
+
# Save current chat
|
| 702 |
+
if st.session_state.messages and st.sidebar.button("Save Chat"):
|
| 703 |
+
chat_title = st.sidebar.text_input("Chat Title",
|
| 704 |
+
value=f"Chat {datetime.now().strftime('%Y-%m-%d %H:%M')}")
|
| 705 |
+
if chat_title:
|
| 706 |
+
save_chat(st.session_state.db_conn, chat_title, st.session_state.messages)
|
| 707 |
+
st.sidebar.success("Chat saved!")
|
| 708 |
+
|
| 709 |
+
# Load previous chats
|
| 710 |
+
chats = get_all_chats(st.session_state.db_conn)
|
| 711 |
+
if chats:
|
| 712 |
+
st.sidebar.markdown("### Previous Chats")
|
| 713 |
+
for chat in chats:
|
| 714 |
+
if st.sidebar.button(f"📜 {chat['title']}", key=f"chat_{chat['id']}"):
|
| 715 |
+
st.session_state.messages = get_chat_messages(st.session_state.db_conn, chat['id'])
|
| 716 |
+
st.session_state.current_chat_id = chat['id']
|
| 717 |
+
st.rerun()
|
| 718 |
def add_annotation(conn: sqlite3.Connection, document_id: int, annotation: str, page_number: Optional[int] = None) -> bool:
|
| 719 |
"""Add an annotation to a document."""
|
| 720 |
try:
|