File size: 5,878 Bytes
9f57d5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import json
import sqlite3
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple

from config import DB_PATH
@contextmanager
def get_db_connection():
    """Yield an SQLite connection to ``DB_PATH`` with name-addressable rows.

    The connection is closed when the ``with`` block exits, whether normally
    or via an exception. Nothing is committed here; writers must call
    ``commit()`` themselves before leaving the block.
    """
    connection = sqlite3.connect(DB_PATH)
    try:
        # sqlite3.Row allows row['column'] access and dict(row) conversion.
        connection.row_factory = sqlite3.Row
        yield connection
    finally:
        connection.close()
def fetch_all_embeddings(table: str) -> List[Tuple[int, str, List[float]]]:
    """Return ``(id, full_text, embedding)`` for every row of *table*.

    The ``embedding`` column is decoded from JSON text. Rows whose embedding
    is NULL or not valid JSON are skipped silently.

    NOTE(review): *table* is formatted straight into the SQL text, so it
    should only ever be a trusted, code-supplied table name.
    """
    with get_db_connection() as conn:
        cursor = conn.execute(f"SELECT id, full_text, embedding FROM {table}")
        results: List[Tuple[int, str, List[float]]] = []
        for record in cursor.fetchall():
            try:
                vector = json.loads(record['embedding'])
            except (json.JSONDecodeError, TypeError):
                # NULL (TypeError) or malformed JSON: leave this row out.
                continue
            results.append((record['id'], record['full_text'], vector))
        return results
def fetch_row_by_id(table: str, row_id: int) -> Dict[str, Any]:
    """Return the row of *table* whose ``id`` equals *row_id* as a plain dict.

    An empty dict is returned when no row matches. Only *row_id* is a bound
    parameter; *table* is formatted into the SQL text and must be trusted.
    """
    with get_db_connection() as conn:
        match = conn.execute(f"SELECT * FROM {table} WHERE id = ?", (row_id,)).fetchone()
        if match is None:
            return {}
        return dict(match)
def fetch_all_faq_embeddings() -> List[Tuple[int, str, str, List[float]]]:
    """Return ``(id, question, answer, embedding)`` for every FAQ entry.

    Embeddings are decoded from the JSON text in ``faq_entries.embedding``;
    entries whose embedding is NULL or malformed are skipped.
    """
    with get_db_connection() as conn:
        faq_rows = conn.execute("SELECT id, question, answer, embedding FROM faq_entries").fetchall()
        entries: List[Tuple[int, str, str, List[float]]] = []
        for faq in faq_rows:
            try:
                vector = json.loads(faq['embedding'])
            except (json.JSONDecodeError, TypeError):
                # Unusable embedding (NULL or invalid JSON): omit this entry.
                continue
            entries.append((faq['id'], faq['question'], faq['answer'], vector))
        return entries
def log_question(
    question: str,
    session_id: Optional[str] = None,
    category: Optional[str] = None,
    answer: Optional[str] = None,
):
    """Record a user question in ``question_logs`` and commit.

    The full record (session_id, question, category, answer) is inserted. If
    the table is an older schema that lacks those extra columns, only the
    question text is stored.

    Raises:
        sqlite3.OperationalError: for database errors other than a missing
            column (e.g. a locked database or a missing table), instead of
            silently retrying with the reduced column set.
    """
    with get_db_connection() as conn:
        try:
            conn.execute(
                "INSERT INTO question_logs (session_id, question, category, answer) "
                "VALUES (?, ?, ?, ?)",
                (session_id, question, category, answer),
            )
        except sqlite3.OperationalError as exc:
            # Legacy-schema fallback only: SQLite reports an unknown INSERT
            # column as "table <name> has no column named <col>".
            if "no column named" not in str(exc):
                raise
            conn.execute("INSERT INTO question_logs (question) VALUES (?)", (question,))
        conn.commit()
