File size: 8,298 Bytes
ed21fd7
216ce00
 
ed21fd7
 
0b69e41
a614a39
3a72d72
 
ed21fd7
 
0b69e41
ed21fd7
0b69e41
ed21fd7
 
 
 
0b69e41
ed21fd7
a614a39
 
0b69e41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18b02aa
3a72d72
 
0b69e41
 
ed21fd7
 
0b69e41
 
 
 
18b02aa
a614a39
0b69e41
 
 
 
ed21fd7
 
0b69e41
 
 
 
 
 
 
ed21fd7
0b69e41
 
 
 
 
 
3a72d72
0b69e41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed21fd7
 
20ed955
0b69e41
 
 
 
 
ed21fd7
 
 
 
0b69e41
 
 
216ce00
0b69e41
216ce00
0b69e41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20ed955
0b69e41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20ed955
0b69e41
 
 
 
 
3a72d72
 
0b69e41
55d1d72
 
 
 
0b69e41
55d1d72
0b69e41
18b02aa
 
 
0b69e41
 
 
20ed955
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os
import streamlit as st
import pandas as pd
import numpy as np
import requests
import json
import time 
import matplotlib.pyplot as plt 
import seaborn as sns 

# --- CONFIG ---
# Note: GEMINI_API_KEY is retrieved from environment variables/secrets.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Fail fast with a visible error in the Streamlit UI when the key is absent;
# st.stop() halts the script so nothing below runs without credentials.
if not GEMINI_API_KEY:
    st.error("❌ Missing Gemini API key. Add it as a secret: GEMINI_API_KEY")
    st.stop()

# Define API endpoints and models
GEMINI_BASE = "https://generativelanguage.googleapis.com/v1beta"
# Using the correct model for structured output
CHAT_MODEL = "gemini-2.5-flash-preview-09-2025" 
# NOTE(review): EMBED_MODEL is defined but never referenced in this file.
EMBED_MODEL = "models/embedding-001" 

# Define the JSON schema for structured output.
# Passed as generationConfig.responseSchema so the model is constrained to
# return exactly these two string fields ('reasoning' and 'code').
ANALYSIS_SCHEMA = {
    "type": "OBJECT",
    "properties": {
        "reasoning": {
            "type": "STRING",
            "description": "A detailed natural language explanation of the analysis, including key findings and context."
        },
        "code": {
            "type": "STRING",
            "description": "The complete, runnable Python code using pandas (df) and streamlit (st). Use st.pyplot() for plots, and st.dataframe() for resulting DataFrames. If no code is needed, this should be an empty string."
        }
    }
}

# System prompt sent with every request; defines the agent's contract:
# markdown reasoning plus self-contained Python that operates on 'df'.
SYSTEM_INSTRUCTION = (
    "You are a world-class Data Analyst Agent. Your task is to analyze the provided DataFrame ('df') "
    "based on the user's question. You MUST respond with a single JSON object conforming to the provided schema. "
    "1. **Reasoning:** Explain your plan, the steps taken, and the insights derived from the data. Format this in Markdown. "
    "2. **Code:** If the question requires calculation, aggregation, or visualization, you MUST generate Python code to execute against the 'df' DataFrame. "
    "   - The DataFrame is already loaded as a variable named 'df'. Do NOT redefine it. "
    "   - Use Streamlit functions for simple outputs: `st.dataframe(...)`, `st.bar_chart()`, `st.line_chart()`. "
    "   - For **ALL** custom, complex plots, you MUST follow this strict Matplotlib sequence: **Start with `plt.figure()`, use `plt.` or `sns.` commands for plotting, and explicitly end with `st.pyplot(plt)`** to display the output. "
    "   - **CRITICAL GUARDRAIL:** When generating code that uses logical conditions (e.g., in `if` statements or for complex filters) on Pandas Series or NumPy arrays, you **MUST** resolve ambiguity by using `.any()` or `.all()`. Do NOT compare a series directly to a single boolean value."
    "   - Ensure the code is self-contained and ready to execute."
)

# --- Helper Functions ---

