amirkiarafiei committed on
Commit
a2134eb
·
1 Parent(s): 6e75e8f

feat: example script for pandas AI

Browse files
.gitignore CHANGED
@@ -1,6 +1,8 @@
1
  .idea
2
  .env
3
  .vscode
 
 
4
 
5
  # Byte-compiled / optimized / DLL files
6
  __pycache__/
 
1
  .idea
2
  .env
3
  .vscode
4
+ .xml
5
+ .iml
6
 
7
  # Byte-compiled / optimized / DLL files
8
  __pycache__/
exports/charts/temp_chart_d2455884-7b1b-4dd5-8e9a-8d928ec9628b.png ADDED
pandasai_visualization.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Visualization script using PandasAI.
4
+
5
+ This script creates a sample dataframe and uses PandasAI to generate
6
+ and save visualizations based on user queries.
7
+
8
+ Usage:
9
+ python pandasai_visualization.py "Create a bar chart of sales by region"
10
+
11
+ Requirements:
12
+ - pandas
13
+ - pandasai
14
+ - matplotlib
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ import pandas as pd
20
+ import matplotlib.pyplot as plt
21
+ import pandasai as pai
22
+ from dotenv import load_dotenv
23
+
24
+
25
def create_sample_dataframe():
    """Build the demo sales table and wrap it in a PandasAI DataFrame.

    The table has eight rows: four regions for the 'Widget' product in Q1
    followed by the same four regions for the 'Gadget' product in Q2, all
    for the year 2023.

    Returns:
        The sample data as a ``pai.DataFrame``.
    """
    regions = ['North', 'South', 'East', 'West']
    rows_per_product = len(regions)

    sales_table = {
        'Region': regions * 2,
        'Product': ['Widget'] * rows_per_product + ['Gadget'] * rows_per_product,
        'Sales': [150, 200, 120, 180, 90, 110, 95, 130],
        'Quarter': ['Q1'] * rows_per_product + ['Q2'] * rows_per_product,
        'Year': [2023] * (2 * rows_per_product),
    }
    return pai.DataFrame(sales_table)
35
+
36
+
37
def visualize_data(df, query):
    """
    Generate a visualization for a natural-language query using PandasAI.

    Args:
        df: DataFrame-like object exposing ``chat(query)`` (the script passes
            the ``pai.DataFrame`` built by ``create_sample_dataframe``).
        query: User query string describing the desired visualization.

    Returns:
        Path to the saved visualization file, or None if the API key is
        missing, no figure was produced, or any step raised an exception.
    """
    output_file = "visualization_output.png"

    try:
        # The PandasAI key is read from the environment, optionally populated
        # from a local .env file.
        load_dotenv()
        api_key = os.environ.get("PANDAS_KEY")
        if not api_key:
            # Report this explicitly; a bare KeyError would only print
            # "'PANDAS_KEY'".
            print("Error generating visualization: PANDAS_KEY is not set "
                  "(define it in the environment or in a .env file)")
            return None
        pai.api_key.set(api_key)

        # Announce before the call so the user sees progress while it runs.
        print(f"Generating visualization for query: '{query}'")
        df.chat(query)

        # plt.savefig() with no open figure creates and saves a blank canvas,
        # so treat "nothing to save" as a failure rather than a success.
        # NOTE(review): PandasAI may already have written the chart itself
        # (e.g. under exports/charts/) and closed the figure -- confirm which
        # file should be returned in that case.
        if not plt.get_fignums():
            print("Error generating visualization: the query did not leave "
                  "an open matplotlib figure to save")
            return None

        plt.savefig(output_file)
        plt.close()

        print(f"Visualization saved to {output_file}")
        return output_file

    except Exception as e:
        # Top-level boundary of the script: report and signal failure.
        print(f"Error generating visualization: {e}")
        return None
75
+
76
+
77
def main():
    """Run the demo: build the sample data and render one visualization.

    The query is taken from the first command-line argument when one is
    given and non-blank; otherwise a built-in example query is used, so
    running the script without arguments behaves as before.
    """
    default_query = "Plot a bar chart of sales by region"
    cli_query = sys.argv[1].strip() if len(sys.argv) > 1 else ""
    query = cli_query or default_query

    # Create sample dataframe
    df = create_sample_dataframe()
    print("Sample DataFrame created:")
    print(df.head())

    # Generate and save visualization
    output_file = visualize_data(df, query)

    if output_file:
        print(f"Visualization process completed. Output saved to: {output_file}")
    else:
        print("Visualization process failed.")


if __name__ == "__main__":
    main()
