# TraceIntel / db/postgres.py
# Siddhant Belkhede
# Deploying
# 2b13511
# File: db/postgres.py
import json
import logging
from contextlib import contextmanager
from typing import List, Optional
from psycopg2 import pool
from core.config import settings
from db.schemas import (
ChatMessageRecord,
MessageRecord,
CallRecord,
ContactRecord,
TimelineRecord,
MediaRecord,
NetworkGraphRecord
)
# Logger shared by PostgresDatabase below; registered under a fixed name
# ("PostgresManager") rather than the module's __name__.
logger = logging.getLogger(name="PostgresManager")
class PostgresDatabase:
    """
    Manages a threaded connection pool for PostgreSQL evidence storage.

    Handles raw evidence retrieval, persistent chat history, and analysis
    caching. Instantiation is cached on the class: after the first successful
    ``PostgresDatabase()`` call, every later call returns the same object, so
    the pool is built only once.

    Error policy: every public query method catches and logs its own
    exceptions. Readers then return an empty result (``[]`` or ``None``).
    Writers return ``False`` on failure and ``True`` on success, so callers
    can detect a failed write without parsing logs.
    """

    # Bounds handed to psycopg2's ThreadedConnectionPool.
    POOL_MIN_CONN = 1
    POOL_MAX_CONN = 20

    # strftime pattern for every timestamp string returned to callers.
    TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

    _instance = None

    def __new__(cls):
        """Return the cached instance, building it (and its pool) on first use."""
        if cls._instance is None:
            instance = super().__new__(cls)
            # Publish only after the pool exists. If _initialize_pool raises,
            # nothing is cached and the next call retries from scratch,
            # instead of returning an object without ``connection_pool``.
            instance._initialize_pool()
            cls._instance = instance
        return cls._instance

    def _initialize_pool(self):
        """Create the thread-safe connection pool from ``core.config.settings``.

        Raises:
            Exception: whatever psycopg2 raises when the pool cannot be built.
                It is logged (with traceback) and re-raised unchanged.
        """
        try:
            self.connection_pool = pool.ThreadedConnectionPool(
                minconn=self.POOL_MIN_CONN,
                maxconn=self.POOL_MAX_CONN,
                host=settings.POSTGRES_HOST,
                port=settings.POSTGRES_PORT,
                database=settings.POSTGRES_DB,
                user=settings.POSTGRES_USER,
                password=settings.POSTGRES_PASSWORD,
            )
        except Exception as e:
            logger.exception("Failed to initialize PostgreSQL pool: %s", e)
            raise
        logger.info("PostgreSQL connection pool initialized successfully.")

    @contextmanager
    def get_connection(self):
        """Borrow a pooled connection for the duration of a ``with`` block.

        The connection is always handed back via ``putconn``, even when the
        body raises.
        """
        conn = self.connection_pool.getconn()
        try:
            yield conn
        finally:
            self.connection_pool.putconn(conn)

    # ------------------------------------------------------------------
    # Low-level helpers. These let exceptions propagate; the public
    # methods decide how to log them and what to return.
    # ------------------------------------------------------------------

    def _fetch_all(self, query: str, params: tuple) -> list:
        """Execute ``query`` with ``params`` and return every result row."""
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(query, params)
                return cursor.fetchall()

    def _fetch_one(self, query: str, params: tuple):
        """Execute ``query`` with ``params`` and return the first row, or ``None``."""
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(query, params)
                return cursor.fetchone()

    def _execute_write(self, query: str, params: tuple) -> None:
        """Execute one data-modifying statement and commit it."""
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(query, params)
                conn.commit()

    @classmethod
    def _format_timestamp(cls, value) -> str:
        """Render a row timestamp with ``TIMESTAMP_FORMAT``; falsy (e.g. NULL) -> ``""``."""
        return value.strftime(cls.TIMESTAMP_FORMAT) if value else ""

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    def initialize_tables(self) -> bool:
        """Create the analysis-results, chat-history and graph-cache tables if absent.

        All three ``CREATE TABLE IF NOT EXISTS`` statements run in one
        transaction and are committed together.

        Returns:
            True if the statements committed, False if anything failed (the
            error is logged, not raised).
        """
        queries = [
            """
            CREATE TABLE IF NOT EXISTS nlp_analysis_results (
                case_id TEXT,
                sender TEXT,
                risk_score_sum FLOAT,
                detected_behaviors TEXT[],
                message_count INTEGER,
                last_analyzed TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (case_id, sender)
            );
            """,
            """
            CREATE TABLE IF NOT EXISTS ai_chat_history (
                id SERIAL PRIMARY KEY,
                case_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
            """,
            """
            CREATE TABLE IF NOT EXISTS graph_cache (
                case_id TEXT PRIMARY KEY,
                graph_data JSONB,
                last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
            """,
        ]
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    for query in queries:
                        cursor.execute(query)
                    conn.commit()
        except Exception as e:
            logger.exception("Failed to initialize database tables: %s", e)
            return False
        logger.info("Forensic persistence tables verified.")
        return True

    # ------------------------------------------------------------------
    # Evidence readers. Each orders by its natural key and then by ``id``,
    # so rows with equal timestamps/names come back in a stable order and
    # LIMIT always cuts off the same subset.
    # ------------------------------------------------------------------

    def get_image_binary(self, case_id: str, file_name: str) -> Optional[bytes]:
        """Return the raw ``file_data`` (BYTEA) for one media file.

        Returns ``None`` if no row matches, the column is NULL, or the query
        fails. An empty but non-NULL BYTEA is returned as ``b""``. If several
        rows share ``(case_id, file_name)``, which one is returned is
        unspecified (the query has no ORDER BY).
        """
        query = (
            "SELECT file_data FROM media_sharing "
            "WHERE case_id = %s AND file_name = %s LIMIT 1;"
        )
        try:
            row = self._fetch_one(query, (case_id, file_name))
        except Exception as e:
            logger.exception("Error retrieving image binary: %s", e)
            return None
        # Compare with None explicitly: an empty memoryview is falsy but is
        # still a real (empty) file, not a missing one.
        if row is None or row[0] is None:
            return None
        return bytes(row[0])

    def get_messages(self, case_id: str, limit: int = 1000) -> List[MessageRecord]:
        """Return up to ``limit`` messages for ``case_id``, oldest first ([] on error)."""
        query = (
            "SELECT id, sender, receiver, timestamp, message FROM messages "
            "WHERE case_id = %s ORDER BY timestamp ASC, id ASC LIMIT %s;"
        )
        try:
            rows = self._fetch_all(query, (case_id, limit))
            return [
                MessageRecord(
                    id=row_id,
                    sender=sender,
                    receiver=receiver,
                    timestamp=self._format_timestamp(ts),
                    message=message,
                )
                for row_id, sender, receiver, ts, message in rows
            ]
        except Exception as e:
            logger.exception("Failed to retrieve messages: %s", e)
            return []

    def get_calls(self, case_id: str, limit: int = 1000) -> List[CallRecord]:
        """Return up to ``limit`` call-log entries for ``case_id``, oldest first ([] on error)."""
        query = (
            "SELECT id, caller, receiver, timestamp, call_duration_seconds, call_type "
            "FROM calls WHERE case_id = %s ORDER BY timestamp ASC, id ASC LIMIT %s;"
        )
        try:
            rows = self._fetch_all(query, (case_id, limit))
            return [
                CallRecord(
                    id=row_id,
                    caller=caller,
                    receiver=receiver,
                    timestamp=self._format_timestamp(ts),
                    call_duration_seconds=duration,
                    call_type=call_type,
                )
                for row_id, caller, receiver, ts, duration, call_type in rows
            ]
        except Exception as e:
            logger.exception("Failed to retrieve calls: %s", e)
            return []

    def get_contacts(self, case_id: str) -> List[ContactRecord]:
        """Return every contact for ``case_id``, sorted by name ([] on error)."""
        query = (
            "SELECT id, name, phone FROM contacts "
            "WHERE case_id = %s ORDER BY name ASC, id ASC;"
        )
        try:
            rows = self._fetch_all(query, (case_id,))
            return [
                ContactRecord(id=row_id, name=name, phone=phone)
                for row_id, name, phone in rows
            ]
        except Exception as e:
            logger.exception("Failed to retrieve contacts: %s", e)
            return []

    def get_timeline(self, case_id: str, limit: int = 1000) -> List[TimelineRecord]:
        """Return up to ``limit`` timeline events for ``case_id``, oldest first ([] on error)."""
        query = (
            "SELECT id, timestamp, event_type, user_name, details FROM timeline "
            "WHERE case_id = %s ORDER BY timestamp ASC, id ASC LIMIT %s;"
        )
        try:
            rows = self._fetch_all(query, (case_id, limit))
            return [
                TimelineRecord(
                    id=row_id,
                    timestamp=self._format_timestamp(ts),
                    event_type=event_type,
                    user_name=user_name,
                    details=details,
                )
                for row_id, ts, event_type, user_name, details in rows
            ]
        except Exception as e:
            logger.exception("Failed to retrieve timeline: %s", e)
            return []

    def get_media_records(self, case_id: str, limit: int = 1000) -> List[MediaRecord]:
        """Return up to ``limit`` media-sharing rows (metadata only, no BYTEA) for ``case_id``."""
        query = (
            "SELECT id, sender, receiver, timestamp, file_path, file_name, file_type "
            "FROM media_sharing WHERE case_id = %s "
            "ORDER BY timestamp ASC, id ASC LIMIT %s;"
        )
        try:
            rows = self._fetch_all(query, (case_id, limit))
            return [
                MediaRecord(
                    id=row_id,
                    sender=sender,
                    receiver=receiver,
                    timestamp=self._format_timestamp(ts),
                    file_path=file_path,
                    file_name=file_name,
                    file_type=file_type,
                )
                for row_id, sender, receiver, ts, file_path, file_name, file_type in rows
            ]
        except Exception as e:
            logger.exception("Failed to retrieve media records: %s", e)
            return []

    # ------------------------------------------------------------------
    # Graph cache
    # ------------------------------------------------------------------

    def save_network_graph(self, case_id: str, graph_data: dict) -> bool:
        """Upsert the JSON-serialised network graph for ``case_id``.

        Returns:
            True on commit, False if serialisation or the query failed
            (logged).
        """
        query = """
            INSERT INTO graph_cache (case_id, graph_data, last_updated)
            VALUES (%s, %s, CURRENT_TIMESTAMP)
            ON CONFLICT (case_id) DO UPDATE SET
                graph_data = EXCLUDED.graph_data,
                last_updated = CURRENT_TIMESTAMP;
        """
        try:
            self._execute_write(query, (case_id, json.dumps(graph_data)))
        except Exception as e:
            logger.exception("Failed to cache network graph: %s", e)
            return False
        logger.info("Network graph cached for case: %s", case_id)
        return True

    def get_network_graph(self, case_id: str) -> Optional[NetworkGraphRecord]:
        """Return the cached graph for ``case_id`` as a ``NetworkGraphRecord``.

        Returns ``None`` when nothing is cached, the stored value is empty, or
        the lookup/model construction fails. This assumes the driver decodes
        the JSONB column into a dict (psycopg2's default); TODO confirm that
        no custom typecaster changes that.
        """
        query = "SELECT graph_data FROM graph_cache WHERE case_id = %s;"
        try:
            row = self._fetch_one(query, (case_id,))
            if not row or not row[0]:
                return None
            return NetworkGraphRecord(**row[0])
        except Exception as e:
            logger.exception("Failed to retrieve graph cache: %s", e)
            return None

    # ------------------------------------------------------------------
    # Chat history
    # ------------------------------------------------------------------

    def save_chat_message(self, case_id: str, role: str, content: str) -> bool:
        """Append one message (``role``/``content``) to the case's chat history.

        Returns:
            True on commit, False on failure (logged).
        """
        query = "INSERT INTO ai_chat_history (case_id, role, content) VALUES (%s, %s, %s);"
        try:
            self._execute_write(query, (case_id, role, content))
        except Exception as e:
            logger.exception("Failed to save chat message: %s", e)
            return False
        return True

    def get_chat_history(self, case_id: str) -> List[ChatMessageRecord]:
        """Return the full chat history for ``case_id`` in insertion order ([] on error).

        ``id`` (SERIAL) breaks timestamp ties. A NULL timestamp is rendered as
        ``""`` instead of aborting the whole read.
        """
        query = (
            "SELECT role, content, timestamp FROM ai_chat_history "
            "WHERE case_id = %s ORDER BY timestamp ASC, id ASC;"
        )
        try:
            rows = self._fetch_all(query, (case_id,))
            return [
                ChatMessageRecord(
                    role=role,
                    content=content,
                    timestamp=self._format_timestamp(ts),
                )
                for role, content, ts in rows
            ]
        except Exception as e:
            logger.exception("Failed to retrieve chat history: %s", e)
            return []

    def clear_chat_history(self, case_id: str) -> bool:
        """Delete every chat-history row for ``case_id``.

        Returns:
            True on commit, False on failure (logged).
        """
        query = "DELETE FROM ai_chat_history WHERE case_id = %s;"
        try:
            self._execute_write(query, (case_id,))
        except Exception as e:
            logger.exception("Failed to clear chat history: %s", e)
            return False
        logger.info("Chat history cleared for case: %s", case_id)
        return True

    # ------------------------------------------------------------------
    # Shutdown
    # ------------------------------------------------------------------

    def close_all(self):
        """Close every pooled connection.

        This is a no-op when the pool was never created or is already
        closed. psycopg2 raises PoolError on a second ``closeall()``.
        """
        connection_pool = getattr(self, "connection_pool", None)
        if connection_pool is None or getattr(connection_pool, "closed", False):
            return
        connection_pool.closeall()
        logger.info("PostgreSQL pool closed.")
# Shared instance, created at import time. Constructing it builds the
# connection pool, so importing this module requires a reachable database.
db_manager = PostgresDatabase()


def get_db() -> PostgresDatabase:
    """Return the module-level ``PostgresDatabase`` instance ``db_manager``."""
    return db_manager