# csv-chat-app / app.py
# amoghsuman
# Update app.py
# fffd7fa (verified)
import gradio as gr
import pandas as pd
import os
import io
from langchain.llms import OpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.agents.agent_types import AgentType
from langchain_experimental.agents import create_pandas_dataframe_agent
# Credential for the OpenAI client, read from the environment (None when unset).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Single LLM instance (temperature 0) shared by the summary chain, the
# refinement chain and the per-dataset pandas agents.
llm = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0)

# Module-level stores shared by the UI callbacks:
#   dataset_dict: uploaded file name -> DataFrame
#   agent_dict:   dataset name -> pandas agent created for that dataset
dataset_dict, agent_dict = {}, {}
def upload_csv(files):
    """Load uploaded CSV files into ``dataset_dict`` and refresh the dataset picker.

    Args:
        files: Value passed by the ``gr.File`` change event: a list of
            uploaded files, or ``None``/empty when the upload is cleared.
            Each item may be a path string or an object exposing its path
            as ``.name``.

    Returns:
        A ``(status_text, radio_update)`` tuple. ``status_text`` has one line
        per file (success with shape, or failure with the error).
        ``radio_update`` is a ``gr.update`` whose choices are the files that
        loaded successfully in this call, with the first one preselected
        (``None`` when nothing loaded).
    """
    messages = []
    filenames = []
    # Clearing the upload widget fires the change event without any files.
    for file in files or []:
        # Accept both plain path strings and temp-file wrappers with ``.name``.
        path = getattr(file, "name", file)
        try:
            fname = os.path.basename(path)
            df = pd.read_csv(path)
        except Exception as e:
            # Best-effort: report the failure and keep processing other files.
            messages.append(f"❌ {path} failed: {e}")
            continue
        dataset_dict[fname] = df
        # An agent cached for an earlier upload with the same name is bound to
        # the old DataFrame; drop it so the next question rebuilds it.
        agent_dict.pop(fname, None)
        messages.append(f"✅ {fname} uploaded. Shape: {df.shape}")
        if fname not in filenames:
            filenames.append(fname)
    selected = filenames[0] if filenames else None
    return "\n".join(messages), gr.update(choices=filenames, value=selected)
def generate_summary(dataset_name):
    """Ask the LLM for a markdown summary of an uploaded dataset.

    The prompt is built from the DataFrame's shape, column names,
    ``df.info()`` output, ``df.describe()`` table and the two most frequent
    values of every object-dtype column.

    Args:
        dataset_name: Key of the dataset in ``dataset_dict`` (the value of
            the dataset radio). May be ``None`` when nothing is selected.

    Returns:
        The markdown text produced by the summary chain, or a short notice
        when no matching dataset has been uploaded.
    """
    # Clicking the button before uploading/selecting used to raise KeyError.
    if not dataset_name or dataset_name not in dataset_dict:
        return "Please upload and select a dataset first."
    df = dataset_dict[dataset_name]

    # DataFrame.info() prints to a stream, so capture it in memory.
    info_buffer = io.StringIO()
    df.info(buf=info_buffer)
    info_str = info_buffer.getvalue()

    # Top-2 most frequent values for each object-dtype (typically text) column.
    value_counts_dict = {
        col: df[col].value_counts().head(2).to_dict()
        for col in df.select_dtypes(include="object").columns
    }
    value_counts_str = "\n".join(
        f"{col}: " + ", ".join(f"{k} ({v})" for k, v in counts.items())
        for col, counts in value_counts_dict.items()
    )

    summary_prompt = PromptTemplate(
        input_variables=["shape", "columns", "info", "describe", "value_counts"],
        template="""
You are a data scientist. Given the following dataset info, generate a clear and concise summary in markdown format using the following structure:
**🧾 Overview:** 1–2 lines about the data
**🔍 Key Insights:** 3–5 bullet points highlighting trends or distributions
**📌 Notable Stats:** 2–3 bullet points covering variance, skew, or standout numerical stats
Dataset shape: {shape}
Column names: {columns}
Info: {info}
Describe: {describe}
Value counts: {value_counts}
"""
    )
    summary_chain = LLMChain(llm=llm, prompt=summary_prompt)
    return summary_chain.run({
        "shape": str(df.shape),
        # Column labels are not guaranteed to be strings; str.join needs them to be.
        "columns": ", ".join(map(str, df.columns)),
        "info": info_str,
        "describe": df.describe().to_string(),
        "value_counts": value_counts_str,
    })
