| | from typing import Dict, List, Any, Optional, Tuple, Union
|
| | import pandas as pd
|
| | import matplotlib.pyplot as plt
|
| | import matplotlib
|
| | import io
|
| | import base64
|
| | import numpy as np
|
| | from llama_index.tools import FunctionTool
|
| | from pathlib import Path
|
| |
|
| |
|
# Select the non-interactive Agg backend: the charts in this module are only
# ever saved to in-memory PNG buffers, so no GUI toolkit or display is needed.
# NOTE(review): matplotlib.pyplot is already imported above this call; older
# matplotlib releases expected the backend to be chosen before that import --
# confirm against the minimum supported matplotlib version.
matplotlib.use('Agg')
|
| |
|
class VisualizationTools:
    """Tools for creating visualizations from CSV data.

    Each ``create_*`` method loads a CSV file from ``csv_directory``, draws a
    single matplotlib chart and returns a dict with chart metadata plus the
    rendered PNG as a base64 string under the ``"image"`` key. The same
    methods are wrapped as LlamaIndex ``FunctionTool`` objects, available
    through ``get_tools()``.
    """

    def __init__(self, csv_directory: str):
        """Initialize with directory containing CSV files.

        Args:
            csv_directory: Directory that CSV file names are resolved against.
        """
        self.csv_directory = csv_directory
        # Cache of loaded frames, keyed by the file name exactly as requested.
        self.dataframes: Dict[str, pd.DataFrame] = {}
        # Rendering settings are assigned before the tools are built so the
        # instance is fully configured when its bound methods are handed out.
        self.figure_size = (10, 6)
        self.dpi = 100
        self.tools = self._create_tools()

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Load a CSV file as DataFrame, with caching.

        Args:
            filename: File name relative to ``csv_directory``; the ``.csv``
                suffix may be omitted.

        Returns:
            The parsed DataFrame. Repeated calls with the same ``filename``
            return the same cached object.

        Raises:
            ValueError: If no matching regular file exists inside
                ``csv_directory``.
        """
        if filename in self.dataframes:
            return self.dataframes[filename]

        base_dir = Path(self.csv_directory).resolve()
        file_path = (base_dir / filename).resolve()
        # is_file() rather than exists(): a sub-directory that shares the
        # requested name must not shadow "<name>.csv".
        if not file_path.is_file() and not filename.endswith('.csv'):
            file_path = (base_dir / f"{filename}.csv").resolve()

        # File names arrive from tool calls, so refuse anything that resolves
        # outside the configured directory ("../x", absolute paths, ...).
        inside_base = base_dir in file_path.parents
        if not (inside_base and file_path.is_file()):
            raise ValueError(f"CSV file not found: {filename}")

        self.dataframes[filename] = pd.read_csv(file_path)
        return self.dataframes[filename]

    def _create_tools(self) -> "List[FunctionTool]":
        """Create LlamaIndex function tools for visualizations.

        Returns:
            One ``FunctionTool`` per public chart method, in the order line,
            bar, scatter, histogram, pie.
        """
        tool_specs = (
            ("create_line_chart", "Create a line chart from CSV data"),
            ("create_bar_chart", "Create a bar chart from CSV data"),
            ("create_scatter_plot", "Create a scatter plot from CSV data"),
            ("create_histogram", "Create a histogram from CSV data"),
            ("create_pie_chart", "Create a pie chart from CSV data"),
        )
        # Every tool is named after the bound method that implements it.
        return [
            FunctionTool.from_defaults(
                name=name,
                description=description,
                fn=getattr(self, name),
            )
            for name, description in tool_specs
        ]

    def get_tools(self) -> "List[FunctionTool]":
        """Get all available visualization tools."""
        return self.tools

    def _figure_to_base64(self, fig) -> str:
        """Convert matplotlib figure to base64 encoded string.

        The figure is rendered as PNG at ``self.dpi`` and is always closed
        afterwards, including when saving fails, so it does not stay
        registered with pyplot.

        Args:
            fig: The matplotlib figure to render.

        Returns:
            The PNG bytes encoded as a base64 ASCII string.
        """
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png', dpi=self.dpi)
        finally:
            plt.close(fig)
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    def _draw_colored_scatter(self, fig, ax, x_values, y_values,
                              color_values, label: str) -> None:
        """Draw a colour-mapped scatter plus its colorbar onto ``ax``."""
        category_names = None
        try:
            points = ax.scatter(x_values, y_values, c=color_values,
                                cmap='viridis', alpha=0.7)
        except ValueError:
            # Matplotlib could not interpret the values as numbers or colours
            # (typically a text column): colour by one integer code per
            # distinct value instead of failing.
            codes, uniques = pd.factorize(color_values)
            category_names = [str(value) for value in uniques]
            points = ax.scatter(x_values, y_values, c=codes,
                                cmap='viridis', alpha=0.7)

        colorbar = fig.colorbar(points, ax=ax, label=label)
        if category_names is not None:
            # Label the integer codes with the category each one stands for.
            colorbar.set_ticks(list(range(len(category_names))))
            colorbar.set_ticklabels(category_names)

    def create_line_chart(
        self,
        filename: str,
        x_column: str,
        y_column: str,
        title: Optional[str] = None,
        limit: int = 50,
    ) -> Dict[str, Any]:
        """Create a line chart visualization.

        Args:
            filename: CSV file to read (see ``_load_dataframe``).
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            title: Chart title; defaults to "<y_column> vs <x_column>".
            limit: Only the first ``limit`` rows are plotted.

        Returns:
            Dict with ``chart_type``, ``x_column``, ``y_column``,
            ``data_points`` (rows plotted) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If a requested column is not in the file.
        """
        df = self._load_dataframe(filename)
        if len(df) > limit:
            df = df.head(limit)

        # Look the columns up before opening a figure, so an unknown column
        # raises without leaving a figure behind.
        x_values, y_values = df[x_column], df[y_column]

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.plot(x_values, y_values, marker='o', linestyle='-')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True)
            img_str = self._figure_to_base64(fig)
        finally:
            # No-op after a successful render; frees the figure on failure.
            plt.close(fig)

        return {
            "chart_type": "line",
            "x_column": x_column,
            "y_column": y_column,
            "data_points": len(df),
            "image": img_str,
        }

    def create_bar_chart(
        self,
        filename: str,
        x_column: str,
        y_column: str,
        title: Optional[str] = None,
        limit: int = 20,
    ) -> Dict[str, Any]:
        """Create a bar chart visualization.

        Args:
            filename: CSV file to read (see ``_load_dataframe``).
            x_column: Column providing the bar positions/labels.
            y_column: Column providing the bar heights.
            title: Chart title; defaults to "<y_column> by <x_column>".
            limit: Only the first ``limit`` rows are plotted.

        Returns:
            Dict with ``chart_type``, ``x_column``, ``y_column``,
            ``categories`` (rows plotted) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If a requested column is not in the file.
        """
        df = self._load_dataframe(filename)
        if len(df) > limit:
            df = df.head(limit)

        x_values, y_values = df[x_column], df[y_column]

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.bar(x_values, y_values)
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} by {x_column}")

            # Style this axes/figure directly instead of going through
            # pyplot's implicit "current figure".
            if len(df) > 5:
                # Slant the labels so they do not overlap with many bars.
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            fig.tight_layout()

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "bar",
            "x_column": x_column,
            "y_column": y_column,
            "categories": len(df),
            "image": img_str,
        }

    def create_scatter_plot(
        self,
        filename: str,
        x_column: str,
        y_column: str,
        color_column: Optional[str] = None,
        title: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Create a scatter plot visualization.

        Args:
            filename: CSV file to read (see ``_load_dataframe``).
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            color_column: Optional column used to colour the points. It is
                ignored when the file has no such column. Values matplotlib
                cannot colour-map directly are coloured per distinct value.
            title: Chart title; defaults to "<y_column> vs <x_column>".

        Returns:
            Dict with ``chart_type``, ``x_column``, ``y_column``,
            ``color_column`` (as passed in), ``data_points`` and ``image``
            (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``x_column`` or ``y_column`` is not in the file.
        """
        df = self._load_dataframe(filename)
        x_values, y_values = df[x_column], df[y_column]
        use_color = bool(color_column) and color_column in df.columns

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            if use_color:
                self._draw_colored_scatter(
                    fig, ax, x_values, y_values, df[color_column], color_column
                )
            else:
                ax.scatter(x_values, y_values, alpha=0.7)

            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True, linestyle='--', alpha=0.7)
            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "scatter",
            "x_column": x_column,
            "y_column": y_column,
            "color_column": color_column,
            "data_points": len(df),
            "image": img_str,
        }

    def create_histogram(
        self,
        filename: str,
        column: str,
        bins: int = 10,
        title: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Create a histogram visualization.

        Args:
            filename: CSV file to read (see ``_load_dataframe``).
            column: Column whose values are binned.
            bins: Number of bins passed to matplotlib.
            title: Chart title; defaults to "Distribution of <column>".

        Returns:
            Dict with ``chart_type``, ``column``, ``bins``, ``data_points``
            (rows in the file) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``column`` is not in the file.
        """
        df = self._load_dataframe(filename)
        values = df[column]

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.hist(values, bins=bins, alpha=0.7, edgecolor='black')
            ax.set_xlabel(column)
            ax.set_ylabel('Frequency')
            ax.set_title(title or f"Distribution of {column}")
            ax.grid(True, linestyle='--', alpha=0.7)
            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "histogram",
            "column": column,
            "bins": bins,
            "data_points": len(df),
            "image": img_str,
        }

    def create_pie_chart(
        self,
        filename: str,
        label_column: str,
        value_column: Optional[str] = None,
        title: Optional[str] = None,
        limit: int = 10,
    ) -> Dict[str, Any]:
        """Create a pie chart visualization.

        Without ``value_column`` each slice is the number of rows carrying a
        label, keeping the ``limit`` most frequent labels. With
        ``value_column`` each slice is the sum of that column per label,
        keeping the ``limit`` largest sums.

        Args:
            filename: CSV file to read (see ``_load_dataframe``).
            label_column: Column whose distinct values become the slices.
            value_column: Optional column summed per label.
            title: Chart title; defaults to "Distribution of <label_column>".
            limit: Maximum number of slices.

        Returns:
            Dict with ``chart_type``, ``label_column``, ``value_column``,
            ``categories`` (number of slices) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If a requested column is not in the file.
        """
        df = self._load_dataframe(filename)

        if value_column is None:
            counts = df[label_column].value_counts().head(limit)
            labels = counts.index.tolist()
            values = counts.values.tolist()
        else:
            totals = df.groupby(label_column)[value_column].sum().reset_index()
            totals = totals.nlargest(limit, value_column)
            labels = totals[label_column].tolist()
            values = totals[value_column].tolist()

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=90,
                   shadow=True)
            # Equal aspect ratio keeps the pie circular.
            ax.axis('equal')
            ax.set_title(title or f"Distribution of {label_column}")
            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "pie",
            "label_column": label_column,
            "value_column": value_column,
            "categories": len(labels),
            "image": img_str,
        }
|
| |
|