| """ |
| Database integration for Streamlit application. |
| |
| This module provides functions to interact with the database for the Streamlit frontend. |
| It wraps the async database functions in sync functions for Streamlit compatibility. |
| """ |
import asyncio
import logging
import os
import threading
from datetime import datetime, timedelta, timezone
from typing import List, Dict, Any, Optional, Union, Tuple

import pandas as pd
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession

from src.models.threat import Threat, ThreatSeverity, ThreatStatus, ThreatCategory
from src.models.indicator import Indicator, IndicatorType
from src.models.dark_web_content import DarkWebContent, DarkWebMention, ContentType, ContentStatus
from src.models.alert import Alert, AlertStatus, AlertCategory
from src.models.report import Report, ReportType, ReportStatus

from src.api.services.dark_web_content_service import (
    create_content, get_content_by_id, get_contents, count_contents,
    create_mention, get_mentions, create_threat_from_content
)
from src.api.services.alert_service import (
    create_alert, get_alert_by_id, get_alerts, count_alerts,
    update_alert_status, mark_alert_as_read, get_alert_counts_by_severity
)
from src.api.services.threat_service import (
    create_threat, get_threat_by_id, get_threats, count_threats,
    update_threat, add_indicator_to_threat, get_threat_statistics
)
from src.api.services.report_service import (
    create_report, get_report_by_id, get_reports, count_reports,
    update_report, add_threat_to_report, publish_report
)

from src.api.schemas import PaginationParams
|
|
| |
| db_url = os.getenv("DATABASE_URL", "") |
| if db_url.startswith("postgresql://"): |
| |
| if "?" in db_url: |
| base_url, params = db_url.split("?", 1) |
| param_list = params.split("&") |
| filtered_params = [p for p in param_list if not p.startswith("sslmode=")] |
| if filtered_params: |
| db_url = f"{base_url}?{'&'.join(filtered_params)}" |
| else: |
| db_url = base_url |
| |
| ASYNC_DATABASE_URL = db_url.replace("postgresql://", "postgresql+asyncpg://", 1) |
| else: |
| ASYNC_DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres" |
|
|
| |
# Module-wide async engine. Its connection pool keeps up to 5 connections and
# may open 10 more as overflow. SQL echo is disabled.
engine = create_async_engine(
    ASYNC_DATABASE_URL,
    pool_size=5,
    max_overflow=10,
    echo=False,
    future=True,
)

# Factory producing AsyncSession objects bound to ``engine``.
# expire_on_commit=False keeps attributes loaded before a commit readable
# afterwards without another database round-trip.
async_session = sessionmaker(
    bind=engine,
    expire_on_commit=False,
    class_=AsyncSession,
)
|
|
|
|
# Event loop on which every coroutine handed to run_async() executes. It is
# created lazily and driven by a daemon thread.
_background_loop: Optional[asyncio.AbstractEventLoop] = None
_background_loop_lock = threading.Lock()


def _get_background_loop() -> asyncio.AbstractEventLoop:
    """Return the shared event loop, starting its daemon thread on first use."""
    global _background_loop
    with _background_loop_lock:
        if _background_loop is None or _background_loop.is_closed():
            loop = asyncio.new_event_loop()
            threading.Thread(
                target=loop.run_forever,
                name="streamlit-db-event-loop",
                daemon=True,
            ).start()
            _background_loop = loop
        return _background_loop


async def _await_awaitable(awaitable):
    """Wrap a non-coroutine awaitable so it can be scheduled with run_coroutine_threadsafe."""
    return await awaitable


def run_async(coro):
    """
    Run an async function in a sync context.

    All calls, from any thread, run on one shared background event loop, and
    the calling thread blocks until the result is ready. A single loop is
    used because the module-level ``engine`` pools asyncpg connections, and
    those connections cannot be used from an event loop other than the one
    that opened them.

    Args:
        coro: Coroutine (or other awaitable) to run.

    Returns:
        The awaitable's result.

    Raises:
        Whatever the awaitable raises.
        RuntimeError: If called from the background loop's own thread,
            because blocking there would deadlock.
    """
    loop = _get_background_loop()
    try:
        current = asyncio.get_running_loop()
    except RuntimeError:
        current = None
    if current is loop:
        if asyncio.iscoroutine(coro):
            coro.close()  # avoid a "coroutine was never awaited" warning
        raise RuntimeError("run_async() cannot be called from the database event loop thread")

    if not asyncio.iscoroutine(coro):
        coro = _await_awaitable(coro)

    future = asyncio.run_coroutine_threadsafe(coro, loop)
    try:
        return future.result()
    except BaseException:
        # If we were interrupted while waiting, stop the orphaned task as well.
        future.cancel()
        raise
