Spaces:
Sleeping
Sleeping
| """ | |
| Database module for the Data Insights App. | |
| Handles data loading, querying with safety checks, and operations on pandas DataFrame. | |
| """ | |
| import pandas as pd | |
| from typing import Optional, Dict, List, Any | |
| import re | |
| from utils import setup_logger, log_safety_block | |
| from config import CSV_FILE_PATH, DANGEROUS_SQL_KEYWORDS | |
# Module-wide logger, configured via the project's shared setup_logger helper.
logger = setup_logger(__name__)
class DataManager:
    """
    Manages data operations with safety features.

    Loads a CSV file into a pandas DataFrame and provides a safe,
    read-only query interface: keyword safety checks, filtering,
    aggregation, summary statistics, and unique-value lookups.
    """

    # Maps each supported aggregation name to the prefix used in the
    # result column (mean -> "avg_<metric>", sum -> "sum_<metric>", ...).
    _AGG_PREFIXES = {
        "mean": "avg",
        "sum": "sum",
        "count": "count",
        "min": "min",
        "max": "max",
    }

    def __init__(self, csv_path: str = CSV_FILE_PATH):
        """
        Initializes DataManager and loads CSV data.

        Inputs: csv_path (string) - path to the CSV file to load.
        Outputs: None
        Raises: whatever load_data() raises if the file cannot be read.
        """
        self.csv_path = csv_path
        self.df: Optional[pd.DataFrame] = None
        self.load_data()
        # load_data() either sets self.df or raises, so len() is safe here.
        logger.info(f"DataManager initialized with {len(self.df)} rows")

    def load_data(self) -> None:
        """
        Loads the CSV file into a pandas DataFrame.

        Inputs: None
        Outputs: None
        Raises: re-raises FileNotFoundError or any pandas error after logging.
        """
        try:
            self.df = pd.read_csv(self.csv_path)
            logger.info(f"Successfully loaded data from {self.csv_path}")
            logger.info(f"Columns: {list(self.df.columns)}")
            logger.info(f"Shape: {self.df.shape}")
        except FileNotFoundError:
            logger.error(f"CSV file not found: {self.csv_path}")
            raise
        except Exception as e:
            logger.error(f"Error loading CSV: {str(e)}")
            raise

    def is_safe_operation(self, query: str) -> bool:
        """
        Checks if a query contains dangerous SQL keywords.
        Blocks all write operations (INSERT, UPDATE, DELETE, DROP, TRUNCATE, ALTER, etc.)

        Inputs: query (string)
        Outputs: boolean (True if safe, False if dangerous)
        """
        query_upper = query.upper()
        for keyword in DANGEROUS_SQL_KEYWORDS:
            # Check for the keyword as a whole word. re.escape guards
            # against keywords containing regex metacharacters, which
            # would otherwise corrupt the pattern (or raise re.error).
            if re.search(rf'\b{re.escape(keyword)}\b', query_upper):
                log_safety_block(query, f"Contains dangerous keyword: {keyword}")
                return False
        logger.info("Query passed safety check")
        return True

    def get_dataframe(self) -> pd.DataFrame:
        """
        Returns the entire DataFrame (for internal use only, not for LLM).
        A copy is returned so callers cannot mutate the loaded data.

        Inputs: None
        Outputs: pandas DataFrame
        """
        return self.df.copy()

    def get_summary_stats(self) -> Dict[str, Any]:
        """
        Returns summary statistics about the dataset.

        Inputs: None
        Outputs: dictionary with statistics; the price_* entries are None
                 when the dataset has no 'price_usd' column.
        """
        # Resolve the optional price column once instead of re-checking
        # membership for each of the three price metrics.
        price = self.df['price_usd'] if 'price_usd' in self.df.columns else None
        stats = {
            "total_rows": len(self.df),
            "total_columns": len(self.df.columns),
            "columns": list(self.df.columns),
            "numeric_columns": list(self.df.select_dtypes(include=['number']).columns),
            "categorical_columns": list(self.df.select_dtypes(include=['object']).columns),
            "missing_values": self.df.isnull().sum().to_dict(),
            "avg_price": float(price.mean()) if price is not None else None,
            "min_price": float(price.min()) if price is not None else None,
            "max_price": float(price.max()) if price is not None else None,
        }
        logger.info("Generated summary statistics")
        return stats

    def filter_data(self, filters: Dict[str, Any], limit: int = 20) -> List[Dict]:
        """
        Filters data based on provided criteria.

        Supported filter values per column:
          - dict with "min" and/or "max": inclusive range filter
          - list: membership filter (isin)
          - anything else: exact-equality filter
        Unknown columns are skipped with a warning.

        Inputs: filters (dict), limit (int) - max rows returned
        Outputs: list of dictionaries (filtered results)
        """
        logger.info(f"Filtering data with criteria: {filters}")
        df_filtered = self.df.copy()
        for column, value in filters.items():
            if column not in df_filtered.columns:
                logger.warning(f"Column '{column}' not found in dataset")
                continue
            if isinstance(value, dict):
                # Range filter, e.g. {"min": 100, "max": 500}; bounds inclusive.
                if "min" in value:
                    df_filtered = df_filtered[df_filtered[column] >= value["min"]]
                if "max" in value:
                    df_filtered = df_filtered[df_filtered[column] <= value["max"]]
                if "min" not in value and "max" not in value:
                    # Previously this case was silently ignored; surface it.
                    logger.warning(f"Range filter for '{column}' has no 'min' or 'max'; ignored")
            elif isinstance(value, list):
                # Multiple values, e.g. ["Apple", "Samsung"]
                df_filtered = df_filtered[df_filtered[column].isin(value)]
            else:
                # Single value: exact equality
                df_filtered = df_filtered[df_filtered[column] == value]
        # Limit results before serializing to records.
        df_filtered = df_filtered.head(limit)
        result = df_filtered.to_dict('records')
        logger.info(f"Filter returned {len(result)} results (limit: {limit})")
        return result

    def aggregate_data(self, group_by: str, metric: str, aggregation: str = "mean") -> List[Dict]:
        """
        Aggregates data by grouping and calculating metrics.

        Inputs: group_by (string) - column to group on
                metric (string) - column to aggregate
                aggregation (string) - one of "mean", "sum", "count", "min", "max"
        Outputs: list of dictionaries (aggregated results); empty list on any
                 error (unknown columns, unknown aggregation, pandas failure).
        """
        logger.info(f"Aggregating data: group_by={group_by}, metric={metric}, agg={aggregation}")
        if group_by not in self.df.columns:
            logger.error(f"Group by column '{group_by}' not found")
            return []
        if metric not in self.df.columns:
            logger.error(f"Metric column '{metric}' not found")
            return []
        prefix = self._AGG_PREFIXES.get(aggregation)
        if prefix is None:
            logger.error(f"Unknown aggregation type: {aggregation}")
            return []
        try:
            # All supported aggregation names are valid pandas GroupBy
            # reductions, so dispatch through .agg instead of an if/elif chain.
            result_df = self.df.groupby(group_by)[metric].agg(aggregation).reset_index()
            result_df.columns = [group_by, f"{prefix}_{metric}"]
            result = result_df.to_dict('records')
            logger.info(f"Aggregation returned {len(result)} results")
            return result
        except Exception as e:
            logger.error(f"Error during aggregation: {str(e)}")
            return []

    def get_unique_values(self, column: str) -> List[Any]:
        """
        Returns unique values for a specific column.

        Inputs: column (string)
        Outputs: list of unique values, capped at 50 entries; empty list if
                 the column does not exist.
        """
        if column not in self.df.columns:
            logger.error(f"Column '{column}' not found")
            return []
        unique_vals = self.df[column].unique().tolist()
        logger.info(f"Found {len(unique_vals)} unique values for column '{column}'")
        return unique_vals[:50]  # Limit to 50 unique values