|
|
import json |
|
|
import pandas as pd |
|
|
import matplotlib.pyplot as plt |
|
|
import gradio as gr |
|
|
from openai import OpenAI |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_openai_client(api_key: str):
    """Create an OpenAI client authenticated with the caller-supplied *api_key*."""
    client = OpenAI(api_key=api_key)
    return client
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_and_normalize_csv(csv_file):
    """Read a CSV and normalise its column names and debit/credit indicator.

    Column names are stripped, lower-cased and have spaces, ``/`` and ``_``
    removed, so e.g. ``"Dr/Cr"`` and ``"DR_CR"`` both become ``"drcr"``.
    If a ``drcr`` column exists, its values are stripped and lower-cased, and
    the short forms ``cr`` / ``dr`` / ``db`` are expanded to ``credit`` /
    ``debit``.

    Args:
        csv_file: Anything accepted by :func:`pandas.read_csv` (path or buffer).

    Returns:
        The normalised DataFrame.
    """
    df = pd.read_csv(csv_file)

    # Literal (non-regex) removals so the result does not depend on the
    # pandas version's default for ``regex``.
    df.columns = (
        df.columns.str.strip()
        .str.lower()
        .str.replace(" ", "", regex=False)
        .str.replace("/", "", regex=False)
        .str.replace("_", "", regex=False)
    )

    if "drcr" in df.columns:
        # NOTE: astype(str) turns missing cells into the string "nan".
        df["drcr"] = (
            df["drcr"].astype(str).str.strip().str.lower()
            # "dr" is the conventional debit abbreviation (paired with "cr");
            # "db" is kept for files that use that spelling instead.
            .replace({"cr": "credit", "dr": "debit", "db": "debit"})
        )

    return df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# System prompt for the planner. Absent values are spelled as the JSON literal
# null *outside* quotes: the model otherwise tends to return the string "null",
# which is not a valid column to group by. The drcr values match the
# normalised "credit"/"debit" values the CSV loader produces.
_INTENT_SYSTEM_PROMPT = (
    "You are a data analysis planner.\n"
    "Return ONLY valid JSON.\n"
    "Do NOT explain.\n\n"
    "JSON format:\n"
    "{\n"
    '  "action": "count" | "sum" | "plot",\n'
    '  "filters": { "year": <integer> | null, "drcr": "credit" | "debit" | null },\n'
    '  "groupby": "year" | "drcr" | null\n'
    "}\n"
    'Use the JSON literal null (not the string "null") for absent values.'
)


def get_intent(question: str, api_key: str, model: str = "gpt-4o-mini") -> dict:
    """Ask the LLM to translate *question* into a structured analysis intent.

    Args:
        question: The user's natural-language question about the CSV.
        api_key: OpenAI API key used to build the client.
        model: Chat model name; defaults to ``"gpt-4o-mini"``.

    Returns:
        The decoded JSON object, expected to carry ``action``, ``filters``
        and ``groupby`` keys (shape is requested by the prompt, not enforced).

    Raises:
        ValueError: if the model returns no content, invalid JSON
            (``json.JSONDecodeError`` is a ``ValueError``), or JSON that is
            not an object.
    """
    client = get_openai_client(api_key)

    response = client.chat.completions.create(
        model=model,
        # JSON mode: the API constrains the reply to a JSON object.
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": _INTENT_SYSTEM_PROMPT},
            {"role": "user", "content": question},
        ],
    )

    content = response.choices[0].message.content
    if not content:
        # content can be None (e.g. refusals); json.loads(None) would raise a
        # confusing TypeError.
        raise ValueError("The model returned an empty response.")

    intent = json.loads(content)
    if not isinstance(intent, dict):
        raise ValueError(f"Expected a JSON object from the model, got: {content!r}")
    return intent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Short spellings of the debit/credit indicator mapped to the canonical
# "credit" / "debit" values, applied to both the data and the filter value.
_DRCR_ALIASES = {"cr": "credit", "dr": "debit", "db": "debit"}


def _normalize_key(name) -> str:
    """Normalise a column name: strip, lower-case, drop spaces, '/' and '_'."""
    return (
        str(name).strip().lower()
        .replace(" ", "").replace("/", "").replace("_", "")
    )


def _resolve_column(name, data: pd.DataFrame):
    """Return the column of *data* that *name* refers to, or None if absent."""
    if name in data.columns:
        return name
    normalized = _normalize_key(name)
    return normalized if normalized in data.columns else None


def _normalize_text(values: pd.Series, column: str) -> pd.Series:
    """Strip and lower-case *values*; fold debit/credit aliases for ``drcr``."""
    text = values.astype(str).str.strip().str.lower()
    if column == "drcr":
        text = text.replace(_DRCR_ALIASES)
    return text


