Spaces:
Configuration error
Configuration error
import base64
import io
import os
import re
from typing import List, Optional, Tuple, Dict, Any

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from litellm import completion
class DataAnalyzer:
    """Build Plotly charts from a pandas DataFrame and return them as HTML.

    ``data`` holds the DataFrame to plot and is ``None`` until a caller
    assigns one. ``width`` and ``height`` are the pixel dimensions applied
    to every figure.
    """

    def __init__(self):
        self.data: Optional[pd.DataFrame] = None
        self.width = 800
        self.height = 500

    def _require_data(self) -> pd.DataFrame:
        """Return the loaded DataFrame, raising ``ValueError`` if none is set."""
        if self.data is None:
            raise ValueError("No data loaded")
        return self.data

    def _split_by(self, column: str) -> List[Tuple[Any, pd.DataFrame]]:
        """Split the data into ``(category, rows)`` pairs for ``column``.

        Categories appear in the order ``Series.unique()`` reports them.
        Missing values (NaN/None/NaT) form their own group: they never
        compare equal to themselves, so an ``==`` mask would select no rows
        and the chart would silently drop them.
        """
        frame = self._require_data()
        values = frame[column]
        groups: List[Tuple[Any, pd.DataFrame]] = []
        for category in values.unique():
            if pd.api.types.is_scalar(category) and pd.isna(category):
                mask = values.isna()
            else:
                mask = values == category
            # Filter once per category; callers pick the columns they need.
            groups.append((category, frame[mask]))
        return groups

    def _render(self, fig, **layout) -> str:
        """Apply the shared size/template plus ``layout`` and return an HTML fragment."""
        fig.update_layout(
            width=self.width,
            height=self.height,
            template="plotly_white",
            **layout,
        )
        return fig.to_html(include_plotlyjs=True, full_html=False)

    def create_histogram(self, column: str, bins: int = 30, title: str = "") -> str:
        """Create a histogram of ``column`` with at most ``bins`` bins.

        Returns the figure as an HTML fragment.
        Raises ``ValueError`` when no data is loaded.
        """
        frame = self._require_data()
        fig = go.Figure()
        fig.add_trace(go.Histogram(
            x=frame[column],
            nbinsx=bins,
            name=column,
        ))
        return self._render(
            fig,
            title=title,
            xaxis_title=column,
            yaxis_title="Count",
        )

    def create_scatter(self, x_col: str, y_col: str, color_col: Optional[str] = None,
                       title: str = "") -> str:
        """Create a scatter plot of ``y_col`` against ``x_col``.

        When ``color_col`` is given, one trace is drawn per distinct value of
        that column and the value is used as hover text.
        Returns the figure as an HTML fragment.
        Raises ``ValueError`` when no data is loaded.
        """
        frame = self._require_data()
        fig = go.Figure()
        if color_col:
            for category, rows in self._split_by(color_col):
                fig.add_trace(go.Scatter(
                    x=rows[x_col],
                    y=rows[y_col],
                    mode='markers',
                    name=str(category),
                    text=rows[color_col],
                ))
        else:
            fig.add_trace(go.Scatter(
                x=frame[x_col],
                y=frame[y_col],
                mode='markers',
            ))
        return self._render(
            fig,
            title=title,
            xaxis_title=x_col,
            yaxis_title=y_col,
            hovermode='closest',
        )

    def create_box(self, x_col: str, y_col: str, title: str = "") -> str:
        """Create one box of ``y_col`` values per distinct value of ``x_col``.

        Returns the figure as an HTML fragment.
        Raises ``ValueError`` when no data is loaded.
        """
        self._require_data()
        fig = go.Figure()
        for category, rows in self._split_by(x_col):
            fig.add_trace(go.Box(
                y=rows[y_col],
                name=str(category),
                boxpoints='all',
                jitter=0.3,
                pointpos=-1.8,
            ))
        return self._render(
            fig,
            title=title,
            yaxis_title=y_col,
            xaxis_title=x_col,
            showlegend=False,
        )

    def create_line(self, x_col: str, y_col: str, color_col: Optional[str] = None,
                    title: str = "") -> str:
        """Create a line-and-marker plot of ``y_col`` against ``x_col``.

        When ``color_col`` is given, one line is drawn per distinct value of
        that column.
        Returns the figure as an HTML fragment.
        Raises ``ValueError`` when no data is loaded.
        """
        frame = self._require_data()
        fig = go.Figure()
        if color_col:
            for category, rows in self._split_by(color_col):
                fig.add_trace(go.Scatter(
                    x=rows[x_col],
                    y=rows[y_col],
                    mode='lines+markers',
                    name=str(category),
                ))
        else:
            fig.add_trace(go.Scatter(
                x=frame[x_col],
                y=frame[y_col],
                mode='lines+markers',
            ))
        return self._render(
            fig,
            title=title,
            xaxis_title=x_col,
            yaxis_title=y_col,
            hovermode='x unified',
        )
