| import base64 |
| import io |
| import os |
| from dataclasses import dataclass |
| from typing import Any, Callable, Dict, List, Optional |
|
|
| import gradio as gr |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| import seaborn as sns |
| from litellm import completion |
|
|
|
|
| |
class CodeEnvironment:
    """Execution environment for model-generated data-analysis code.

    Code runs via ``exec`` in a persistent namespace pre-populated with
    ``pd``, ``np``, ``plt`` and ``sns``. Text the code prints to stdout and
    any matplotlib figures it leaves open are captured and returned.

    WARNING: this is NOT a sandbox. ``exec`` gives the executed code full
    access to the interpreter, filesystem and network of the host process.
    Only use it where running LLM-generated code is acceptable (e.g. an
    isolated container).
    """

    def __init__(self):
        # A single dict serves as both globals and locals. With two separate
        # dicts, exec() runs code as if inside a class body: functions that
        # the executed code defines cannot see top-level names it defines
        # (e.g. "x = 1\ndef f(): return x\nf()" raises NameError).
        self.globals: Dict[str, Any] = {
            'pd': pd,
            'np': np,
            'plt': plt,
            'sns': sns,
        }
        # Kept for backward compatibility; it is the same dict as ``globals``.
        self.locals: Dict[str, Any] = self.globals

    def execute(self, code: str, df: pd.DataFrame = None) -> Dict[str, Any]:
        """Run ``code`` and capture what it prints and plots.

        Args:
            code: Python source to execute.
            df: If given, bound to the name ``df`` in the namespace before
                execution; otherwise any previously bound ``df`` is kept.

        Returns:
            A dict with keys:
              - ``'output'``: text written to stdout during execution,
                including output produced before an exception;
              - ``'figures'``: open figures as
                ``data:image/png;base64,...`` URIs (empty on failure);
              - ``'error'``: ``None`` on success, else
                ``"<ExceptionType>: <message>"``.
        """
        import contextlib

        if df is not None:
            self.globals['df'] = df

        result: Dict[str, Any] = {'output': '', 'figures': [], 'error': None}
        output_buffer = io.StringIO()

        try:
            # NOTE: redirect_stdout swaps the process-wide sys.stdout, so
            # output from concurrent executions in other threads may mix.
            with contextlib.redirect_stdout(output_buffer):
                exec(code, self.globals)
            result['figures'] = self._collect_figures()
        except (Exception, SystemExit) as e:
            # SystemExit is caught too so a stray exit() in generated code
            # is reported instead of unwinding the caller.
            result['error'] = f"{type(e).__name__}: {e}"
        finally:
            result['output'] = output_buffer.getvalue()
            output_buffer.close()
            # Close every figure, including those left open by failing code,
            # so they are not captured again by the next execute() call.
            plt.close('all')

        return result

    @staticmethod
    def _collect_figures() -> List[str]:
        """Encode every open matplotlib figure as a base64 PNG data URI."""
        figures = []
        for num in plt.get_fignums():
            buf = io.BytesIO()
            plt.figure(num).savefig(buf, format='png')
            img_str = base64.b64encode(buf.getvalue()).decode()
            figures.append(f"data:image/png;base64,{img_str}")
        return figures
|
|
@dataclass
class Tool:
    """A named callable that the analysis agent can list in its prompt.

    Constructor arguments, in order: ``name``, ``description``, ``func``.
    """

    # Identifier under which the tool is listed.
    name: str
    # Short human-readable explanation of what the tool does.
    description: str
    # The callable that implements the tool.
    func: Callable[..., Any]