|
|
|
|
async def get_session():
    """
    Yield exactly one session created by ``async_session``.

    The session is closed by the enclosing ``async with`` when the generator
    is resumed past the ``yield`` or closed.
    """
    session_factory = async_session
    async with session_factory() as db_session:
        yield db_session
|
|
|
|
def get_db_session():
    """
    Get a database session for use in Streamlit.

    Returns a fresh ``AsyncSession`` built directly from the ``async_session``
    factory. The session is not wrapped in a half-consumed async generator,
    so nothing can close it behind the caller's back when such a generator
    is garbage-collected.

    The caller owns the returned session and must close it when done, for
    example with ``run_async(session.close())``.

    Returns:
        AsyncSession: A new, unclosed session.
    """
    return async_session()
|
|
|
|
async def get_async_session():
    """
    Async generator that yields one session and finalizes it transactionally.

    When the consumer resumes the generator past the ``yield``, the session
    is committed. If an exception is thrown into the generator at the
    ``yield``, or the commit itself fails, the session is rolled back and the
    exception is re-raised. The session is closed in every case.

    This is a plain async generator, not an async context manager, so it
    cannot be used directly in ``async with``. Either drive it as a
    dependency-style generator, or wrap it::

        from contextlib import asynccontextmanager

        async with asynccontextmanager(get_async_session)() as session:
            ...  # use session here
    """
    async with async_session() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            # A bare ``raise`` keeps the original traceback untouched.
            raise