class ChatAnalyzer:
    """Chat front-end: loads a data file, queries the LLM, and renders its plots.

    ``analyzer`` is the ``DataAnalyzer`` holding the loaded DataFrame.
    ``history`` is the list of ``(speaker_or_question, text)`` pairs shown
    in the chat window.
    """

    # Fenced ```python blocks in the model's reply; group 1 is the code.
    _CODE_BLOCK_RE = re.compile(r'```python\n(.*?)```', re.DOTALL)

    def __init__(self):
        self.analyzer = DataAnalyzer()
        self.history: List[Tuple[str, str]] = []

    @staticmethod
    def _join_names(names) -> str:
        """Comma-join column labels, tolerating non-string labels (e.g. ints)."""
        return ', '.join(map(str, names))

    @staticmethod
    def _is_plot_html(value: Any) -> bool:
        """Return True when ``value`` is a string that looks like a rendered plot."""
        return isinstance(value, str) and ('<div' in value or '<script' in value)

    def process_file(self, file: Any) -> List[Tuple[str, str]]:
        """Load an uploaded CSV/Excel file into the analyzer and reset the chat.

        ``file`` may be an object exposing the path as ``.name``, a plain
        path string, or ``None`` (upload cleared). Returns the new chat
        history: a single "System" entry describing the data or the error.
        """
        if file is None:
            # Upload was cleared: drop the stale data instead of erroring.
            self.analyzer.data = None
            self.history = []
            return self.history
        try:
            path = str(getattr(file, "name", file))
            lowered = path.lower()  # accept .CSV / .XLSX as well
            if lowered.endswith('.csv'):
                data = pd.read_csv(path)
            elif lowered.endswith(('.xlsx', '.xls')):
                data = pd.read_excel(path)
            else:
                return [("System", "Error: Please upload a CSV or Excel file.")]

            # Best effort: text columns that parse as dates become datetime,
            # everything else is left as it was.
            for col in data.select_dtypes(include=['object']).columns:
                try:
                    data[col] = pd.to_datetime(data[col])
                except Exception:
                    continue
            self.analyzer.data = data

            numeric = data.select_dtypes(include=[np.number]).columns
            dates = data.select_dtypes(include=['datetime64']).columns
            categorical = data.select_dtypes(include=['object']).columns
            info = (
                "Data loaded successfully!\n"
                f"Shape: {data.shape}\n"
                f"Columns: {self._join_names(data.columns)}\n"
                f"Numeric columns: {self._join_names(numeric)}\n"
                f"Date columns: {self._join_names(dates)}\n"
                f"Categorical columns: {self._join_names(categorical)}\n"
            )
            self.history = [("System", info)]
            return self.history
        except Exception as e:
            self.history = [("System", f"Error loading file: {str(e)}")]
            return self.history

    def chat(self, message: str, api_key: str) -> Tuple[List[Tuple[str, str]], str]:
        """Send ``message`` to the model and render any plots its reply requests.

        Returns ``(history, plots_html)``. Failures are reported as chat
        entries rather than raised.
        """
        # Precondition hints are shown after the existing conversation but
        # are not stored in the history.
        if self.analyzer.data is None:
            return self.history + [(message, "Please upload a data file first.")], ""
        if not api_key:
            return self.history + [(message, "Please provide an OpenAI API key.")], ""
        try:
            context = self._get_data_context()
            # The key is passed per call rather than written to os.environ,
            # which is process-wide and would be shared by every request.
            completion_response = completion(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": self._get_system_prompt()},
                    {"role": "user", "content": f"{context}\n\nUser question: {message}"},
                ],
                temperature=0.7,
                api_key=api_key,
            )
            analysis = completion_response.choices[0].message.content or ""
            plots_html, errors = self._render_plots(analysis)
            for error in errors:
                analysis += f"\n\nError creating visualization: {error}"
            self.history.append((message, analysis))
            return self.history, plots_html
        except Exception as e:
            self.history.append((message, f"Error: {str(e)}"))
            return self.history, ""

    def _render_plots(self, analysis: str) -> Tuple[str, List[str]]:
        """Run the ```python blocks in ``analysis`` and collect the plot HTML.

        A plot is collected when the code passes plot HTML to ``print`` or
        when a single-expression block evaluates to plot HTML. Returns
        ``(plots_html, errors)`` with one error message per failed block.

        SECURITY: the code is produced by the model and runs with full
        interpreter privileges through eval/exec. It must be treated as
        untrusted and sandboxed or replaced by a restricted command format
        before exposing this app to untrusted users.
        """
        plots: List[str] = []
        errors: List[str] = []

        def capture(*values, **_kwargs):
            # Stand-in for print(): keep plot HTML instead of writing to stdout.
            for value in values:
                if self._is_plot_html(value):
                    plots.append(value)

        for code in self._CODE_BLOCK_RE.findall(analysis):
            namespace = {
                'analyzer': self.analyzer,
                'df': self.analyzer.data,
                'print': capture,
            }
            try:
                try:
                    expression = compile(code, "<analysis>", "eval")
                except SyntaxError:
                    # Not a single expression: run it as statements. Deciding
                    # at compile time means the code only ever runs once.
                    exec(compile(code, "<analysis>", "exec"), namespace)
                else:
                    result = eval(expression, namespace)
                    if self._is_plot_html(result):
                        plots.append(result)
            except Exception as exc:
                errors.append(str(exc))

        html = ''.join(f'<div class="plot-container">{plot}</div>' for plot in plots)
        return html, errors

    def _get_data_context(self) -> str:
        """Describe the loaded data and the plotting helpers for the model."""
        df = self.analyzer.data
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        date_cols = df.select_dtypes(include=['datetime64']).columns
        categorical_cols = df.select_dtypes(include=['object']).columns
        # describe() is only meaningful when there is something numeric.
        stats = df[numeric_cols].describe().to_string() if len(numeric_cols) > 0 else "No numeric columns"
        return (
            "\n"
            "Data Information:\n"
            f"- Shape: {df.shape}\n"
            f"- Numeric columns: {self._join_names(numeric_cols)}\n"
            f"- Date columns: {self._join_names(date_cols)}\n"
            f"- Categorical columns: {self._join_names(categorical_cols)}\n"
            "Basic Statistics:\n"
            f"{stats}\n"
            "Available visualization functions:\n"
            "- analyzer.create_histogram(column, bins, title)\n"
            "- analyzer.create_scatter(x_col, y_col, color_col, title)\n"
            "- analyzer.create_box(x_col, y_col, title)\n"
            "- analyzer.create_line(x_col, y_col, color_col, title)\n"
        )

    def _get_system_prompt(self) -> str:
        """Return the fixed system prompt describing the plotting helpers."""
        return (
            "You are a data analysis assistant specialized in creating interactive visualizations.\n"
            "Available visualization functions:\n"
            "1. create_histogram(column, bins, title) - For distribution analysis\n"
            "2. create_scatter(x_col, y_col, color_col, title) - For relationship analysis\n"
            "3. create_box(x_col, y_col, title) - For categorical comparisons\n"
            "4. create_line(x_col, y_col, color_col, title) - For trend analysis\n"
            "Example usage:\n"
            "```python\n"
            "# Create histogram\n"
            "result = analyzer.create_histogram(\n"
            "column='Salary',\n"
            "bins=20,\n"
            "title='Salary Distribution'\n"
            ")\n"
            "print(result)\n"
            "# Create scatter plot with time series\n"
            "result = analyzer.create_scatter(\n"
            "x_col='Date',\n"
            "y_col='Salary',\n"
            "color_col='Title',\n"
            "title='Salary Trends by Title'\n"
            ")\n"
            "print(result)\n"
            "# Create box plot\n"
            "result = analyzer.create_box(\n"
            "x_col='Title',\n"
            "y_col='Salary',\n"
            "title='Salary Distribution by Title'\n"
            ")\n"
            "print(result)\n"
            "```\n"
            "Always wrap code in Python code blocks and use print() to display the visualizations.\n"
            "Provide analysis and insights about what the visualizations show."
        )
