import json
import sqlite3
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple

from config import DB_PATH
| |
|
@contextmanager
def get_db_connection(db_path: Optional[str] = None):
    """Yield a SQLite connection whose rows can be read by column name.

    Args:
        db_path: Database to open. Defaults to ``DB_PATH`` from ``config``;
            pass another path (or ``":memory:"``) for tests or alternate stores.

    Yields:
        A ``sqlite3.Connection`` with ``row_factory`` set to ``sqlite3.Row``,
        so results support ``row["column"]`` access and ``dict(row)``.

    The connection is closed on exit, including when the ``with`` body raises.
    Nothing is committed here: callers that write must call ``conn.commit()``
    themselves, otherwise uncommitted changes are discarded on close.
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    conn.row_factory = sqlite3.Row
    try:
        yield conn
    finally:
        conn.close()
| |
|
def fetch_all_embeddings(table: str) -> List[Tuple[int, str, List[float]]]:
    """Load every ``(id, full_text, embedding)`` triple stored in *table*.

    The ``embedding`` column is decoded with ``json.loads``; rows whose value
    cannot be decoded (invalid JSON, or a NULL / non-text value) are skipped
    instead of raising.

    NOTE: *table* is formatted directly into the SQL text (identifiers cannot
    be bound as ``?`` parameters), so it must only ever come from trusted code.
    """
    query = f"SELECT id, full_text, embedding FROM {table}"
    with get_db_connection() as conn:
        records = conn.execute(query).fetchall()

    results: List[Tuple[int, str, List[float]]] = []
    for record in records:
        try:
            vector = json.loads(record['embedding'])
        except (json.JSONDecodeError, TypeError):
            continue
        results.append((record['id'], record['full_text'], vector))
    return results
| |
|
def fetch_row_by_id(table: str, row_id: int) -> Dict[str, Any]:
    """Return the row of *table* whose ``id`` equals *row_id*, as a plain dict.

    Returns an empty dict when no row matches.

    NOTE: *table* is formatted directly into the SQL text (only *row_id* is a
    bound parameter), so it must only ever come from trusted code.
    """
    sql = f"SELECT * FROM {table} WHERE id = ?"
    with get_db_connection() as conn:
        found = conn.execute(sql, (row_id,)).fetchone()
        if not found:
            return {}
        return dict(found)
| |
|
def fetch_all_faq_embeddings() -> List[Tuple[int, str, str, List[float]]]:
    """Load every FAQ entry as an ``(id, question, answer, embedding)`` tuple.

    Reads all rows of ``faq_entries`` and JSON-decodes the ``embedding``
    column; entries whose embedding cannot be decoded (invalid JSON, or a
    NULL / non-text value) are silently left out.
    """
    with get_db_connection() as conn:
        faq_rows = conn.execute(
            "SELECT id, question, answer, embedding FROM faq_entries"
        ).fetchall()

    entries: List[Tuple[int, str, str, List[float]]] = []
    for faq in faq_rows:
        try:
            vector = json.loads(faq['embedding'])
        except (json.JSONDecodeError, TypeError):
            continue
        entries.append((faq['id'], faq['question'], faq['answer'], vector))
    return entries
| |
|
def log_question(
    question: str,
    session_id: Optional[str] = None,
    category: Optional[str] = None,
    answer: Optional[str] = None,
):
    """Insert one user question, with its context, into ``question_logs``.

    Args:
        question: The user's question text.
        session_id: Session the question belongs to, if known.
        category: Category assigned to the question, if any.
        answer: Answer that was returned, if any.

    If ``question_logs`` lacks one of the context columns (presumably an older
    schema), only the question text is stored. Any other database error
    propagates to the caller.
    """
    with get_db_connection() as conn:
        cur = conn.cursor()
        try:
            cur.execute("""
                INSERT INTO question_logs (session_id, question, category, answer)
                VALUES (?, ?, ?, ?)
            """, (session_id, question, category, answer))
        except sqlite3.OperationalError as exc:
            # SQLite reports a missing column as "table X has no column named Y".
            # Only that case warrants the reduced insert; other operational
            # errors (locked/read-only database, missing table, ...) must not
            # be masked by a retry that silently drops the context fields.
            if "no column named" not in str(exc):
                raise
            cur.execute("INSERT INTO question_logs (question) VALUES (?)", (question,))

        conn.commit()
| |
|
def get_session_state(session_id: str) -> Dict[str, Any]:
    """Look up the persisted state of one session.

    Args:
        session_id: Value matched against ``user_sessions.session_id``.

    Returns:
        All columns of the matching ``user_sessions`` row as a dict or, when no
        row exists, a default state: no preference, both counters at 0 and an
        empty JSON object string as ``knowledge_context``.
    """
    with get_db_connection() as conn:
        existing = conn.execute(
            "SELECT * FROM user_sessions WHERE session_id = ?", (session_id,)
        ).fetchone()
        if not existing:
            return {"preference": None, "msg_count": 0, "clarification_count": 0, "knowledge_context": "{}"}
        return dict(existing)
| |
|
def update_session_state(
    session_id: str,
    preference: Optional[str] = None,
    increment_count: bool = True,
    increment_clarification: bool = False,
    reset_clarification: bool = False,
    knowledge_update: Optional[Dict] = None,
):
    """Create or update the ``user_sessions`` row for *session_id*.

    Args:
        session_id: Session key in ``user_sessions``.
        preference: New preference; a falsy value keeps the stored one.
        increment_count: Count this call as a message (bumps ``msg_count``).
        increment_clarification: Add one to ``clarification_count``.
        reset_clarification: Set ``clarification_count`` to 0 for an existing
            session; takes precedence over ``increment_clarification``.
            Ignored when the session row is created.
        knowledge_update: Key/value pairs merged into the JSON object stored
            in ``knowledge_context``.

    When an existing session's message count would exceed 10, its memory is
    reset: ``msg_count`` restarts at 1 and preference, knowledge and
    clarification count are cleared. ``knowledge_update`` from this call is
    still merged in after the reset.
    """
    with get_db_connection() as conn:
        cur = conn.cursor()

        cur.execute(
            "SELECT preference, msg_count, clarification_count, knowledge_context "
            "FROM user_sessions WHERE session_id = ?",
            (session_id,),
        )
        row = cur.fetchone()

        current_knowledge: Dict[str, Any] = {}
        if row is not None:
            curr_pref, curr_count, curr_clarification, curr_knowledge_json = row

            try:
                stored_knowledge = json.loads(curr_knowledge_json)
            except (ValueError, TypeError):
                # ValueError covers json.JSONDecodeError (and undecodable bytes);
                # TypeError covers a NULL / non-text column value.
                stored_knowledge = None
            # Only a JSON object can be merged with knowledge_update below;
            # anything else (null, list, number) starts from an empty dict.
            if isinstance(stored_knowledge, dict):
                current_knowledge = stored_knowledge

            new_pref = preference if preference else curr_pref
            new_count = curr_count + 1 if increment_count else curr_count

            if new_count > 10:
                # NOTE(review): a preference passed on this very call is
                # discarded here, while knowledge_update is still applied
                # below — confirm that asymmetry is intended. The leading
                # character of the message also looks like a mis-decoded emoji.
                print(f"๐ Session {session_id} reached 10 messages. Resetting memory context.")
                new_count = 1
                new_pref = None
                current_knowledge = {}
                new_clarification = 0
            elif reset_clarification:
                new_clarification = 0
            elif increment_clarification:
                new_clarification = curr_clarification + 1
            else:
                new_clarification = curr_clarification
        else:
            new_pref = preference
            new_count = 1 if increment_count else 0
            new_clarification = 1 if increment_clarification else 0

        # Shared by both the update and the insert path.
        if knowledge_update:
            current_knowledge.update(knowledge_update)
        new_knowledge_json = json.dumps(current_knowledge)

        if row is not None:
            cur.execute("""
                UPDATE user_sessions
                SET preference = ?, msg_count = ?, clarification_count = ?, knowledge_context = ?, last_updated = CURRENT_TIMESTAMP
                WHERE session_id = ?
            """, (new_pref, new_count, new_clarification, new_knowledge_json, session_id))
        else:
            cur.execute("""
                INSERT INTO user_sessions (session_id, preference, msg_count, clarification_count, knowledge_context)
                VALUES (?, ?, ?, ?, ?)
            """, (session_id, new_pref, new_count, new_clarification, new_knowledge_json))

        conn.commit()