|
|
class AnalysisAgent:
    """LLM-driven data-analysis agent.

    Sends the request to a chat model via ``litellm.completion``, extracts
    the fenced Python code blocks from the reply, runs each one in a shared
    ``CodeEnvironment`` and appends the captured output, figures and errors
    to the model's answer as Markdown.
    """

    def __init__(
        self,
        model_id: str = "gpt-4o-mini",
        temperature: float = 0.7,
    ):
        self.model_id = model_id
        self.temperature = temperature
        self.tools: List[Tool] = []
        self.code_env = CodeEnvironment()

    def add_tool(self, name: str, description: str, func: Callable) -> None:
        """Register a tool: list it in the system prompt and make it callable.

        The function is bound under ``name`` in the code environment's
        namespace, so code written by the model can call the tool it was
        told about.
        """
        self.tools.append(Tool(name=name, description=description, func=func))
        self.code_env.globals[name] = func

    def run(self, prompt: str, df: pd.DataFrame = None) -> str:
        """Ask the model for an analysis and execute the code it returns.

        Args:
            prompt: User message sent after the system prompt.
            df: DataFrame exposed to the executed code as ``df`` (optional).

        Returns:
            The model's reply, a blank line, then the Markdown rendering of
            each executed code block (printed output, figures or error).
            If anything raises (e.g. the API call), ``"Error: <message>"``.
        """
        messages = [
            {"role": "system", "content": self._get_system_prompt()},
            {"role": "user", "content": prompt},
        ]

        try:
            response = completion(
                model=self.model_id,
                messages=messages,
                temperature=self.temperature,
            )
            # ``content`` can be None in OpenAI-style responses; treat that
            # as an empty answer rather than failing during code extraction.
            analysis = response.choices[0].message.content or ""

            sections = []
            for code in self._extract_code(analysis):
                rendered = self._format_result(self.code_env.execute(code, df))
                if rendered:
                    sections.append(rendered)

            return analysis + "\n\n" + "\n\n".join(sections)

        except Exception as e:
            return f"Error: {str(e)}"

    @staticmethod
    def _format_result(result: Dict[str, Any]) -> str:
        """Render one ``CodeEnvironment.execute`` result as Markdown.

        Printed text is fenced so Markdown keeps its whitespace and table
        alignment. Figures are embedded as images. An error replaces the
        figures. Returns ``""`` when there is nothing to show.
        """
        parts = []
        output = result.get('output')
        if output:
            parts.append("```\n" + output.rstrip("\n") + "\n```")
        if result.get('error'):
            parts.append(f"Error executing code: {result['error']}")
        else:
            parts.extend(
                f"![Figure {i}]({uri})"
                for i, uri in enumerate(result.get('figures') or [], start=1)
            )
        return "\n\n".join(parts)

    def _get_system_prompt(self) -> str:
        """Build the system prompt listing registered tools and conventions."""
        tools_desc = "\n".join(
            f"- {tool.name}: {tool.description}" for tool in self.tools
        ) or "- (none)"

        lines = [
            "You are a data analysis assistant.",
            "",
            "Available tools (callable by name from your code):",
            tools_desc,
            "",
            "Capabilities:",
            "- Data analysis (pandas, numpy)",
            "- Visualization (matplotlib, seaborn)",
            "- Statistical analysis (scipy)",
            "- Machine learning (sklearn)",
            "",
            "When writing code:",
            # Only ```python / ```py fences are extracted and executed.
            "- Put code in ```python fenced code blocks",
            "- `pd`, `np`, `plt` and `sns` are already imported",
            "- Use print() for text results; leave figures open (do not call "
            "plt.close()) so they can be captured",
            "- Create clear visualizations",
            "- Include explanations",
            "- Handle errors gracefully",
        ]
        return "\n".join(lines) + "\n"

    @staticmethod
    def _extract_code(text: str) -> List[str]:
        """Return the bodies of Python fenced code blocks in ``text``.

        Accepts ``python``, ``python3`` or ``py`` fence tags in any case,
        tolerating trailing spaces and CRLF on the opening fence line.
        Other fences (plain ```, ```pycon, ```bash) are ignored.
        """
        import re
        pattern = r"```[ \t]*(?:python3?|py)[ \t]*\r?\n(.*?)```"
        return re.findall(pattern, text or "", re.DOTALL | re.IGNORECASE)
|
|
def process_file(file: Any) -> Optional[pd.DataFrame]:
    """Load an uploaded CSV or Excel file into a DataFrame.

    Args:
        file: The value Gradio passes for a ``gr.File`` input. This is either
            a filepath string or a tempfile-like object exposing the path as
            ``.name``, depending on the installed Gradio version.

    Returns:
        The parsed DataFrame. Returns ``None`` if no file was given, the
        extension is not ``.csv``/``.xlsx``/``.xls`` (case-insensitive), or
        reading failed.
    """
    if not file:
        return None

    try:
        # Accept both a plain path and an object carrying the path in ``.name``.
        path = os.fspath(getattr(file, "name", file))
        ext = os.path.splitext(path)[1].lower()

        if ext == '.csv':
            return pd.read_csv(path)
        if ext in ('.xlsx', '.xls'):
            return pd.read_excel(path)
    except Exception as e:
        print(f"Error reading file: {str(e)}")
        return None

    print(f"Unsupported file type: {ext or '(none)'}")
    return None
