# jzou19950715's picture
# Update app.py
# 4e06409 verified
# raw
# history blame
# 13 kB
"""
Advanced Data Analysis Assistant with Interactive Visualizations
Integrates smolagents, GPT-4, and interactive Plotly visualizations.
"""
import base64
import io
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, Tuple
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import seaborn as sns
from smolagents import CodeAgent, LiteLLMModel, tool
from datetime import datetime, timedelta
# Constants
# File extensions accepted by the upload widget (passed to gr.File as file_types).
SUPPORTED_FILE_TYPES = [".csv", ".xlsx", ".xls"]
# Model identifier handed to LiteLLMModel when the assistant is constructed.
DEFAULT_MODEL = "gpt-4o-mini"
# Default path of the JSON file in which AnalysisHistory persists its entries.
HISTORY_FILE = "analysis_history.json"
@dataclass
class VisualizationConfig:
    """Rendering options applied to figures captured from executed code."""

    # Figure size in pixels.
    width: int = 800
    height: int = 500
    # Name of the Plotly layout template.
    template: str = "plotly_white"
    # Whether axis grid lines are requested.
    show_grid: bool = True
    # Interactivity flag; stored on the config only.
    interactive: bool = True
class DataPreprocessor:
    """Handles data preprocessing and validation."""

    @staticmethod
    def preprocess_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """Preprocess the dataframe and return it together with metadata.

        Text columns whose values all parse as dates are converted to
        datetimes, and missing values are filled forward and then backward.
        The input dataframe is left untouched; a processed copy is returned.

        Args:
            df: The raw dataframe to clean.

        Returns:
            A ``(processed_df, metadata)`` tuple. ``metadata`` describes the
            *original* frame (shape, missing counts, dtypes, numeric and
            object columns) plus the list of columns treated as temporal.
        """
        metadata = {
            "original_shape": df.shape,
            "missing_values": df.isnull().sum().to_dict(),
            "dtypes": df.dtypes.astype(str).to_dict(),
            "numeric_columns": df.select_dtypes(include=[np.number]).columns.tolist(),
            "categorical_columns": df.select_dtypes(include=['object']).columns.tolist(),
            "temporal_columns": []
        }

        # Work on a copy so the caller's dataframe is never modified in place.
        df = df.copy()

        # Handle date/time columns
        for col in df.columns:
            series = df[col]

            # Columns that are already datetimes need no conversion.
            if pd.api.types.is_datetime64_any_dtype(series):
                metadata["temporal_columns"].append(col)
                continue

            # Only text-like columns are candidates: pd.to_datetime happily
            # interprets plain numbers as epoch offsets, which would silently
            # turn every numeric column into timestamps.
            is_text = (
                pd.api.types.is_object_dtype(series)
                or pd.api.types.is_string_dtype(series)
            )
            if not is_text or not series.notna().any():
                continue

            try:
                converted = pd.to_datetime(series)
            except (ValueError, TypeError, OverflowError):
                # Not parseable as dates; leave the column as it is.
                continue

            metadata["temporal_columns"].append(col)
            df[col] = converted

        # Handle missing values (forward fill, then backward fill for any
        # leading gaps). fillna(method=...) is deprecated in recent pandas.
        df = df.ffill().bfill()
        return df, metadata
