private-ai-backend / database.py
adeebjamal's picture
renamed column
308fcc4
import os
import psycopg2
from psycopg2.extras import RealDictCursor
import logging
import query_constants
logger = logging.getLogger(__name__)
def get_db_connection():
"""Create and return a database connection."""
db_url = os.environ.get("DATABASE_URL")
if not db_url:
raise ValueError("DATABASE_URL environment variable is not set")
try:
conn = psycopg2.connect(db_url, cursor_factory=RealDictCursor)
return conn
except Exception as e:
logger.error(f"Error connecting to database: {e}")
raise
def init_db():
"""Create tables if they don't exist."""
conn = None
try:
conn = get_db_connection()
cursor = conn.cursor()
# Table 1: Conversations
cursor.execute(query_constants.CREATE_CONVERSATIONS_TABLE)
# Table 2: Messages
cursor.execute(query_constants.CREATE_MESSAGES_TABLE)
conn.commit()
logger.info("Database initialized successfully.")
except Exception as e:
logger.error(f"Error initializing database: {e}")
if conn:
conn.rollback()
raise
finally:
if conn:
conn.close()
def create_conversation(title: str):
"""Insert new conversation, return it."""
conn = None
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
query_constants.INSERT_CONVERSATION,
(title,)
)
new_conv = cursor.fetchone()
conn.commit()
return dict(new_conv)
except Exception as e:
logger.error(f"Error creating conversation: {e}")
if conn:
conn.rollback()
raise
finally:
if conn:
conn.close()
def get_all_conversations():
"""Return all conversations with message count."""
conn = None
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(query_constants.GET_ALL_CONVERSATIONS)
conversations = cursor.fetchall()
return [dict(conv) for conv in conversations]
except Exception as e:
logger.error(f"Error getting all conversations: {e}")
raise
finally:
if conn:
conn.close()
def get_conversation(id: int):
"""Return single conversation by ID."""
conn = None
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(query_constants.GET_CONVERSATION_BY_ID, (id,))
conv = cursor.fetchone()
return dict(conv) if conv else None
except Exception as e:
logger.error(f"Error getting conversation by id: {e}")
raise
finally:
if conn:
conn.close()
def save_message(conv_id: int, query: str, resp: str):
"""Insert one Q&A row into messages."""
conn = None
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
query_constants.INSERT_MESSAGE,
(conv_id, query, resp)
)
new_msg = cursor.fetchone()
conn.commit()
return dict(new_msg)
except Exception as e:
logger.error(f"Error saving message: {e}")
if conn:
conn.rollback()
raise
finally:
if conn:
conn.close()
def update_message_response(message_id: int, new_response: str):
"""Update the response column of an existing message row."""
conn = None
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
query_constants.UPDATE_MESSAGE_RESPONSE,
(new_response, message_id)
)
updated_msg = cursor.fetchone()
conn.commit()
return dict(updated_msg) if updated_msg else None
except Exception as e:
logger.error(f"Error updating message response: {e}")
if conn:
conn.rollback()
raise
finally:
if conn:
conn.close()
def update_conversation_updated_at(conv_id: int):
"""Stamp updated_at = NOW() on the given conversation."""
conn = None
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(query_constants.UPDATE_CONVERSATION_UPDATED_AT, (conv_id,))
conn.commit()
logger.info(f"Touched updated_at for conversation {conv_id}")
except Exception as e:
logger.error(f"Error updating conversation updated_at: {e}")
if conn:
conn.rollback()
raise
finally:
if conn:
conn.close()
def get_messages(conv_id: int, limit: int = 10):
"""Return last N messages as LLM history format."""
conn = None
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
query_constants.GET_MESSAGES_FOR_LLM,
(conv_id, limit)
)
# Fetching returns descending order (latest first), we need chronological for LLM prompt
rows = cursor.fetchall()
rows.reverse()
history = []
for row in rows:
history.append({"role": "user", "content": row["user_query"]})
history.append({"role": "assistant", "content": row["response"]})
return history
except Exception as e:
logger.error(f"Error getting messages for LLM context: {e}")
raise
finally:
if conn:
conn.close()
def get_messages_paginated(conv_id: int, start_row: int, end_row: int):
"""Return paginated messages."""
conn = None
try:
conn = get_db_connection()
cursor = conn.cursor()
# Calculate pagination
offset = start_row - 1
limit = end_row - start_row + 1
# Get total messages count
cursor.execute(query_constants.COUNT_MESSAGES, (conv_id,))
total_messages = cursor.fetchone()["total"]
# Get paginated messages
cursor.execute(
query_constants.GET_PAGINATED_MESSAGES,
(conv_id, offset, limit)
)
messages = [dict(row) for row in cursor.fetchall()]
return {
"total_messages": total_messages,
"messages": messages
}
except Exception as e:
logger.error(f"Error getting paginated messages: {e}")
raise
finally:
if conn:
conn.close()
def rename_conversation(conv_id: int, new_name: str):
"""Update a conversation's title."""
conn = None
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(query_constants.RENAME_CONVERSATION, (new_name, conv_id))
updated = cursor.fetchone()
conn.commit()
return dict(updated) if updated else None
except Exception as e:
logger.error(f"Error renaming conversation: {e}")
if conn:
conn.rollback()
raise
finally:
if conn:
conn.close()
def delete_conversation(conv_id: int):
"""Delete a conversation and all its messages."""
conn = None
try:
conn = get_db_connection()
cursor = conn.cursor()
# Delete messages first (also handled by CASCADE, but explicit is safer)
cursor.execute(query_constants.DELETE_MESSAGES_BY_CONVERSATION, (conv_id,))
# Delete the conversation itself
cursor.execute(query_constants.DELETE_CONVERSATION, (conv_id,))
deleted = cursor.fetchone()
conn.commit()
return dict(deleted) if deleted else None
except Exception as e:
logger.error(f"Error deleting conversation: {e}")
if conn:
conn.rollback()
raise
finally:
if conn:
conn.close()