jzou19950715 committed on
Commit
48a1160
·
verified ·
1 Parent(s): 4e06409

Create tools.py

Browse files
Files changed (1) hide show
  1. tools.py +339 -0
tools.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ """
5
+ Analysis and visualization tools for data analysis assistant.
6
+ Provides a collection of tools for data analysis, statistical computations,
7
+ and interactive visualizations using Plotly.
8
+ """
9
+
10
+ import logging
11
+ from typing import Any, Dict, List, Optional, Tuple, Union
12
+ from datetime import datetime
13
+ from pathlib import Path
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+ import plotly.express as px
18
+ import plotly.graph_objects as go
19
+ from plotly.subplots import make_subplots
20
+ import seaborn as sns
21
+ from scipy import stats
22
+ from smolagents import tool
23
+
24
# Module-wide logging setup: timestamped records at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Logger named after this module; the tools in this file report errors through it.
logger = logging.getLogger(__name__)
30
+
31
class AnalysisError(Exception):
    """Error raised by the analysis tools when an analysis or plot cannot be produced."""
34
+
35
@tool
def create_time_series_plot(
    df: pd.DataFrame,
    time_column: str,
    value_column: str,
    title: Optional[str] = None
) -> Dict[str, Any]:
    """
    Create an interactive time series plot.

    Args:
        df: Input DataFrame
        time_column: Name of the time column
        value_column: Name of the value column to plot
        title: Optional title for the plot

    Returns:
        Dict with the Plotly line figure under "figure" and the mean, std, min and max of the value column under "stats"

    Raises:
        AnalysisError: If either column is missing or the plot or statistics cannot be computed
    """
    try:
        # Validate inputs before touching Plotly so the failure reason is explicit.
        if time_column not in df.columns or value_column not in df.columns:
            raise AnalysisError(f"Columns {time_column} or {value_column} not found in DataFrame")

        # Line chart of the value column against the time column.
        fig = px.line(
            df,
            x=time_column,
            y=value_column,
            title=title or f"{value_column} over Time",
            template="plotly_white"
        )

        # Custom hover text; the empty <extra> tag hides the secondary trace box.
        fig.update_traces(
            hovertemplate=(
                f"{time_column}: %{{x}}<br>"
                f"{value_column}: %{{y:.2f}}<br>"
                "<extra></extra>"
            )
        )

        # Look the column up once and derive every statistic from it.
        series = df[value_column]
        stats_dict = {
            "mean": series.mean(),
            "std": series.std(),
            "min": series.min(),
            "max": series.max()
        }

        return {"figure": fig, "stats": stats_dict}

    except Exception as e:
        # Every failure (including the validation error above) is logged and
        # surfaced as AnalysisError; the original exception is kept as the cause.
        logger.error(f"Error in create_time_series_plot: {str(e)}")
        raise AnalysisError(f"Failed to create time series plot: {str(e)}") from e
90
+
91
@tool
def create_correlation_heatmap(df: pd.DataFrame, numeric_only: bool = True) -> Dict[str, Any]:
    """
    Create an interactive correlation heatmap.

    Args:
        df: Input DataFrame
        numeric_only: Whether to include only numeric columns

    Returns:
        Dict with the Plotly heatmap under "figure" and the correlation matrix as a nested dict under "correlation_matrix"

    Raises:
        AnalysisError: If there are no columns to correlate or the correlation or figure cannot be computed
    """
    try:
        # Select numeric columns if requested; otherwise df.corr() receives
        # every column as-is.
        if numeric_only:
            df = df.select_dtypes(include=[np.number])

        # An empty selection would only yield an empty, meaningless heatmap.
        if df.shape[1] == 0:
            raise AnalysisError("No columns available to compute correlations")

        # Calculate correlation matrix
        corr_matrix = df.corr()

        # Hand Plotly plain arrays/lists rather than pandas objects.
        labels = corr_matrix.columns.tolist()
        corr_values = corr_matrix.to_numpy()

        # Diverging colour scale centred on zero so the sign of a correlation
        # is visible at a glance; cell text shows the rounded coefficient.
        fig = go.Figure(data=go.Heatmap(
            z=corr_values,
            x=labels,
            y=labels,
            colorscale='RdBu',
            zmid=0,
            text=np.round(corr_values, 2),
            texttemplate='%{text:.2f}',
            textfont={"size": 10},
            hoverongaps=False
        ))

        # Update layout
        fig.update_layout(
            title="Correlation Heatmap",
            template="plotly_white",
            width=800,
            height=800
        )

        return {
            "figure": fig,
            "correlation_matrix": corr_matrix.to_dict()
        }

    except Exception as e:
        # Log, then surface as AnalysisError with the original error as cause.
        logger.error(f"Error in create_correlation_heatmap: {str(e)}")
        raise AnalysisError(f"Failed to create correlation heatmap: {str(e)}") from e
