| """ |
| Database connection and operations module for SquidLLM backend. |
| Handles all interactions with Supabase PostgreSQL database. |
| """ |
|
|
| import os |
| import psycopg2 |
| from psycopg2.extras import RealDictCursor |
| from typing import Optional, List, Dict, Any |
|
|
|
|
class Database:
    """Database connection handler for PostgreSQL/Supabase.

    Holds a single psycopg2 connection that is opened lazily by
    ``get_connection`` and shared by every method. The connection runs with
    ``autocommit = False``. Write methods commit on success and roll back on
    failure. Read methods roll back if their query fails, so that one bad
    statement cannot leave the shared connection stuck in an aborted
    transaction.
    """

    # SQL fragments shared by several ``responses`` methods.
    _RESPONSE_COLUMNS = "id, prompt_id, llm, response, jailbroken, note"
    _INSERT_RESPONSE_SQL = (
        "INSERT INTO responses (prompt_id, llm, response, jailbroken, note) "
        "VALUES (%s, %s, %s, %s, %s) RETURNING id"
    )

    def __init__(self):
        """Initialize the handler without connecting.

        The connection is opened lazily on first use.

        Raises:
            ValueError: If the ``RDS_LOGIN`` environment variable is unset or empty.
        """
        self.connection_string = os.getenv('RDS_LOGIN')
        if not self.connection_string:
            raise ValueError("RDS_LOGIN environment variable is not set")
        self.conn = None
        self._tables_created = False

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def _connect(self):
        """Open a new connection and store it on ``self.conn``.

        Raises:
            ConnectionError: If psycopg2 cannot connect. The original error is chained.
        """
        try:
            self.conn = psycopg2.connect(
                self.connection_string,
                connect_timeout=10,
                # libpq TCP keepalive settings, so that a connection dropped
                # without notice is eventually detected.
                keepalives=1,
                keepalives_idle=30,
                keepalives_interval=10,
                keepalives_count=5,
            )
            self.conn.autocommit = False
        except psycopg2.Error as e:
            raise ConnectionError(f"Failed to connect to database: {e}") from e

    def _create_tables(self):
        """Create the ``prompts`` and ``responses`` tables if they don't exist.

        This runs at most once per instance, after the first successful
        creation. On failure the transaction is rolled back so the connection
        stays usable, and the error is re-raised.
        """
        if self._tables_created:
            return

        try:
            with self.conn.cursor() as cursor:
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS prompts (
                        id SERIAL PRIMARY KEY,
                        text TEXT NOT NULL,
                        note TEXT
                    )
                """)
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS responses (
                        id SERIAL PRIMARY KEY,
                        prompt_id INTEGER NOT NULL REFERENCES prompts(id) ON DELETE CASCADE,
                        llm VARCHAR(255) NOT NULL,
                        response TEXT NOT NULL,
                        jailbroken BOOLEAN NOT NULL DEFAULT FALSE,
                        note TEXT
                    )
                """)
            self.conn.commit()
        except psycopg2.Error:
            self._rollback_quietly(self.conn)
            raise
        self._tables_created = True

    def get_connection(self):
        """Return a live connection, connecting and creating tables if needed.

        An existing connection is probed with ``SELECT 1``. If the probe fails
        with an operational or interface error, the old connection is discarded
        and a new one is opened. This is a single retry.

        Returns:
            The shared psycopg2 connection.

        Raises:
            ConnectionError: If a (re)connection attempt fails. Errors raised
                while creating tables are propagated unchanged.
        """
        try:
            if self.conn is None or self.conn.closed:
                self._connect()
                self._create_tables()

            with self.conn.cursor() as cursor:
                cursor.execute("SELECT 1")
        except (psycopg2.OperationalError, psycopg2.InterfaceError):
            # The connection is broken: close what is left of it and reconnect.
            if self.conn is not None and not self.conn.closed:
                try:
                    self.conn.close()
                except psycopg2.Error:
                    pass  # Best effort only; the connection is being replaced anyway.
            self._connect()
            self._create_tables()
        return self.conn

    @staticmethod
    def _rollback_quietly(conn):
        """Roll back ``conn``, ignoring psycopg2 errors.

        A rollback on an already-dead connection would otherwise replace the
        more informative original error with a connection error.
        """
        try:
            conn.rollback()
        except psycopg2.Error:
            pass

    def _fetch(self, query: str, params: Any = None, *, one: bool = False):
        """Run a read-only query and return the rows as plain dicts.

        Args:
            query: SQL with ``%s`` placeholders.
            params: Bind parameters for ``query``, or None if it has none.
            one: If True, return only the first row, or None if there are no rows.

        Returns:
            A single dict or None when ``one`` is True; otherwise a list of dicts.

        Raises:
            psycopg2.Error: If the query fails. The transaction is rolled back
                first, so the shared connection is not left in an aborted state.
        """
        conn = self.get_connection()
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(query, params)
                if one:
                    row = cursor.fetchone()
                    return dict(row) if row else None
                return [dict(row) for row in cursor.fetchall()]
        except psycopg2.Error:
            self._rollback_quietly(conn)
            raise

    @staticmethod
    def _paginate(query: str, params: List[Any], limit: Optional[int], offset: Optional[int]) -> tuple:
        """Append a parameterized LIMIT/OFFSET clause to ``query``.

        A ``limit`` of None or 0 means no limit. An ``offset`` of None or 0 adds
        no OFFSET. Values are passed as bind parameters and are never
        interpolated into the SQL text.

        Returns:
            ``(query, params)``: the extended query and a new parameter list.
        """
        params = list(params)
        if limit:
            query += " LIMIT %s"
            params.append(limit)
        if offset:
            query += " OFFSET %s"
            params.append(offset)
        return query, params

    # ------------------------------------------------------------------
    # Prompts
    # ------------------------------------------------------------------

    def insert_prompt(self, text: str, note: Optional[str] = None) -> int:
        """
        Insert a new prompt into the database.

        Args:
            text: The prompt text
            note: Optional note about the prompt

        Returns:
            The ID of the inserted prompt

        Raises:
            Exception: If the insert fails. The transaction is rolled back and
                the psycopg2 error is chained.
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(
                    "INSERT INTO prompts (text, note) VALUES (%s, %s) RETURNING id",
                    (text, note),
                )
                prompt_id = cursor.fetchone()[0]
            conn.commit()
        except psycopg2.Error as e:
            self._rollback_quietly(conn)
            raise Exception(f"Failed to insert prompt: {e}") from e
        return prompt_id

    def get_prompt(self, prompt_id: int) -> Optional[Dict[str, Any]]:
        """
        Get a prompt by ID.

        Args:
            prompt_id: The ID of the prompt

        Returns:
            Dictionary with keys id, text and note, or None if not found
        """
        return self._fetch(
            "SELECT id, text, note FROM prompts WHERE id = %s",
            (prompt_id,),
            one=True,
        )

    def get_all_prompts(self, limit: Optional[int] = None, offset: int = 0) -> List[Dict[str, Any]]:
        """
        Get all prompts from the database, newest (highest id) first.

        Args:
            limit: Maximum number of prompts to return (None or 0 for no limit)
            offset: Number of prompts to skip. This is applied even when no limit is given.

        Returns:
            List of dictionaries with prompt data
        """
        query, params = self._paginate(
            "SELECT id, text, note FROM prompts ORDER BY id DESC", [], limit, offset
        )
        return self._fetch(query, params or None)

    def update_prompt_note(self, prompt_id: int, note: Optional[str]) -> bool:
        """
        Update the note for a prompt.

        Args:
            prompt_id: The ID of the prompt
            note: The new note (can be None to clear)

        Returns:
            True if update was successful, False if prompt not found

        Raises:
            Exception: If the update fails. The transaction is rolled back and
                the psycopg2 error is chained.
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(
                    "UPDATE prompts SET note = %s WHERE id = %s",
                    (note, prompt_id),
                )
                updated = cursor.rowcount > 0
            conn.commit()
        except psycopg2.Error as e:
            self._rollback_quietly(conn)
            raise Exception(f"Failed to update prompt note: {e}") from e
        return updated

    # ------------------------------------------------------------------
    # Responses
    # ------------------------------------------------------------------

    def insert_response(
        self,
        prompt_id: int,
        llm: str,
        response: str,
        jailbroken: bool = False,
        note: Optional[str] = None
    ) -> int:
        """
        Insert a new response into the database.

        Args:
            prompt_id: The ID of the associated prompt
            llm: The name/identifier of the LLM
            response: The response text
            jailbroken: Whether the response indicates a jailbreak
            note: Optional note about the response

        Returns:
            The ID of the inserted response

        Raises:
            Exception: If the insert fails. The transaction is rolled back and
                the psycopg2 error is chained.
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(
                    self._INSERT_RESPONSE_SQL,
                    (prompt_id, llm, response, jailbroken, note),
                )
                response_id = cursor.fetchone()[0]
            conn.commit()
        except psycopg2.Error as e:
            self._rollback_quietly(conn)
            raise Exception(f"Failed to insert response: {e}") from e
        return response_id

    def insert_batch_responses(
        self,
        responses: List[Dict[str, Any]]
    ) -> List[int]:
        """
        Insert multiple responses in a single transaction (all or nothing).

        Args:
            responses: List of dictionaries with required keys prompt_id, llm and
                response, and optional keys jailbroken (default False) and note

        Returns:
            List of inserted response IDs, in input order

        Raises:
            KeyError: If a dictionary lacks a required key. Nothing is inserted.
            Exception: If an insert fails. The whole batch is rolled back and
                the psycopg2 error is chained.
        """
        # Build every parameter tuple before touching the database. A missing key
        # then raises KeyError up front, instead of leaving a half-inserted batch
        # in the open transaction where the next commit would persist it.
        rows = [
            (
                resp['prompt_id'],
                resp['llm'],
                resp['response'],
                resp.get('jailbroken', False),
                resp.get('note'),
            )
            for resp in responses
        ]

        conn = self.get_connection()
        inserted_ids = []
        try:
            with conn.cursor() as cursor:
                for row in rows:
                    cursor.execute(self._INSERT_RESPONSE_SQL, row)
                    inserted_ids.append(cursor.fetchone()[0])
            conn.commit()
        except psycopg2.Error as e:
            self._rollback_quietly(conn)
            raise Exception(f"Failed to insert batch responses: {e}") from e
        return inserted_ids

    def get_response(self, response_id: int) -> Optional[Dict[str, Any]]:
        """
        Get a response by ID.

        Args:
            response_id: The ID of the response

        Returns:
            Dictionary with response data or None if not found
        """
        return self._fetch(
            f"SELECT {self._RESPONSE_COLUMNS} FROM responses WHERE id = %s",
            (response_id,),
            one=True,
        )

    def get_responses_by_prompt(self, prompt_id: int) -> List[Dict[str, Any]]:
        """
        Get all responses for a specific prompt, ordered by ascending id.

        Args:
            prompt_id: The ID of the prompt

        Returns:
            List of dictionaries with response data
        """
        return self._fetch(
            f"SELECT {self._RESPONSE_COLUMNS} FROM responses "
            "WHERE prompt_id = %s ORDER BY id",
            (prompt_id,),
        )

    def get_all_responses(
        self,
        limit: Optional[int] = None,
        offset: int = 0,
        llm_filter: Optional[str] = None,
        jailbroken_filter: Optional[str] = None,
        prompt_id_filter: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Get responses from the database, newest (highest id) first.

        Args:
            limit: Maximum number of responses to return (None or 0 for no limit)
            offset: Number of responses to skip. This is applied even when no limit is given.
            llm_filter: Optional filter by LLM name (empty string means no filter)
            jailbroken_filter: Optional filter by jailbroken status. 'true'
                (case-insensitive) matches jailbroken rows. Any other non-empty
                string matches non-jailbroken rows. None or '' disables the filter.
            prompt_id_filter: Optional filter by prompt ID (None or 0 means no filter)

        Returns:
            List of dictionaries with response data
        """
        conditions = []
        params = []

        if llm_filter:
            conditions.append("llm = %s")
            params.append(llm_filter)

        if jailbroken_filter is not None and jailbroken_filter != '':
            conditions.append("jailbroken = %s")
            params.append(jailbroken_filter.lower() == 'true')

        if prompt_id_filter:
            conditions.append("prompt_id = %s")
            params.append(prompt_id_filter)

        query = f"SELECT {self._RESPONSE_COLUMNS} FROM responses"
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        query += " ORDER BY id DESC"

        query, params = self._paginate(query, params, limit, offset)
        return self._fetch(query, params or None)

    def update_response_note(self, response_id: int, note: Optional[str]) -> bool:
        """
        Update the note for a response.

        Args:
            response_id: The ID of the response
            note: The new note (can be None to clear)

        Returns:
            True if update was successful, False if response not found

        Raises:
            Exception: If the update fails. The transaction is rolled back and
                the psycopg2 error is chained.
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(
                    "UPDATE responses SET note = %s WHERE id = %s",
                    (note, response_id),
                )
                updated = cursor.rowcount > 0
            conn.commit()
        except psycopg2.Error as e:
            self._rollback_quietly(conn)
            raise Exception(f"Failed to update response note: {e}") from e
        return updated

    def update_response_jailbroken(self, response_id: int, jailbroken: bool) -> bool:
        """
        Update the jailbroken status for a response.

        Args:
            response_id: The ID of the response
            jailbroken: The new jailbroken status

        Returns:
            True if update was successful, False if response not found

        Raises:
            Exception: If the update fails. The transaction is rolled back and
                the psycopg2 error is chained.
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(
                    "UPDATE responses SET jailbroken = %s WHERE id = %s",
                    (jailbroken, response_id),
                )
                updated = cursor.rowcount > 0
            conn.commit()
        except psycopg2.Error as e:
            self._rollback_quietly(conn)
            raise Exception(f"Failed to update response jailbroken status: {e}") from e
        return updated

    def close(self):
        """Close the database connection if one is open; otherwise do nothing."""
        if self.conn is not None and not self.conn.closed:
            self.conn.close()
|
|