# Provenance (header scraped from the Hugging Face file viewer):
#   author: jzou19950715
#   commit: 9bb3afe (verified) - "Update app.py"
#   size:   13.5 kB
import os
from typing import List, Optional, Tuple, Dict, Any
import base64
import io
import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from litellm import completion
class DataAnalyzer:
    """Render Plotly charts from a loaded DataFrame as HTML snippets.

    ``data`` is ``None`` until a caller assigns a DataFrame; every
    ``create_*`` method raises ``ValueError`` until then.  ``width`` and
    ``height`` are forwarded to each figure's layout.
    """

    def __init__(self):
        self.data: Optional[pd.DataFrame] = None
        self.width = 800
        self.height = 500

    def _require(self, *columns) -> pd.DataFrame:
        """Return the loaded frame after checking that ``columns`` exist.

        ``None`` entries are skipped so optional columns can be passed through.

        Raises:
            ValueError: no data has been loaded.
            KeyError: one or more requested columns are not in the frame.
        """
        if self.data is None:
            raise ValueError("No data loaded")
        missing = [c for c in columns if c is not None and c not in self.data.columns]
        if missing:
            raise KeyError(
                f"Column(s) not found: {missing}; available: {list(self.data.columns)}"
            )
        return self.data

    def _groups(self, column):
        """Yield ``(category, rows)`` for each distinct value of ``column``.

        Categories are produced in the order ``unique()`` reports them.
        Missing values are matched with ``isna()``: an equality mask never
        matches NaN, which would leave that category with no rows.
        """
        frame = self.data
        values = frame[column]
        for category in values.unique():
            if pd.api.types.is_scalar(category) and pd.isna(category):
                mask = values.isna()
            else:
                mask = values == category
            # Filter once per category and reuse the subset for every axis.
            yield category, frame[mask]

    def _to_html(self, fig, **layout) -> str:
        """Apply the shared layout plus ``layout`` overrides and export HTML."""
        fig.update_layout(
            width=self.width,
            height=self.height,
            template="plotly_white",
            **layout,
        )
        # include_plotlyjs=True is kept from the original; see the Plotly
        # to_html documentation for what it adds to each snippet.
        return fig.to_html(include_plotlyjs=True, full_html=False)

    def _add_xy_traces(self, fig, x_col, y_col, color_col, mode, label_points):
        """Add one Scatter trace per ``color_col`` category, or a single trace.

        ``label_points`` attaches the category value as per-point hover text.
        """
        if not color_col:
            fig.add_trace(go.Scatter(x=self.data[x_col], y=self.data[y_col], mode=mode))
            return
        for category, rows in self._groups(color_col):
            trace_args = dict(x=rows[x_col], y=rows[y_col], mode=mode, name=str(category))
            if label_points:
                trace_args["text"] = rows[color_col]
            fig.add_trace(go.Scatter(**trace_args))

    def create_histogram(self, column: str, bins: int = 30, title: str = "") -> str:
        """Create a histogram of ``column`` with at most ``bins`` bins.

        Returns the figure as an HTML fragment.
        Raises ValueError when no data is loaded, KeyError for an unknown column.
        """
        frame = self._require(column)
        fig = go.Figure()
        fig.add_trace(go.Histogram(x=frame[column], nbinsx=bins, name=column))
        return self._to_html(fig, title=title, xaxis_title=column, yaxis_title="Count")

    def create_scatter(self, x_col: str, y_col: str, color_col: Optional[str] = None,
                       title: str = "") -> str:
        """Create a scatter plot of ``y_col`` against ``x_col``.

        When ``color_col`` is truthy, one trace is drawn per distinct value of
        that column.  Returns the figure as an HTML fragment.
        Raises ValueError when no data is loaded, KeyError for an unknown column.
        """
        self._require(x_col, y_col, color_col or None)
        fig = go.Figure()
        self._add_xy_traces(fig, x_col, y_col, color_col, mode='markers', label_points=True)
        return self._to_html(
            fig, title=title, xaxis_title=x_col, yaxis_title=y_col, hovermode='closest'
        )

    def create_box(self, x_col: str, y_col: str, title: str = "") -> str:
        """Create one box of ``y_col`` values per distinct value of ``x_col``.

        Returns the figure as an HTML fragment.
        Raises ValueError when no data is loaded, KeyError for an unknown column.
        """
        self._require(x_col, y_col)
        fig = go.Figure()
        for category, rows in self._groups(x_col):
            fig.add_trace(go.Box(
                y=rows[y_col],
                name=str(category),
                boxpoints='all',
                jitter=0.3,
                pointpos=-1.8,
            ))
        return self._to_html(
            fig, title=title, xaxis_title=x_col, yaxis_title=y_col, showlegend=False
        )

    def create_line(self, x_col: str, y_col: str, color_col: Optional[str] = None,
                    title: str = "") -> str:
        """Create a line plot (lines plus markers) of ``y_col`` against ``x_col``.

        When ``color_col`` is truthy, one line is drawn per distinct value of
        that column.  Returns the figure as an HTML fragment.
        Raises ValueError when no data is loaded, KeyError for an unknown column.
        """
        self._require(x_col, y_col, color_col or None)
        fig = go.Figure()
        self._add_xy_traces(
            fig, x_col, y_col, color_col, mode='lines+markers', label_points=False
        )
        return self._to_html(
            fig, title=title, xaxis_title=x_col, yaxis_title=y_col, hovermode='x unified'
        )