140
+
141
@tool
def create_statistical_summary(df: pd.DataFrame, column: str) -> Dict[str, Any]:
    """
    Create statistical summary with visualization for a column.

    Args:
        df: Input DataFrame
        column: Column name to analyze

    Returns:
        Dict with a two-row Plotly figure (histogram above a box plot) under "figure" and the summary statistics under "stats"

    Raises:
        AnalysisError: If the column is missing or the summary or figure cannot be computed
    """
    try:
        if column not in df.columns:
            raise AnalysisError(f"Column {column} not found in DataFrame")

        # Look the column up once and reuse it everywhere below.
        series = df[column]

        # Baseline statistics from pandas describe().
        summary_stats = series.describe().to_dict()

        # Shape statistics only make sense for numeric data; missing values
        # are dropped once and shared by both computations.
        if pd.api.types.is_numeric_dtype(series):
            non_null = series.dropna()
            summary_stats.update({
                "skewness": float(stats.skew(non_null)),
                "kurtosis": float(stats.kurtosis(non_null))
            })

        # Two stacked panels: distribution on top, box plot underneath.
        fig = make_subplots(rows=2, cols=1)

        # Add histogram
        fig.add_trace(
            go.Histogram(
                x=series,
                name="Distribution",
                nbinsx=30
            ),
            row=1, col=1
        )

        # Add box plot
        fig.add_trace(
            go.Box(
                y=series,
                name="Box Plot"
            ),
            row=2, col=1
        )

        # Update layout
        fig.update_layout(
            title=f"Statistical Analysis of {column}",
            showlegend=False,
            template="plotly_white",
            height=800
        )

        return {
            "figure": fig,
            "stats": summary_stats
        }

    except Exception as e:
        # Log, then surface as AnalysisError with the original error as cause.
        logger.error(f"Error in create_statistical_summary: {str(e)}")
        raise AnalysisError(f"Failed to create statistical summary: {str(e)}") from e
205
+
206
@tool
def detect_outliers(
    df: pd.DataFrame,
    column: str,
    method: str = "zscore",
    threshold: float = 3.0
) -> Dict[str, Any]:
    """
    Detect outliers in a column using various methods.

    Args:
        df: Input DataFrame
        column: Column to analyze
        method: Detection method ('zscore' or 'iqr')
        threshold: Threshold for outlier detection

    Returns:
        Dict with the Plotly scatter figure under "figure", the index labels of the outlier rows under "outlier_indices" and their number under "outlier_count"

    Raises:
        AnalysisError: If the column is missing, the method is unknown or the detection or figure fails
    """
    try:
        if column not in df.columns:
            raise AnalysisError(f"Column {column} not found in DataFrame")

        # Missing values are excluded from the analysis. Everything below is
        # indexed through `values` (not `df`) so the mask and the index labels
        # always have the same length, even when rows were dropped here.
        values = df[column].dropna()

        if method == "zscore":
            # Flag points whose absolute z-score exceeds the threshold.
            z_scores = np.abs(stats.zscore(values))
            outlier_mask = np.asarray(z_scores > threshold, dtype=bool)
        elif method == "iqr":
            # Flag points lying more than threshold * IQR outside the quartiles.
            q1 = values.quantile(0.25)
            q3 = values.quantile(0.75)
            iqr = q3 - q1
            lower_bound = q1 - threshold * iqr
            upper_bound = q3 + threshold * iqr
            outlier_mask = ((values < lower_bound) | (values > upper_bound)).to_numpy()
        else:
            raise AnalysisError(f"Unknown outlier detection method: {method}")

        normal_points = values[~outlier_mask]
        outlier_points = values[outlier_mask]

        # Create visualization
        fig = go.Figure()

        # Add main scatter plot
        fig.add_trace(
            go.Scatter(
                x=normal_points.index,
                y=normal_points,
                mode='markers',
                name='Normal Points',
                marker=dict(color='blue')
            )
        )

        # Add outliers
        fig.add_trace(
            go.Scatter(
                x=outlier_points.index,
                y=outlier_points,
                mode='markers',
                name='Outliers',
                marker=dict(color='red')
            )
        )

        fig.update_layout(
            title=f"Outlier Detection for {column}",
            template="plotly_white",
            showlegend=True
        )

        return {
            "figure": fig,
            "outlier_indices": outlier_points.index.tolist(),
            "outlier_count": int(outlier_mask.sum())
        }

    except Exception as e:
        # Log, then surface as AnalysisError with the original error as cause.
        logger.error(f"Error in detect_outliers: {str(e)}")
        raise AnalysisError(f"Failed to detect outliers: {str(e)}") from e
