"""
Database module for the Data Insights App.
Handles data loading, querying with safety checks, and operations on pandas DataFrame.
"""
import pandas as pd
from typing import Optional, Dict, List, Any
import re
from utils import setup_logger, log_safety_block
from config import CSV_FILE_PATH, DANGEROUS_SQL_KEYWORDS
logger = setup_logger(__name__)
class DataManager:
    """
    Manages data operations with safety features.

    Loads a CSV file into a pandas DataFrame at construction time and
    exposes a read-only query interface: keyword safety checks, summary
    statistics, filtering, aggregation, and unique-value listing.
    """

    def __init__(self, csv_path: str = CSV_FILE_PATH):
        """
        Initializes DataManager and eagerly loads CSV data.

        Inputs: csv_path (string) - path to the CSV file to load.
        Outputs: None
        Raises: FileNotFoundError (or any pandas parsing error) if the
        CSV cannot be loaded; propagated from load_data().
        """
        self.csv_path = csv_path
        self.df: Optional[pd.DataFrame] = None
        # load_data() raises on failure, so self.df is a DataFrame past this point.
        self.load_data()
        logger.info(f"DataManager initialized with {len(self.df)} rows")

    def load_data(self) -> None:
        """
        Loads the CSV file into self.df.

        Inputs: None
        Outputs: None
        Raises: FileNotFoundError if the path does not exist; any other
        pandas error is logged and re-raised.
        """
        try:
            self.df = pd.read_csv(self.csv_path)
            logger.info(f"Successfully loaded data from {self.csv_path}")
            logger.info(f"Columns: {list(self.df.columns)}")
            logger.info(f"Shape: {self.df.shape}")
        except FileNotFoundError:
            logger.error(f"CSV file not found: {self.csv_path}")
            raise
        except Exception as e:
            logger.error(f"Error loading CSV: {str(e)}")
            raise

    def is_safe_operation(self, query: str) -> bool:
        """
        Checks if a query contains dangerous SQL keywords.
        Blocks all write operations (INSERT, UPDATE, DELETE, DROP, TRUNCATE, ALTER, etc.)
        Inputs: query (string)
        Outputs: boolean (True if safe, False if dangerous)
        """
        query_upper = query.upper()
        for keyword in DANGEROUS_SQL_KEYWORDS:
            # Check for the keyword as a whole word. re.escape guards
            # against keywords containing regex metacharacters, which
            # would otherwise corrupt the pattern and bypass the check.
            if re.search(rf'\b{re.escape(keyword)}\b', query_upper):
                log_safety_block(query, f"Contains dangerous keyword: {keyword}")
                return False
        logger.info("Query passed safety check")
        return True

    def get_dataframe(self) -> pd.DataFrame:
        """
        Returns a copy of the entire DataFrame (for internal use only, not for LLM).

        Inputs: None
        Outputs: pandas DataFrame (a copy, so callers cannot mutate self.df)
        """
        return self.df.copy()

    def get_summary_stats(self) -> Dict[str, Any]:
        """
        Returns summary statistics about the dataset.

        Inputs: None
        Outputs: dictionary with row/column counts, column groupings by
        dtype, per-column missing-value counts, and price aggregates
        (None when the 'price_usd' column is absent).
        """
        has_price = 'price_usd' in self.df.columns
        stats = {
            "total_rows": len(self.df),
            "total_columns": len(self.df.columns),
            "columns": list(self.df.columns),
            "numeric_columns": list(self.df.select_dtypes(include=['number']).columns),
            "categorical_columns": list(self.df.select_dtypes(include=['object']).columns),
            "missing_values": self.df.isnull().sum().to_dict(),
            "avg_price": float(self.df['price_usd'].mean()) if has_price else None,
            "min_price": float(self.df['price_usd'].min()) if has_price else None,
            "max_price": float(self.df['price_usd'].max()) if has_price else None,
        }
        logger.info("Generated summary statistics")
        return stats

    def filter_data(self, filters: Dict[str, Any], limit: int = 20) -> List[Dict]:
        """
        Filters data based on provided criteria.

        Inputs:
            filters (dict): column name -> criterion. A criterion may be
                a dict with "min"/"max" keys (inclusive range), a list of
                allowed values, or a single value for equality.
            limit (int): maximum number of rows returned.
        Outputs: list of dictionaries (filtered rows). Unknown columns
        are logged and skipped rather than raising.
        """
        logger.info(f"Filtering data with criteria: {filters}")
        df_filtered = self.df.copy()
        for column, value in filters.items():
            if column not in df_filtered.columns:
                logger.warning(f"Column '{column}' not found in dataset")
                continue
            if isinstance(value, dict):
                # Range filter (e.g., {"min": 100, "max": 500}); bounds are inclusive.
                if "min" in value and "max" in value:
                    df_filtered = df_filtered[
                        (df_filtered[column] >= value["min"]) &
                        (df_filtered[column] <= value["max"])
                    ]
                elif "min" in value:
                    df_filtered = df_filtered[df_filtered[column] >= value["min"]]
                elif "max" in value:
                    df_filtered = df_filtered[df_filtered[column] <= value["max"]]
            elif isinstance(value, list):
                # Multiple allowed values (e.g., ["Apple", "Samsung"])
                df_filtered = df_filtered[df_filtered[column].isin(value)]
            else:
                # Single value equality match
                df_filtered = df_filtered[df_filtered[column] == value]
        # Limit results after all filters have been applied
        df_filtered = df_filtered.head(limit)
        result = df_filtered.to_dict('records')
        logger.info(f"Filter returned {len(result)} results (limit: {limit})")
        return result

    def aggregate_data(self, group_by: str, metric: str, aggregation: str = "mean") -> List[Dict]:
        """
        Aggregates data by grouping and calculating metrics.

        Inputs:
            group_by (string): column to group on.
            metric (string): column to aggregate.
            aggregation (string): one of "mean", "sum", "count", "min", "max".
        Outputs: list of dictionaries (aggregated results); empty list on
        unknown columns, unknown aggregation, or aggregation error.
        """
        logger.info(f"Aggregating data: group_by={group_by}, metric={metric}, agg={aggregation}")
        if group_by not in self.df.columns:
            logger.error(f"Group by column '{group_by}' not found")
            return []
        if metric not in self.df.columns:
            logger.error(f"Metric column '{metric}' not found")
            return []
        # Supported aggregations; "mean" keeps its historical "avg_" column prefix.
        supported = {"mean": "avg", "sum": "sum", "count": "count", "min": "min", "max": "max"}
        if aggregation not in supported:
            logger.error(f"Unknown aggregation type: {aggregation}")
            return []
        try:
            # Single dispatch through GroupBy.agg replaces the per-aggregation
            # if/elif chain; output column names are unchanged.
            result_df = self.df.groupby(group_by)[metric].agg(aggregation).reset_index()
            result_df.columns = [group_by, f"{supported[aggregation]}_{metric}"]
            result = result_df.to_dict('records')
            logger.info(f"Aggregation returned {len(result)} results")
            return result
        except Exception as e:
            logger.error(f"Error during aggregation: {str(e)}")
            return []

    def get_unique_values(self, column: str) -> List[Any]:
        """
        Returns unique values for a specific column.

        Inputs: column (string)
        Outputs: list of unique values, capped at 50 entries; empty list
        when the column does not exist.
        """
        if column not in self.df.columns:
            logger.error(f"Column '{column}' not found")
            return []
        unique_vals = self.df[column].unique().tolist()
        logger.info(f"Found {len(unique_vals)} unique values for column '{column}'")
        return unique_vals[:50]  # Limit to 50 unique values