Chamin09 committed on
Commit
8d9d697
·
verified ·
1 Parent(s): 10ee83d

Create data_tooks.py

Browse files
Files changed (1) hide show
  1. tools/data_tooks.py +152 -0
tools/data_tooks.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any, Optional, Callable
2
+ import pandas as pd
3
+ import numpy as np
4
+ from pathlib import Path
5
+
6
class PandasDataTools:
    """Tools for data analysis operations on CSV files.

    Each public method loads (and caches) a CSV file from ``csv_directory``
    and returns a plain ``dict`` describing the result, so the methods can
    be exposed as callable "tools" (see :meth:`get_tools`).
    """

    # Maps a user-facing comparison operator to the pandas Series method
    # that implements it element-wise.
    _COMPARISON_METHODS = {
        "==": "eq",
        "!=": "ne",
        ">": "gt",
        "<": "lt",
        ">=": "ge",
        "<=": "le",
    }

    # Maps a user-facing aggregation name to the pandas aggregation name.
    # String names are used instead of NumPy callables because passing
    # callables such as ``np.mean`` to ``GroupBy.agg`` is deprecated in
    # recent pandas releases. "count" maps to "size" so that it keeps
    # counting every row in the group (including missing values), exactly
    # like applying ``len`` to each group.
    _AGG_FUNCTIONS = {
        "mean": "mean",
        "sum": "sum",
        "min": "min",
        "max": "max",
        "count": "size",
        "median": "median",
    }

    def __init__(self, csv_directory: str):
        """Initialize with directory containing CSV files.

        Args:
            csv_directory: Directory in which CSV files are looked up.
        """
        self.csv_directory = csv_directory
        # Cache of already-loaded DataFrames, keyed by the filename exactly
        # as it was requested by the caller.
        self.dataframes: Dict[str, pd.DataFrame] = {}

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Load a CSV file as DataFrame, with caching.

        The file is looked up inside ``csv_directory``. If ``filename`` does
        not name a regular file and does not already end in ``.csv``, the
        ``.csv`` extension is appended and the lookup is retried.

        Args:
            filename: Name of the CSV file, with or without ``.csv``.

        Returns:
            The loaded (or previously cached) DataFrame.

        Raises:
            ValueError: If no matching file exists.
        """
        if filename not in self.dataframes:
            base_dir = Path(self.csv_directory)
            file_path = base_dir / filename
            # is_file() rather than exists(): a directory that happens to
            # share the requested name must not block the ".csv" fallback.
            if not file_path.is_file() and not filename.endswith(".csv"):
                file_path = base_dir / f"{filename}.csv"

            if not file_path.is_file():
                raise ValueError(f"CSV file not found: {filename}")

            self.dataframes[filename] = pd.read_csv(file_path)

        return self.dataframes[filename]

    def get_tools(self) -> List[Dict[str, Any]]:
        """Get all available data tools.

        Returns:
            A list of dicts, each with a ``name``, a human-readable
            ``description`` and the bound method under ``function``.
        """
        return [
            {
                "name": "describe_csv",
                "description": "Get statistical description of a CSV file",
                "function": self.describe_csv,
            },
            {
                "name": "filter_data",
                "description": "Filter CSV data based on conditions",
                "function": self.filter_data,
            },
            {
                "name": "group_and_aggregate",
                "description": "Group data and calculate aggregate statistics",
                "function": self.group_and_aggregate,
            },
            {
                "name": "sort_data",
                "description": "Sort data by specified columns",
                "function": self.sort_data,
            },
            {
                "name": "calculate_correlation",
                "description": "Calculate correlation between columns",
                "function": self.calculate_correlation,
            },
        ]

    # Tool implementations
    def describe_csv(self, filename: str) -> Dict[str, Any]:
        """Get statistical description of CSV data.

        Args:
            filename: Name of the CSV file to describe.

        Returns:
            A dict with ``statistics`` (``DataFrame.describe()`` as a dict),
            ``shape``, ``columns`` and ``dtypes`` (column name -> dtype name).

        Raises:
            ValueError: If the CSV file cannot be found.
        """
        df = self._load_dataframe(filename)

        return {
            "statistics": df.describe().to_dict(),
            "shape": df.shape,
            "columns": df.columns.tolist(),
            "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
        }

    def filter_data(self, filename: str, column: str, condition: str, value: Any) -> Dict[str, Any]:
        """Filter data based on condition (==, >, <, >=, <=, !=, contains).

        Args:
            filename: Name of the CSV file to filter.
            column: Column the condition is applied to.
            condition: One of ``==``, ``!=``, ``>``, ``<``, ``>=``, ``<=`` or
                ``contains`` (case-insensitive).
            value: Value to compare against. For numeric columns, ordered
                comparisons convert it with ``float()``; equality checks
                convert numeric-looking strings so that ``"5"`` matches ``5``.
                For ``contains`` it is matched as a literal substring.

        Returns:
            A dict with ``result_count`` (matching rows), ``results`` (the
            first 10 matching rows as records) and ``total_count`` (rows in
            the file), or ``{"error": ...}`` for an unsupported condition.

        Raises:
            ValueError: If the CSV file cannot be found, or if an ordered
                comparison on a numeric column gets a non-numeric value.
            KeyError: If ``column`` does not exist.
        """
        df = self._load_dataframe(filename)

        if condition in self._COMPARISON_METHODS:
            series = df[column]
            operand = value
            if pd.api.types.is_numeric_dtype(series):
                if condition in ("==", "!="):
                    # Tool arguments often arrive as strings; without this
                    # "5" would never equal the integer 5.
                    if isinstance(value, str):
                        try:
                            operand = float(value)
                        except ValueError:
                            operand = value
                else:
                    operand = float(value)
            # Non-numeric columns are compared against the raw value so that
            # ordered comparisons also work on text (e.g. ISO date strings).
            mask = getattr(series, self._COMPARISON_METHODS[condition])(operand)
        elif condition.lower() == "contains":
            # regex=False: treat the value as plain text, so characters such
            # as "." or "(" match literally instead of being parsed as regex.
            mask = df[column].astype(str).str.contains(str(value), regex=False)
        else:
            return {"error": f"Unsupported condition: {condition}"}

        filtered = df[mask]

        return {
            "result_count": len(filtered),
            "results": filtered.head(10).to_dict(orient="records"),
            "total_count": len(df),
        }

    def group_and_aggregate(self, filename: str, group_by: str, agg_column: str,
                            agg_function: str = "mean") -> Dict[str, Any]:
        """Group by column and calculate aggregate statistic.

        Args:
            filename: Name of the CSV file to aggregate.
            group_by: Column whose distinct values define the groups.
            agg_column: Column the aggregation is computed on.
            agg_function: One of ``mean``, ``sum``, ``min``, ``max``,
                ``count`` (rows per group) or ``median``.

        Returns:
            A dict echoing ``group_by``, ``aggregated_column`` and
            ``aggregation``, plus ``results`` mapping each group to its
            aggregated value; or ``{"error": ...}`` for an unsupported
            aggregation function.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``group_by`` or ``agg_column`` does not exist.
        """
        df = self._load_dataframe(filename)

        if agg_function not in self._AGG_FUNCTIONS:
            return {"error": f"Unsupported aggregation function: {agg_function}"}

        grouped = df.groupby(group_by)[agg_column].agg(self._AGG_FUNCTIONS[agg_function])

        return {
            "group_by": group_by,
            "aggregated_column": agg_column,
            "aggregation": agg_function,
            "results": grouped.to_dict(),
        }

    def sort_data(self, filename: str, sort_by: str, ascending: bool = True) -> Dict[str, Any]:
        """Sort data by column.

        Args:
            filename: Name of the CSV file to sort.
            sort_by: Column to sort by.
            ascending: Sort ascending when True, descending when False.

        Returns:
            A dict echoing ``sorted_by`` and ``ascending``, plus ``results``
            holding the first 10 rows of the sorted data as records.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``sort_by`` does not exist.
        """
        df = self._load_dataframe(filename)

        sorted_df = df.sort_values(by=sort_by, ascending=ascending)

        return {
            "sorted_by": sort_by,
            "ascending": ascending,
            "results": sorted_df.head(10).to_dict(orient="records"),
        }

    def calculate_correlation(self, filename: str, column1: str, column2: str) -> Dict[str, Any]:
        """Calculate correlation between two columns.

        Args:
            filename: Name of the CSV file to analyse.
            column1: First column.
            column2: Second column.

        Returns:
            A dict with ``correlation``, ``column1`` and ``column2``, or
            ``{"error": ...}`` if the correlation could not be computed
            (for example, a missing or non-numeric column).

        Raises:
            ValueError: If the CSV file cannot be found.
        """
        df = self._load_dataframe(filename)

        # Deliberately broad: any failure while computing the correlation is
        # reported back to the caller as an error dict instead of raising.
        try:
            correlation = df[column1].corr(df[column2])
        except Exception as e:
            return {"error": f"Could not calculate correlation: {str(e)}"}

        return {
            "correlation": correlation,
            "column1": column1,
            "column2": column2,
        }