class ChatAnalyzer:
    """Chat front-end: loads a data file, queries the LLM and renders its plots.

    ``history`` is the list of ``(user_text, reply_text)`` pairs returned to
    the chat widget; ``analyzer`` holds the loaded DataFrame and plot helpers.
    """

    def __init__(self):
        self.analyzer = DataAnalyzer()
        self.history: List[Tuple[str, str]] = []

    @staticmethod
    def _column_list(columns) -> str:
        """Join column labels with ', ', stringifying non-string labels first."""
        return ', '.join(map(str, columns))

    def _reply(self, message: str, answer: str) -> Tuple[List[Tuple[str, str]], str]:
        """Record ``answer`` in the history and return it with no plot HTML."""
        self.history.append((message, answer))
        return self.history, ""

    def process_file(self, file: Any) -> List[Tuple[str, str]]:
        """Load an uploaded CSV/Excel file into the analyzer.

        Args:
            file: a path string, an object exposing the path as ``.name``,
                or ``None`` (presumably what the upload widget sends when
                cleared -- confirm against the installed Gradio version).

        Returns:
            The chat history, reset to a single "System" entry describing the
            loaded data or the load error.  An unsupported or missing file
            returns an error entry without touching the stored history.
        """
        unsupported = [("System", "Error: Please upload a CSV or Excel file.")]
        path = file if isinstance(file, str) else getattr(file, "name", None)
        if not path:
            return unsupported
        try:
            # Match extensions case-insensitively so "DATA.CSV" is accepted.
            lowered = str(path).lower()
            if lowered.endswith('.csv'):
                frame = pd.read_csv(path)
            elif lowered.endswith(('.xlsx', '.xls')):
                frame = pd.read_excel(path)
            else:
                return unsupported
            # Best effort: text columns that parse as dates become datetime;
            # the rest are left untouched.
            for col in frame.select_dtypes(include=['object']).columns:
                try:
                    frame[col] = pd.to_datetime(frame[col])
                except Exception:
                    continue
            self.analyzer.data = frame
            numeric = frame.select_dtypes(include=[np.number]).columns
            dates = frame.select_dtypes(include=['datetime64']).columns
            categorical = frame.select_dtypes(include=['object']).columns
            info = (
                "Data loaded successfully!\n"
                f"Shape: {frame.shape}\n"
                f"Columns: {self._column_list(frame.columns)}\n"
                f"Numeric columns: {self._column_list(numeric)}\n"
                f"Date columns: {self._column_list(dates)}\n"
                f"Categorical columns: {self._column_list(categorical)}\n"
            )
            self.history = [("System", info)]
        except Exception as e:
            self.history = [("System", f"Error loading file: {str(e)}")]
        return self.history

    def chat(self, message: str, api_key: str) -> Tuple[List[Tuple[str, str]], str]:
        """Send ``message`` to the model and render any plots it asks for.

        Returns:
            ``(history, plots_html)`` where ``history`` is the full chat
            history including this exchange and ``plots_html`` is the
            concatenated plot markup (empty when nothing was plotted).
        """
        if self.analyzer.data is None:
            return self._reply(message, "Please upload a data file first.")
        if not api_key:
            return self._reply(message, "Please provide an OpenAI API key.")
        try:
            context = self._get_data_context()
            # The key is passed per call instead of being written to
            # os.environ, so it is not left behind in process-wide state.
            completion_response = completion(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": self._get_system_prompt()},
                    {"role": "user", "content": f"{context}\n\nUser question: {message}"}
                ],
                temperature=0.7,
                api_key=api_key,
            )
            analysis = completion_response.choices[0].message.content or ""
            plots_html, errors = self._render_plots(analysis)
            for problem in errors:
                analysis += f"\n\nError creating visualization: {problem}"
            return self._reply(message, analysis)[0], plots_html
        except Exception as e:
            return self._reply(message, f"Error: {str(e)}")

    def _render_plots(self, analysis: str) -> Tuple[str, List[str]]:
        """Run the ```python blocks found in ``analysis`` and collect plot HTML.

        A string counts as a plot when it contains ``<div`` or ``<script``.
        Plots are collected from values passed to ``print`` and from the value
        of a block that is a single expression.

        Returns:
            ``(html, errors)``: the wrapped plot fragments concatenated in
            output order, and one message per block that raised.
        """
        import re

        fragments: List[str] = []
        errors: List[str] = []

        def is_plot(value) -> bool:
            return isinstance(value, str) and ('<div' in value or '<script' in value)

        def capture(*values, **_ignored):
            # Stand-in for print(): keeps plot markup instead of discarding it.
            fragments.extend(v for v in values if is_plot(v))

        for code in re.findall(r'```python\n(.*?)```', analysis, re.DOTALL):
            namespace = {
                'analyzer': self.analyzer,
                'df': self.analyzer.data,
                'print': capture,
            }
            # NOTE(security): this executes model-generated code with the
            # full privileges of this process; it is not sandboxed.
            try:
                try:
                    expression = compile(code, "<assistant code>", "eval")
                except SyntaxError:
                    # Not a single expression: run it as statements.  Only the
                    # compile step decides this, so code is never run twice.
                    exec(code, namespace)
                else:
                    result = eval(expression, namespace)
                    if is_plot(result):
                        fragments.append(result)
            except Exception as e:
                errors.append(str(e))
        html = "".join(
            f'<div class="plot-container">{fragment}</div>' for fragment in fragments
        )
        return html, errors

    def _get_data_context(self) -> str:
        """Describe the loaded data (shape, column kinds, statistics) for the model."""
        df = self.analyzer.data
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        date_cols = df.select_dtypes(include=['datetime64']).columns
        categorical_cols = df.select_dtypes(include=['object']).columns
        if len(numeric_cols) > 0:
            stats = df[numeric_cols].describe().to_string()
        else:
            stats = "No numeric columns"
        return (
            "\n"
            "Data Information:\n"
            f"- Shape: {df.shape}\n"
            f"- Numeric columns: {self._column_list(numeric_cols)}\n"
            f"- Date columns: {self._column_list(date_cols)}\n"
            f"- Categorical columns: {self._column_list(categorical_cols)}\n"
            "Basic Statistics:\n"
            f"{stats}\n"
            "Available visualization functions:\n"
            "- analyzer.create_histogram(column, bins, title)\n"
            "- analyzer.create_scatter(x_col, y_col, color_col, title)\n"
            "- analyzer.create_box(x_col, y_col, title)\n"
            "- analyzer.create_line(x_col, y_col, color_col, title)\n"
        )

    def _get_system_prompt(self) -> str:
        """Return the fixed system prompt describing the plotting helpers."""
        return """You are a data analysis assistant specialized in creating interactive visualizations.
Available visualization functions:
1. create_histogram(column, bins, title) - For distribution analysis
2. create_scatter(x_col, y_col, color_col, title) - For relationship analysis
3. create_box(x_col, y_col, title) - For categorical comparisons
4. create_line(x_col, y_col, color_col, title) - For trend analysis
Example usage:
```python
# Create histogram
result = analyzer.create_histogram(
column='Salary',
bins=20,
title='Salary Distribution'
)
print(result)
# Create scatter plot with time series
result = analyzer.create_scatter(
x_col='Date',
y_col='Salary',
color_col='Title',
title='Salary Trends by Title'
)
print(result)
# Create box plot
result = analyzer.create_box(
x_col='Title',
y_col='Salary',
title='Salary Distribution by Title'
)
print(result)
```
Always wrap code in Python code blocks and use print() to display the visualizations.
Provide analysis and insights about what the visualizations show."""
def create_interface():
    """Assemble the Gradio Blocks UI for the data-analysis chat and return it."""
    # One ChatAnalyzer instance backs every event handler wired below.
    chat_backend = ChatAnalyzer()

    # Stylesheet handed to gr.Blocks.
    stylesheet = """
.container { max-width: 1200px; margin: auto; }
.plot-container {
margin: 20px 0;
padding: 15px;
border: 1px solid #e0e0e0;
border-radius: 8px;
background: white;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.chat-message {
margin-bottom: 15px;
padding: 10px;
border-radius: 8px;
background: #f8f9fa;
}
"""

    # Heading and feature list shown at the top of the page.
    intro_markdown = """
# Interactive Data Analysis Chat
Upload your data and chat with AI to analyze it! Features:
- Interactive visualizations
- Natural language analysis
- Statistical insights
- Trend detection
"""

    # Canned questions offered as clickable examples.
    sample_questions = [
        ["Show me a histogram of salary distribution"],
        ["Create a scatter plot of salary trends over time"],
        ["Show me box plots of salaries by title"],
        ["Analyze the trends and patterns in the data"],
    ]

    with gr.Blocks(css=stylesheet) as demo:
        gr.Markdown(intro_markdown)

        with gr.Row():
            # Left column: data upload and credentials.
            with gr.Column(scale=1):
                upload = gr.File(
                    label="Upload Data (CSV or Excel)",
                    file_types=[".csv", ".xlsx", ".xls"],
                )
                key_box = gr.Textbox(
                    label="OpenAI API Key",
                    type="password",
                    placeholder="Enter your API key",
                )
            # Right column: conversation and question entry.
            with gr.Column(scale=2):
                chat_window = gr.Chatbot(height=400, elem_classes="chat-message")
                question = gr.Textbox(
                    label="Ask about your data",
                    placeholder="e.g., Show me trends in the data",
                    lines=2,
                )
                send_button = gr.Button("Send")

        # Area that receives the plot markup returned by the chat handler.
        plot_area = gr.HTML(label="Visualizations", elem_classes="plot-container")

        # Loading a file replaces the chat contents with the load summary.
        upload.change(
            chat_backend.process_file,
            inputs=[upload],
            outputs=[chat_window],
        )
        # Sending a question updates both the chat and the plot area.
        send_button.click(
            chat_backend.chat,
            inputs=[question, key_box],
            outputs=[chat_window, plot_area],
        )

        gr.Examples(examples=sample_questions, inputs=question)

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and start serving it.  The interface
    # stays bound to the module-level name ``demo``.
    demo = create_interface()
    demo.launch()