tentacool / database.py
dhruv575
Filtering for responses
7f082cb
"""
Database connection and operations module for SquidLLM backend.
Handles all interactions with Supabase PostgreSQL database.
"""
import os
from contextlib import contextmanager
from typing import Optional, List, Dict, Any

import psycopg2
from psycopg2.extras import RealDictCursor
class Database:
    """Database connection handler for PostgreSQL/Supabase.

    Wraps a single, lazily opened psycopg2 connection (autocommit disabled)
    and exposes helpers for the ``prompts`` and ``responses`` tables.

    Every helper ends the transaction it runs in: it commits on success and
    rolls back on any failure, so one failed statement cannot leave the
    shared connection stuck in an aborted transaction.
    """

    # Shared by insert_response() and insert_batch_responses().
    _INSERT_RESPONSE_SQL = (
        "INSERT INTO responses (prompt_id, llm, response, jailbroken, note) "
        "VALUES (%s, %s, %s, %s, %s) RETURNING id"
    )
    _SELECT_RESPONSES_SQL = (
        "SELECT id, prompt_id, llm, response, jailbroken, note FROM responses"
    )

    def __init__(self):
        """Initialize database handler. Connection is lazy-loaded on first use.

        Raises:
            ValueError: If the RDS_LOGIN environment variable is unset or empty.
        """
        self.connection_string = os.getenv('RDS_LOGIN')
        if not self.connection_string:
            raise ValueError("RDS_LOGIN environment variable is not set")
        self.conn = None
        self._tables_created = False

    # ========== CONNECTION MANAGEMENT ==========

    def _connect(self):
        """Establish connection to the database.

        Raises:
            ConnectionError: If psycopg2 fails to open the connection.
        """
        try:
            self.conn = psycopg2.connect(
                self.connection_string,
                connect_timeout=10,
                keepalives=1,
                keepalives_idle=30,
                keepalives_interval=10,
                keepalives_count=5,
            )
            self.conn.autocommit = False
        except psycopg2.Error as e:
            raise ConnectionError(f"Failed to connect to database: {e}") from e

    def _create_tables(self):
        """Create database tables if they don't exist.

        Runs at most once successfully per instance; a failed attempt is
        rolled back and retried on the next call.
        """
        if self._tables_created:
            return
        cursor = self.conn.cursor()
        try:
            # Create Prompts table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS prompts (
                    id SERIAL PRIMARY KEY,
                    text TEXT NOT NULL,
                    note TEXT
                )
            """)
            # Create Responses table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS responses (
                    id SERIAL PRIMARY KEY,
                    prompt_id INTEGER NOT NULL REFERENCES prompts(id) ON DELETE CASCADE,
                    llm VARCHAR(255) NOT NULL,
                    response TEXT NOT NULL,
                    jailbroken BOOLEAN NOT NULL DEFAULT FALSE,
                    note TEXT
                )
            """)
            self.conn.commit()
        except psycopg2.Error:
            # Do not leave the connection in an aborted transaction.
            self._rollback_quietly(self.conn)
            raise
        finally:
            cursor.close()
        self._tables_created = True

    def _is_alive(self) -> bool:
        """Return True if a trivial query succeeds on the current connection."""
        try:
            cursor = self.conn.cursor()
            try:
                cursor.execute("SELECT 1")
            finally:
                cursor.close()
        except psycopg2.Error:
            # Covers dropped connections as well as a transaction left in an
            # aborted state; in both cases the connection must be replaced.
            return False
        return True

    def _discard_connection(self):
        """Close the current connection, ignoring errors, and forget it."""
        if self.conn is not None and not self.conn.closed:
            try:
                self.conn.close()
            except Exception:
                # Best effort: the connection is being thrown away anyway.
                pass
        self.conn = None

    @staticmethod
    def _rollback_quietly(conn):
        """Roll back ``conn``, suppressing database errors from the rollback.

        Used while another exception is already being handled, so that a
        failing rollback does not hide the original error.
        """
        try:
            conn.rollback()
        except psycopg2.Error:
            pass

    def get_connection(self):
        """Get a database connection. Connects lazily and creates tables if needed.

        An existing connection is probed with ``SELECT 1``; if the probe
        fails the connection is closed and a new one is opened.

        Returns:
            The open psycopg2 connection.

        Raises:
            ConnectionError: If a (re)connection attempt fails.
        """
        if self.conn is None or self.conn.closed:
            self._connect()
        elif not self._is_alive():
            self._discard_connection()
            self._connect()
        # No-op once the tables exist; retries if an earlier attempt failed.
        self._create_tables()
        return self.conn

    @contextmanager
    def _cursor(self, failure_message: Optional[str] = None, dict_rows: bool = False):
        """Yield a cursor and finish the transaction when the block exits.

        Commits if the block completes, rolls back on any exception, and
        always closes the cursor.

        Args:
            failure_message: If given, a psycopg2 error is re-raised as
                ``Exception(f"{failure_message}: {error}")``. If None, the
                psycopg2 error propagates unchanged.
            dict_rows: If True, the cursor uses RealDictCursor.
        """
        conn = self.get_connection()
        if dict_rows:
            cursor = conn.cursor(cursor_factory=RealDictCursor)
        else:
            cursor = conn.cursor()
        try:
            yield cursor
            conn.commit()
        except psycopg2.Error as e:
            self._rollback_quietly(conn)
            if failure_message is None:
                raise
            raise Exception(f"{failure_message}: {e}") from e
        except BaseException:
            # Non-database failures (e.g. a KeyError while building the
            # parameters) must not leave earlier statements pending, or a
            # later commit would persist them.
            self._rollback_quietly(conn)
            raise
        finally:
            cursor.close()

    # ========== QUERY BUILDING ==========

    @staticmethod
    def _pagination(limit, offset) -> tuple:
        """Build a parameterized ``LIMIT``/``OFFSET`` suffix.

        Args:
            limit: Maximum row count; None or 0 means no limit.
            offset: Rows to skip; 0 or None means no offset.

        Returns:
            Tuple ``(sql_suffix, params)``; the suffix is empty when neither
            value applies.

        Raises:
            ValueError: If a value cannot be converted with ``int()``.
        """
        clause = ""
        params = []
        if limit:
            clause += " LIMIT %s"
            params.append(int(limit))
        if offset:
            clause += " OFFSET %s"
            params.append(int(offset))
        return clause, params

    @staticmethod
    def _build_responses_query(
        limit=None,
        offset=0,
        llm_filter=None,
        jailbroken_filter=None,
        prompt_id_filter=None
    ) -> tuple:
        """Build the SELECT statement and parameters for get_all_responses().

        Returns:
            Tuple ``(query, params)`` suitable for ``cursor.execute``.
        """
        conditions = []
        params = []
        if llm_filter:
            conditions.append("llm = %s")
            params.append(llm_filter)
        if jailbroken_filter is not None and jailbroken_filter != '':
            if isinstance(jailbroken_filter, bool):
                wanted = jailbroken_filter
            else:
                # Only the text 'true' (any case) selects jailbroken rows.
                wanted = str(jailbroken_filter).lower() == 'true'
            conditions.append("jailbroken = %s")
            params.append(wanted)
        if prompt_id_filter:
            conditions.append("prompt_id = %s")
            params.append(prompt_id_filter)

        query = Database._SELECT_RESPONSES_SQL
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        query += " ORDER BY id DESC"

        page_clause, page_params = Database._pagination(limit, offset)
        return query + page_clause, params + page_params

    # ========== PROMPT OPERATIONS ==========

    def insert_prompt(self, text: str, note: Optional[str] = None) -> int:
        """
        Insert a new prompt into the database.

        Args:
            text: The prompt text
            note: Optional note about the prompt

        Returns:
            The ID of the inserted prompt

        Raises:
            Exception: If the database rejects the insert or the commit.
        """
        with self._cursor("Failed to insert prompt") as cursor:
            cursor.execute(
                "INSERT INTO prompts (text, note) VALUES (%s, %s) RETURNING id",
                (text, note)
            )
            return cursor.fetchone()[0]

    def get_prompt(self, prompt_id: int) -> Optional[Dict[str, Any]]:
        """
        Get a prompt by ID.

        Args:
            prompt_id: The ID of the prompt

        Returns:
            Dictionary with prompt data or None if not found
        """
        with self._cursor(dict_rows=True) as cursor:
            cursor.execute(
                "SELECT id, text, note FROM prompts WHERE id = %s",
                (prompt_id,)
            )
            row = cursor.fetchone()
            return dict(row) if row else None

    def get_all_prompts(self, limit: Optional[int] = None, offset: int = 0) -> List[Dict[str, Any]]:
        """
        Get all prompts from the database, newest ID first.

        Args:
            limit: Maximum number of prompts to return (None or 0 for all)
            offset: Number of prompts to skip

        Returns:
            List of dictionaries with prompt data
        """
        page_clause, page_params = self._pagination(limit, offset)
        query = "SELECT id, text, note FROM prompts ORDER BY id DESC" + page_clause
        with self._cursor(dict_rows=True) as cursor:
            cursor.execute(query, page_params or None)
            return [dict(row) for row in cursor.fetchall()]

    def update_prompt_note(self, prompt_id: int, note: Optional[str]) -> bool:
        """
        Update the note for a prompt.

        Args:
            prompt_id: The ID of the prompt
            note: The new note (can be None to clear)

        Returns:
            True if update was successful, False if prompt not found

        Raises:
            Exception: If the database rejects the update or the commit.
        """
        with self._cursor("Failed to update prompt note") as cursor:
            cursor.execute(
                "UPDATE prompts SET note = %s WHERE id = %s",
                (note, prompt_id)
            )
            return cursor.rowcount > 0

    # ========== RESPONSE OPERATIONS ==========

    def insert_response(
        self,
        prompt_id: int,
        llm: str,
        response: str,
        jailbroken: bool = False,
        note: Optional[str] = None
    ) -> int:
        """
        Insert a new response into the database.

        Args:
            prompt_id: The ID of the associated prompt
            llm: The name/identifier of the LLM
            response: The response text
            jailbroken: Whether the response indicates a jailbreak
            note: Optional note about the response

        Returns:
            The ID of the inserted response

        Raises:
            Exception: If the database rejects the insert or the commit.
        """
        with self._cursor("Failed to insert response") as cursor:
            cursor.execute(
                self._INSERT_RESPONSE_SQL,
                (prompt_id, llm, response, jailbroken, note)
            )
            return cursor.fetchone()[0]

    def insert_batch_responses(
        self,
        responses: List[Dict[str, Any]]
    ) -> List[int]:
        """
        Insert multiple responses in a single transaction.

        Either every response is stored or none is: any failure, including
        a missing required key in one of the items, rolls the batch back.

        Args:
            responses: List of dictionaries with keys: prompt_id, llm, response, jailbroken, note
                (``jailbroken`` defaults to False and ``note`` to None)

        Returns:
            List of inserted response IDs

        Raises:
            Exception: If the database rejects an insert or the commit.
            KeyError: If an item lacks ``prompt_id``, ``llm`` or ``response``.
        """
        if not responses:
            return []
        inserted_ids = []
        with self._cursor("Failed to insert batch responses") as cursor:
            for resp in responses:
                cursor.execute(
                    self._INSERT_RESPONSE_SQL,
                    (
                        resp['prompt_id'],
                        resp['llm'],
                        resp['response'],
                        resp.get('jailbroken', False),
                        resp.get('note')
                    )
                )
                inserted_ids.append(cursor.fetchone()[0])
        return inserted_ids

    def get_response(self, response_id: int) -> Optional[Dict[str, Any]]:
        """
        Get a response by ID.

        Args:
            response_id: The ID of the response

        Returns:
            Dictionary with response data or None if not found
        """
        with self._cursor(dict_rows=True) as cursor:
            cursor.execute(
                self._SELECT_RESPONSES_SQL + " WHERE id = %s",
                (response_id,)
            )
            row = cursor.fetchone()
            return dict(row) if row else None

    def get_responses_by_prompt(self, prompt_id: int) -> List[Dict[str, Any]]:
        """
        Get all responses for a specific prompt, ordered by ascending ID.

        Args:
            prompt_id: The ID of the prompt

        Returns:
            List of dictionaries with response data
        """
        with self._cursor(dict_rows=True) as cursor:
            cursor.execute(
                self._SELECT_RESPONSES_SQL + " WHERE prompt_id = %s ORDER BY id",
                (prompt_id,)
            )
            return [dict(row) for row in cursor.fetchall()]

    def get_all_responses(
        self,
        limit: Optional[int] = None,
        offset: int = 0,
        llm_filter: Optional[str] = None,
        jailbroken_filter: Optional[str] = None,
        prompt_id_filter: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Get all responses from the database, newest ID first.

        Args:
            limit: Maximum number of responses to return (None or 0 for all)
            offset: Number of responses to skip
            llm_filter: Optional filter by LLM name
            jailbroken_filter: Optional filter by jailbroken status ('true', 'false', or None for all);
                a real bool is accepted as well
            prompt_id_filter: Optional filter by prompt ID

        Returns:
            List of dictionaries with response data
        """
        query, params = self._build_responses_query(
            limit, offset, llm_filter, jailbroken_filter, prompt_id_filter
        )
        with self._cursor(dict_rows=True) as cursor:
            cursor.execute(query, params or None)
            return [dict(row) for row in cursor.fetchall()]

    def update_response_note(self, response_id: int, note: Optional[str]) -> bool:
        """
        Update the note for a response.

        Args:
            response_id: The ID of the response
            note: The new note (can be None to clear)

        Returns:
            True if update was successful, False if response not found

        Raises:
            Exception: If the database rejects the update or the commit.
        """
        with self._cursor("Failed to update response note") as cursor:
            cursor.execute(
                "UPDATE responses SET note = %s WHERE id = %s",
                (note, response_id)
            )
            return cursor.rowcount > 0

    def update_response_jailbroken(self, response_id: int, jailbroken: bool) -> bool:
        """
        Update the jailbroken status for a response.

        Args:
            response_id: The ID of the response
            jailbroken: The new jailbroken status

        Returns:
            True if update was successful, False if response not found

        Raises:
            Exception: If the database rejects the update or the commit.
        """
        with self._cursor("Failed to update response jailbroken status") as cursor:
            cursor.execute(
                "UPDATE responses SET jailbroken = %s WHERE id = %s",
                (jailbroken, response_id)
            )
            return cursor.rowcount > 0

    def close(self):
        """Close the database connection."""
        if self.conn and not self.conn.closed:
            self.conn.close()