Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from io import StringIO | |
| import sys | |
| import traceback | |
| from contextlib import redirect_stdout, redirect_stderr | |
| import openai | |
| from openai import OpenAI | |
| import re | |
| import warnings | |
# Silence every Python warning process-wide. This is a global filter, so it
# also hides warnings raised by code later run through exec().
warnings.filterwarnings('ignore')

# Page-level Streamlit settings, issued before any other st.* call in this file.
# NOTE(review): the emoji literals in this copy of the file look mis-encoded
# (e.g. "π"); restore the intended characters from the original source.
st.set_page_config(
    initial_sidebar_state="expanded",
    layout="wide",
    page_icon="π",
    page_title="CSV Chat Assistant",
)

# Stylesheet injected once per run: gradient title text (.main-header),
# extra spacing above alerts (.stAlert) and the output panel (.code-output).
_CUSTOM_CSS = """
<style>
.main-header {
font-size: 3rem;
font-weight: bold;
text-align: center;
margin-bottom: 2rem;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.stAlert {
margin-top: 1rem;
}
.code-output {
background-color: #f0f2f6;
padding: 1rem;
border-radius: 0.5rem;
border-left: 3px solid #667eea;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
def execute_python_code(code, df):
    """Execute ``code`` against a copy of ``df`` and capture what it prints.

    SECURITY: ``code`` comes from an LLM driven by free-form user input and is
    run with ``exec`` and the full set of builtins. This is NOT a sandbox: the
    code can import modules, touch the filesystem and reach the network with
    this process's privileges. Only deploy where that is acceptable.

    Args:
        code: Python source to run.
        df: DataFrame exposed to the code as ``df``. It is copied first so the
            caller's object is never modified.

    Returns:
        A 4-tuple ``(output, error, exception, result_df)``:
        - ``output``: captured stdout ("" on failure).
        - ``error``: captured stderr; on failure, followed by the error
          description and the full traceback.
        - ``exception``: ``None`` on success, otherwise a non-empty string
          describing the exception (falls back to the exception class name,
          so callers can rely on its truthiness).
        - ``result_df``: on success, whatever ``df`` is bound to in the
          execution namespace afterwards; on failure, the copy that was handed
          to the code (which may have been modified in place).
    """
    df_copy = df.copy()

    # Names the generated code may use without importing anything itself.
    namespace = {
        'df': df_copy,
        'pd': pd,
        'np': np,
        'plt': plt,
        'sns': sns,
        'px': px,
        'go': go,
        'st': st,
        '__builtins__': __builtins__,
    }

    stdout_capture = StringIO()
    stderr_capture = StringIO()
    try:
        # The context managers restore sys.stdout/sys.stderr however exec exits.
        with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
            exec(code, namespace)
    except (Exception, SystemExit) as exc:
        # SystemExit is caught so a stray sys.exit()/exit() in generated code
        # reports an error instead of aborting the whole Streamlit script run.
        # str(exc) is "" for message-less exceptions (bare `assert`, KeyError(''),
        # SystemExit()); fall back to the class name so the failure flag is
        # always truthy for callers that test it with `if exception:`.
        detail = str(exc) or type(exc).__name__
        error = stderr_capture.getvalue() + f"\nError: {detail}\n{traceback.format_exc()}"
        return "", error, detail, df_copy

    return (
        stdout_capture.getvalue(),
        stderr_capture.getvalue(),
        None,
        namespace.get('df', df_copy),
    )
def _strip_code_fences(text):
    """Return the code in a chat reply with Markdown code fences removed."""
    # Capture the body of every complete fenced block, whatever its info
    # string ("```python", "```py", bare "```"). Anything outside the fences
    # is explanatory prose, which would be a SyntaxError if executed.
    fenced_blocks = re.findall(r"```[^\n]*\n(.*?)```", text, flags=re.DOTALL)
    if fenced_blocks:
        return "\n\n".join(block.strip() for block in fenced_blocks)
    # No complete fence (plain code, or a reply cut off by max_tokens):
    # just drop any stray fence markers together with their language tag.
    return re.sub(r"```[\w+-]*", "", text).strip()


def generate_python_code(user_query, df_info, api_key, model="gpt-3.5-turbo"):
    """Ask an OpenAI chat model to write pandas/plotting code for ``user_query``.

    Args:
        user_query: The user's natural-language request.
        df_info: Dataset summary; the keys ``columns``, ``shape``, ``dtypes``
            and ``sample`` are embedded in the prompt.
        api_key: OpenAI API key used to build the client.
        model: Chat-completions model name (defaults to the original
            hard-coded ``"gpt-3.5-turbo"``).

    Returns:
        The generated Python source with Markdown fences stripped. Returns
        ``None`` after showing an ``st.error`` when the request fails or the
        model returns no code.
    """
    try:
        client = OpenAI(api_key=api_key)
        prompt = f"""
