| """ |
| plotly_tools.py |
| |
| Generates interactive plots from natural-language queries using a LangChain SQL chain (SQLDatabaseChain) and Plotly. |
| Saves plots as both HTML and PNG. |
| |
| Author: Saivivek Katkuri |
| Date: June 2025 |
| """ |
|
|
| import os |
| import plotly.express as px |
| import pandas as pd |
| from datetime import datetime |
| from langchain_community.utilities import SQLDatabase |
|
|
| from langchain_experimental.sql import SQLDatabaseChain |
| from langchain.chat_models import ChatOpenAI |
|
|
| |
def _clean_sql(raw_sql: str) -> str:
    """Return the bare SQL statement contained in an LLM response.

    Chat models may wrap generated SQL in Markdown code fences
    (```sql ... ```) or repeat a leading ``SQLQuery:`` label; either would
    make the text invalid when executed directly.
    """
    import re

    sql = raw_sql.strip()
    sql = re.sub(r"^```[A-Za-z]*\s*", "", sql)  # opening fence + optional language tag
    sql = re.sub(r"\s*```$", "", sql)  # closing fence
    sql = re.sub(r"^SQLQuery:\s*", "", sql, flags=re.IGNORECASE)
    return sql.strip()


def _run_read_only_query(db_path: str, sql: str) -> pd.DataFrame:
    """Execute ``sql`` on the SQLite file at ``db_path`` and return the rows as a DataFrame.

    The statement is generated by an LLM from user input, so it runs on a
    connection opened with SQLite's ``mode=ro`` URI parameter: a statement
    that tries to write to the database fails instead of modifying it.
    The connection is always closed, even if the query raises.
    """
    import sqlite3
    from contextlib import closing
    from pathlib import Path

    # as_uri() requires an absolute path and percent-encodes characters such
    # as spaces, "?" and "#" that would otherwise be parsed as URI syntax.
    uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
    with closing(sqlite3.connect(uri, uri=True)) as conn:
        return pd.read_sql_query(sql, conn)


def _make_figure(df: pd.DataFrame):
    """Choose a Plotly Express chart based on the number of result columns.

    * 1 column   -> histogram of that column.
    * 2 columns  -> bar chart (first column on x, second on y).
    * 3+ columns -> scatter of the first two columns, colored by the third
      (any further columns are ignored).

    ``df`` must have at least one column.
    """
    cols = df.columns
    if len(cols) == 1:
        return px.histogram(df, x=cols[0])
    if len(cols) == 2:
        return px.bar(df, x=cols[0], y=cols[1])
    return px.scatter(df, x=cols[0], y=cols[1], color=cols[2])


def generate_plot(query: str, db_path: str):
    """
    Generate a plot from a natural language query using a LangChain SQL chain and Plotly.

    The LLM translates ``query`` into a SQL statement. The chain is built with
    ``return_sql=True``, so it returns the statement instead of a
    natural-language answer. The statement is executed read-only and the
    resulting DataFrame is plotted, with the chart type chosen by column count.
    The plot is written under ``plots/`` as both interactive HTML and a static PNG.

    Args:
        query (str): User query (e.g., "Show sales distribution by region").
        db_path (str): Path to SQLite database.

    Returns:
        tuple: ``(message, html_path, png_path)``. On success ``message``
        starts with "✅" and lists the result columns. On any failure,
        including a query that returns no data, ``message`` starts with "❌"
        and both paths are None.
    """
    try:
        db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
        llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
        # return_sql=True stops the chain after SQL generation. Without it,
        # run() executes the query itself and returns the LLM's prose answer,
        # which is not a valid statement for read_sql_query.
        sql_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=True)

        sql = _clean_sql(sql_chain.run(query))
        df = _run_read_only_query(db_path, sql)

        if df.empty:
            return "❌ Error generating plot: the query returned no data to plot.", None, None

        fig = _make_figure(df)

        os.makedirs("plots", exist_ok=True)
        # Microsecond precision keeps calls made within the same second from
        # overwriting each other's output files.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
        html_path = os.path.join("plots", f"plot_{timestamp}.html")
        png_path = os.path.join("plots", f"plot_{timestamp}.png")
        fig.write_html(html_path)
        fig.write_image(png_path)

        columns = ", ".join(map(str, df.columns))
        return f"✅ Generated plot based on your query. Columns: {columns}", html_path, png_path

    except Exception as e:
        # User-facing tool boundary: report any failure as a message, not an exception.
        return f"❌ Error generating plot: {str(e)}", None, None