def chat_with_gemini(prompt, context):
    """Send a prompt and data context to Gemini; return the parsed structured result.

    Parameters
    ----------
    prompt : str
        The user's natural-language question.
    context : str
        Serialized DataFrame context (column names + head rows) sent to the model.

    Returns
    -------
    dict
        Parsed JSON conforming to ANALYSIS_SCHEMA ('reasoning' and 'code' keys).

    Raises
    ------
    requests.exceptions.RequestException
        If every retry attempt fails at the HTTP level.
    KeyError, IndexError, json.JSONDecodeError
        If the response body does not have the expected candidate structure.
    """
    # Correctly prepend 'models/' to the model name in the URL path
    url = f"{GEMINI_BASE}/models/{CHAT_MODEL}:generateContent?key={GEMINI_API_KEY}"

    # Construct the full prompt including the data context
    full_prompt = f"Data Context (DataFrame Head and Columns):\n{context}\n\nUser Question: {prompt}"

    payload = {
        "contents": [
            {"parts": [{"text": full_prompt}]}
        ],
        "systemInstruction": {"parts": [{"text": SYSTEM_INSTRUCTION}]},
        "generationConfig": {
            "responseMimeType": "application/json",
            "responseSchema": ANALYSIS_SCHEMA
        }
    }

    max_retries = 5
    delay = 1
    for attempt in range(max_retries):
        try:
            # FIX: json= serializes the payload and sets the Content-Type header
            # in one step; timeout= prevents the Streamlit run from hanging
            # forever on a stalled connection (the original had no timeout).
            r = requests.post(url, json=payload, timeout=60)
            r.raise_for_status()
            data = r.json()

            # Extract the structured-output text from the first candidate.
            json_str = data["candidates"][0]["content"]["parts"][0]["text"]
            return json.loads(json_str)

        except requests.exceptions.RequestException as e:
            if attempt < max_retries - 1:
                # Exponential backoff: 1s, 2s, 4s, 8s between retries.
                time.sleep(delay)
                delay *= 2
            else:
                st.error(f"API Request Failed: {e}")
                raise  # bare raise preserves the original traceback
        except (KeyError, IndexError, json.JSONDecodeError) as e:
            # Malformed/unexpected response shape — retrying will not help.
            st.error(f"Failed to parse model response or execute operation: {e}")
            raise

# --- UI ---
st.title("✨Data Analyst Agent (Code Execution Enabled)")
st.write("Upload a CSV file and ask natural language questions. The agent now generates and executes Python code to provide precise data analysis and visualizations.")

# State variable to hold the DataFrame, initialized once per session.
if 'df' not in st.session_state:
    st.session_state.df = pd.DataFrame()

uploaded = st.file_uploader("Upload CSV", type=["csv"])

if uploaded:
    # Use st.cache_data to avoid reloading the file multiple times
    @st.cache_data
    def load_data(file):
        """Parse the uploaded CSV; surface errors and fall back to an empty frame."""
        try:
            return pd.read_csv(file)
        except Exception as e:
            st.error(f"Failed to load CSV: {e}")
            return pd.DataFrame()

    st.session_state.df = load_data(uploaded)

    if not st.session_state.df.empty:
        st.subheader("Data Preview (First 5 Rows)")
        st.dataframe(st.session_state.df.head())

        question = st.text_area("Ask a complex question or request a visualization (e.g., 'Show the average of the 'Sales' column', 'Plot the distribution of 'Age'):")

        if st.button("Analyze & Execute") and question:
            df = st.session_state.df  # Local variable for code execution context

            # Summarize the dataset (columns + first 5 rows) as LLM context.
            context = f"Dataset Columns: {', '.join(df.columns.astype(str))}\n\nFirst 5 rows of data:\n{df.head(5).to_string(index=False)}"

            st.markdown("---")
            st.subheader("πŸ€– Analysis Steps")

            with st.spinner("1. Generating analysis plan and code..."):
                try:
                    # 1. Get structured response (reasoning + code) from the LLM.
                    analysis_result = chat_with_gemini(question, context)

                    reasoning = analysis_result.get('reasoning', "No reasoning provided.")
                    code = analysis_result.get('code', "")

                    st.markdown("#### πŸ’¬ Reasoning:")
                    st.markdown(reasoning)

                    st.markdown("#### 🐍 Generated Code:")
                    st.code(code, language='python')

                except Exception as e:
                    st.error(f"Step 1 Failed (LLM Interaction): {e}")
                    reasoning = ""
                    code = ""

            if code:
                with st.spinner("2. Executing code and generating output..."):
                    # SECURITY NOTE: exec() runs model-generated code with full
                    # process privileges — acceptable only for trusted/local use.
                    # Namespace the generated code may reference.
                    exec_scope = {
                        'df': df,
                        'st': st,
                        'pd': pd,
                        'np': np,
                        'plt': plt,
                        'sns': sns,
                    }

                    # Append a neutral statement to the code to prevent implicit
                    # display of the last expression value.
                    final_code = code + "\nNone"

                    try:
                        # FIX: execute in a SINGLE namespace. The original call
                        # exec(final_code, globals(), local_scope) used separate
                        # globals/locals dicts, which gives the generated code
                        # class-body scoping: any function or comprehension it
                        # defines cannot resolve names like 'df' or 'pd'.
                        exec(final_code, exec_scope)

                        st.success("Code execution complete. Results are displayed above.")

                    except Exception as e:
                        st.error(f"Step 2 Failed (Code Execution Error): The agent generated invalid code. Check the console for full traceback.")
                        st.exception(e)
                    finally:
                        # FIX: always close Matplotlib figures — even when the
                        # generated code raised — to prevent cross-run
                        # figure contamination.
                        plt.close('all')
            else:
                st.info("No code was generated, as the question was purely informational.")
    else:
        st.info("The uploaded CSV file appears to be empty.")

else:
    st.info("πŸ‘† Upload a CSV file to begin the full analysis experience.")