You are a Python code generator for data analysis. Generate Python code based on the user's request.
Dataset Information:
- Columns: {list(df_info['columns'])}
- Shape: {df_info['shape']}
- Data types: {df_info['dtypes']}
- Sample data (first few rows): {df_info['sample']}
User Query: {user_query}
Guidelines:
1. The dataframe is already loaded as 'df'
2. Use pandas, numpy, matplotlib, seaborn, or plotly as needed
3. For visualizations, use st.pyplot(plt.gcf()) for matplotlib/seaborn or st.plotly_chart() for plotly
4. Print results using print() statements
5. Keep code concise and focused on the user's request
6. If creating plots, make sure to show them in Streamlit
7. Handle missing values appropriately
8. Use descriptive variable names
Return ONLY the Python code, no explanations or markdown formatting.
"""
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful data analysis assistant that generates Python code."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,
            max_tokens=1000,
        )
        # message.content may be None (e.g. a refusal), so coerce before stripping.
        content = response.choices[0].message.content or ""
        code = _strip_code_fences(content.strip())
        if not code:
            st.error("Error generating code: the model returned an empty response.")
            return None
        return code
    except Exception as e:
        # UI boundary: surface any client/network/API failure to the user.
        st.error(f"Error generating code: {str(e)}")
        return None
def get_dataframe_info(df):
    """Build the metadata summary of ``df`` used when prompting for code.

    Returns a dict with these keys:
    - ``columns``: list of column names.
    - ``shape``: ``df.shape``.
    - ``dtypes``: mapping of column -> dtype.
    - ``sample``: the first three rows as a nested dict.
    - ``describe``: ``df.describe()`` as a dict, or ``{}`` when ``df`` has no
      numeric columns.
    - ``null_counts``: mapping of column -> number of missing values.
    """
    # Summary statistics are only computed when there is something numeric.
    has_numeric = not df.select_dtypes(include=[np.number]).columns.empty
    stats = df.describe().to_dict() if has_numeric else {}

    return dict(
        columns=df.columns.tolist(),
        shape=df.shape,
        dtypes=df.dtypes.to_dict(),
        sample=df.head(3).to_dict(),
        describe=stats,
        null_counts=df.isnull().sum().to_dict(),
    )
def _render_sidebar():
    """Draw the sidebar (API key + CSV uploader) and return ``(api_key, uploaded_file)``."""
    with st.sidebar:
        st.header("π§ Configuration")

        st.subheader("π OpenAI API Key")
        api_key = st.text_input(
            "Enter your OpenAI API Key:",
            type="password",
            help="Get your API key from https://platform.openai.com/api-keys",
        )
        if api_key:
            st.success("β API Key provided")
        else:
            st.warning("β οΈ Please enter your OpenAI API Key")

        st.divider()

        st.subheader("π Upload CSV File")
        uploaded_file = st.file_uploader(
            "Drag and drop your CSV file here:",
            type=['csv'],
            help="Upload a CSV file to start analyzing your data",
        )
        if uploaded_file is not None:
            st.success(f"β File uploaded: {uploaded_file.name}")
            st.json({
                "Filename": uploaded_file.name,
                "File size": f"{uploaded_file.size} bytes",
            })
    return api_key, uploaded_file


def _load_csv(uploaded_file):
    """Parse the uploaded CSV; on failure show the error in the UI and return ``None``."""
    try:
        # Defensive rewind: the uploader's object is file-like, and a buffer
        # left at EOF by an earlier read would parse as empty.
        # NOTE(review): confirm whether this Streamlit version ever reuses it.
        uploaded_file.seek(0)
        return pd.read_csv(uploaded_file)
    except Exception as e:
        # pandas raises many unrelated types here (ParserError, EmptyDataError,
        # UnicodeDecodeError, ...), so catch broadly, but ONLY around parsing.
        st.error(f"Error loading CSV file: {str(e)}")
        st.write("Please make sure your CSV file is properly formatted.")
        return None


def _render_dataset_overview(df):
    """Show a preview of ``df`` next to its shape and numeric/text column lists."""
    preview_col, info_col = st.columns([2, 1])
    with preview_col:
        st.subheader("π Dataset Overview")
        st.dataframe(df.head(), use_container_width=True)
    with info_col:
        st.subheader("π Dataset Info")
        st.write(f"**Shape:** {df.shape}")
        st.write(f"**Columns:** {len(df.columns)}")
        st.write(f"**Rows:** {len(df)}")

        numeric_cols = df.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) > 0:
            st.write("**Numerical Columns:**")
            for col in numeric_cols:
                st.write(f"- {col}")

        text_cols = df.select_dtypes(include=['object']).columns
        if len(text_cols) > 0:
            st.write("**Text Columns:**")
            for col in text_cols:
                st.write(f"- {col}")


def _render_query_input(df, api_key):
    """Show example prompts, the query box, and the Generate/Clear buttons."""
    st.subheader("π¬ Chat with Your Data")
    st.write("Ask questions about your data in natural language. Examples:")

    examples = [
        "Show me the first 10 rows",
        "What are the summary statistics?",
        "Create a histogram of [column_name]",
        "Show correlation between columns",
        "Find rows where [column] > [value]",
        "Create a scatter plot of X vs Y",
        "Group by [column] and show counts",
    ]
    example_cols = st.columns(3)
    for i, example in enumerate(examples):
        with example_cols[i % 3]:
            # These buttons render before the text area below, so a click
            # pre-fills the box within this same run.
            if st.button(example, key=f"example_{i}"):
                st.session_state['user_query'] = example

    user_query = st.text_area(
        "Ask a question about your data:",
        value=st.session_state.get('user_query', ''),
        height=100,
        placeholder="e.g., 'Show me a bar chart of the top 10 values in the sales column'",
    )

    run_col, clear_col, _ = st.columns([1, 1, 4])
    with run_col:
        if st.button("π Generate & Run", type="primary"):
            if user_query.strip():
                with st.spinner("Generating Python code..."):
                    df_info = get_dataframe_info(df)
                    code = generate_python_code(user_query, df_info, api_key)
                if code:
                    st.session_state['generated_code'] = code
                    st.session_state['user_query'] = user_query
                    st.rerun()
            else:
                st.warning("Please enter a question first.")
    with clear_col:
        if st.button("ποΈ Clear"):
            for key in ('generated_code', 'user_query'):
                if key in st.session_state:
                    del st.session_state[key]
            st.rerun()


def _render_results(df):
    """Show the stored generated code and the result of executing it against ``df``."""
    import html  # local import: only this panel needs HTML escaping

    code = st.session_state['generated_code']
    st.divider()
    code_col, result_col = st.columns([1, 1])

    with code_col:
        st.subheader("π Generated Python Code")
        st.code(code, language='python')
        # st.code already renders its own copy-to-clipboard control.
        st.caption("Use the copy icon in the top-right corner of the code block to copy it.")

    with result_col:
        st.subheader("π― Results")
        # NOTE: this runs on every Streamlit rerun while code is stored, so the
        # plots the code draws via st.* calls are redrawn each time.
        with st.spinner("Executing code..."):
            output, error, exception, _result_df = execute_python_code(code, df)

        if exception:
            st.error(f"**Error occurred:**\n{error}")
        else:
            if output:
                # A single markdown call, so the styled element actually wraps
                # the text. The string starts with <pre> so the raw-HTML block
                # runs to </pre> even if the output contains blank lines. The
                # output comes from generated code, so it is escaped.
                st.markdown(
                    '<pre class="code-output" style="white-space: pre-wrap;">'
                    f"{html.escape(output)}</pre>",
                    unsafe_allow_html=True,
                )
            if error:
                st.warning(f"**Warnings:**\n{error}")
            # Fallback for code that drew a matplotlib figure without showing it.
            # May duplicate a figure the code already passed to st.pyplot.
            if plt.get_fignums():
                st.pyplot(plt.gcf())

    # pyplot state is process-global and survives reruns. Close every figure
    # (also after a failed run) so stale figures do not leak or resurface.
    plt.close('all')


def _render_footer():
    """Draw the page footer."""
    st.divider()
    st.markdown("""
