jzou19950715's picture
Update app.py
2952e7a verified
raw
history blame
6.33 kB
import os
from typing import Optional
import gradio as gr
import pandas as pd
from smolagents import CodeAgent, LiteLLMModel, tool
# Tool definitions to showcase smolagents capabilities
@tool
def analyze_dataframe(df: pd.DataFrame, analysis_type: str) -> str:
    """Return a text report describing a pandas DataFrame.

    Args:
        df: The DataFrame to analyze.
        analysis_type: Which report to produce. Use "summary" for descriptive
            statistics (df.describe()) or "info" for the column, dtype and
            memory overview (df.info()).

    Returns:
        The requested report as text, or "Unknown analysis type" when
        analysis_type is neither "summary" nor "info".
    """
    # Function-scope import keeps this tool self-contained.
    import io

    if analysis_type == "summary":
        return str(df.describe())
    if analysis_type == "info":
        # DataFrame.info() writes its report to a stream and returns None, so
        # str(df.info()) would yield the literal "None". Capture the report
        # in an in-memory buffer instead.
        buffer = io.StringIO()
        df.info(buf=buffer)
        return buffer.getvalue()
    return "Unknown analysis type"
@tool
def plot_data(df: pd.DataFrame, plot_type: str) -> None:
    """Draw a plot of a DataFrame using matplotlib.

    Args:
        df: The DataFrame to plot.
        plot_type: Which plot to draw. Use "correlation" for an annotated
            heatmap of the pairwise correlations between numeric columns, or
            "distribution" for histograms produced by df.hist(). Any other
            value draws nothing.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    if plot_type == "correlation":
        plt.figure(figsize=(10, 8))
        # Restrict to numeric columns: since pandas 2.0, corr() raises on
        # text columns instead of silently dropping them.
        sns.heatmap(df.corr(numeric_only=True), annot=True)
        plt.title("Correlation Heatmap")
    elif plot_type == "distribution":
        df.hist(figsize=(15, 10))
    # Applied for every plot_type, matching the original control flow.
    plt.tight_layout()
def process_file(file: gr.File) -> Optional[pd.DataFrame]:
    """Load an uploaded CSV or Excel file into a DataFrame.

    Args:
        file: The upload handed over by the Gradio file input. Either an
            object exposing the file path as ``.name`` or a plain path string
            is accepted.

    Returns:
        The parsed DataFrame, or ``None`` when nothing was uploaded, the
        extension is not .csv/.xlsx/.xls, or reading the file fails.
    """
    if not file:
        return None

    # Fall back to the value itself when it is already a path string.
    path = str(getattr(file, "name", file))
    # Compare extensions case-insensitively so names like "DATA.CSV" load.
    lowered = path.lower()

    try:
        if lowered.endswith(".csv"):
            return pd.read_csv(path)
        if lowered.endswith((".xlsx", ".xls")):
            return pd.read_excel(path)
        return None
    except Exception as e:
        # Deliberately broad: this is a best-effort loader for user uploads,
        # and the caller turns ``None`` into an error message for the UI.
        print(f"Error reading {path}: {str(e)}")
        return None
def _build_prompt(file_name: str, df: pd.DataFrame, query: str) -> str:
    """Compose the agent prompt: DataFrame overview, tool list, user request."""
    # Column labels are not guaranteed to be strings (e.g. integer headers),
    # so convert each one before joining.
    column_names = ", ".join(str(col) for col in df.columns)
    dtype_lines = "\n".join(
        f"- {col}: {dtype}" for col, dtype in df.dtypes.items()
    )
    file_info = "\n".join([
        "",
        f"Uploaded file: {file_name}",
        f"DataFrame Shape: {df.shape}",
        f"Columns: {column_names}",
        "Column Types:",
        dtype_lines,
        "",
    ])
    return "\n".join([
        "",
        file_info,
        "The data has been loaded into a pandas DataFrame called 'df'.",
        "Available tools:",
        "- analyze_dataframe: Perform basic DataFrame analysis",
        "- plot_data: Create various plots",
        "Additional capabilities:",
        "- Full pandas, numpy, matplotlib, seaborn access",
        "- Machine learning with sklearn",
        "- Statistical analysis with scipy",
        f"User request: {query}",
        "Please analyze the data and provide:",
        "1. A clear explanation of your approach",
        "2. Code for the analysis",
        "3. Visualizations where relevant",
        "4. Key insights and findings",
        "",
    ])


def analyze_data(
    file: gr.File,
    query: str,
    api_key: str,
    temperature: float = 0.7,
) -> str:
    """Answer a natural-language request about an uploaded data file.

    Loads the upload with ``process_file``, builds a prompt describing the
    DataFrame, and runs a smolagents ``CodeAgent`` backed by ``gpt-4o-mini``
    with the DataFrame passed in as ``df``.

    Args:
        file: Upload received from the Gradio file input.
        query: The user's request in natural language.
        api_key: API key handed to the model for this request.
        temperature: Sampling temperature passed to the model.

    Returns:
        The agent's result as text, or a message starting with "Error" when
        an input is missing, the file cannot be loaded, or the run fails.
    """
    if not api_key:
        return "Error: Please provide an API key."
    if not file:
        return "Error: Please upload a file."

    try:
        # Load the data first so a bad upload fails before any model setup.
        df = process_file(file)
        if df is None:
            return "Error: Could not process uploaded file."

        # Hand the key to the model directly. Writing it to os.environ would
        # leave it in process-wide state after this request has finished.
        model = LiteLLMModel(
            model_id="gpt-4o-mini",
            api_key=api_key,
            temperature=temperature,
        )
        agent = CodeAgent(
            tools=[analyze_dataframe, plot_data],
            model=model,
            additional_authorized_imports=[
                "pandas",
                "numpy",
                "matplotlib",
                "seaborn",
                "plotly",
                "sklearn",
                "scipy",
            ],
            max_steps=5,
            verbosity_level=1,
        )

        file_name = getattr(file, "name", str(file))
        prompt = _build_prompt(file_name, df, query)

        result = agent.run(prompt, additional_args={"df": df})
        # The agent's result is not guaranteed to be a string; coerce it so
        # the return value matches the annotation.
        return str(result)
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"Error occurred: {str(e)}"
def create_interface():
    """Build the Gradio Blocks UI and wire the Analyze button to analyze_data.

    Returns the assembled ``gr.Blocks`` object; launching it is left to the
    caller.
    """
    intro_markdown = "\n".join([
        "",
        "# AI Agent Testing Interface",
        "Test the capabilities of AI agents using smolagents library. Upload data files and ask questions in natural language.",
        "**Features:**",
        "- Data analysis and visualization",
        "- Machine learning capabilities",
        "- Statistical analysis",
        "- Custom tool integration",
        "**Note**: Requires your own API key for GPT-4.",
        "",
    ])
    sample_queries = (
        "Perform comprehensive exploratory data analysis including distributions, correlations, and key statistics",
        "Create visualizations showing relationships between numeric variables",
        "Identify and analyze outliers in the dataset",
        "Perform clustering analysis and visualize the results",
        "Calculate summary statistics and create box plots for numeric columns",
    )

    with gr.Blocks(title="AI Agent Testing Interface") as demo:
        gr.Markdown(intro_markdown)

        with gr.Row():
            # Left column: all inputs plus the trigger button.
            with gr.Column():
                upload = gr.File(
                    label="Upload Data File (CSV/Excel)",
                    file_types=[".csv", ".xlsx", ".xls"],
                )
                question = gr.Textbox(
                    label="What would you like to analyze?",
                    placeholder="e.g., Analyze the relationships between variables and create visualizations",
                    lines=3,
                )
                key_box = gr.Textbox(
                    label="API Key (Required)",
                    placeholder="Your API key",
                    type="password",
                )
                temperature_slider = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                )
                run_button = gr.Button("Analyze")
            # Right column: the rendered result.
            with gr.Column():
                result_view = gr.Markdown(label="Output")

        # Run the analysis when the button is clicked.
        run_button.click(
            fn=analyze_data,
            inputs=[upload, question, key_box, temperature_slider],
            outputs=result_view,
        )

        # Example requests; the file slot is left empty for the user to fill.
        gr.Examples(
            examples=[[None, text] for text in sample_queries],
            inputs=[upload, question],
        )

    return demo
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    create_interface().launch()