File size: 2,232 Bytes
472e1d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import sqlite3
import pandas as pd
from langchain.tools import BaseTool
import os
import matplotlib.pyplot as plt
from utils.consts import DB_PATH, PLOTS_DIR

# Fetch the names of the user tables once at import time so the tool
# descriptions below can tell the agent which tables exist. SQLite's own
# bookkeeping tables (reserved "sqlite_" prefix, e.g. sqlite_sequence) are
# excluded because they are not part of the data model.
# NOTE: sqlite3.connect creates an empty database file if DB_PATH does not exist.
_conn = sqlite3.connect(DB_PATH)
try:
    _TABLES = [
        row[0]
        for row in _conn.execute(
            r"SELECT name FROM sqlite_master "
            r"WHERE type='table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\';"
        )
    ]
finally:
    # Close even if the query fails (e.g. DB_PATH is not a SQLite database).
    _conn.close()
_TABLES_LIST = ", ".join(_TABLES)

class SQLiteQueryTool(BaseTool):
    """LangChain tool that runs a SQL query on the SQLite database and returns CSV.

    Failures are returned as a string prefixed with ``"SQL Error: "`` instead of
    being raised, so the calling agent can read the message and retry.
    """

    name: str = "sqlite_query"
    description: str = f"Executes a SQL query against the ecommerce SQLite database and returns results as CSV. Available tables: {_TABLES_LIST}."

    def _run(self, query: str) -> str:
        """Execute *query* and return its result set as CSV (without an index column).

        The connection is switched to ``query_only`` mode first: the query text
        is produced by the agent, so it must not be able to modify or drop data.

        Args:
            query: A single SQL statement.

        Returns:
            The result rows as CSV text, or ``"SQL Error: <message>"`` on failure
            (including any attempted write, which SQLite rejects as read-only).
        """
        print(f"[SQLiteQueryTool] Executing query: {query}")
        conn = sqlite3.connect(DB_PATH)
        try:
            # sqlite3 executes only one statement per call, so the agent's query
            # cannot turn this pragma off and then write in the same call.
            conn.execute("PRAGMA query_only = ON")
            df = pd.read_sql_query(query, conn)
            return df.to_csv(index=False)
        except Exception as e:
            # Deliberately broad: this is the tool boundary, and the agent
            # needs the error text rather than an exception.
            return f"SQL Error: {e}"
        finally:
            conn.close()

    async def _arun(self, query: str) -> str:
        """Async execution is not implemented; always raises NotImplementedError."""
        raise NotImplementedError("Async not supported for SQLiteQueryTool")


class PlotSQLTool(BaseTool):
    """LangChain tool that runs a SQL query and renders the result as a PNG chart.

    Results with more than one column are drawn as a bar chart, and a single
    column as a line chart. Failures are returned as a string prefixed with
    ``"Plot Error: "`` instead of being raised, so the agent can react to them.
    """

    name: str = "plot_sql"
    description: str = f"Executes a SQL query and generates a plot saved as a PNG; returns markdown image link. Available tables: {_TABLES_LIST}."

    def _run(self, query: str) -> str:
        """Execute *query* and save a chart of the result under ``PLOTS_DIR``.

        Args:
            query: A single SQL statement.

        Returns:
            A markdown image link ``![Plot](<path>)`` to the saved PNG, or
            ``"Plot Error: <message>"`` if the query or the plotting fails.
        """
        print(f"[PlotSQLTool] Executing query: {query}")
        conn = sqlite3.connect(DB_PATH)
        fig = None
        try:
            # The query comes from the agent; forbid writes/DDL on this connection.
            conn.execute("PRAGMA query_only = ON")
            df = pd.read_sql_query(query, conn)
            # The chart type is chosen from the raw result width.
            kind = 'bar' if df.shape[1] > 1 else 'line'
            # Pandas drops a leading text column (e.g. "SELECT category, SUM(x)")
            # as non-numeric. Use it as the x-axis labels instead of the
            # 0..n-1 row index. Positional access tolerates duplicate column names.
            if df.shape[1] > 1 and not pd.api.types.is_numeric_dtype(df.iloc[:, 0]):
                df = df.iloc[:, 1:].set_index(df.iloc[:, 0])
            # Draw on one explicitly created figure. DataFrame.plot() without
            # ax= opens its own figure, which would leave an extra empty one open.
            fig, ax = plt.subplots()
            df.plot(kind=kind, legend=False, ax=ax)
            # Microsecond resolution so that plots made within the same second
            # do not overwrite each other.
            timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S_%f")
            filename = f"plot_{timestamp}.png"
            # Save the plot to the configured output directory.
            os.makedirs(PLOTS_DIR, exist_ok=True)
            filepath = os.path.join(PLOTS_DIR, filename)
            fig.tight_layout()
            fig.savefig(filepath)
            return f"![Plot]({filepath})"
        except Exception as e:
            # Deliberately broad: this is the tool boundary, and the agent
            # needs the error text rather than an exception.
            return f"Plot Error: {e}"
        finally:
            # Release the figure on every path so that a long-running agent
            # does not accumulate open matplotlib figures.
            if fig is not None:
                plt.close(fig)
            conn.close()

    async def _arun(self, query: str) -> str:
        """Async execution is not implemented; always raises NotImplementedError."""
        raise NotImplementedError("Async not supported for PlotSQLTool")