def get_session_state(session_id: str) -> Dict[str, Any]:
    """Load the stored ``user_sessions`` row for *session_id* as a dict.

    For an unknown session a default state is returned instead; no row is
    created by this call.
    """
    with get_db_connection() as conn:
        stored = conn.execute(
            "SELECT * FROM user_sessions WHERE session_id = ?", (session_id,)
        ).fetchone()
        if stored is not None:
            return dict(stored)
    # No row yet: report an empty, fresh session.
    return {
        "preference": None,
        "msg_count": 0,
        "clarification_count": 0,
        "knowledge_context": "{}",
    }
def _decode_knowledge(raw: Optional[str]) -> Dict[str, Any]:
    """Decode a stored ``knowledge_context`` value, falling back to ``{}``."""
    try:
        decoded = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        # NULL column (TypeError) or corrupt JSON: start with empty knowledge.
        return {}
    # json.loads may yield a list/str/number/None; only a dict can be merged into.
    return decoded if isinstance(decoded, dict) else {}


def update_session_state(
    session_id: str,
    preference: Optional[str] = None,
    increment_count: bool = True,
    increment_clarification: bool = False,
    reset_clarification: bool = False,
    knowledge_update: Optional[Dict] = None,
):
    """Create or update the ``user_sessions`` row for *session_id* and commit.

    Args:
        session_id: Session key.
        preference: New preference; a falsy value keeps the stored one.
        increment_count: Count this call as one message (msg_count + 1).
        increment_clarification: Add one to clarification_count.
        reset_clarification: Set clarification_count to 0; for an existing
            session this takes precedence over increment_clarification.
        knowledge_update: Keys merged into the stored knowledge dict, which
            is persisted as JSON in ``knowledge_context``.

    10-message memory rule: when an existing session's message count would
    exceed 10, the stored context is wiped (msg_count becomes 1,
    clarification_count 0, stored preference and knowledge cleared). Values
    supplied by this same call (*preference*, *knowledge_update*) are still
    applied after the wipe.
    """
    with get_db_connection() as conn:
        cur = conn.cursor()
        cur.execute(
            "SELECT preference, msg_count, clarification_count, knowledge_context "
            "FROM user_sessions WHERE session_id = ?",
            (session_id,),
        )
        row = cur.fetchone()

        if row is not None:
            curr_pref, curr_count, curr_clarification, curr_knowledge_json = row
            # Treat NULL counters as 0 so the arithmetic below cannot fail.
            curr_count = curr_count or 0
            curr_clarification = curr_clarification or 0
            current_knowledge = _decode_knowledge(curr_knowledge_json)

            new_pref = preference if preference else curr_pref
            new_count = curr_count + 1 if increment_count else curr_count

            if new_count > 10:
                print(f"🔄 Session {session_id} reached 10 messages. Resetting memory context.")
                new_count = 1
                # Drop the stored preference, but keep one supplied by this call.
                new_pref = preference or None
                current_knowledge = {}
                new_clarification = 0
            elif reset_clarification:
                new_clarification = 0
            elif increment_clarification:
                new_clarification = curr_clarification + 1
            else:
                new_clarification = curr_clarification
        else:
            new_pref = preference
            new_count = 1 if increment_count else 0
            new_clarification = 1 if increment_clarification else 0
            current_knowledge = {}

        # Merge this call's knowledge after any reset so it is never lost.
        if knowledge_update:
            current_knowledge.update(knowledge_update)
        new_knowledge_json = json.dumps(current_knowledge)

        if row is not None:
            cur.execute(
                "UPDATE user_sessions "
                "SET preference = ?, msg_count = ?, clarification_count = ?, "
                "knowledge_context = ?, last_updated = CURRENT_TIMESTAMP "
                "WHERE session_id = ?",
                (new_pref, new_count, new_clarification, new_knowledge_json, session_id),
            )
        else:
            cur.execute(
                "INSERT INTO user_sessions "
                "(session_id, preference, msg_count, clarification_count, knowledge_context) "
                "VALUES (?, ?, ?, ?, ?)",
                (session_id, new_pref, new_count, new_clarification, new_knowledge_json),
            )
        conn.commit()