<div style="text-align: center; color: #666;">
<p>Built with β€οΈ using Streamlit | Powered by OpenAI GPT-3.5</p>
</div>
""", unsafe_allow_html=True)


def main():
    """Render the CSV Chat Assistant page for one Streamlit run.

    The sidebar collects an OpenAI API key and a CSV upload. Once both are
    present, the main area shows a dataset overview and the natural-language
    query controls. If generated code is stored in ``st.session_state``, the
    code is shown and its execution results are displayed.
    """
    st.markdown('<h1 class="main-header">π CSV Chat Assistant</h1>', unsafe_allow_html=True)
    st.markdown("Upload your CSV file and chat with your data using natural language!")

    api_key, uploaded_file = _render_sidebar()

    if uploaded_file is None:
        st.info("π Please upload a CSV file in the sidebar to get started.")
    elif not api_key:
        st.info("π Please enter your OpenAI API key in the sidebar to use the chat feature.")
    else:
        df = _load_csv(uploaded_file)
        if df is not None:
            st.session_state['df'] = df
            st.session_state['api_key'] = api_key

            _render_dataset_overview(df)
            st.divider()
            _render_query_input(df, api_key)
            if 'generated_code' in st.session_state:
                _render_results(df)

    _render_footer()


if __name__ == "__main__":
    main()