Spaces:
Sleeping
Sleeping
| # File: db/postgres.py | |
| import json | |
| import logging | |
| from contextlib import contextmanager | |
| from typing import List, Optional | |
| from psycopg2 import pool | |
| from core.config import settings | |
| from db.schemas import ( | |
| ChatMessageRecord, | |
| MessageRecord, | |
| CallRecord, | |
| ContactRecord, | |
| TimelineRecord, | |
| MediaRecord, | |
| NetworkGraphRecord | |
| ) | |
| logger = logging.getLogger("PostgresManager") | |
| class PostgresDatabase: | |
| """ | |
| Manages a threaded connection pool for PostgreSQL evidence storage. | |
| Handles raw evidence retrieval, persistent chat history, and analysis caching. | |
| """ | |
| _instance = None | |
| def __new__(cls): | |
| if cls._instance is None: | |
| cls._instance = super(PostgresDatabase, cls).__new__(cls) | |
| cls._instance._initialize_pool() | |
| return cls._instance | |
| def _initialize_pool(self): | |
| """Initializes the thread-safe connection pool using core settings.""" | |
| try: | |
| self.connection_pool = pool.ThreadedConnectionPool( | |
| minconn=1, | |
| maxconn=20, | |
| host=settings.POSTGRES_HOST, | |
| port=settings.POSTGRES_PORT, | |
| database=settings.POSTGRES_DB, | |
| user=settings.POSTGRES_USER, | |
| password=settings.POSTGRES_PASSWORD | |
| ) | |
| logger.info("PostgreSQL connection pool initialized successfully.") | |
| except Exception as e: | |
| logger.error(f"Failed to initialize PostgreSQL pool: {e}") | |
| raise | |
| def get_connection(self): | |
| """Context manager for safe connection acquisition and release.""" | |
| conn = self.connection_pool.getconn() | |
| try: | |
| yield conn | |
| finally: | |
| self.connection_pool.putconn(conn) | |
| def initialize_tables(self): | |
| """Sets up persistent storage tables for analysis results and chat history.""" | |
| queries = [ | |
| """ | |
| CREATE TABLE IF NOT EXISTS nlp_analysis_results ( | |
| case_id TEXT, | |
| sender TEXT, | |
| risk_score_sum FLOAT, | |
| detected_behaviors TEXT[], | |
| message_count INTEGER, | |
| last_analyzed TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | |
| PRIMARY KEY (case_id, sender) | |
| ); | |
| """, | |
| """ | |
| CREATE TABLE IF NOT EXISTS ai_chat_history ( | |
| id SERIAL PRIMARY KEY, | |
| case_id TEXT NOT NULL, | |
| role TEXT NOT NULL, | |
| content TEXT NOT NULL, | |
| timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP | |
| ); | |
| """, | |
| """ | |
| CREATE TABLE IF NOT EXISTS graph_cache ( | |
| case_id TEXT PRIMARY KEY, | |
| graph_data JSONB, | |
| last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP | |
| ); | |
| """ | |
| ] | |
| try: | |
| with self.get_connection() as conn: | |
| with conn.cursor() as cursor: | |
| for query in queries: | |
| cursor.execute(query) | |
| conn.commit() | |
| logger.info("Forensic persistence tables verified.") | |
| except Exception as e: | |
| logger.error(f"Failed to initialize database tables: {e}") | |
| def get_image_binary(self, case_id: str, file_name: str) -> Optional[bytes]: | |
| """Retrieves raw binary image data (BYTEA) for vision analysis or streaming.""" | |
| query = "SELECT file_data FROM media_sharing WHERE case_id = %s AND file_name = %s LIMIT 1;" | |
| try: | |
| with self.get_connection() as conn: | |
| with conn.cursor() as cursor: | |
| cursor.execute(query, (case_id, file_name)) | |
| result = cursor.fetchone() | |
| return bytes(result[0]) if result and result[0] else None | |
| except Exception as e: | |
| logger.error(f"Error retrieving image binary: {e}") | |
| return None | |
| def get_messages(self, case_id: str, limit: int = 1000) -> List[MessageRecord]: | |
| """Retrieves message logs for a specific forensic case.""" | |
| query = "SELECT id, sender, receiver, timestamp, message FROM messages WHERE case_id = %s ORDER BY timestamp ASC LIMIT %s;" | |
| results = [] | |
| try: | |
| with self.get_connection() as conn: | |
| with conn.cursor() as cursor: | |
| cursor.execute(query, (case_id, limit)) | |
| for row in cursor.fetchall(): | |
| results.append(MessageRecord( | |
| id=row[0], sender=row[1], receiver=row[2], | |
| timestamp=row[3].strftime("%Y-%m-%d %H:%M:%S") if row[3] else "", | |
| message=row[4] | |
| )) | |
| return results | |
| except Exception as e: | |
| logger.error(f"Failed to retrieve messages: {e}") | |
| return [] | |
| def get_calls(self, case_id: str, limit: int = 1000) -> List[CallRecord]: | |
| """Retrieves call logs for a specific forensic case.""" | |
| query = "SELECT id, caller, receiver, timestamp, call_duration_seconds, call_type FROM calls WHERE case_id = %s ORDER BY timestamp ASC LIMIT %s;" | |
| results = [] | |
| try: | |
| with self.get_connection() as conn: | |
| with conn.cursor() as cursor: | |
| cursor.execute(query, (case_id, limit)) | |
| for row in cursor.fetchall(): | |
| results.append(CallRecord( | |
| id=row[0], caller=row[1], receiver=row[2], | |
| timestamp=row[3].strftime("%Y-%m-%d %H:%M:%S") if row[3] else "", | |
| call_duration_seconds=row[4], call_type=row[5] | |
| )) | |
| return results | |
| except Exception as e: | |
| logger.error(f"Failed to retrieve calls: {e}") | |
| return [] | |
| def get_contacts(self, case_id: str) -> List[ContactRecord]: | |
| """Retrieves the contact list associated with a case.""" | |
| query = "SELECT id, name, phone FROM contacts WHERE case_id = %s ORDER BY name ASC;" | |
| results = [] | |
| try: | |
| with self.get_connection() as conn: | |
| with conn.cursor() as cursor: | |
| cursor.execute(query, (case_id,)) | |
| for row in cursor.fetchall(): | |
| results.append(ContactRecord(id=row[0], name=row[1], phone=row[2])) | |
| return results | |
| except Exception as e: | |
| logger.error(f"Failed to retrieve contacts: {e}") | |
| return [] | |
| def get_timeline(self, case_id: str, limit: int = 1000) -> List[TimelineRecord]: | |
| """Retrieves a chronological timeline of events for a case.""" | |
| query = "SELECT id, timestamp, event_type, user_name, details FROM timeline WHERE case_id = %s ORDER BY timestamp ASC LIMIT %s;" | |
| results = [] | |
| try: | |
| with self.get_connection() as conn: | |
| with conn.cursor() as cursor: | |
| cursor.execute(query, (case_id, limit)) | |
| for row in cursor.fetchall(): | |
| results.append(TimelineRecord( | |
| id=row[0], | |
| timestamp=row[1].strftime("%Y-%m-%d %H:%M:%S") if row[1] else "", | |
| event_type=row[2], user_name=row[3], details=row[4] | |
| )) | |
| return results | |
| except Exception as e: | |
| logger.error(f"Failed to retrieve timeline: {e}") | |
| return [] | |
| def get_media_records(self, case_id: str, limit: int = 1000) -> List[MediaRecord]: | |
| """Retrieves sharing metadata for media files in a case.""" | |
| query = "SELECT id, sender, receiver, timestamp, file_path, file_name, file_type FROM media_sharing WHERE case_id = %s ORDER BY timestamp ASC LIMIT %s;" | |
| results = [] | |
| try: | |
| with self.get_connection() as conn: | |
| with conn.cursor() as cursor: | |
| cursor.execute(query, (case_id, limit)) | |
| for row in cursor.fetchall(): | |
| results.append(MediaRecord( | |
| id=row[0], sender=row[1], receiver=row[2], | |
| timestamp=row[3].strftime("%Y-%m-%d %H:%M:%S") if row[3] else "", | |
| file_path=row[4], file_name=row[5], file_type=row[6] | |
| )) | |
| return results | |
| except Exception as e: | |
| logger.error(f"Failed to retrieve media records: {e}") | |
| return [] | |
| def save_network_graph(self, case_id: str, graph_data: dict): | |
| """Caches the network topology JSON into the database.""" | |
| query = """ | |
| INSERT INTO graph_cache (case_id, graph_data, last_updated) | |
| VALUES (%s, %s, CURRENT_TIMESTAMP) | |
| ON CONFLICT (case_id) DO UPDATE SET | |
| graph_data = EXCLUDED.graph_data, | |
| last_updated = CURRENT_TIMESTAMP; | |
| """ | |
| try: | |
| with self.get_connection() as conn: | |
| with conn.cursor() as cursor: | |
| cursor.execute(query, (case_id, json.dumps(graph_data))) | |
| conn.commit() | |
| logger.info(f"Network graph cached for case: {case_id}") | |
| except Exception as e: | |
| logger.error(f"Failed to cache network graph: {e}") | |
| def get_network_graph(self, case_id: str) -> Optional[NetworkGraphRecord]: | |
| """Retrieves the cached network topology for visualization.""" | |
| query = "SELECT graph_data FROM graph_cache WHERE case_id = %s;" | |
| try: | |
| with self.get_connection() as conn: | |
| with conn.cursor() as cursor: | |
| cursor.execute(query, (case_id,)) | |
| result = cursor.fetchone() | |
| return NetworkGraphRecord(**result[0]) if result and result[0] else None | |
| except Exception as e: | |
| logger.error(f"Failed to retrieve graph cache: {e}") | |
| return None | |
| def save_chat_message(self, case_id: str, role: str, content: str): | |
| """Persists an AI or Human message to the case history.""" | |
| query = "INSERT INTO ai_chat_history (case_id, role, content) VALUES (%s, %s, %s);" | |
| try: | |
| with self.get_connection() as conn: | |
| with conn.cursor() as cursor: | |
| cursor.execute(query, (case_id, role, content)) | |
| conn.commit() | |
| except Exception as e: | |
| logger.error(f"Failed to save chat message: {e}") | |
| def get_chat_history(self, case_id: str) -> List[ChatMessageRecord]: | |
| """Retrieves the full conversational history for a case.""" | |
| query = "SELECT role, content, timestamp FROM ai_chat_history WHERE case_id = %s ORDER BY timestamp ASC;" | |
| history = [] | |
| try: | |
| with self.get_connection() as conn: | |
| with conn.cursor() as cursor: | |
| cursor.execute(query, (case_id,)) | |
| for row in cursor.fetchall(): | |
| history.append(ChatMessageRecord( | |
| role=row[0], content=row[1], | |
| timestamp=row[2].strftime("%Y-%m-%d %H:%M:%S") | |
| )) | |
| return history | |
| except Exception as e: | |
| logger.error(f"Failed to retrieve chat history: {e}") | |
| return [] | |
| def clear_chat_history(self, case_id: str): | |
| """Wipes the conversation history for a specific case context.""" | |
| query = "DELETE FROM ai_chat_history WHERE case_id = %s;" | |
| try: | |
| with self.get_connection() as conn: | |
| with conn.cursor() as cursor: | |
| cursor.execute(query, (case_id,)) | |
| conn.commit() | |
| logger.info(f"Chat history cleared for case: {case_id}") | |
| except Exception as e: | |
| logger.error(f"Failed to clear chat history: {e}") | |
| def close_all(self): | |
| """Safely terminates all connections in the pool.""" | |
| if self.connection_pool: | |
| self.connection_pool.closeall() | |
| logger.info("PostgreSQL pool closed.") | |
| db_manager = PostgresDatabase() | |
| def get_db(): return db_manager |