def ask_csv(dataset_name, user_input, history):
    """Answer a natural-language question about a dataset and extend the chat.

    The question is run through a pandas DataFrame agent (created once per
    dataset name and cached in ``agent_dict``); the agent's raw output is then
    rewritten by a second LLM chain into a short markdown answer.

    Args:
        dataset_name: Key of the dataset in ``dataset_dict``; may be ``None``
            when nothing is selected.
        user_input: The question typed by the user.
        history: Current chat history as a list of ``(question, answer)``
            pairs; ``None`` is treated as an empty history.

    Returns:
        ``(history, history)`` - the updated history twice, matching the two
        outputs wired to this callback.
    """
    if history is None:
        history = []

    # Nothing to ask: leave the conversation untouched instead of spending an
    # LLM call on a blank prompt.
    if not user_input or not user_input.strip():
        return history, history

    # Asking before a dataset is uploaded/selected used to raise KeyError.
    if not dataset_name or dataset_name not in dataset_dict:
        history.append((user_input, "Please upload and select a dataset first."))
        return history, history

    df = dataset_dict[dataset_name]
    if dataset_name not in agent_dict:
        # NOTE: allow_dangerous_code=True opts in to the agent running
        # model-generated Python against the DataFrame; only expose this app
        # to trusted users.
        agent_dict[dataset_name] = create_pandas_dataframe_agent(
            llm=llm,
            df=df,
            agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=False,
            handle_parsing_errors=True,
            allow_dangerous_code=True,
        )
    agent = agent_dict[dataset_name]
    raw_output = agent.run(user_input)

    # Second pass: condense the agent's raw answer into a short markdown reply.
    refine_prompt = PromptTemplate(
        input_variables=["question", "raw_output"],
        template="""
You are a helpful data analyst assistant. Summarize the answer to the question below in a concise, quantified, and markdown-friendly way.
**Question:** {question}
**Agent Raw Output:** {raw_output}
Final response format:
- Quantified values wherever applicable
- Reasoning used
- 4–6 lines max
- Markdown-friendly format
"""
    )
    refine_chain = LLMChain(llm=llm, prompt=refine_prompt)
    final_response = refine_chain.run({
        "question": user_input,
        "raw_output": raw_output,
    })
    history.append((user_input, final_response))
    return history, history
# ─────────────────────────────────────────────
# 🎛 Gradio App
# ─────────────────────────────────────────────
# Build the UI: upload + dataset picker, one-click summary, and a chat panel.
# NOTE(review): the nesting under each gr.Row() is reconstructed from a
# flattened source - confirm the intended layout.
with gr.Blocks(title="CSV Chat Assistant") as app:
    gr.Markdown("## 🧠 Talk to Your CSV in Natural Language")

    # Upload area and its status box.
    with gr.Row():
        file_upload = gr.File(file_types=[".csv"], file_count="multiple", label="Upload CSV(s)")
        upload_output = gr.Textbox(label="Upload Status")

    # Choices are filled in by upload_csv once files have been loaded.
    dataset_radio = gr.Radio(choices=[], label="Select Dataset", interactive=True)
    file_upload.change(fn=upload_csv, inputs=[file_upload], outputs=[upload_output, dataset_radio])

    # LLM-generated markdown summary of the selected dataset.
    with gr.Row():
        summary_button = gr.Button("🧾 Generate Summary")
        summary_output = gr.Markdown()
    summary_button.click(fn=generate_summary, inputs=dataset_radio, outputs=summary_output)

    # Chat panel; ask_csv returns the history twice, hence the two outputs.
    chatbot = gr.Chatbot(label="CSV Chat", height=400)
    msg = gr.Textbox(label="Type your question")
    clear = gr.Button("Clear Chat")
    msg.submit(fn=ask_csv, inputs=[dataset_radio, msg, chatbot], outputs=[chatbot, chatbot])
    # Reset the chatbot to an empty history.
    clear.click(lambda: [], None, chatbot)

app.launch()