import contextlib
import datetime
import json
import logging
import os
import uuid

import psycopg2
|
|
# Logger for this module; records are routed by the module's dotted name.
logger = logging.getLogger(name=__name__)
|
|
class PostgresMemoryManager:
    """Store chat sessions and their message history in PostgreSQL.

    The connection string is read from the ``NEON_DATABASE_URL`` environment
    variable when the manager is constructed. Persistence is best-effort:
    database failures are logged, and the public methods return a fallback
    value instead of raising. Callers therefore keep working (without
    persistence) when the database is missing or unreachable.

    The manager holds a single connection. Call :meth:`close`, or use the
    manager as a context manager, to release it.
    """

    def __init__(self):
        self.db_url = os.getenv("NEON_DATABASE_URL")
        self.conn = None  # psycopg2 connection, or None when unavailable
        self._connect()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Close the database connection, if any. Safe to call repeatedly."""
        self._drop_connection()

    def _connect(self):
        """Open a connection and make sure the schema exists.

        Leaves ``self.conn`` as ``None`` when no URL is configured or the
        connection attempt fails. Both cases are logged, not raised.
        """
        if not self.db_url:
            logger.error("No NEON_DATABASE_URL provided. Cannot connect to DB.")
            return

        try:
            self.conn = psycopg2.connect(self.db_url)
        except Exception as e:
            logger.error("Failed to connect to database: %s.", e)
            self.conn = None
            return

        self._initialize_tables()
        # Table setup may have dropped a connection that broke mid-DDL.
        if self.conn is not None:
            logger.info("Connected to PostgreSQL DB.")

    def _ensure_connection(self):
        """Reconnect if a URL is configured and the connection is missing or closed.

        ``conn.closed`` only reflects failures psycopg2 has already observed.
        A connection dropped by the server is noticed on the next statement,
        which is why ``_run`` also retries once after such a failure.
        """
        if self.db_url and (self.conn is None or self.conn.closed != 0):
            self._connect()

    def _drop_connection(self):
        """Forget the current connection, closing it on a best-effort basis."""
        conn, self.conn = self.conn, None
        if conn is not None:
            # Teardown must never raise: the connection may already be broken.
            with contextlib.suppress(Exception):
                conn.close()

    def _recover_after_error(self):
        """Roll back the failed transaction, or drop the connection if that fails."""
        try:
            self.conn.rollback()
        except Exception:
            # rollback() fails when the connection itself is unusable; drop it
            # so the next call reconnects through _ensure_connection().
            self._drop_connection()

    def _initialize_tables(self):
        """Create the ``sessions`` and ``conversations`` tables if they are missing.

        Also adds the ``ip_address`` and ``location_data`` columns to an
        existing ``sessions`` table that lacks them. Errors are logged, not
        raised.
        """
        if self.conn is None:
            return
        try:
            with self.conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS sessions (
                        id VARCHAR(255) PRIMARY KEY,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        metadata TEXT
                    );
                """)
                cur.execute("""
                    ALTER TABLE sessions ADD COLUMN IF NOT EXISTS ip_address TEXT;
                    ALTER TABLE sessions ADD COLUMN IF NOT EXISTS location_data TEXT;
                """)
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS conversations (
                        id VARCHAR(255) PRIMARY KEY,
                        session_id VARCHAR(255) REFERENCES sessions(id),
                        user_message TEXT,
                        ai_response TEXT,
                        timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    );
                """)
            self.conn.commit()
        except Exception as e:
            logger.error("Error initializing tables: %s", e)
            self._recover_after_error()

    def _run(self, action, operation):
        """Run ``operation(cursor)`` in a transaction and commit it.

        Args:
            action: Phrase used in the error log, as in
                ``"DB Error <action>: <error>"``.
            operation: Callable taking a cursor. Its return value is passed
                through. It must be safe to run twice (see below).

        Returns:
            The value returned by ``operation``. Returns ``None`` when there
            is no connection or the operation fails.

        If a failure leaves the connection closed (for example, the server
        dropped an idle connection), the manager reconnects and runs
        ``operation`` one more time. Callers generate row ids before calling.
        A retried INSERT whose first attempt did commit therefore hits the
        primary key and fails, rather than creating a duplicate row.
        """
        for attempt in range(2):
            self._ensure_connection()
            if self.conn is None:
                return None
            try:
                with self.conn.cursor() as cur:
                    result = operation(cur)
                # Commit even after a read: psycopg2 opens a transaction on the
                # first statement, and this ends it instead of leaving the
                # connection idle in transaction.
                self.conn.commit()
                return result
            except Exception as e:
                logger.error("DB Error %s: %s", action, e)
                connection_lost = self.conn.closed != 0
                if connection_lost:
                    self._drop_connection()
                else:
                    self._recover_after_error()
                if not connection_lost or attempt > 0:
                    return None
                logger.warning("Database connection lost; reconnecting and retrying once.")
        return None

    def create_session(self, metadata="", ip_address=None, location_data=None):
        """Insert a new session row and return its id.

        Args:
            metadata: Free-form text stored in ``sessions.metadata``.
            ip_address: Client address, stored as text. Optional.
            location_data: Value stored JSON-encoded in
                ``sessions.location_data``. Falsy values are stored as NULL.

        Returns:
            The new session id, a UUID4 string. It is returned even when the
            row could not be persisted.

        Raises:
            TypeError / ValueError: propagated from ``json.dumps`` when
                ``location_data`` cannot be JSON-encoded.
        """
        session_id = str(uuid.uuid4())
        location_data_str = json.dumps(location_data) if location_data else None

        def insert(cur):
            cur.execute(
                "INSERT INTO sessions (id, metadata, ip_address, location_data) VALUES (%s, %s, %s, %s)",
                (session_id, metadata, ip_address, location_data_str),
            )

        self._run("creating session", insert)
        return session_id

    def add_message(self, session_id, user_message, ai_response):
        """Append one user/AI exchange to a session's history.

        Best-effort: failures, including an unknown ``session_id`` rejected
        by the foreign key, are logged. Always returns ``None``.
        """
        conv_id = str(uuid.uuid4())

        def insert(cur):
            cur.execute(
                "INSERT INTO conversations (id, session_id, user_message, ai_response) VALUES (%s, %s, %s, %s)",
                (conv_id, session_id, user_message, ai_response),
            )

        self._run("adding message", insert)

    def get_history(self, session_id):
        """Return a session's exchanges ordered by timestamp, oldest first.

        Each item is a dict with ``user_message``, ``ai_response`` and
        ``timestamp``. The timestamp is an ISO-8601 string, or ``None`` when
        the column is NULL. Returns ``[]`` when there is no connection or the
        query fails.
        """
        def select(cur):
            cur.execute(
                "SELECT user_message, ai_response, timestamp FROM conversations WHERE session_id = %s ORDER BY timestamp ASC",
                (session_id,),
            )
            return cur.fetchall()

        rows = self._run("getting history", select)
        if not rows:
            return []
        return [
            {
                "user_message": user_message,
                "ai_response": ai_response,
                "timestamp": ts.isoformat() if ts is not None else None,
            }
            for user_message, ai_response, ts in rows
        ]
|
|