Spaces:
Sleeping
Sleeping
| import os | |
| import psycopg2 | |
| from psycopg2.extras import RealDictCursor | |
| import logging | |
| import query_constants | |
| logger = logging.getLogger(__name__) | |
| def get_db_connection(): | |
| """Create and return a database connection.""" | |
| db_url = os.environ.get("DATABASE_URL") | |
| if not db_url: | |
| raise ValueError("DATABASE_URL environment variable is not set") | |
| try: | |
| conn = psycopg2.connect(db_url, cursor_factory=RealDictCursor) | |
| return conn | |
| except Exception as e: | |
| logger.error(f"Error connecting to database: {e}") | |
| raise | |
| def init_db(): | |
| """Create tables if they don't exist.""" | |
| conn = None | |
| try: | |
| conn = get_db_connection() | |
| cursor = conn.cursor() | |
| # Table 1: Conversations | |
| cursor.execute(query_constants.CREATE_CONVERSATIONS_TABLE) | |
| # Table 2: Messages | |
| cursor.execute(query_constants.CREATE_MESSAGES_TABLE) | |
| conn.commit() | |
| logger.info("Database initialized successfully.") | |
| except Exception as e: | |
| logger.error(f"Error initializing database: {e}") | |
| if conn: | |
| conn.rollback() | |
| raise | |
| finally: | |
| if conn: | |
| conn.close() | |
| def create_conversation(title: str): | |
| """Insert new conversation, return it.""" | |
| conn = None | |
| try: | |
| conn = get_db_connection() | |
| cursor = conn.cursor() | |
| cursor.execute( | |
| query_constants.INSERT_CONVERSATION, | |
| (title,) | |
| ) | |
| new_conv = cursor.fetchone() | |
| conn.commit() | |
| return dict(new_conv) | |
| except Exception as e: | |
| logger.error(f"Error creating conversation: {e}") | |
| if conn: | |
| conn.rollback() | |
| raise | |
| finally: | |
| if conn: | |
| conn.close() | |
| def get_all_conversations(): | |
| """Return all conversations with message count.""" | |
| conn = None | |
| try: | |
| conn = get_db_connection() | |
| cursor = conn.cursor() | |
| cursor.execute(query_constants.GET_ALL_CONVERSATIONS) | |
| conversations = cursor.fetchall() | |
| return [dict(conv) for conv in conversations] | |
| except Exception as e: | |
| logger.error(f"Error getting all conversations: {e}") | |
| raise | |
| finally: | |
| if conn: | |
| conn.close() | |
| def get_conversation(id: int): | |
| """Return single conversation by ID.""" | |
| conn = None | |
| try: | |
| conn = get_db_connection() | |
| cursor = conn.cursor() | |
| cursor.execute(query_constants.GET_CONVERSATION_BY_ID, (id,)) | |
| conv = cursor.fetchone() | |
| return dict(conv) if conv else None | |
| except Exception as e: | |
| logger.error(f"Error getting conversation by id: {e}") | |
| raise | |
| finally: | |
| if conn: | |
| conn.close() | |
| def save_message(conv_id: int, query: str, resp: str): | |
| """Insert one Q&A row into messages.""" | |
| conn = None | |
| try: | |
| conn = get_db_connection() | |
| cursor = conn.cursor() | |
| cursor.execute( | |
| query_constants.INSERT_MESSAGE, | |
| (conv_id, query, resp) | |
| ) | |
| new_msg = cursor.fetchone() | |
| conn.commit() | |
| return dict(new_msg) | |
| except Exception as e: | |
| logger.error(f"Error saving message: {e}") | |
| if conn: | |
| conn.rollback() | |
| raise | |
| finally: | |
| if conn: | |
| conn.close() | |
| def update_message_response(message_id: int, new_response: str): | |
| """Update the response column of an existing message row.""" | |
| conn = None | |
| try: | |
| conn = get_db_connection() | |
| cursor = conn.cursor() | |
| cursor.execute( | |
| query_constants.UPDATE_MESSAGE_RESPONSE, | |
| (new_response, message_id) | |
| ) | |
| updated_msg = cursor.fetchone() | |
| conn.commit() | |
| return dict(updated_msg) if updated_msg else None | |
| except Exception as e: | |
| logger.error(f"Error updating message response: {e}") | |
| if conn: | |
| conn.rollback() | |
| raise | |
| finally: | |
| if conn: | |
| conn.close() | |
| def update_conversation_updated_at(conv_id: int): | |
| """Stamp updated_at = NOW() on the given conversation.""" | |
| conn = None | |
| try: | |
| conn = get_db_connection() | |
| cursor = conn.cursor() | |
| cursor.execute(query_constants.UPDATE_CONVERSATION_UPDATED_AT, (conv_id,)) | |
| conn.commit() | |
| logger.info(f"Touched updated_at for conversation {conv_id}") | |
| except Exception as e: | |
| logger.error(f"Error updating conversation updated_at: {e}") | |
| if conn: | |
| conn.rollback() | |
| raise | |
| finally: | |
| if conn: | |
| conn.close() | |
| def get_messages(conv_id: int, limit: int = 10): | |
| """Return last N messages as LLM history format.""" | |
| conn = None | |
| try: | |
| conn = get_db_connection() | |
| cursor = conn.cursor() | |
| cursor.execute( | |
| query_constants.GET_MESSAGES_FOR_LLM, | |
| (conv_id, limit) | |
| ) | |
| # Fetching returns descending order (latest first), we need chronological for LLM prompt | |
| rows = cursor.fetchall() | |
| rows.reverse() | |
| history = [] | |
| for row in rows: | |
| history.append({"role": "user", "content": row["user_query"]}) | |
| history.append({"role": "assistant", "content": row["response"]}) | |
| return history | |
| except Exception as e: | |
| logger.error(f"Error getting messages for LLM context: {e}") | |
| raise | |
| finally: | |
| if conn: | |
| conn.close() | |
| def get_messages_paginated(conv_id: int, start_row: int, end_row: int): | |
| """Return paginated messages.""" | |
| conn = None | |
| try: | |
| conn = get_db_connection() | |
| cursor = conn.cursor() | |
| # Calculate pagination | |
| offset = start_row - 1 | |
| limit = end_row - start_row + 1 | |
| # Get total messages count | |
| cursor.execute(query_constants.COUNT_MESSAGES, (conv_id,)) | |
| total_messages = cursor.fetchone()["total"] | |
| # Get paginated messages | |
| cursor.execute( | |
| query_constants.GET_PAGINATED_MESSAGES, | |
| (conv_id, offset, limit) | |
| ) | |
| messages = [dict(row) for row in cursor.fetchall()] | |
| return { | |
| "total_messages": total_messages, | |
| "messages": messages | |
| } | |
| except Exception as e: | |
| logger.error(f"Error getting paginated messages: {e}") | |
| raise | |
| finally: | |
| if conn: | |
| conn.close() | |
| def rename_conversation(conv_id: int, new_name: str): | |
| """Update a conversation's title.""" | |
| conn = None | |
| try: | |
| conn = get_db_connection() | |
| cursor = conn.cursor() | |
| cursor.execute(query_constants.RENAME_CONVERSATION, (new_name, conv_id)) | |
| updated = cursor.fetchone() | |
| conn.commit() | |
| return dict(updated) if updated else None | |
| except Exception as e: | |
| logger.error(f"Error renaming conversation: {e}") | |
| if conn: | |
| conn.rollback() | |
| raise | |
| finally: | |
| if conn: | |
| conn.close() | |
| def delete_conversation(conv_id: int): | |
| """Delete a conversation and all its messages.""" | |
| conn = None | |
| try: | |
| conn = get_db_connection() | |
| cursor = conn.cursor() | |
| # Delete messages first (also handled by CASCADE, but explicit is safer) | |
| cursor.execute(query_constants.DELETE_MESSAGES_BY_CONVERSATION, (conv_id,)) | |
| # Delete the conversation itself | |
| cursor.execute(query_constants.DELETE_CONVERSATION, (conv_id,)) | |
| deleted = cursor.fetchone() | |
| conn.commit() | |
| return dict(deleted) if deleted else None | |
| except Exception as e: | |
| logger.error(f"Error deleting conversation: {e}") | |
| if conn: | |
| conn.rollback() | |
| raise | |
| finally: | |
| if conn: | |
| conn.close() | |