def _filter_mask(data: pd.DataFrame, column: str, value) -> pd.Series:
    """Boolean mask of rows whose *column* equals *value*.

    Numeric columns compare numerically, so ``"2022"`` matches ``2022``; other
    columns compare as stripped, case-insensitive text, so ``"Credit"``
    matches ``"credit"``.
    """
    series = data[column]
    if pd.api.types.is_numeric_dtype(series):
        target = pd.to_numeric(value, errors="coerce")
        if pd.isna(target):
            # A non-numeric filter value can never equal a numeric cell.
            return pd.Series(False, index=data.index)
        return series == target
    wanted = _normalize_text(pd.Series([value]), column).iloc[0]
    return _normalize_text(series, column) == wanted


def _resolve_groupby(group_col, data: pd.DataFrame):
    """Map the intent's ``groupby`` to a real column, or None for no grouping.

    Raises:
        ValueError: if a grouping column is requested but not present.
    """
    if not group_col:
        return None
    # The model sometimes spells "no grouping" as the string "null"/"none".
    if isinstance(group_col, str) and group_col.strip().lower() in {"null", "none"}:
        return None
    column = _resolve_column(group_col, data)
    if column is None:
        raise ValueError(f"Cannot group by unknown column: {group_col!r}")
    return column


def execute_intent(intent: dict, df: pd.DataFrame):
    """Run a structured *intent* (as produced by get_intent) against *df*.

    Args:
        intent: Dict with optional keys ``action`` (``"count"``, ``"sum"`` or
            ``"plot"``), ``filters`` (column -> value; ``None`` values and
            columns missing from *df* are skipped) and ``groupby`` (a column
            name or null).
        df: The data to analyse; it is copied, never modified.

    Returns:
        ``count``: an int, or a per-group Series of row counts.
        ``sum``: the total of the ``amount`` column, or a per-group Series.
        ``plot``: the per-group row counts that were drawn as a bar chart.

    Raises:
        ValueError: unknown action, unknown groupby column, plot without a
            groupby or with no matching rows, sum without an ``amount``
            column, or amounts that cannot be parsed as numbers.
    """
    data = df.copy()

    # "filters" may be present but null in the model's JSON.
    for key, value in (intent.get("filters") or {}).items():
        if value is None:
            continue
        column = _resolve_column(key, data)
        # Best effort: filters on columns this CSV lacks are ignored.
        if column is not None:
            data = data[_filter_mask(data, column, value)]

    action = intent.get("action")
    group_col = _resolve_groupby(intent.get("groupby"), data)

    if action == "count":
        if group_col:
            return data.groupby(group_col).size()
        return len(data)

    if action == "sum":
        if "amount" not in data.columns:
            raise ValueError("Cannot sum: the CSV has no 'amount' column.")
        # Coerce text amounts to numbers; summing an object column would
        # otherwise concatenate the strings instead of adding the values.
        data = data.assign(amount=pd.to_numeric(data["amount"]))
        if group_col:
            return data.groupby(group_col)["amount"].sum()
        return data["amount"].sum()

    if action == "plot":
        if not group_col:
            raise ValueError("Plot requires groupby")
        result = data.groupby(group_col).size()
        if result.empty:
            raise ValueError("Nothing to plot: no rows match the filters.")
        fig, ax = plt.subplots()
        try:
            result.plot(kind="bar", ax=ax)
            ax.set_title("Result")
            fig.tight_layout()
            plt.show()
        finally:
            # Release the figure; otherwise every request in the long-running
            # server leaves one open figure behind.
            plt.close(fig)
        return result

    raise ValueError(f"Unknown action: {action}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def answer_question(question: str, api_key: str, df: pd.DataFrame):
    """Answer *question* about *df*: plan it with the LLM, then execute the plan."""
    return execute_intent(get_intent(question, api_key), df)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_answer(api_key, csv_file, question):
    """Gradio callback: validate the inputs, answer the question, format the result.

    Args:
        api_key: OpenAI API key from the password textbox.
        csv_file: The uploaded CSV as delivered by ``gr.File`` (None if absent).
        question: The natural-language question.

    Returns:
        Always a user-facing string: a prompt for missing input, the result
        (Series/DataFrames via ``to_string``), or ``"Error: ..."``.
    """
    try:
        # Treat whitespace-only input as missing. Stripping the key also stops
        # a pasted trailing newline/space from failing authentication.
        api_key = (api_key or "").strip()
        question = (question or "").strip()

        if not api_key:
            return "Please provide your OpenAI API key."

        if csv_file is None:
            return "Please upload a CSV file."

        if not question:
            return "Please enter a question."

        df = load_and_normalize_csv(csv_file)
        result = answer_question(question, api_key, df)

        # pandas results render better as a table than via str().
        if hasattr(result, "to_string"):
            return result.to_string()

        return str(result)

    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing.
        return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Web UI: API key, CSV upload and question go in; a plain-text answer comes out.
demo = gr.Interface(
    fn=gradio_answer,
    title="Chat with your CSV 📊",
    description="Upload any CSV file and ask natural language questions about it",
    inputs=[
        gr.Textbox(label="OpenAI API Key", type="password"),
        gr.File(label="Upload CSV File", file_types=[".csv"]),
        gr.Textbox(
            label="Ask a question about your CSV",
            placeholder="How many credit operations happened in 2022?",
        ),
    ],
    outputs=gr.Textbox(label="Answer"),
)

if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    demo.launch()