class CodeExecutionEnvironment:
    """Environment for executing analysis code and collecting its results.

    NOTE(security): the code is run with ``exec`` and full builtins. Nothing
    here sandboxes it, so only code from a trusted source should be executed.
    """

    def __init__(self, visualization_config: Optional["VisualizationConfig"] = None):
        """Set up the namespaces handed to executed code.

        Args:
            visualization_config: Layout settings applied to captured
                figures; a default ``VisualizationConfig`` is used if omitted.
        """
        self.viz_config = visualization_config or VisualizationConfig()
        self.globals = {
            'pd': pd,
            'np': np,
            'px': px,
            'go': go,
            'make_subplots': make_subplots,
            'sns': sns
        }
        self.locals = {}

    def _render_figure(self, figure) -> str:
        """Apply the visualization config to a Plotly figure and return HTML."""
        figure.update_layout(
            width=self.viz_config.width,
            height=self.viz_config.height,
            template=self.viz_config.template,
        )
        # Grid visibility is an axis property, not a layout property.
        figure.update_xaxes(showgrid=self.viz_config.show_grid)
        figure.update_yaxes(showgrid=self.viz_config.show_grid)
        return figure.to_html(
            include_plotlyjs=True,
            full_html=False,
            config={'displayModeBar': True}
        )

    def execute(self, code: str, df: Optional[pd.DataFrame] = None) -> Dict[str, Any]:
        """Execute code and capture outputs including visualizations.

        Args:
            code: Python source to run.
            df: Optional dataframe exposed to the code as ``df``.

        Returns:
            A dict with keys ``output`` (captured stdout), ``plotly_html``
            (one HTML fragment per Plotly figure found in the local
            namespace), ``error`` (message string or ``None``) and
            ``dataframe_updates`` (the local ``df`` if the code rebound it,
            else ``None``).
        """
        from contextlib import redirect_stdout

        if df is not None:
            self.globals['df'] = df

        result = {
            'output': '',
            'plotly_html': [],
            'error': None,
            'dataframe_updates': None
        }
        output_buffer = io.StringIO()

        try:
            # redirect_stdout restores whatever stream was active before the
            # call, rather than forcing sys.__stdout__.
            with redirect_stdout(output_buffer):
                exec(code, self.globals, self.locals)

            # Capture Plotly figures. plotly.express has no Figure class of
            # its own; its functions return plotly.graph_objects.Figure.
            for value in list(self.locals.values()):
                if isinstance(value, go.Figure):
                    result['plotly_html'].append(self._render_figure(value))

            # Capture DataFrame updates
            if 'df' in self.locals and self.locals['df'] is not df:
                result['dataframe_updates'] = self.locals['df']
        except Exception as e:
            result['error'] = f"Error executing code: {str(e)}"
        finally:
            # Keep whatever was printed, even if the code failed part-way.
            result['output'] = output_buffer.getvalue()
            output_buffer.close()

        return result
