File size: 2,196 Bytes
fd7f144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import contextlib
import os
import re
import sys
from io import StringIO

import matplotlib.pyplot as plt
import pandas as pd
import requests
from dotenv import load_dotenv

# Populate the process environment from a local .env file (python-dotenv),
# so GEMINI_API_KEY can be supplied without exporting it in the shell.
load_dotenv()

# Credentials and endpoint for the Gemini generateContent REST call.
# API_KEY is None when the variable is not set.
API_KEY = os.environ.get("GEMINI_API_KEY")
API_URL = (
    "https://generativelanguage.googleapis.com/v1beta/models/"
    "gemini-2.0-flash-exp:generateContent"
)

def get_df_info(df):
    """Build a short textual summary of *df* for embedding in an LLM prompt.

    The summary has the column names, the first three rows rendered with
    ``DataFrame.to_string``, and a mapping of column name to dtype. The
    sections are separated by newlines.
    """
    sections = [
        f"Columns: {list(df.columns)}",
        "Sample data:",
        df.head(3).to_string(),
        f"Dtypes: {df.dtypes.to_dict()}",
    ]
    return "\n".join(sections)

def call_gemini(prompt, *, timeout=60):
    """Send *prompt* to the Gemini ``generateContent`` endpoint and return the reply text.

    Args:
        prompt: Plain-text prompt, sent as a single text part of one content entry.
        timeout: Seconds passed to ``requests.post`` as its timeout. Without one,
            a stalled connection would block the caller indefinitely.

    Returns:
        The ``text`` of the first part of the first candidate in the JSON reply.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status. The
            message includes the response body, which may carry the API's
            error detail.
        requests.RequestException: On connection failures or timeouts.
        KeyError, IndexError: If a successful response does not contain
            candidate text at the expected path.
    """
    headers = {"Content-Type": "application/json", "x-goog-api-key": API_KEY}
    data = {"contents": [{"parts": [{"text": prompt}]}]}
    response = requests.post(API_URL, headers=headers, json=data, timeout=timeout)

    # Without this check an error reply (e.g. missing/invalid key) surfaces
    # later as a confusing KeyError on "candidates".
    if not response.ok:
        raise requests.HTTPError(
            f"Gemini API request failed with status {response.status_code}: {response.text}",
            response=response,
        )

    return response.json()["candidates"][0]["content"]["parts"][0]["text"]

# Matches Markdown fence markers with an optional language tag:
# ```, ```python, ```py, ```python3, ...
_FENCE_RE = re.compile(r"```[\w+-]*")


def _strip_code_fences(text):
    """Remove Markdown code-fence markers from *text* and trim surrounding whitespace."""
    return _FENCE_RE.sub("", text).strip()


def _format_result(result, printed):
    """Turn the executed code's ``result`` (or its captured stdout) into display text.

    *printed* is used only when ``result`` was never assigned. DataFrames and
    Series are rendered with ``to_string``. Falsy but real answers such as
    ``0``, ``False`` or an empty list are kept. Only a missing result or an
    empty string yields the ``"Query executed."`` placeholder.
    """
    if result is None and printed:
        result = printed.strip()
    if isinstance(result, (pd.DataFrame, pd.Series)):
        result = result.to_string()
    # Compare against None/"" explicitly. A bare truthiness test would drop
    # answers like 0 and would raise for multi-element NumPy arrays.
    if result is None or (isinstance(result, str) and not result):
        return "Query executed."
    return str(result)


def query_csv(df, user_query, *, plot_path="plot.png"):
    """Answer a natural-language question about *df* with Gemini-generated pandas code.

    The model is asked for pandas/matplotlib code that stores its answer in a
    variable named ``result``. That code runs against a copy of *df*, with
    ``pd`` and ``plt`` in scope. Anything it prints is captured. If it leaves a
    matplotlib figure open, the current figure is saved to *plot_path*.

    Args:
        df: DataFrame to query. A copy is handed to the generated code, so
            in-place changes made by that code do not reach the caller's object.
        user_query: The question forwarded to the model.
        plot_path: File the current figure is written to when the generated
            code creates one. Defaults to ``"plot.png"``, the previous
            hard-coded location.

    Returns:
        A ``(text, figure_path, code)`` tuple:

        - ``text``: the stringified ``result``, or the captured stdout when
          ``result`` is unset. It is ``"Query executed."`` when there is
          nothing to show, and ``"Error: ..."`` when executing the code raised.
        - ``figure_path``: *plot_path* if a figure was saved, otherwise ``None``.
        - ``code``: the cleaned code that was executed.

    Raises:
        Whatever ``call_gemini`` raises, because the model request happens
        outside the error handling that wraps execution.
    """
    df_info = get_df_info(df)

    prompt = f"""You are a data analyst. Given this dataframe info:
{df_info}

User question: {user_query}

Write Python code using pandas to answer this. The dataframe is called 'df'.
IMPORTANT: Always assign your final answer to a variable called 'result'.
Whenever possible, show both text result AND a plot for better visualization.
Use matplotlib for plots (don't call plt.show()).
Only output Python code, no markdown, no explanation."""

    code = _strip_code_fences(call_gemini(prompt))

    namespace = {"df": df.copy(), "pd": pd, "plt": plt}
    captured = StringIO()

    try:
        plt.close("all")
        # SECURITY: exec() runs model-generated code with full interpreter
        # privileges (files, network, imports). Only expose this to trusted
        # users, or move execution into a real sandbox.
        # redirect_stdout restores sys.stdout even when exec raises something
        # outside Exception (e.g. SystemExit from a generated exit() call).
        with contextlib.redirect_stdout(captured):
            exec(code, namespace)

        fig = None
        if plt.get_fignums():
            plt.gcf().savefig(plot_path, dpi=100, bbox_inches="tight")
            fig = plot_path

        text = _format_result(namespace.get("result"), captured.getvalue())
    except Exception as e:
        return f"Error: {e}", None, code
    finally:
        # Always release figures, whether execution succeeded or failed.
        plt.close("all")

    return text, fig, code