| import os |
| from typing import Optional |
|
|
| import gradio as gr |
| import pandas as pd |
| from smolagents import CodeAgent, LiteLLMModel, tool |
|
|
|
|
| |
@tool
def analyze_dataframe(df: pd.DataFrame, analysis_type: str) -> str:
    """Analyze a pandas DataFrame and return the result as text.

    Args:
        df: The DataFrame to analyze.
        analysis_type: Which analysis to run: "summary" for descriptive
            statistics (DataFrame.describe) or "info" for the column listing
            with dtypes and non-null counts (DataFrame.info).

    Returns:
        The analysis rendered as a string, or "Unknown analysis type" when
        analysis_type is not recognised.
    """
    import io

    if analysis_type == "summary":
        return str(df.describe())
    if analysis_type == "info":
        # DataFrame.info() writes its report to a buffer (stdout by default)
        # and returns None, so capture the report instead of str()-ing None.
        buffer = io.StringIO()
        df.info(buf=buffer)
        return buffer.getvalue()
    return "Unknown analysis type"
|
|
@tool
def plot_data(df: pd.DataFrame, plot_type: str) -> None:
    """Create a matplotlib plot from a pandas DataFrame.

    Args:
        df: The DataFrame to plot.
        plot_type: Kind of plot to draw: "correlation" for an annotated
            heatmap of correlations between the numeric columns, or
            "distribution" for histograms of the DataFrame's columns.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    if plot_type == "correlation":
        plt.figure(figsize=(10, 8))
        # numeric_only=True: since pandas 2.0, DataFrame.corr() raises on
        # non-numeric columns instead of silently dropping them.
        sns.heatmap(df.corr(numeric_only=True), annot=True)
        plt.title("Correlation Heatmap")
    elif plot_type == "distribution":
        df.hist(figsize=(15, 10))
        plt.tight_layout()
    else:
        # Fail loudly so the calling agent sees the valid options instead of
        # assuming a plot was produced.
        raise ValueError(
            f"Unknown plot_type {plot_type!r}; "
            "expected 'correlation' or 'distribution'."
        )
|
|
def process_file(file: "gr.File") -> Optional[pd.DataFrame]:
    """Load an uploaded CSV or Excel file into a DataFrame.

    Args:
        file: The uploaded-file value: either a filesystem path string or an
            object whose ``.name`` attribute holds the path. A falsy value
            (``None``, ``""``) means nothing was uploaded.

    Returns:
        The parsed DataFrame, or ``None`` when no file was given, the
        extension is not .csv/.xlsx/.xls, or reading failed (the error is
        printed to stdout).
    """
    if not file:
        return None

    # Accept both plain path strings and file wrappers exposing `.name`.
    path = getattr(file, "name", file)
    # Compare extensions case-insensitively so e.g. "DATA.CSV" is accepted.
    extension = os.path.splitext(path)[1].lower()

    try:
        if extension == ".csv":
            return pd.read_csv(path)
        if extension in (".xlsx", ".xls"):
            return pd.read_excel(path)
    except Exception as e:
        # Best-effort loader: report the problem and let the caller show a
        # generic "could not process" message.
        print(f"Error reading {path}: {str(e)}")
        return None
    return None
|
|
| def analyze_data( |
| file: gr.File, |
| query: str, |
| api_key: str, |
| temperature: float = 0.7, |
| ) -> str: |
| """Process user request and generate analysis using smolagents""" |
| |
| if not api_key: |
| return "Error: Please provide an API key." |
| |
| if not file: |
| return "Error: Please upload a file." |
| |
| try: |
| |
| os.environ["OPENAI_API_KEY"] = api_key |
| |
| |
| model = LiteLLMModel( |
| model_id="gpt-4o-mini", |
| temperature=temperature |
| ) |
| |
| |
| agent = CodeAgent( |
| tools=[analyze_dataframe, plot_data], |
| model=model, |
| additional_authorized_imports=[ |
| "pandas", |
| "numpy", |
| "matplotlib", |
| "seaborn", |
| "plotly", |
| "sklearn", |
| "scipy" |
| ], |
| max_steps=5, |
| verbosity_level=1 |
| ) |
| |
| |
| df = process_file(file) |
| if df is None: |
| return "Error: Could not process uploaded file." |
| |
| |
| file_info = f""" |
| Uploaded file: {file.name} |
| DataFrame Shape: {df.shape} |
| Columns: {', '.join(df.columns)} |
| Column Types: |
| {chr(10).join([f'- {col}: {dtype}' for col, dtype in df.dtypes.items()])} |
| """ |
| |
| |
| prompt = f""" |
| {file_info} |
| |
| The data has been loaded into a pandas DataFrame called 'df'. |
| Available tools: |
| - analyze_dataframe: Perform basic DataFrame analysis |
| - plot_data: Create various plots |
| |
| Additional capabilities: |
| - Full pandas, numpy, matplotlib, seaborn access |
| - Machine learning with sklearn |
| - Statistical analysis with scipy |
| |
| User request: {query} |
| |
| Please analyze the data and provide: |
| 1. A clear explanation of your approach |
| 2. Code for the analysis |
| 3. Visualizations where relevant |
| 4. Key insights and findings |
| """ |
| |
| |
| result = agent.run(prompt, additional_args={"df": df}) |
| return result |
| |
| except Exception as e: |
| return f"Error occurred: {str(e)}" |
|
|
def create_interface():
    """Build the Gradio Blocks UI for the data-analysis agent.

    The left column collects the data file, the analysis request, the API key
    and the sampling temperature. The right column shows the agent's answer as
    Markdown. Clicking "Analyze" calls ``analyze_data`` with those four inputs.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks(title="AI Agent Testing Interface") as interface:
        gr.Markdown("""
        # AI Agent Testing Interface

        Test the capabilities of AI agents using smolagents library. Upload data files and ask questions in natural language.

        **Features:**
        - Data analysis and visualization
        - Machine learning capabilities
        - Statistical analysis
        - Custom tool integration

        **Note**: Requires your own OpenAI API key (the agent runs on `gpt-4o-mini`).
        """)

        with gr.Row():
            with gr.Column():
                file = gr.File(
                    label="Upload Data File (CSV/Excel)",
                    file_types=[".csv", ".xlsx", ".xls"]
                )
                query = gr.Textbox(
                    label="What would you like to analyze?",
                    placeholder="e.g., Analyze the relationships between variables and create visualizations",
                    lines=3
                )
                api_key = gr.Textbox(
                    label="API Key (Required)",
                    placeholder="Your API key",
                    type="password"
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1
                )
                analyze_btn = gr.Button("Analyze")

            with gr.Column():
                output = gr.Markdown(label="Output")

        analyze_btn.click(
            analyze_data,
            inputs=[file, query, api_key, temperature],
            outputs=output
        )

        # Examples fill only the request box; the file slot is None so users
        # pair each example prompt with their own upload.
        gr.Examples(
            examples=[
                [None, "Perform comprehensive exploratory data analysis including distributions, correlations, and key statistics"],
                [None, "Create visualizations showing relationships between numeric variables"],
                [None, "Identify and analyze outliers in the dataset"],
                [None, "Perform clustering analysis and visualize the results"],
                [None, "Calculate summary statistics and create box plots for numeric columns"],
            ],
            inputs=[file, query]
        )

    return interface
|
|
if __name__ == "__main__":
    # Script entry point: build the Gradio app and start its web server.
    demo = create_interface()
    demo.launch()