class AnalysisHistory:
    """Manages analysis history and persistence."""

    def __init__(self, history_file: str = HISTORY_FILE):
        """Load any existing history from ``history_file``.

        Args:
            history_file: Path of the JSON file used for persistence.
        """
        self.history_file = history_file
        self.history = self._load_history()

    def _load_history(self) -> List[Dict]:
        """Read the history file, returning an empty list if it is unusable.

        A missing, unreadable or malformed file yields ``[]``. A payload that
        is not a list, and entries that are not dicts, are discarded so the
        other methods can rely on a list of dicts.
        """
        try:
            with open(self.history_file, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # OSError: missing/unreadable file. ValueError: invalid JSON or
            # undecodable bytes (JSONDecodeError and UnicodeDecodeError).
            return []
        if not isinstance(loaded, list):
            return []
        return [entry for entry in loaded if isinstance(entry, dict)]

    def add_entry(self, query: str, result: str) -> None:
        """Add new analysis entry to history and rewrite the history file.

        Args:
            query: The user's analysis request.
            result: Text form of the analysis result.

        Raises:
            OSError: If the history file cannot be written.
        """
        entry = {
            'timestamp': datetime.now().isoformat(),
            'query': query,
            'result': result
        }
        self.history.append(entry)
        with open(self.history_file, 'w', encoding='utf-8') as f:
            json.dump(self.history, f)

    def get_recent_analyses(self, limit: int = 5) -> List[Dict]:
        """Get the most recent analysis entries, newest first.

        Args:
            limit: Maximum number of entries to return.

        Returns:
            Up to ``limit`` entries ordered by descending timestamp. Entries
            lacking a timestamp sort as the oldest.
        """
        return sorted(
            self.history,
            key=lambda x: str(x.get('timestamp', '')),
            reverse=True
        )[:limit]
class DataAnalysisAssistant:
    """Enhanced data analysis assistant with visualization capabilities."""

    def __init__(self, api_key: str):
        """Create the model, execution environment, history and agent.

        Args:
            api_key: API key passed to the LiteLLM model wrapper.
        """
        self.model = LiteLLMModel(
            model_id=DEFAULT_MODEL,
            api_key=api_key
        )
        self.code_env = CodeExecutionEnvironment()
        self.history = AnalysisHistory()
        # Initialize the agent. No custom tools are registered, but the
        # tools argument is passed explicitly (as an empty list) because
        # CodeAgent expects it.
        self.agent = CodeAgent(
            tools=[],
            model=self.model,
            additional_authorized_imports=[
                'pandas', 'numpy', 'plotly.express', 'plotly.graph_objects',
                'seaborn', 'scipy', 'statsmodels'
            ],
        )

    def analyze(self, df: pd.DataFrame, query: str) -> str:
        """Perform analysis with interactive visualizations.

        Args:
            df: The dataframe to analyze.
            query: The user's natural-language request.

        Returns:
            An HTML string with the analysis text and any execution results,
            or an ``"Analysis failed: ..."`` message if anything raised.
        """
        # Preprocess data
        df, metadata = DataPreprocessor.preprocess_dataframe(df)
        # Create context for the agent
        context = self._create_analysis_context(df, metadata, query)
        try:
            # Get analysis plan
            response = self.agent.run(context, additional_args={"df": df})
            # Extract and execute code blocks
            results = self._execute_analysis(response, df)
            # Save to history
            self.history.add_entry(query, str(response))
            return self._format_results(response, results)
        except Exception as e:
            return f"Analysis failed: {str(e)}"

    def _create_analysis_context(self, df: pd.DataFrame, metadata: Dict, query: str) -> str:
        """Create detailed context for analysis.

        Args:
            df: The preprocessed dataframe (not read here).
            metadata: Mapping with ``original_shape``, ``numeric_columns``,
                ``categorical_columns`` and ``temporal_columns``.
            query: The user's request, embedded verbatim.

        Returns:
            The prompt text given to the agent.
        """
        def listing(key: str) -> str:
            # Column labels are not always strings (e.g. integer headers),
            # so convert before joining.
            return ', '.join(str(name) for name in metadata[key])

        lines = [
            "",
            "Analyze the following data with interactive visualizations.",
            "DataFrame Information:",
            f"- Shape: {metadata['original_shape']}",
            f"- Numeric columns: {listing('numeric_columns')}",
            f"- Categorical columns: {listing('categorical_columns')}",
            f"- Temporal columns: {listing('temporal_columns')}",
            f"User Query: {query}",
            "Guidelines:",
            "1. Use Plotly for interactive visualizations",
            "2. Store figures in variables named 'fig'",
            "3. Include clear titles and labels",
            "4. Add hover information",
            "5. Use color effectively",
            "6. Handle errors gracefully",
            "The DataFrame is available as 'df'.",
            "",
        ]
        return "\n".join(lines)

    def _execute_analysis(self, response: str, df: pd.DataFrame) -> List[Dict]:
        """Execute code blocks from analysis.

        Args:
            response: Agent response; fenced ``python`` blocks are extracted
                from its string form.
            df: Dataframe made available to each executed block.

        Returns:
            One result dict per code block, in order of appearance.
        """
        import re

        # Extract code blocks
        code_blocks = re.findall(r'```python\n(.*?)```', str(response), re.DOTALL)
        return [self.code_env.execute(code, df) for code in code_blocks]

    def _format_results(self, response: str, results: List[Dict]) -> str:
        """Format analysis results with visualizations.

        Free text (analysis, printed output, error messages) is HTML-escaped
        so characters such as ``<`` display literally; Plotly fragments are
        already HTML and are inserted unchanged.

        Args:
            response: Agent response whose text is shown first.
            results: Result dicts with ``error``, ``output`` and
                ``plotly_html`` keys.

        Returns:
            The HTML fragments joined with newlines.
        """
        from html import escape

        output_parts = []
        # Add analysis text
        analysis_text = str(response).replace("```python", "").replace("```", "")
        output_parts.append(f'<div class="analysis-text">{escape(analysis_text)}</div>')
        # Add execution results
        for result in results:
            if result['error']:
                output_parts.append(f'<div class="error">{escape(str(result["error"]))}</div>')
                continue
            if result['output']:
                output_parts.append(f'<pre>{escape(str(result["output"]))}</pre>')
            for fragment in result['plotly_html']:
                output_parts.append(
                    f'<div class="plot-container">{fragment}</div>'
                )
        return "\n".join(output_parts)
def process_file(file: "gr.File") -> Optional[pd.DataFrame]:
    """Process uploaded file into DataFrame.

    Args:
        file: The uploaded file. Either an object exposing the path as
            ``.name`` or a plain path string is accepted.

    Returns:
        The parsed dataframe, or ``None`` when no file was supplied.

    Raises:
        RuntimeError: If the extension is unsupported or the file cannot be
            read; the underlying exception is attached as the cause.
    """
    if not file:
        return None
    try:
        # Fall back to the value itself when it has no ``name`` attribute.
        file_path = Path(getattr(file, "name", file))
        # Compare case-insensitively so "DATA.CSV" is treated like "data.csv".
        suffix = file_path.suffix.lower()
        if suffix == '.csv':
            return pd.read_csv(file_path)
        if suffix in ('.xlsx', '.xls'):
            return pd.read_excel(file_path)
        raise ValueError(f"Unsupported file type: {file_path.suffix}")
    except Exception as e:
        raise RuntimeError(f"Error reading file: {str(e)}") from e
def analyze_data(
    file: "gr.File",
    query: str,
    api_key: str,
) -> str:
    """Main analysis function for Gradio interface.

    Args:
        file: The uploaded data file.
        query: The user's analysis request.
        api_key: API key for the model provider. Surrounding whitespace is
            ignored, so a key pasted with a trailing newline still works and
            a blank key is reported as missing.

    Returns:
        The analysis as an HTML string, or a message starting with
        ``"Error:"`` when input is missing or processing fails.
    """
    # A whitespace-only key is as unusable as an empty one.
    api_key = (api_key or "").strip()
    if not api_key:
        return "Error: Please provide an API key"
    if not file:
        return "Error: Please upload a data file"
    try:
        # Process file
        df = process_file(file)
        if df is None:
            return "Error: Could not process file"
        # Create assistant and run analysis
        assistant = DataAnalysisAssistant(api_key)
        return assistant.analyze(df, query)
    except Exception as e:
        return f"Error: {str(e)}"
def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching.

    The layout is a header, an input column (file upload, query box, API key
    box, button) beside an HTML output column, and a set of example queries.
    The button is wired to ``analyze_data``.
    """
    stylesheet = """
.plot-container {
margin: 20px 0;
padding: 15px;
border: 1px solid #e0e0e0;
border-radius: 8px;
background: white;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.analysis-text {
margin: 20px 0;
line-height: 1.6;
}
.error {
color: red;
padding: 10px;
margin: 10px 0;
border-left: 4px solid red;
}
"""
    header_markdown = """
# Advanced Data Analysis Assistant
Upload your data and get AI-powered analysis with interactive visualizations.
**Features:**
- Interactive Plotly visualizations
- gpt-4o-mini powered analysis
- Time series analysis
- Statistical insights
- Natural language queries
**Required:** OpenAI API key
"""
    # Example prompts; each is paired with an empty file slot below.
    example_queries = [
        "Show trends over time with interactive visualizations",
        "Create a comprehensive analysis of relationships between variables",
        "Analyze distributions and statistical patterns",
        "Generate financial metrics and performance indicators",
    ]

    with gr.Blocks(css=stylesheet) as interface:
        gr.Markdown(header_markdown)
        with gr.Row():
            # Left column: inputs and the trigger button.
            with gr.Column():
                file_input = gr.File(
                    label="Upload Data File",
                    file_types=SUPPORTED_FILE_TYPES
                )
                query_input = gr.Textbox(
                    label="What would you like to analyze?",
                    placeholder="e.g., Analyze trends and patterns in the data with interactive visualizations",
                    lines=3
                )
                key_input = gr.Textbox(
                    label="OpenAI API Key",
                    placeholder="Your API key",
                    type="password"
                )
                run_button = gr.Button("Analyze")
            # Right column: rendered results.
            with gr.Column():
                results_view = gr.HTML(label="Analysis Results")
        run_button.click(
            analyze_data,
            inputs=[file_input, query_input, key_input],
            outputs=results_view
        )
        # Add examples
        gr.Examples(
            examples=[[None, text] for text in example_queries],
            inputs=[file_input, query_input]
        )
    return interface
if __name__ == "__main__":
    # Script entry point: build the UI, then start the Gradio server.
    interface = create_interface()
    interface.launch()