File size: 2,095 Bytes
62c2e4c
 
 
 
 
 
 
 
 
 
 
 
 
 
f5ea8b5
f63bd4c
 
62c2e4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
plotly_tools.py

Generates interactive Plotly charts from natural-language user queries, using a LangChain SQLDatabaseChain to query the database.
Saves plots as both HTML and PNG.

Author: Saivivek Katkuri
Date: June 2025
"""

import os
import sqlite3
from contextlib import closing
from datetime import datetime
from pathlib import Path

import pandas as pd
import plotly.express as px
from langchain.chat_models import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

def _clean_sql(raw_sql: str) -> str:
    """Strip the wrappers a chat model may put around generated SQL.

    Removes a leading ``SQLQuery:`` label and a surrounding Markdown code
    fence (with an optional ``sql``/``sqlite`` language tag on its first
    line) so that the remaining text can be executed as a statement.
    """
    sql = raw_sql.strip()

    label = "SQLQuery:"
    if sql.startswith(label):
        sql = sql[len(label):].strip()

    fence = "```"
    if sql.startswith(fence):
        sql = sql[len(fence):]
        if sql.endswith(fence):
            sql = sql[:-len(fence)]
        # The first line of a fenced block is the (possibly empty) language tag.
        first_line, newline, rest = sql.partition("\n")
        if newline and first_line.strip().lower() in ("", "sql", "sqlite"):
            sql = rest

    return sql.strip()


# Function to generate Plotly plot from SQL query
def generate_plot(query: str, db_path: str):
    """
    Generates a plot from a natural language query using LangChain and Plotly.

    The query is translated to SQL by a LangChain ``SQLDatabaseChain``; the
    SQL is then executed against the SQLite database and the result is
    charted. The chart type is picked from the number of result columns:
    one column -> histogram, two -> bar chart (x, y), three or more ->
    scatter plot (x, y, colour).

    This function does not raise: every failure is reported through the
    returned description, with both paths set to ``None``.

    Args:
        query (str): User query (e.g., "Show sales distribution by region").
        db_path (str): Path to SQLite database.

    Returns:
        str: Description of the plot, or an error message starting with
            "❌ Error generating plot:".
        str: Path to HTML plot (``None`` on error).
        str: Path to PNG plot (``None`` on error).
    """
    # Reject blank input up front instead of spending an LLM call on it.
    if not isinstance(query, str) or not query.strip():
        return "❌ Error generating plot: query must be a non-empty string", None, None

    try:
        # Setup LangChain SQL chain. return_sql=True makes the chain hand back
        # the generated SQL statement instead of executing it and returning a
        # natural-language answer, which could not be fed to pandas.
        db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
        llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
        sql_chain = SQLDatabaseChain.from_llm(
            llm, db, verbose=True, return_sql=True
        )

        # Get the generated SQL statement
        sql = _clean_sql(sql_chain.run(query))

        # Load result into DataFrame.
        # NOTE: the SQL is LLM-generated from user input, so it is executed on
        # a read-only connection to keep it from modifying the database.
        read_only_uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
        with closing(sqlite3.connect(read_only_uri, uri=True)) as conn:
            df = pd.read_sql_query(sql, conn)

        if df.empty:
            return "❌ Error generating plot: query returned no data", None, None

        # Determine plot type (simple logic)
        columns = list(df.columns)
        if len(columns) == 2:
            fig = px.bar(df, x=columns[0], y=columns[1])
        elif len(columns) >= 3:
            fig = px.scatter(df, x=columns[0], y=columns[1], color=columns[2])
        else:
            fig = px.histogram(df)

        # Save plots. Microseconds are part of the name so that two plots
        # generated within the same second do not overwrite each other.
        os.makedirs("plots", exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
        html_path = f"plots/plot_{timestamp}.html"
        png_path = f"plots/plot_{timestamp}.png"
        fig.write_html(html_path)
        fig.write_image(png_path)

        column_list = ', '.join(map(str, columns))
        return f"✅ Generated plot based on your query. Columns: {column_list}", html_path, png_path

    except Exception as e:
        # Deliberate catch-all: callers rely on the (message, None, None)
        # contract rather than on exceptions.
        return f"❌ Error generating plot: {str(e)}", None, None