def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching.

    One ``ChatAnalyzer`` is created here and its bound methods are used as
    the event handlers, so that single instance serves every event.
    """
    chat_backend = ChatAnalyzer()

    # Page styling for the outer container, plot cards and chat bubbles.
    custom_css = (
        "\n"
        ".container { max-width: 1200px; margin: auto; }\n"
        ".plot-container {\n"
        "margin: 20px 0;\n"
        "padding: 15px;\n"
        "border: 1px solid #e0e0e0;\n"
        "border-radius: 8px;\n"
        "background: white;\n"
        "box-shadow: 0 2px 4px rgba(0,0,0,0.1);\n"
        "}\n"
        ".chat-message {\n"
        "margin-bottom: 15px;\n"
        "padding: 10px;\n"
        "border-radius: 8px;\n"
        "background: #f8f9fa;\n"
        "}\n"
    )

    intro_markdown = (
        "\n"
        "# Interactive Data Analysis Chat\n"
        "Upload your data and chat with AI to analyze it! Features:\n"
        "- Interactive visualizations\n"
        "- Natural language analysis\n"
        "- Statistical insights\n"
        "- Trend detection\n"
    )

    sample_questions = [
        ["Show me a histogram of salary distribution"],
        ["Create a scatter plot of salary trends over time"],
        ["Show me box plots of salaries by title"],
        ["Analyze the trends and patterns in the data"],
    ]

    with gr.Blocks(css=custom_css) as demo:
        gr.Markdown(intro_markdown)

        with gr.Row():
            # Left column: data upload and credentials.
            with gr.Column(scale=1):
                file = gr.File(
                    label="Upload Data (CSV or Excel)",
                    file_types=[".csv", ".xlsx", ".xls"],
                )
                api_key = gr.Textbox(
                    label="OpenAI API Key",
                    type="password",
                    placeholder="Enter your API key",
                )
            # Right column: conversation and question entry.
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    height=400,
                    elem_classes="chat-message",
                )
                message = gr.Textbox(
                    label="Ask about your data",
                    placeholder="e.g., Show me trends in the data",
                    lines=2,
                )
                send = gr.Button("Send")

        # NOTE(review): the original indentation was lost; the plot area and
        # the examples are assumed to sit below the row - confirm.
        plot_output = gr.HTML(
            label="Visualizations",
            elem_classes="plot-container",
        )

        # Uploading a file resets the chat; Send posts the question.
        file.change(
            chat_backend.process_file,
            inputs=[file],
            outputs=[chatbot],
        )
        send.click(
            chat_backend.chat,
            inputs=[message, api_key],
            outputs=[chatbot, plot_output],
        )

        # Clickable sample prompts that fill the question box.
        gr.Examples(
            examples=sample_questions,
            inputs=message,
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app and serve it with Gradio's
    # default launch settings. The name `demo` is kept at module scope.
    demo = create_interface()
    demo.launch()