File size: 7,848 Bytes
b821944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""
Database module for the Data Insights App.
Handles data loading, querying with safety checks, and operations on pandas DataFrame.
"""
import pandas as pd
from typing import Optional, Dict, List, Any
import re
from utils import setup_logger, log_safety_block
from config import CSV_FILE_PATH, DANGEROUS_SQL_KEYWORDS

logger = setup_logger(__name__)


class DataManager:
    """
    Manages data operations with safety features.
    Loads CSV into pandas DataFrame and provides safe query interface.
    """
    
    def __init__(self, csv_path: str = CSV_FILE_PATH):
        """
        Create a DataManager and eagerly load the CSV into memory.

        Inputs: csv_path (string) - path to the CSV file; defaults to
            the configured CSV_FILE_PATH.
        Outputs: None
        Raises: whatever load_data() raises when the file is missing
            or unreadable.
        """
        self.csv_path = csv_path
        # Populated by load_data(); typed Optional so a failed load
        # leaves a well-defined value.
        self.df: Optional[pd.DataFrame] = None
        self.load_data()
        logger.info(f"DataManager initialized with {len(self.df)} rows")
    
    def load_data(self) -> None:
        """
        Read the CSV at self.csv_path into self.df.

        Inputs: None
        Outputs: None
        Raises: FileNotFoundError when the path does not exist; any
            other parse/IO error is logged and re-raised.
        """
        try:
            self.df = pd.read_csv(self.csv_path)
        except FileNotFoundError:
            logger.error(f"CSV file not found: {self.csv_path}")
            raise
        except Exception as e:
            logger.error(f"Error loading CSV: {str(e)}")
            raise
        else:
            # Success path: record what was loaded for debugging.
            logger.info(f"Successfully loaded data from {self.csv_path}")
            logger.info(f"Columns: {list(self.df.columns)}")
            logger.info(f"Shape: {self.df.shape}")
    
    def is_safe_operation(self, query: str) -> bool:
        """
        Checks if a query contains dangerous SQL keywords.
        Blocks all write operations (INSERT, UPDATE, DELETE, DROP, TRUNCATE, ALTER, etc.)
        listed in DANGEROUS_SQL_KEYWORDS.
        Inputs: query (string)
        Outputs: boolean (True if safe, False if dangerous)
        """
        # Uppercase once so keyword matching is case-insensitive.
        query_upper = query.upper()

        for keyword in DANGEROUS_SQL_KEYWORDS:
            # Check for keyword as a whole word. re.escape guards
            # against keywords that contain regex metacharacters,
            # which would otherwise corrupt the pattern.
            if re.search(rf'\b{re.escape(keyword)}\b', query_upper):
                log_safety_block(query, f"Contains dangerous keyword: {keyword}")
                return False

        logger.info("Query passed safety check")
        return True
    
    def get_dataframe(self) -> pd.DataFrame:
        """
        Return a defensive copy of the full dataset.

        For internal use only, not for LLM consumption.
        Inputs: None
        Outputs: pandas DataFrame (copy, so callers cannot mutate state)
        """
        snapshot = self.df
        return snapshot.copy()
    
    def get_summary_stats(self) -> Dict[str, Any]:
        """
        Returns summary statistics about the dataset.
        Inputs: None
        Outputs: dictionary with row/column counts, column groupings,
            missing-value counts, and price aggregates (the price_*
            entries are None when the 'price_usd' column is absent).
        """
        # Hoist the column lookup so membership is tested once instead
        # of three times.
        price = self.df['price_usd'] if 'price_usd' in self.df.columns else None
        stats = {
            "total_rows": len(self.df),
            "total_columns": len(self.df.columns),
            "columns": list(self.df.columns),
            "numeric_columns": list(self.df.select_dtypes(include=['number']).columns),
            "categorical_columns": list(self.df.select_dtypes(include=['object']).columns),
            "missing_values": self.df.isnull().sum().to_dict(),
            "avg_price": float(price.mean()) if price is not None else None,
            "min_price": float(price.min()) if price is not None else None,
            "max_price": float(price.max()) if price is not None else None,
        }
        logger.info("Generated summary statistics")
        return stats
    
    def filter_data(self, filters: Dict[str, Any], limit: int = 20) -> List[Dict]:
        """
        Filters data based on provided criteria.

        Each filter value may be:
          - a dict with "min" and/or "max" keys -> inclusive range filter
          - a list -> membership filter (row value in list)
          - any scalar -> exact-equality filter
        Unknown column names are logged and skipped.
        Inputs: filters (dict), limit (int) - maximum rows returned
        Outputs: list of dictionaries (filtered results)
        """
        logger.info(f"Filtering data with criteria: {filters}")

        # No upfront full-table copy: boolean indexing below already
        # produces new DataFrames, and self.df is never mutated here.
        df_filtered = self.df

        # Apply filters
        for column, value in filters.items():
            if column not in df_filtered.columns:
                logger.warning(f"Column '{column}' not found in dataset")
                continue

            # Handle different filter types
            if isinstance(value, dict):
                # Range filter (e.g., {"min": 100, "max": 500});
                # applying the bounds independently is equivalent to
                # the combined mask.
                if "min" in value:
                    df_filtered = df_filtered[df_filtered[column] >= value["min"]]
                if "max" in value:
                    df_filtered = df_filtered[df_filtered[column] <= value["max"]]
            elif isinstance(value, list):
                # Multiple values (e.g., ["Apple", "Samsung"])
                df_filtered = df_filtered[df_filtered[column].isin(value)]
            else:
                # Single value
                df_filtered = df_filtered[df_filtered[column] == value]

        # Limit results
        df_filtered = df_filtered.head(limit)

        result = df_filtered.to_dict('records')
        logger.info(f"Filter returned {len(result)} results (limit: {limit})")

        return result
    
    def aggregate_data(self, group_by: str, metric: str, aggregation: str = "mean") -> List[Dict]:
        """
        Aggregates data by grouping and calculating metrics.

        Supported aggregations: mean, sum, count, min, max.
        Inputs: group_by (string), metric (string), aggregation (string)
        Outputs: list of dictionaries (aggregated results); empty list
            on unknown columns, unknown aggregation, or computation error.
        """
        logger.info(f"Aggregating data: group_by={group_by}, metric={metric}, agg={aggregation}")

        if group_by not in self.df.columns:
            logger.error(f"Group by column '{group_by}' not found")
            return []

        if metric not in self.df.columns:
            logger.error(f"Metric column '{metric}' not found")
            return []

        # Maps aggregation name -> result-column prefix; also serves as
        # the whitelist of supported operations, replacing the previous
        # five duplicated groupby branches.
        prefixes = {"mean": "avg", "sum": "sum", "count": "count", "min": "min", "max": "max"}
        if aggregation not in prefixes:
            logger.error(f"Unknown aggregation type: {aggregation}")
            return []

        try:
            # .agg(name) dispatches to the same reduction as calling
            # .mean()/.sum()/... directly.
            result_df = self.df.groupby(group_by)[metric].agg(aggregation).reset_index()
            result_df.columns = [group_by, f"{prefixes[aggregation]}_{metric}"]

            result = result_df.to_dict('records')
            logger.info(f"Aggregation returned {len(result)} results")

            return result
        except Exception as e:
            logger.error(f"Error during aggregation: {str(e)}")
            return []
    
    def get_unique_values(self, column: str) -> List[Any]:
        """
        List the distinct values of a column.

        Inputs: column (string)
        Outputs: list of unique values, capped at 50 entries; empty
            list when the column does not exist.
        """
        if column not in self.df.columns:
            logger.error(f"Column '{column}' not found")
            return []

        distinct = self.df[column].unique().tolist()
        logger.info(f"Found {len(distinct)} unique values for column '{column}'")

        # Cap the payload so callers never receive an unbounded list.
        return distinct[:50]