"""
plotly_tools.py

Generates interactive plots based on user queries using LangChain SQL Agent
and Plotly. Saves plots as both HTML and PNG.

Author: Saivivek Katkuri
Date: June 2025
"""

import os
import re
import sqlite3
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple

import pandas as pd
import plotly.express as px
from langchain.chat_models import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain


def _extract_sql(raw: str) -> str:
    """Return the bare SQL statement from the LLM output.

    Chat models frequently wrap the statement in a markdown code fence or
    prefix it with ``SQLQuery:``; neither is valid SQL, so both are removed.
    """
    text = str(raw).strip()
    fenced = re.search(r"```(?:sql)?\s*(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1).strip()
    if text.lower().startswith("sqlquery:"):
        text = text[len("sqlquery:"):].strip()
    return text


def _build_figure(df: pd.DataFrame):
    """Pick a Plotly figure type from the number of result columns.

    One column -> histogram, two -> bar chart, three or more -> scatter
    coloured by the third column. The caller guarantees at least one column.
    """
    cols = list(df.columns)
    if len(cols) == 1:
        return px.histogram(df, x=cols[0])
    if len(cols) == 2:
        return px.bar(df, x=cols[0], y=cols[1])
    return px.scatter(df, x=cols[0], y=cols[1], color=cols[2])


# Function to generate Plotly plot from SQL query
def generate_plot(query: str, db_path: str) -> Tuple[str, Optional[str], Optional[str]]:
    """
    Generates a plot from a natural language query using LangChain SQL Agent and Plotly.

    The LLM translates the question into a SQL statement, the statement is
    executed against the SQLite database, and the resulting table is plotted
    and written to the ``plots`` directory as HTML and PNG.

    Args:
        query (str): User query (e.g., "Show sales distribution by region").
        db_path (str): Path to SQLite database.

    Returns:
        str: Description of the plot, or an error message on failure.
        str: Path to HTML plot (None on failure).
        str: Path to PNG plot (None on failure).
    """
    try:
        # Setup LangChain SQL agent
        db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
        llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
        # return_sql=True makes the chain hand back the generated SQL itself
        # instead of executing it and summarising the rows in prose; the SQL
        # text is what pandas needs below.
        sql_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=True)

        # Get the SQL statement produced for the user's question
        sql_text = _extract_sql(sql_chain.run(query))
        if not sql_text:
            raise ValueError("the model did not produce a SQL statement")

        # Load result into DataFrame. The connection is opened read-only so
        # that model-generated SQL can never modify the database, and it is
        # closed deterministically once the rows are fetched.
        read_only_uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
        with closing(sqlite3.connect(read_only_uri, uri=True)) as conn:
            df = pd.read_sql_query(sql_text, conn)

        if df.empty:
            raise ValueError("the query returned no data to plot")

        # Determine plot type (simple logic)
        fig = _build_figure(df)

        # Save plots
        os.makedirs("plots", exist_ok=True)
        # Microseconds keep two plots generated within the same second from
        # overwriting each other.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
        html_path = f"plots/plot_{timestamp}.html"
        png_path = f"plots/plot_{timestamp}.png"
        fig.write_html(html_path)
        fig.write_image(png_path)

        column_list = ", ".join(map(str, df.columns))
        return (
            f"✅ Generated plot based on your query. Columns: {column_list}",
            html_path,
            png_path,
        )

    except Exception as e:
        # Tool boundary: callers expect a message tuple rather than an exception.
        return f"❌ Error generating plot: {str(e)}", None, None