"""
Database module for the Data Insights App.
Handles data loading, querying with safety checks, and operations on a
pandas DataFrame.
"""

import re
from typing import Optional, Dict, List, Any

import pandas as pd

from utils import setup_logger, log_safety_block
from config import CSV_FILE_PATH, DANGEROUS_SQL_KEYWORDS

logger = setup_logger(__name__)


class DataManager:
    """
    Manages data operations with safety features.
    Loads a CSV into a pandas DataFrame and provides a safe query interface.
    """

    # Supported aggregation names mapped to the prefix used for the
    # resulting metric column (note that "mean" is reported as "avg_").
    _AGGREGATION_PREFIXES = {
        "mean": "avg",
        "sum": "sum",
        "count": "count",
        "min": "min",
        "max": "max",
    }

    def __init__(self, csv_path: str = CSV_FILE_PATH):
        """
        Initializes DataManager and loads CSV data.

        Inputs: csv_path (string)
        Outputs: None
        Raises: whatever load_data() raises when the file cannot be read
        """
        self.csv_path = csv_path
        self.df: Optional[pd.DataFrame] = None
        # load_data() either populates self.df or raises, so len() is safe.
        self.load_data()
        logger.info(f"DataManager initialized with {len(self.df)} rows")

    def load_data(self) -> None:
        """
        Loads the CSV file at self.csv_path into self.df.

        Inputs: None
        Outputs: None
        Raises: FileNotFoundError if the file does not exist; any other
                exception from pandas.read_csv is logged and re-raised.
        """
        try:
            self.df = pd.read_csv(self.csv_path)
            logger.info(f"Successfully loaded data from {self.csv_path}")
            logger.info(f"Columns: {list(self.df.columns)}")
            logger.info(f"Shape: {self.df.shape}")
        except FileNotFoundError:
            logger.error(f"CSV file not found: {self.csv_path}")
            raise
        except Exception as e:
            logger.error(f"Error loading CSV: {str(e)}")
            raise

    def is_safe_operation(self, query: str) -> bool:
        """
        Checks if a query contains any keyword from DANGEROUS_SQL_KEYWORDS.
        Matching is case-insensitive and on whole words only.

        Inputs: query (string)
        Outputs: boolean (True if safe, False if dangerous)
        """
        query_upper = query.upper()

        for keyword in DANGEROUS_SQL_KEYWORDS:
            # Escape the keyword so regex metacharacters in the configured
            # value are matched literally, and upper-case it so keywords
            # stored in lower/mixed case still match the upper-cased query.
            pattern = rf'\b{re.escape(keyword.upper())}\b'
            if re.search(pattern, query_upper):
                log_safety_block(query, f"Contains dangerous keyword: {keyword}")
                return False

        logger.info("Query passed safety check")
        return True

    def get_dataframe(self) -> pd.DataFrame:
        """
        Returns a copy of the entire DataFrame (for internal use only,
        not for LLM). Callers may mutate the copy freely.

        Inputs: None
        Outputs: pandas DataFrame
        """
        return self.df.copy()

    def get_summary_stats(self) -> Dict[str, Any]:
        """
        Returns summary statistics about the dataset.

        The price statistics are None when the dataset has no
        'price_usd' column.

        Inputs: None
        Outputs: dictionary with statistics
        """
        has_price = 'price_usd' in self.df.columns
        price = self.df['price_usd'] if has_price else None

        stats = {
            "total_rows": len(self.df),
            "total_columns": len(self.df.columns),
            "columns": list(self.df.columns),
            "numeric_columns": list(self.df.select_dtypes(include=['number']).columns),
            "categorical_columns": list(self.df.select_dtypes(include=['object']).columns),
            "missing_values": self.df.isnull().sum().to_dict(),
            "avg_price": float(price.mean()) if has_price else None,
            "min_price": float(price.min()) if has_price else None,
            "max_price": float(price.max()) if has_price else None,
        }

        logger.info("Generated summary statistics")
        return stats

    def filter_data(self, filters: Dict[str, Any], limit: int = 20) -> List[Dict]:
        """
        Filters data based on provided criteria.

        Each filter value selects the comparison applied to its column:
        a dict with "min" and/or "max" keys is an inclusive range, a list
        is a membership test, and anything else is an equality test.
        Filters naming an unknown column are skipped with a warning.

        Inputs: filters (dict), limit (int)
        Outputs: list of dictionaries (filtered results, at most `limit` rows)
        """
        logger.info(f"Filtering data with criteria: {filters}")

        df_filtered = self.df.copy()

        for column, value in filters.items():
            if column not in df_filtered.columns:
                logger.warning(f"Column '{column}' not found in dataset")
                continue

            if isinstance(value, dict):
                # Range filter (e.g., {"min": 100, "max": 500}); each bound
                # is optional and both are inclusive.
                if "min" in value:
                    df_filtered = df_filtered[df_filtered[column] >= value["min"]]
                if "max" in value:
                    df_filtered = df_filtered[df_filtered[column] <= value["max"]]
            elif isinstance(value, list):
                # Multiple values (e.g., ["Apple", "Samsung"])
                df_filtered = df_filtered[df_filtered[column].isin(value)]
            else:
                # Single value
                df_filtered = df_filtered[df_filtered[column] == value]

        # Limit results
        df_filtered = df_filtered.head(limit)

        result = df_filtered.to_dict('records')
        logger.info(f"Filter returned {len(result)} results (limit: {limit})")
        return result

    def aggregate_data(self, group_by: str, metric: str, aggregation: str = "mean") -> List[Dict]:
        """
        Aggregates data by grouping and calculating metrics.

        Supported aggregations: "mean", "sum", "count", "min", "max".
        The result column is named "<prefix>_<metric>", where the prefix is
        "avg" for "mean" and the aggregation name otherwise.
        Unknown columns, unknown aggregations, and errors raised by pandas
        are logged and yield an empty list.

        Inputs: group_by (string), metric (string), aggregation (string)
        Outputs: list of dictionaries (aggregated results)
        """
        logger.info(f"Aggregating data: group_by={group_by}, metric={metric}, agg={aggregation}")

        if group_by not in self.df.columns:
            logger.error(f"Group by column '{group_by}' not found")
            return []

        if metric not in self.df.columns:
            logger.error(f"Metric column '{metric}' not found")
            return []

        prefix = self._AGGREGATION_PREFIXES.get(aggregation)
        if prefix is None:
            logger.error(f"Unknown aggregation type: {aggregation}")
            return []

        try:
            result_df = self.df.groupby(group_by)[metric].agg(aggregation).reset_index()
            result_df.columns = [group_by, f"{prefix}_{metric}"]

            result = result_df.to_dict('records')
            logger.info(f"Aggregation returned {len(result)} results")
            return result
        except Exception as e:
            # Best-effort contract: callers get an empty list, not a crash.
            logger.error(f"Error during aggregation: {str(e)}")
            return []

    def get_unique_values(self, column: str) -> List[Any]:
        """
        Returns unique values for a specific column.

        The logged count is the full number of unique values; the returned
        list is capped at the first 50.

        Inputs: column (string)
        Outputs: list of unique values (at most 50), or [] if the column
                 does not exist
        """
        if column not in self.df.columns:
            logger.error(f"Column '{column}' not found")
            return []

        unique_vals = self.df[column].unique().tolist()
        logger.info(f"Found {len(unique_vals)} unique values for column '{column}'")
        return unique_vals[:50]  # Limit to 50 unique values