| | from typing import Dict, List, Any, Optional, Tuple, Union |
| | import pandas as pd |
| | import matplotlib.pyplot as plt |
| | import matplotlib |
| | import io |
| | import base64 |
| | import numpy as np |
| | from pathlib import Path |
| |
|
| | |
# Select matplotlib's non-interactive "Agg" raster backend: the charts in this
# module are only ever rendered to in-memory PNG buffers, never shown in a window.
matplotlib.use('Agg')
| |
|
class VisualizationTools:
    """Tools for creating visualizations from CSV data.

    Each ``create_*`` method loads a CSV file from ``csv_directory``, draws a
    chart with matplotlib and returns a dictionary describing the chart
    together with the rendered PNG as a base64-encoded string under the
    ``"image"`` key.

    Attributes:
        csv_directory: Directory in which CSV files are looked up.
        dataframes: Cache of already-loaded DataFrames, keyed by the filename
            string that was passed to the loader.
        figure_size: ``(width, height)`` passed to ``plt.subplots`` as ``figsize``.
        dpi: Resolution passed to ``Figure.savefig``.
    """

    def __init__(self, csv_directory: str):
        """Initialize with directory containing CSV files.

        Args:
            csv_directory: Directory that the ``filename`` arguments of the
                chart methods are resolved against.
        """
        self.csv_directory = csv_directory
        self.dataframes: Dict[str, pd.DataFrame] = {}
        self.figure_size = (10, 6)
        self.dpi = 100

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Load a CSV file as DataFrame, with caching.

        The file is looked up as ``csv_directory/filename``; if that is not a
        regular file and ``filename`` does not already end in ``.csv``, the
        name with ``.csv`` appended is tried as well.

        Args:
            filename: CSV file name, with or without the ``.csv`` suffix.

        Returns:
            The parsed DataFrame. Repeated calls with the same ``filename``
            return the same cached object.

        Raises:
            ValueError: If no matching regular file exists.
        """
        if filename not in self.dataframes:
            base_dir = Path(self.csv_directory)
            file_path = base_dir / filename
            # is_file() rather than exists(): a directory (or an empty
            # filename, which resolves to csv_directory itself) must be
            # reported as "not found" instead of failing inside read_csv.
            if not file_path.is_file() and not filename.endswith('.csv'):
                file_path = base_dir / f"{filename}.csv"

            if not file_path.is_file():
                raise ValueError(f"CSV file not found: {filename}")

            self.dataframes[filename] = pd.read_csv(file_path)

        return self.dataframes[filename]

    @staticmethod
    def _require_columns(df: pd.DataFrame, *columns: str) -> None:
        """Raise ``KeyError`` naming every requested column absent from ``df``.

        Called before a figure is created so that a bad column name fails
        fast with a readable message and without allocating a figure.
        """
        missing = [column for column in columns if column not in df.columns]
        if missing:
            raise KeyError(
                "Column(s) not found in CSV data: "
                + ", ".join(str(column) for column in missing)
            )

    def get_tools(self) -> List[Dict[str, Any]]:
        """Get all available visualization tools.

        Returns:
            One dictionary per chart type with the keys ``"name"``,
            ``"description"`` and ``"function"`` (the bound method to call).
        """
        tools = [
            {
                "name": "create_line_chart",
                "description": "Create a line chart from CSV data",
                "function": self.create_line_chart,
            },
            {
                "name": "create_bar_chart",
                "description": "Create a bar chart from CSV data",
                "function": self.create_bar_chart,
            },
            {
                "name": "create_scatter_plot",
                "description": "Create a scatter plot from CSV data",
                "function": self.create_scatter_plot,
            },
            {
                "name": "create_histogram",
                "description": "Create a histogram from CSV data",
                "function": self.create_histogram,
            },
            {
                "name": "create_pie_chart",
                "description": "Create a pie chart from CSV data",
                "function": self.create_pie_chart,
            },
        ]
        return tools

    def _figure_to_base64(self, fig) -> str:
        """Convert matplotlib figure to base64 encoded string.

        The figure is rendered as PNG at ``self.dpi`` and is always closed
        afterwards, even when rendering fails.

        Args:
            fig: The matplotlib figure to render.

        Returns:
            The PNG bytes, base64-encoded and decoded to a ``str``.
        """
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png', dpi=self.dpi)
        finally:
            # Close unconditionally so a failed render does not leave the
            # figure registered with pyplot.
            plt.close(fig)
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    def create_line_chart(self, filename: str, x_column: str, y_column: str,
                          title: Optional[str] = None, limit: int = 50) -> Dict[str, Any]:
        """Create a line chart visualization.

        Args:
            filename: CSV file to load (see ``_load_dataframe``).
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            title: Chart title; defaults to ``"<y_column> vs <x_column>"``.
            limit: Only the first ``limit`` rows are plotted.

        Returns:
            Dictionary with ``chart_type``, ``x_column``, ``y_column``,
            ``data_points`` (number of rows plotted) and ``image``.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``x_column`` or ``y_column`` is not in the data.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, x_column, y_column)
        df = df.head(limit)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.plot(df[x_column], df[y_column], marker='o', linestyle='-')

            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True)

            img_str = self._figure_to_base64(fig)
        finally:
            # Closing an already-closed figure is harmless; this covers
            # failures that happen before rendering.
            plt.close(fig)

        return {
            "chart_type": "line",
            "x_column": x_column,
            "y_column": y_column,
            "data_points": len(df),
            "image": img_str,
        }

    def create_bar_chart(self, filename: str, x_column: str, y_column: str,
                         title: Optional[str] = None, limit: int = 20) -> Dict[str, Any]:
        """Create a bar chart visualization.

        Args:
            filename: CSV file to load (see ``_load_dataframe``).
            x_column: Column providing the bar positions/labels.
            y_column: Column providing the bar heights.
            title: Chart title; defaults to ``"<y_column> by <x_column>"``.
            limit: Only the first ``limit`` rows are plotted.

        Returns:
            Dictionary with ``chart_type``, ``x_column``, ``y_column``,
            ``categories`` (number of rows plotted) and ``image``.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``x_column`` or ``y_column`` is not in the data.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, x_column, y_column)
        df = df.head(limit)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.bar(df[x_column], df[y_column])

            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} by {x_column}")

            # With more than a handful of bars the labels are slanted.
            # Operate on this axes/figure explicitly instead of pyplot's
            # implicit "current figure".
            if len(df) > 5:
                for tick_label in ax.get_xticklabels():
                    tick_label.set_rotation(45)
                    tick_label.set_horizontalalignment('right')

            fig.tight_layout()

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "bar",
            "x_column": x_column,
            "y_column": y_column,
            "categories": len(df),
            "image": img_str,
        }

    def create_scatter_plot(self, filename: str, x_column: str, y_column: str,
                            color_column: Optional[str] = None,
                            title: Optional[str] = None) -> Dict[str, Any]:
        """Create a scatter plot visualization.

        Args:
            filename: CSV file to load (see ``_load_dataframe``).
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            color_column: Optional column used to color the points (with a
                colorbar). It is ignored when it is not present in the data.
            title: Chart title; defaults to ``"<y_column> vs <x_column>"``.

        Returns:
            Dictionary with ``chart_type``, ``x_column``, ``y_column``,
            ``color_column`` (as passed in), ``data_points`` and ``image``.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``x_column`` or ``y_column`` is not in the data.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, x_column, y_column)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            if color_column and color_column in df.columns:
                scatter = ax.scatter(df[x_column], df[y_column], c=df[color_column],
                                     cmap='viridis', alpha=0.7)
                fig.colorbar(scatter, ax=ax, label=color_column)
            else:
                ax.scatter(df[x_column], df[y_column], alpha=0.7)

            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True, linestyle='--', alpha=0.7)

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "scatter",
            "x_column": x_column,
            "y_column": y_column,
            "color_column": color_column,
            "data_points": len(df),
            "image": img_str,
        }

    def create_histogram(self, filename: str, column: str, bins: int = 10,
                         title: Optional[str] = None) -> Dict[str, Any]:
        """Create a histogram visualization.

        Args:
            filename: CSV file to load (see ``_load_dataframe``).
            column: Column whose values are binned.
            bins: Passed to ``Axes.hist`` as ``bins``.
            title: Chart title; defaults to ``"Distribution of <column>"``.

        Returns:
            Dictionary with ``chart_type``, ``column``, ``bins``,
            ``data_points`` (number of rows in the CSV) and ``image``.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``column`` is not in the data.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, column)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.hist(df[column], bins=bins, alpha=0.7, edgecolor='black')

            ax.set_xlabel(column)
            ax.set_ylabel('Frequency')
            ax.set_title(title or f"Distribution of {column}")
            ax.grid(True, linestyle='--', alpha=0.7)

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "histogram",
            "column": column,
            "bins": bins,
            "data_points": len(df),
            "image": img_str,
        }

    def create_pie_chart(self, filename: str, label_column: str,
                         value_column: Optional[str] = None,
                         title: Optional[str] = None, limit: int = 10) -> Dict[str, Any]:
        """Create a pie chart visualization.

        Without ``value_column`` the slices are the occurrence counts of each
        distinct value in ``label_column``. With ``value_column`` the slices
        are the per-label sums of that column. In both cases only the
        ``limit`` largest slices are kept.

        Args:
            filename: CSV file to load (see ``_load_dataframe``).
            label_column: Column providing the slice labels.
            value_column: Optional column summed per label for slice sizes.
            title: Chart title; defaults to ``"Distribution of <label_column>"``.
            limit: Maximum number of slices.

        Returns:
            Dictionary with ``chart_type``, ``label_column``, ``value_column``,
            ``categories`` (number of slices) and ``image``.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``label_column`` (or a given ``value_column``) is not
                in the data.
        """
        df = self._load_dataframe(filename)

        if value_column is None:
            self._require_columns(df, label_column)
            # value_counts() is already sorted by descending count.
            data = df[label_column].value_counts().head(limit)
            labels = data.index.tolist()
            values = data.values.tolist()
        else:
            self._require_columns(df, label_column, value_column)
            grouped = df.groupby(label_column)[value_column].sum().reset_index()
            grouped = grouped.nlargest(limit, value_column)
            labels = grouped[label_column].tolist()
            values = grouped[value_column].tolist()

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=90, shadow=True)
            # Equal aspect ratio so the pie is drawn as a circle.
            ax.axis('equal')

            ax.set_title(title or f"Distribution of {label_column}")

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "pie",
            "label_column": label_column,
            "value_column": value_column,
            "categories": len(labels),
            "image": img_str,
        }
| |
|