postgre_mcp_server.py CHANGED
@@ -2,11 +2,13 @@ import os
2
  from contextlib import asynccontextmanager
3
  from dataclasses import dataclass
4
  from typing import Optional, AsyncIterator
5
-
6
  import asyncpg
7
- from flask.cli import load_dotenv
8
  from mcp.server.fastmcp import FastMCP, Context
9
  from pydantic import Field
 
 
 
10
 
11
  # Constants
12
  DEFAULT_QUERY_LIMIT = 100
@@ -578,5 +580,67 @@ def find_relationships(table_name: str, schema: str = 'public') -> str:
578
  return f"Error finding relationships: {str(e)}"
579
 
580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
  if __name__ == "__main__":
582
  mcp.run()
 
2
  from contextlib import asynccontextmanager
3
  from dataclasses import dataclass
4
  from typing import Optional, AsyncIterator
 
5
  import asyncpg
6
+ from dotenv import load_dotenv
7
  from mcp.server.fastmcp import FastMCP, Context
8
  from pydantic import Field
9
+ import pandasai as pai
10
+ import matplotlib as plt
11
+ import pandas as pd
12
 
13
  # Constants
14
  DEFAULT_QUERY_LIMIT = 100
 
580
  return f"Error finding relationships: {str(e)}"
581
 
582
 
583
@mcp.tool(description="Visualizes query results using a prompt and JSON data.")
async def visualize_results(json_data: dict, vis_prompt: str) -> str:
    """
    Generates a visualization based on query results using PandasAI.

    Args:
        json_data (dict): A dictionary containing the query results.
            It should have two keys:
            - 'columns': A list of column names (strings).
            - 'data': A list of lists, where each inner list represents a
              row of data. Each element in the inner list corresponds to a
              column in 'columns'.
            Example:
                {
                    'columns': ['Region', 'Product', 'Sales'],
                    'data': [
                        ['North', 'Widget', 150],
                        ['South', 'Widget', 200]
                    ]
                }
        vis_prompt (str): A natural language prompt describing the desired
            visualization (e.g., "Create a bar chart showing sales by region").

    Returns:
        str: "Visualization saved as <file>" on success, or a message starting
            with "Visualization error:" if the API key is missing, no figure
            was produced, or any step raised an exception.
    """
    # Imported locally on purpose: savefig()/close() live in matplotlib.pyplot,
    # not in the top-level matplotlib package that the module-level name `plt`
    # is bound to, and `sys` is needed for stderr diagnostics.
    import sys
    import matplotlib.pyplot as plt

    def _debug(message: str) -> None:
        # Diagnostics go to stderr. The server is started with mcp.run() and no
        # explicit transport; on FastMCP's stdio transport stdout carries the
        # protocol messages, so plain print() would corrupt the stream.
        print(message, file=sys.stderr)

    output_file = "visualization_output.png"

    try:
        _debug("\nVisualization Tool Debug:")
        _debug(f"Received json_data: {json_data}")
        _debug(f"Received vis_prompt: {vis_prompt}")

        # Convert JSON to DataFrame
        df = pd.DataFrame(json_data["data"], columns=json_data["columns"])
        _debug(f"Created DataFrame:\n{df.head()}")

        # Initialize PandasAI
        df_ai = pai.DataFrame(df)
        _debug("Initialized PandasAI DataFrame")

        load_dotenv()
        api_key = os.environ.get("PANDAS_KEY")
        if not api_key:
            # Slicing a missing key would raise TypeError; fail with a clear
            # message instead. The key itself is never logged.
            _debug("Visualization error: PANDAS_KEY is not set")
            return "Visualization error: PANDAS_KEY is not set"
        pai.api_key.set(api_key)

        # Generate visualization
        _debug(f"Attempting to generate visualization with prompt: '{vis_prompt}'")
        df_ai.chat(vis_prompt)

        # plt.savefig() with no open figure would write a blank image, so
        # report that case instead of claiming success.
        # NOTE(review): PandasAI may save the chart itself and close the
        # figure -- confirm which file should be reported in that case.
        if not plt.get_fignums():
            _debug("Visualization error: no figure was produced for the prompt")
            return "Visualization error: no figure was produced for the prompt"

        # Save plot
        plt.savefig(output_file)
        plt.close()
        _debug(f"Saved visualization to {output_file}")

        return f"Visualization saved as {output_file}"
    except Exception as e:
        # Tool boundary: return the failure as text rather than raising.
        _debug(f"Visualization error: {e}")
        _debug(f"Error type: {type(e)}")
        return f"Visualization error: {e}"
643
+
644
+
645
  if __name__ == "__main__":
646
  mcp.run()
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