# CSV Chat Assistant — chat with uploaded CSV files via LangChain + Gradio.
import gradio as gr
import pandas as pd
import os
import io
from langchain.llms import OpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.agents.agent_types import AgentType
from langchain_experimental.agents import create_pandas_dataframe_agent

# Load API key
# NOTE(review): OPENAI_API_KEY is None when the env var is unset — the
# OpenAI client will then fail at request time; confirm deployment sets it.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
llm = OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)  # temperature=0 for deterministic output

# Global data stores (module-level, shared across all Gradio sessions)
dataset_dict = {}  # filename -> pandas DataFrame for each uploaded CSV
agent_dict = {}    # filename -> cached dataframe agent (built lazily in ask_csv)
def upload_csv(files):
    """Read each uploaded CSV into a DataFrame and register it for chat.

    Args:
        files: list of uploaded file objects from gr.File (each exposes a
            ``.name`` path and is readable by ``pd.read_csv``).

    Returns:
        tuple: (status message, gr.update refreshing the dataset radio).
        The radio is repopulated with *all* known datasets, not just this
        batch, so earlier uploads remain selectable.
    """
    messages = []
    for file in files:
        try:
            fname = os.path.basename(file.name)
            dataset_dict[fname] = pd.read_csv(file)
            # Drop any cached agent for this name: re-uploading the same
            # filename must not leave an agent bound to the old DataFrame.
            agent_dict.pop(fname, None)
            messages.append(f"β {fname} uploaded. Shape: {dataset_dict[fname].shape}")
        except Exception as e:
            messages.append(f"β {file.name} failed: {e}")
    # Offer every dataset uploaded so far, not only the current batch.
    all_names = list(dataset_dict)
    return "\n".join(messages), gr.update(
        choices=all_names, value=all_names[0] if all_names else None
    )
def generate_summary(dataset_name):
    """Generate an LLM-written markdown summary of the selected dataset.

    Args:
        dataset_name: key into ``dataset_dict`` chosen via the UI radio.

    Returns:
        str: markdown summary from the LLM, or a help message when no
        dataset has been uploaded/selected yet.
    """
    # Guard: the radio starts empty, so the button can be clicked before
    # any dataset exists — fail softly instead of raising KeyError.
    if not dataset_name or dataset_name not in dataset_dict:
        return "Please upload a CSV and select a dataset first."
    df = dataset_dict[dataset_name]

    # df.info() prints to a buffer rather than returning a string.
    df_info = io.StringIO()
    df.info(buf=df_info)
    info_str = df_info.getvalue()

    # Top-2 value counts for every categorical (object-dtype) column.
    value_counts_dict = {
        col: df[col].value_counts().head(2).to_dict()
        for col in df.select_dtypes(include='object').columns
    }
    value_counts_str = "\n".join(
        f"{col}: {', '.join([f'{k} ({v})' for k, v in counts.items()])}"
        for col, counts in value_counts_dict.items()
    )

    summary_prompt = PromptTemplate(
        input_variables=["shape", "columns", "info", "describe", "value_counts"],
        template="""
You are a data scientist. Given the following dataset info, generate a clear and concise summary in markdown format using the following structure:
**π§Ύ Overview:** 1β2 lines about the data
**π Key Insights:** 3β5 bullet points highlighting trends or distributions
**π Notable Stats:** 2β3 bullet points covering variance, skew, or standout numerical stats
Dataset shape: {shape}
Column names: {columns}
Info: {info}
Describe: {describe}
Value counts: {value_counts}
"""
    )
    summary_chain = LLMChain(llm=llm, prompt=summary_prompt)
    return summary_chain.run({
        "shape": str(df.shape),
        "columns": ", ".join(df.columns),
        "info": info_str,
        "describe": df.describe().to_string(),
        "value_counts": value_counts_str
    })
def ask_csv(dataset_name, user_input, history):
    """Answer a natural-language question about the selected dataset.

    Args:
        dataset_name: key into ``dataset_dict`` from the UI radio.
        user_input: the user's question.
        history: chat history as a list of (question, answer) tuples;
            may be None/empty on the first turn.

    Returns:
        tuple: (history, history) — duplicated because the Gradio event
        lists the chatbot component twice in ``outputs``.
    """
    history = history or []

    # Guard: a question can be submitted before any dataset is selected —
    # answer in-chat instead of raising KeyError.
    if not dataset_name or dataset_name not in dataset_dict:
        history.append((user_input, "Please upload a CSV and select a dataset first."))
        return history, history
    df = dataset_dict[dataset_name]

    # Build one agent per dataset lazily and cache it. allow_dangerous_code
    # is required by langchain-experimental because the agent executes
    # LLM-generated Python against the DataFrame.
    if dataset_name not in agent_dict:
        agent_dict[dataset_name] = create_pandas_dataframe_agent(
            llm=llm,
            df=df,
            agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=False,
            handle_parsing_errors=True,
            allow_dangerous_code=True
        )
    agent = agent_dict[dataset_name]
    raw_output = agent.run(user_input)

    # Second LLM pass: condense the agent's raw trace into a short,
    # markdown-friendly answer.
    refine_prompt = PromptTemplate(
        input_variables=["question", "raw_output"],
        template="""
You are a helpful data analyst assistant. Summarize the answer to the question below in a concise, quantified, and markdown-friendly way.
**Question:** {question}
**Agent Raw Output:** {raw_output}
Final response format:
- Quantified values wherever applicable
- Reasoning used
- 4β6 lines max
- Markdown-friendly format
"""
    )
    refine_chain = LLMChain(llm=llm, prompt=refine_prompt)
    final_response = refine_chain.run({
        "question": user_input,
        "raw_output": raw_output
    })

    history.append((user_input, final_response))
    return history, history
# ─────────────────────────────────────────────
# Gradio App — wires the UI components to the handlers above
# ─────────────────────────────────────────────
with gr.Blocks(title="CSV Chat Assistant") as app:
    gr.Markdown("## π§ Talk to Your CSV in Natural Language")

    # Upload area: accepts multiple CSVs; status box reports each result.
    with gr.Row():
        file_upload = gr.File(file_types=[".csv"], file_count="multiple", label="Upload CSV(s)")
        upload_output = gr.Textbox(label="Upload Status")

    # Dataset selector; upload_csv repopulates its choices via gr.update.
    dataset_radio = gr.Radio(choices=[], label="Select Dataset", interactive=True)
    file_upload.change(fn=upload_csv, inputs=[file_upload], outputs=[upload_output, dataset_radio])

    # One-click LLM summary of the selected dataset, rendered as markdown.
    with gr.Row():
        summary_button = gr.Button("π§Ύ Generate Summary")
        summary_output = gr.Markdown()
    summary_button.click(fn=generate_summary, inputs=dataset_radio, outputs=summary_output)

    # Chat: submitting the textbox feeds ask_csv; the chatbot component is
    # both an input (current history) and both outputs.
    chatbot = gr.Chatbot(label="CSV Chat", height=400)
    msg = gr.Textbox(label="Type your question")
    clear = gr.Button("Clear Chat")
    msg.submit(fn=ask_csv, inputs=[dataset_radio, msg, chatbot], outputs=[chatbot, chatbot])
    clear.click(lambda: [], None, chatbot)  # reset chat history to an empty list

app.launch()