|
|
|
|
|
""" |
|
|
Database module for Multi-Personality Chat Bot |
|
|
Handles SQLite database operations for chat history and analytics. |
|
|
""" |
|
|
|
|
|
import logging
import os
import sqlite3
import tempfile
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional
|
|
|
|
|
logger = logging.getLogger(__name__)


class ChatDatabase:
    """Simple SQLite database for chat storage.

    Every public method opens its own short-lived connection to ``db_path``
    and always closes it before returning. Failures are logged and turned
    into a neutral return value (``False``, ``[]`` or zeroed statistics)
    rather than raised to the caller.
    """

    # Columns that older ``messages`` tables may lack, with the statement that
    # adds each one. SQLite's ``ALTER TABLE ... ADD COLUMN`` rejects a
    # CURRENT_TIMESTAMP default, so ``timestamp`` is added without a default;
    # _INSERT_MESSAGE_SQL supplies the value explicitly instead.
    _MESSAGE_COLUMN_MIGRATIONS = {
        'personality_type': "ALTER TABLE messages ADD COLUMN personality_type TEXT",
        'sender_type': "ALTER TABLE messages ADD COLUMN sender_type TEXT",
        'session_id': "ALTER TABLE messages ADD COLUMN session_id TEXT",
        'response_time': "ALTER TABLE messages ADD COLUMN response_time REAL",
        'timestamp': "ALTER TABLE messages ADD COLUMN timestamp DATETIME",
    }

    # Writing CURRENT_TIMESTAMP explicitly yields the same value as the column
    # default on fresh tables and also stamps rows in migrated legacy tables.
    _INSERT_MESSAGE_SQL = """
        INSERT INTO messages (username, message, personality_type, sender_type,
                              session_id, response_time, timestamp)
        VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
    """

    def __init__(self, db_path: Optional[str] = None):
        """Resolve the database location and create the schema.

        Args:
            db_path: Path to the SQLite file. When omitted or empty, the
                ``DB_PATH`` environment variable is used, then
                ``<cwd>/data/chat_data.db``.
        """
        resolved_path = db_path or os.getenv("DB_PATH") or os.path.join(os.getcwd(), "data", "chat_data.db")
        parent_dir = os.path.dirname(resolved_path) or "."
        try:
            os.makedirs(parent_dir, exist_ok=True)
        except (OSError, ValueError) as e:
            # Not fatal: initialize_database() retries and can fall back to the temp dir.
            logger.warning(f"Could not create DB directory '{parent_dir}': {e}")

        self.db_path = resolved_path
        self.initialize_database()

    @contextmanager
    def _connect(self, path: Optional[str] = None) -> Iterator[sqlite3.Connection]:
        """Yield a connection that is committed on success and always closed.

        ``sqlite3.Connection`` used as a context manager only commits or rolls
        back; it does not close the connection, so closing is done here.

        Args:
            path: Database file to open; defaults to ``self.db_path``.
        """
        conn = sqlite3.connect(path or self.db_path)
        try:
            with conn:  # commit on normal exit, roll back on exception
                yield conn
        finally:
            conn.close()

    def _create_schema(self, path: str) -> None:
        """Create the tables (and migrate ``messages``) in the database at ``path``."""
        with self._connect(path) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL,
                    message TEXT NOT NULL,
                    personality_type TEXT NOT NULL,
                    sender_type TEXT NOT NULL CHECK(sender_type IN ('user', 'bot')),
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    session_id TEXT,
                    response_time REAL
                )
            """)
            # A pre-existing table is left untouched by CREATE TABLE IF NOT
            # EXISTS, so bring older layouts up to date explicitly.
            self._ensure_messages_columns(conn)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS personality_stats (
                    personality_type TEXT PRIMARY KEY,
                    total_messages INTEGER DEFAULT 0,
                    avg_response_time REAL DEFAULT 0.0,
                    last_used DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)
        logger.info("Database initialized successfully")

    def initialize_database(self):
        """Create database tables if they don't exist.

        If the configured file cannot be opened, its parent directory is
        (re)created and the open retried once. If that also fails,
        ``chat_data.db`` in the system temp directory is used and
        ``self.db_path`` is updated to point at it. All errors are logged,
        never raised.
        """
        try:
            self._create_schema(self.db_path)
            return
        except sqlite3.OperationalError as e:
            if "unable to open database file" not in str(e).lower():
                logger.error(f"Database initialization error: {str(e)}")
                return
        except Exception as e:
            logger.error(f"Database initialization error: {str(e)}")
            return

        # The file could not be opened: retry after (re)creating its directory.
        try:
            os.makedirs(os.path.dirname(self.db_path) or ".", exist_ok=True)
            self._create_schema(self.db_path)
            return
        except Exception as e:
            logger.warning(f"Could not open database at '{self.db_path}': {e}")

        # Last resort: a database file in the system temp directory.
        fallback_path = os.path.join(tempfile.gettempdir(), "chat_data.db")
        try:
            self._create_schema(fallback_path)
        except Exception as e2:
            logger.error(f"Database initialization error (fallback failed): {e2}")
            return
        # Only switch paths once the fallback is known to work.
        self.db_path = fallback_path
        logger.warning(f"DB path fallback in use: {fallback_path}")

    def _ensure_messages_columns(self, conn: sqlite3.Connection) -> None:
        """Ensure required columns exist in messages table for backward compatibility.

        Each missing column is added independently; a failure is logged as a
        warning and does not stop the remaining migrations. Nothing is raised.
        """
        try:
            existing = {row[1] for row in conn.execute("PRAGMA table_info(messages)")}
            for name, alter_sql in self._MESSAGE_COLUMN_MIGRATIONS.items():
                if name in existing:
                    continue
                try:
                    conn.execute(alter_sql)
                    logger.info(f"Added missing column to messages table: {name}")
                except sqlite3.Error as e:
                    logger.warning(f"Could not add column '{name}' to messages: {e}")
            conn.commit()
        except Exception as e:
            logger.error(f"Error ensuring messages columns: {e}")

    def save_message(self, username: str, message: str, personality_type: str,
                     sender_type: str, session_id: Optional[str] = None,
                     response_time: Optional[float] = None) -> bool:
        """Save a message to the database.

        If the insert fails because a column is missing, the ``messages``
        table is migrated and the insert retried once.

        Args:
            username: Author name stored with the message.
            message: Message text.
            personality_type: Personality the message belongs to.
            sender_type: ``'user'`` or ``'bot'`` (enforced by a CHECK constraint
                on tables created by this module).
            session_id: Optional session identifier.
            response_time: Optional response time stored with the message.

        Returns:
            True if the row was stored, False otherwise (the error is logged).
        """
        params = (username, message, personality_type, sender_type, session_id, response_time)
        try:
            with self._connect() as conn:
                conn.execute(self._INSERT_MESSAGE_SQL, params)
            return True
        except sqlite3.OperationalError as e:
            if "no column named" not in str(e).lower():
                logger.error(f"Operational DB error saving message: {e}")
                return False
            logger.warning(f"Schema issue detected ('{e}'). Attempting to migrate and retry insert...")
        except Exception as e:
            logger.error(f"Unexpected DB error saving message: {str(e)}")
            return False

        try:
            with self._connect() as conn:
                self._ensure_messages_columns(conn)
                conn.execute(self._INSERT_MESSAGE_SQL, params)
        except Exception as e2:
            logger.error(f"Retry after migration failed: {e2}")
            return False
        logger.info("Insert succeeded after schema migration")
        return True

    def get_recent_messages(self, personality_type: Optional[str] = None, limit: int = 50) -> List[Dict[str, Any]]:
        """Get recent messages, newest first, optionally filtered by personality type.

        Args:
            personality_type: When truthy, only messages for this personality
                are returned.
            limit: Maximum number of rows, passed to SQL ``LIMIT``.

        Returns:
            Dicts with keys ``username``, ``message``, ``personality_type``,
            ``sender_type`` and ``timestamp``; ``[]`` on error.
        """
        keys = ('username', 'message', 'personality_type', 'sender_type', 'timestamp')
        query = "SELECT username, message, personality_type, sender_type, timestamp FROM messages"
        params: tuple = ()
        if personality_type:
            query += " WHERE personality_type = ?"
            params = (personality_type,)
        # CURRENT_TIMESTAMP has one-second resolution; rowid breaks ties so
        # rows written within the same second still come back newest first.
        query += " ORDER BY timestamp DESC, rowid DESC LIMIT ?"

        try:
            with self._connect() as conn:
                rows = conn.execute(query, params + (limit,)).fetchall()
        except Exception as e:
            logger.error(f"Error retrieving messages: {str(e)}")
            return []
        return [dict(zip(keys, row)) for row in rows]

    def clear_personality_chat(self, personality_type: str, username: Optional[str] = None) -> bool:
        """Clear chat history for a personality (optionally for a specific user).

        Returns:
            True on success, False if the delete failed (the error is logged).
        """
        query = "DELETE FROM messages WHERE personality_type = ?"
        params: tuple = (personality_type,)
        if username:
            query += " AND username = ?"
            params += (username,)
        try:
            with self._connect() as conn:
                conn.execute(query, params)
            return True
        except Exception as e:
            logger.error(f"Error clearing chat: {str(e)}")
            return False

    def get_personality_stats(self, personality_type: str) -> Dict[str, Any]:
        """Get statistics for a specific personality.

        Returns:
            ``personality_type``, ``total_messages`` (number of bot messages)
            and ``avg_response_time`` (mean over bot messages that have a
            response time, rounded to 2 decimals, 0.0 when there are none).
            Zeroed values are returned if the query fails.
        """
        try:
            with self._connect() as conn:
                # AVG() skips NULLs, so rows without a response_time count
                # towards total_messages but not towards the average.
                message_count, avg_response_time = conn.execute("""
                    SELECT COUNT(*), AVG(response_time) FROM messages
                    WHERE personality_type = ? AND sender_type = 'bot'
                """, (personality_type,)).fetchone()
        except Exception as e:
            logger.error(f"Error getting personality stats: {str(e)}")
            return {
                'personality_type': personality_type,
                'total_messages': 0,
                'avg_response_time': 0.0
            }
        return {
            'personality_type': personality_type,
            'total_messages': message_count,
            'avg_response_time': round(avg_response_time or 0.0, 2)
        }

    def get_all_stats(self) -> Dict[str, Any]:
        """Get overall statistics.

        Returns:
            ``total_messages`` (all rows) and ``personality_breakdown``
            (row count per personality_type); zeroed values on error.
        """
        try:
            with self._connect() as conn:
                total_messages = conn.execute("SELECT COUNT(*) FROM messages").fetchone()[0]
                personality_breakdown = dict(conn.execute("""
                    SELECT personality_type, COUNT(*)
                    FROM messages
                    GROUP BY personality_type
                """).fetchall())
        except Exception as e:
            logger.error(f"Error getting all stats: {str(e)}")
            return {
                'total_messages': 0,
                'personality_breakdown': {}
            }
        return {
            'total_messages': total_messages,
            'personality_breakdown': personality_breakdown
        }
|
|
|