282
+
283
+ # Additional utility functions
284
def validate_dataframe(df: pd.DataFrame) -> Tuple[bool, str]:
    """
    Check whether a DataFrame is usable for analysis.

    The checks run in order and the first failing one determines the message:
    the frame is None, the frame is empty, the frame has duplicate column names.

    Args:
        df: Input DataFrame

    Returns:
        Tuple of (is_valid, error_message); the message is "" when valid
    """
    if df is None:
        problem = "DataFrame is None"
    elif df.empty:
        problem = "DataFrame is empty"
    elif df.columns.duplicated().any():
        problem = "DataFrame contains duplicate column names"
    else:
        problem = ""

    # An empty problem string means every check passed.
    return (not problem), problem
304
+
305
def get_numeric_columns(df: pd.DataFrame) -> List[str]:
    """Return the names of the columns selected by the numpy number dtype, in DataFrame order."""
    numeric_part = df.select_dtypes(include=[np.number])
    column_names = numeric_part.columns.tolist()
    return column_names
308
+
309
def get_temporal_columns(df: pd.DataFrame) -> List[str]:
    """
    Get list of temporal columns from DataFrame.

    A column counts as temporal when it is not numeric and pd.to_datetime
    accepts it. Numeric columns are skipped on purpose: pd.to_datetime would
    read plain numbers as offsets from the epoch and so accept every numeric
    column, which made ordinary measurements look like timestamps.

    Args:
        df: Input DataFrame

    Returns:
        Names of the temporal columns, in DataFrame order
    """
    temporal_cols = []
    for col in df.columns:
        try:
            if pd.api.types.is_numeric_dtype(df[col]):
                continue
            pd.to_datetime(df[col])
        except Exception:
            # Best-effort probe: any conversion failure just means "not
            # temporal". Unlike a bare except, this lets KeyboardInterrupt
            # and SystemExit propagate.
            continue
        temporal_cols.append(col)
    return temporal_cols
319
+
320
if __name__ == "__main__":
    # Smoke test: build a small synthetic dataset and exercise one tool.
    # Messages go through the module logger, like the rest of this file.
    logger.info("Running tools.py tests...")

    # Create sample data: 100 daily points with normally distributed values
    # and a random category label.
    dates = pd.date_range(start='2023-01-01', periods=100, freq='D')
    df = pd.DataFrame({
        'date': dates,
        'value': np.random.normal(100, 10, 100),
        'category': np.random.choice(['A', 'B', 'C'], 100)
    })

    # Test time series plot; a failure is logged rather than raised so the
    # script always runs to completion.
    try:
        create_time_series_plot(df, 'date', 'value')
    except Exception as e:
        logger.error("Time series plot test failed: %s", e)
    else:
        logger.info("Time series plot created successfully")

    # Add more tests as needed