|
|
def analyze_data(
    file: Any,
    query: str,
    api_key: str,
    temperature: float = 0.7,
) -> str:
    """Gradio callback: analyse an uploaded data file according to ``query``.

    Args:
        file: Value from the ``gr.File`` input. This is a filepath string or
            a tempfile-like object with ``.name``, depending on the Gradio
            version.
        query: The user's free-text analysis request.
        api_key: OpenAI API key supplied by the user.
        temperature: Sampling temperature forwarded to the model.

    Returns:
        Markdown with the model's analysis followed by the results of running
        its code, or a message starting with ``"Error"`` on failure.
    """
    if not api_key:
        return "Error: Please provide an API key."

    if not file:
        return "Error: Please upload a file."

    try:
        # Load the data first so a bad upload fails before any setup work.
        df = process_file(file)
        if df is None:
            return "Error: Could not process file."

        # NOTE(review): os.environ is process-global. Concurrent Gradio
        # sessions overwrite each other's key, and it lingers after the
        # request. Passing the key per call (litellm's ``api_key=`` argument)
        # would avoid this.
        os.environ["OPENAI_API_KEY"] = api_key

        agent = AnalysisAgent(
            model_id="gpt-4o-mini",
            temperature=temperature,
        )

        # Only the base name is sent to the model; the upload's temporary
        # directory path is noise and leaks server details.
        file_name = os.path.basename(os.fspath(getattr(file, "name", file)))
        # Column labels need not be strings (e.g. numeric Excel headers),
        # so convert them before joining.
        columns = ', '.join(map(str, df.columns))
        column_types = "\n".join(
            f"- {col}: {dtype}" for col, dtype in df.dtypes.items()
        )

        file_info = (
            f"File: {file_name}\n"
            f"Shape: {df.shape}\n"
            f"Columns: {columns}\n"
            "\n"
            "Column Types:\n"
            f"{column_types}\n"
        )

        prompt = (
            f"\n{file_info}\n"
            "The data is loaded in a pandas DataFrame called 'df'.\n"
            "\n"
            f"User request: {query}\n"
            "\n"
            "Please analyze the data and provide:\n"
            "1. Clear explanation of approach\n"
            "2. Code with visualizations\n"
            "3. Key insights and findings\n"
        )

        return agent.run(prompt, df=df)

    except Exception as e:
        return f"Error occurred: {str(e)}"
|
|
def create_interface():
    """Build the Gradio Blocks UI for the data-analysis assistant.

    The layout has an input column (file upload, request text, API key,
    temperature, Analyze button) and an output column rendering Markdown.
    The button is wired to ``analyze_data``.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve the UI.
    """
    # Built line by line so no source indentation leaks into the Markdown
    # (4+ leading spaces would render as a code block).
    intro = "\n".join([
        "# AI Data Analysis Assistant",
        "",
        "Upload your data file and get AI-powered analysis with visualizations.",
        "",
        "**Features:**",
        "- Data analysis and visualization",
        "- Statistical analysis",
        "- Machine learning capabilities",
        "",
        "**Note**: Requires your own OpenAI API key.",
    ])

    with gr.Blocks(title="AI Data Analysis Assistant") as interface:
        gr.Markdown(intro)

        with gr.Row():
            with gr.Column():
                file = gr.File(
                    label="Upload Data File",
                    file_types=[".csv", ".xlsx", ".xls"],
                )
                query = gr.Textbox(
                    label="What would you like to analyze?",
                    placeholder="e.g., Create visualizations showing relationships between variables",
                    lines=3,
                )
                api_key = gr.Textbox(
                    label="API Key (Required)",
                    placeholder="Your API key",
                    type="password",
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                )
                analyze_btn = gr.Button("Analyze")

            with gr.Column():
                output = gr.Markdown(label="Output")

        analyze_btn.click(
            analyze_data,
            inputs=[file, query, api_key, temperature],
            outputs=output,
        )

        # Examples only prefill the request text; the file slot stays empty.
        gr.Examples(
            examples=[
                [None, "Show the distribution of values and key statistics"],
                [None, "Create a correlation analysis with heatmap"],
                [None, "Identify and visualize any outliers in the data"],
                [None, "Generate summary plots for the main variables"],
            ],
            inputs=[file, query],
        )

    return interface
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and serve it with Gradio's launcher.
    create_interface().launch()