|
|
|
|
| |
def get_dark_web_contents(
    page: int = 1,
    size: int = 10,
    content_type: Optional[List[ContentType]] = None,
    content_status: Optional[List[ContentStatus]] = None,
    source_name: Optional[str] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> pd.DataFrame:
    """
    Get dark web contents as a DataFrame.

    The query runs in a session opened for this call alone. The session is
    closed before returning, so its pooled connection is released. Row
    values are read while the session is still open.

    Args:
        page: Page number
        size: Page size
        content_type: Filter by content type
        content_status: Filter by content status
        source_name: Filter by source name
        search_query: Search in title and content
        from_date: Filter by scraped_at >= from_date
        to_date: Filter by scraped_at <= to_date

    Returns:
        pd.DataFrame: One row per content item, or an empty DataFrame when
        nothing matches.
    """
    pagination = PaginationParams(page=page, size=size)

    async def _load_rows() -> List[Dict[str, Any]]:
        async with async_session() as session:
            contents = await get_contents(
                db=session,
                pagination=pagination,
                content_type=content_type,
                content_status=content_status,
                source_name=source_name,
                search_query=search_query,
                from_date=from_date,
                to_date=to_date,
            )
            return [
                {
                    "id": content.id,
                    "url": content.url,
                    "title": content.title,
                    "content_type": content.content_type.value if content.content_type else None,
                    "content_status": content.content_status.value if content.content_status else None,
                    "source_name": content.source_name,
                    "source_type": content.source_type,
                    "language": content.language,
                    "scraped_at": content.scraped_at,
                    "relevance_score": content.relevance_score,
                    "sentiment_score": content.sentiment_score,
                }
                for content in contents or []
            ]

    rows = run_async(_load_rows())
    return pd.DataFrame(rows) if rows else pd.DataFrame()
|
|
|
|
def add_dark_web_content(
    url: str,
    content: str,
    title: Optional[str] = None,
    content_type: ContentType = ContentType.OTHER,
    source_name: Optional[str] = None,
    source_type: Optional[str] = None,
) -> Optional[DarkWebContent]:
    """
    Add a new dark web content.

    The content is created in a session owned by this call. The session is
    committed and then closed before returning, which releases its pooled
    connection.

    Args:
        url: URL of the content
        content: Text content
        title: Title of the content
        content_type: Type of content
        source_name: Name of the source
        source_type: Type of source

    Returns:
        Optional[DarkWebContent]: Created content or None
    """
    async def _create():
        async with async_session() as session:
            created = await create_content(
                db=session,
                url=url,
                content=content,
                title=title,
                content_type=content_type,
                source_name=source_name,
                source_type=source_type,
            )
            # This function owns the session, so it commits; this is harmless
            # if the service already committed.
            await session.commit()
            return created

    return run_async(_create())
|
|
|
|
def get_dark_web_mentions(
    page: int = 1,
    size: int = 10,
    keyword: Optional[str] = None,
    content_id: Optional[int] = None,
    is_verified: Optional[bool] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> pd.DataFrame:
    """
    Get dark web mentions as a DataFrame.

    The query runs in a session opened for this call alone. The session is
    closed before returning, so its pooled connection is released. Row
    values are read while the session is still open.

    Args:
        page: Page number
        size: Page size
        keyword: Filter by keyword
        content_id: Filter by content ID
        is_verified: Filter by verification status
        from_date: Filter by created_at >= from_date
        to_date: Filter by created_at <= to_date

    Returns:
        pd.DataFrame: One row per mention, or an empty DataFrame when nothing
        matches.
    """
    pagination = PaginationParams(page=page, size=size)

    async def _load_rows() -> List[Dict[str, Any]]:
        async with async_session() as session:
            mentions = await get_mentions(
                db=session,
                pagination=pagination,
                keyword=keyword,
                content_id=content_id,
                is_verified=is_verified,
                from_date=from_date,
                to_date=to_date,
            )
            return [
                {
                    "id": mention.id,
                    "content_id": mention.content_id,
                    "keyword": mention.keyword,
                    "snippet": mention.snippet,
                    "mention_type": mention.mention_type,
                    "confidence": mention.confidence,
                    "is_verified": mention.is_verified,
                    "created_at": mention.created_at,
                }
                for mention in mentions or []
            ]

    rows = run_async(_load_rows())
    return pd.DataFrame(rows) if rows else pd.DataFrame()
|
|
|
|
def add_dark_web_mention(
    content_id: int,
    keyword: str,
    context: Optional[str] = None,
    snippet: Optional[str] = None,
) -> Optional[DarkWebMention]:
    """
    Add a new dark web mention.

    The mention is created in a session owned by this call. The session is
    committed and then closed before returning, which releases its pooled
    connection.

    Args:
        content_id: ID of the content where the mention was found
        keyword: Keyword that was mentioned
        context: Text surrounding the mention
        snippet: Extract of text containing the mention

    Returns:
        Optional[DarkWebMention]: Created mention or None
    """
    async def _create():
        async with async_session() as session:
            created = await create_mention(
                db=session,
                content_id=content_id,
                keyword=keyword,
                context=context,
                snippet=snippet,
            )
            # This function owns the session, so it commits; this is harmless
            # if the service already committed.
            await session.commit()
            return created

    return run_async(_create())
|
|
|
|
| |
def get_alerts_df(
    page: int = 1,
    size: int = 10,
    severity: Optional[List[ThreatSeverity]] = None,
    status: Optional[List[AlertStatus]] = None,
    category: Optional[List[AlertCategory]] = None,
    is_read: Optional[bool] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> pd.DataFrame:
    """
    Get alerts as a DataFrame.

    The query runs in a session opened for this call alone. The session is
    closed before returning, so its pooled connection is released. Row
    values are read while the session is still open.

    Args:
        page: Page number
        size: Page size
        severity: Filter by severity
        status: Filter by status
        category: Filter by category
        is_read: Filter by read status
        search_query: Search in title and description
        from_date: Filter by generated_at >= from_date
        to_date: Filter by generated_at <= to_date

    Returns:
        pd.DataFrame: One row per alert, or an empty DataFrame when nothing
        matches.
    """
    pagination = PaginationParams(page=page, size=size)

    async def _load_rows() -> List[Dict[str, Any]]:
        async with async_session() as session:
            alerts = await get_alerts(
                db=session,
                pagination=pagination,
                severity=severity,
                status=status,
                category=category,
                is_read=is_read,
                search_query=search_query,
                from_date=from_date,
                to_date=to_date,
            )
            return [
                {
                    "id": alert.id,
                    "title": alert.title,
                    "description": alert.description,
                    "severity": alert.severity.value if alert.severity else None,
                    "status": alert.status.value if alert.status else None,
                    "category": alert.category.value if alert.category else None,
                    "generated_at": alert.generated_at,
                    "source_url": alert.source_url,
                    "is_read": alert.is_read,
                    "threat_id": alert.threat_id,
                    "mention_id": alert.mention_id,
                    "assigned_to_id": alert.assigned_to_id,
                    "action_taken": alert.action_taken,
                    "resolved_at": alert.resolved_at,
                }
                for alert in alerts or []
            ]

    rows = run_async(_load_rows())
    return pd.DataFrame(rows) if rows else pd.DataFrame()
|
|
|
|
def add_alert(
    title: str,
    description: str,
    severity: ThreatSeverity,
    category: AlertCategory,
    source_url: Optional[str] = None,
    threat_id: Optional[int] = None,
    mention_id: Optional[int] = None,
) -> Optional[Alert]:
    """
    Add a new alert.

    The alert is created in a session owned by this call. The session is
    committed and then closed before returning, which releases its pooled
    connection.

    Args:
        title: Alert title
        description: Alert description
        severity: Alert severity
        category: Alert category
        source_url: Source URL for the alert
        threat_id: ID of related threat
        mention_id: ID of related dark web mention

    Returns:
        Optional[Alert]: Created alert or None
    """
    async def _create():
        async with async_session() as session:
            created = await create_alert(
                db=session,
                title=title,
                description=description,
                severity=severity,
                category=category,
                source_url=source_url,
                threat_id=threat_id,
                mention_id=mention_id,
            )
            # This function owns the session, so it commits; this is harmless
            # if the service already committed.
            await session.commit()
            return created

    return run_async(_create())
|
|
|
|
def update_alert(
    alert_id: int,
    status: AlertStatus,
    action_taken: Optional[str] = None,
) -> Optional[Alert]:
    """
    Update alert status.

    The update runs in a session owned by this call. The session is committed
    and then closed before returning, which releases its pooled connection.

    Args:
        alert_id: Alert ID
        status: New status
        action_taken: Description of action taken

    Returns:
        Optional[Alert]: Updated alert or None
    """
    async def _update():
        async with async_session() as session:
            updated = await update_alert_status(
                db=session,
                alert_id=alert_id,
                status=status,
                action_taken=action_taken,
            )
            # This function owns the session, so it commits; this is harmless
            # if the service already committed.
            await session.commit()
            return updated

    return run_async(_update())
|
|
|
|
def get_alert_severity_counts(
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> Dict[str, int]:
    """
    Get count of alerts by severity.

    The query runs in a session opened for this call alone, and the session
    is closed before returning.

    Args:
        from_date: Filter by generated_at >= from_date
        to_date: Filter by generated_at <= to_date

    Returns:
        Dict[str, int]: Mapping of severity to count. An empty dict is
        returned if the service returns nothing.
    """
    async def _count():
        async with async_session() as session:
            return await get_alert_counts_by_severity(
                db=session,
                from_date=from_date,
                to_date=to_date,
            )

    return run_async(_count()) or {}
|
|
|
|
| |
def get_threats_df(
    page: int = 1,
    size: int = 10,
    severity: Optional[List[ThreatSeverity]] = None,
    status: Optional[List[ThreatStatus]] = None,
    category: Optional[List[ThreatCategory]] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> pd.DataFrame:
    """
    Get threats as a DataFrame.

    The query runs in a session opened for this call alone. The session is
    closed before returning, so its pooled connection is released. Row
    values are read while the session is still open.

    Args:
        page: Page number
        size: Page size
        severity: Filter by severity
        status: Filter by status
        category: Filter by category
        search_query: Search in title and description
        from_date: Filter by discovered_at >= from_date
        to_date: Filter by discovered_at <= to_date

    Returns:
        pd.DataFrame: One row per threat, or an empty DataFrame when nothing
        matches.
    """
    pagination = PaginationParams(page=page, size=size)

    async def _load_rows() -> List[Dict[str, Any]]:
        async with async_session() as session:
            threats = await get_threats(
                db=session,
                pagination=pagination,
                severity=severity,
                status=status,
                category=category,
                search_query=search_query,
                from_date=from_date,
                to_date=to_date,
            )
            return [
                {
                    "id": threat.id,
                    "title": threat.title,
                    "description": threat.description,
                    "severity": threat.severity.value if threat.severity else None,
                    "status": threat.status.value if threat.status else None,
                    "category": threat.category.value if threat.category else None,
                    "source_url": threat.source_url,
                    "source_name": threat.source_name,
                    "source_type": threat.source_type,
                    "discovered_at": threat.discovered_at,
                    "affected_entity": threat.affected_entity,
                    "affected_entity_type": threat.affected_entity_type,
                    "confidence_score": threat.confidence_score,
                    "risk_score": threat.risk_score,
                }
                for threat in threats or []
            ]

    rows = run_async(_load_rows())
    return pd.DataFrame(rows) if rows else pd.DataFrame()
|
|
|
|
def add_threat(
    title: str,
    description: str,
    severity: ThreatSeverity,
    category: ThreatCategory,
    status: ThreatStatus = ThreatStatus.NEW,
    source_url: Optional[str] = None,
    source_name: Optional[str] = None,
    source_type: Optional[str] = None,
    affected_entity: Optional[str] = None,
    affected_entity_type: Optional[str] = None,
    confidence_score: float = 0.0,
    risk_score: float = 0.0,
) -> Optional[Threat]:
    """
    Add a new threat.

    The threat is created in a session owned by this call. The session is
    committed and then closed before returning, which releases its pooled
    connection.

    Args:
        title: Threat title
        description: Threat description
        severity: Threat severity
        category: Threat category
        status: Threat status
        source_url: URL of the source
        source_name: Name of the source
        source_type: Type of source
        affected_entity: Name of affected entity
        affected_entity_type: Type of affected entity
        confidence_score: Confidence score (0-1)
        risk_score: Risk score (0-1)

    Returns:
        Optional[Threat]: Created threat or None
    """
    async def _create():
        async with async_session() as session:
            created = await create_threat(
                db=session,
                title=title,
                description=description,
                severity=severity,
                category=category,
                status=status,
                source_url=source_url,
                source_name=source_name,
                source_type=source_type,
                affected_entity=affected_entity,
                affected_entity_type=affected_entity_type,
                confidence_score=confidence_score,
                risk_score=risk_score,
            )
            # This function owns the session, so it commits; this is harmless
            # if the service already committed.
            await session.commit()
            return created

    return run_async(_create())
|
|
|
|
def add_indicator(
    threat_id: int,
    value: str,
    indicator_type: IndicatorType,
    description: Optional[str] = None,
    is_verified: bool = False,
    context: Optional[str] = None,
    source: Optional[str] = None,
) -> Optional[Indicator]:
    """
    Add an indicator to a threat.

    The indicator is created in a session owned by this call. The session is
    committed and then closed before returning, which releases its pooled
    connection.

    Args:
        threat_id: Threat ID
        value: Indicator value
        indicator_type: Indicator type
        description: Indicator description
        is_verified: Whether the indicator is verified
        context: Context of the indicator
        source: Source of the indicator

    Returns:
        Optional[Indicator]: Created indicator or None
    """
    async def _create():
        async with async_session() as session:
            created = await add_indicator_to_threat(
                db=session,
                threat_id=threat_id,
                value=value,
                indicator_type=indicator_type,
                description=description,
                is_verified=is_verified,
                context=context,
                source=source,
            )
            # This function owns the session, so it commits; this is harmless
            # if the service already committed.
            await session.commit()
            return created

    return run_async(_create())
|
|
|
|
def get_threat_stats(
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> Dict[str, Any]:
    """
    Get threat statistics.

    The query runs in a session opened for this call alone, and the session
    is closed before returning.

    Args:
        from_date: Filter by discovered_at >= from_date
        to_date: Filter by discovered_at <= to_date

    Returns:
        Dict[str, Any]: Threat statistics. An empty dict is returned if the
        service returns nothing.
    """
    async def _stats():
        async with async_session() as session:
            return await get_threat_statistics(
                db=session,
                from_date=from_date,
                to_date=to_date,
            )

    return run_async(_stats()) or {}
|
|
|
|
| |
def get_reports_df(
    page: int = 1,
    size: int = 10,
    report_type: Optional[List[ReportType]] = None,
    status: Optional[List[ReportStatus]] = None,
    severity: Optional[List[ThreatSeverity]] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> pd.DataFrame:
    """
    Get reports as a DataFrame.

    The query runs in a session opened for this call alone. The session is
    closed before returning, so its pooled connection is released. Row
    values are read while the session is still open.

    Args:
        page: Page number
        size: Page size
        report_type: Filter by report type
        status: Filter by status
        severity: Filter by severity
        search_query: Search in title and summary
        from_date: Filter by created_at >= from_date
        to_date: Filter by created_at <= to_date

    Returns:
        pd.DataFrame: One row per report, or an empty DataFrame when nothing
        matches.
    """
    pagination = PaginationParams(page=page, size=size)

    async def _load_rows() -> List[Dict[str, Any]]:
        async with async_session() as session:
            reports = await get_reports(
                db=session,
                pagination=pagination,
                report_type=report_type,
                status=status,
                severity=severity,
                search_query=search_query,
                from_date=from_date,
                to_date=to_date,
            )
            return [
                {
                    "id": report.id,
                    "report_id": report.report_id,
                    "title": report.title,
                    "summary": report.summary,
                    "report_type": report.report_type.value if report.report_type else None,
                    "status": report.status.value if report.status else None,
                    "severity": report.severity.value if report.severity else None,
                    "publish_date": report.publish_date,
                    "created_at": report.created_at,
                    "time_period_start": report.time_period_start,
                    "time_period_end": report.time_period_end,
                    "author_id": report.author_id,
                }
                for report in reports or []
            ]

    rows = run_async(_load_rows())
    return pd.DataFrame(rows) if rows else pd.DataFrame()
|
|
|
|
def add_report(
    title: str,
    summary: str,
    content: str,
    report_type: ReportType,
    report_id: str,
    status: ReportStatus = ReportStatus.DRAFT,
    severity: Optional[ThreatSeverity] = None,
    publish_date: Optional[datetime] = None,
    time_period_start: Optional[datetime] = None,
    time_period_end: Optional[datetime] = None,
    keywords: Optional[List[str]] = None,
    author_id: Optional[int] = None,
) -> Optional[Report]:
    """
    Add a new report.

    The report is created in a session owned by this call. The session is
    committed and then closed before returning, which releases its pooled
    connection.

    Args:
        title: Report title
        summary: Report summary
        content: Report content
        report_type: Type of report
        report_id: Custom ID for the report
        status: Report status
        severity: Report severity
        publish_date: Publication date
        time_period_start: Start of time period covered
        time_period_end: End of time period covered
        keywords: List of keywords related to the report
        author_id: ID of the report author

    Returns:
        Optional[Report]: Created report or None
    """
    async def _create():
        async with async_session() as session:
            created = await create_report(
                db=session,
                title=title,
                summary=summary,
                content=content,
                report_type=report_type,
                report_id=report_id,
                status=status,
                severity=severity,
                publish_date=publish_date,
                time_period_start=time_period_start,
                time_period_end=time_period_end,
                keywords=keywords,
                author_id=author_id,
            )
            # This function owns the session, so it commits; this is harmless
            # if the service already committed.
            await session.commit()
            return created

    return run_async(_create())
|
|
|
|
| |
# Length in days of each supported time-range label.
_TIME_RANGE_DAYS = {
    "Last 24 Hours": 1,
    "Last 7 Days": 7,
    "Last 30 Days": 30,
    "Last Quarter": 90,
}
# Length used for any label not listed above.
_DEFAULT_TIME_RANGE_DAYS = 30


def get_time_range_dates(time_range: str) -> Tuple[datetime, datetime]:
    """
    Get start and end dates for a time range ending now.

    Args:
        time_range: One of "Last 24 Hours", "Last 7 Days", "Last 30 Days" or
            "Last Quarter" (90 days). Any other value falls back to 30 days.

    Returns:
        Tuple[datetime, datetime]: (start_date, end_date). Both are naive
        datetimes expressed in UTC.
    """
    # Same naive-UTC value as the deprecated datetime.utcnow().
    end_date = datetime.now(timezone.utc).replace(tzinfo=None)
    days = _TIME_RANGE_DAYS.get(time_range, _DEFAULT_TIME_RANGE_DAYS)
    return end_date - timedelta(days=days), end_date
|
|
|
|
| |
def init_db_connection() -> bool:
    """
    Initialize database connection and check if tables exist.

    Runs a query against ``information_schema.tables`` in a short-lived
    session, which is closed afterwards.

    Returns:
        bool: True if a table named ``users`` is listed. False if it is not,
        or if the check raised (the error is logged with its traceback).
    """
    from sqlalchemy import text

    query = text("SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'users')")

    async def _users_table_exists() -> bool:
        async with async_session() as session:
            result = await session.execute(query)
            return bool(result.scalar())

    try:
        return run_async(_users_table_exists())
    except Exception as e:
        logging.getLogger(__name__).exception("Error checking database: %s", e)
        return False