# app.py — AI-powered data analysis Gradio Space
# (source scraped from Hugging Face file viewer; author: pavanmutha,
#  commit b9c72b0 "Update app.py", 8.67 kB)
import os
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import shap
import lime.lime_tabular
import optuna
import wandb
from smolagents import HfApiModel, CodeAgent
from huggingface_hub import login
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.preprocessing import LabelEncoder
# Authenticate with Hugging Face.
# NOTE(review): login() is called unconditionally; if the HF_TOKEN secret is
# unset this passes token=None — confirm the Space always defines it.
hf_token = os.getenv("HF_TOKEN")
login(token=hf_token)
# SmolAgent initialization: remote Mixtral inference endpoint used by the
# CodeAgent constructed further down.
model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
# Cleaned dataset shared across all Gradio callbacks; set by upload_file().
df_global = None
def clean_data(df):
    """Return a model-ready copy of *df*.

    Steps:
      1. Drop rows and columns that are entirely NaN.
      2. Integer-encode every object (string) column. Codes are assigned
         over the sorted unique string values, matching what sklearn's
         ``LabelEncoder.fit_transform`` produced in the previous version.
      3. Mean-impute any remaining numeric NaNs.

    The caller's DataFrame is never mutated.
    """
    # dropna returns new frames; .copy() additionally guarantees the column
    # assignments below cannot write through to the caller's data.
    df = df.dropna(how='all', axis=1).dropna(how='all', axis=0).copy()
    for col in df.select_dtypes(include='object').columns:
        # astype(str) normalises mixed/NaN cells (NaN -> "nan", as before);
        # categorical codes follow sorted unique values (LabelEncoder-compatible),
        # so no sklearn dependency is needed here.
        df[col] = df[col].astype(str).astype('category').cat.codes
    # Fill numeric NaNs that survived the all-NaN row/column drop.
    df = df.fillna(df.mean(numeric_only=True))
    return df
def upload_file(file):
    """Load an uploaded CSV/Excel file, clean it, and cache it globally.

    Parameters
    ----------
    file : str | file-like | None
        Either a filesystem path (gradio ``gr.File(type="filepath")`` passes
        a plain ``str``) or an object exposing a ``.name`` path attribute.
        ``None`` means nothing was uploaded.

    Returns
    -------
    pandas.DataFrame
        First five rows of the cleaned data, or a one-column error frame.

    Side effects: sets the module-level ``df_global`` used by the other
    callbacks.
    """
    global df_global
    if file is None:
        return pd.DataFrame({"Error": ["No file uploaded."]})
    # gradio hands us a plain path string with type="filepath"; older
    # versions passed a tempfile-like object with a .name attribute.
    path = file if isinstance(file, str) else file.name
    ext = os.path.splitext(path)[-1].lower()
    try:
        df = pd.read_csv(path) if ext == ".csv" else pd.read_excel(path)
    except Exception as e:
        # Surface the parse failure in the preview table instead of crashing.
        return pd.DataFrame({"Error": [f"Could not read file: {e}"]})
    df = clean_data(df)
    df_global = df
    return df.head()
import textwrap  # NOTE(review): imported but never used in the visible code
# NOTE(review): additional_notes is never referenced — presumably meant to be
# appended to the agent prompt; confirm intent or remove.
additional_notes = "Please note: Perform a comprehensive analysis including visualizations and insights."
# Initialize the agent.
# CodeAgent lets the LLM write and execute Python, restricted to the
# explicitly authorized imports below (no extra tools attached).
agent = CodeAgent(
    tools=[],
    model=model,
    additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"]
)
def run_agent(_):
    """Run the SmolAgents CodeAgent over the cached dataset.

    The dataset is handed to the agent through a temporary CSV file, which
    is removed once the run finishes (the previous version leaked one temp
    file per click). The agent's raw text output is split into:

    Returns
    -------
    tuple[str, list[str]]
        (insights text, list of ``.png`` paths found in the output).
    """
    if df_global is None:
        return "Please upload a file first.", []
    from tempfile import NamedTemporaryFile
    # delete=False so the agent subprocess/sandbox can reopen it by path.
    temp_file = NamedTemporaryFile(delete=False, suffix=".csv")
    df_global.to_csv(temp_file.name, index=False)
    temp_file.close()
    prompt = """
You are an expert data analyst.
1. Load the provided dataset and analyze the structure.
2. Automatically detect key numeric and categorical columns.
3. Perform:
- Basic descriptive statistics.
- Null and duplicate checks.
- Insightful relationships between key columns.
- At least 3 visualizations showing important trends.
4. Derive at least 3 actionable real-world insights.
5. Save all visualizations to ./figures/ directory.
Return:
- A summary of the insights in clean bullet-point format.
- File paths of the generated visualizations.
"""
    try:
        result = agent.run(prompt, additional_args={"source_file": temp_file.name})
    finally:
        # The agent only needs the CSV during run(); don't leak it.
        try:
            os.remove(temp_file.name)
        except OSError:
            pass
    # agent.run may return a rich/non-str result object; be defensive.
    result = str(result)
    # Lines ending in ".png" are treated as plot paths, everything else as prose.
    image_paths = [line.strip() for line in result.splitlines() if line.strip().endswith(".png")]
    insights = "\n".join(line for line in result.splitlines() if not line.strip().endswith(".png"))
    return insights, image_paths
def train_model(_):
    """Tune a RandomForest with Optuna, log everything to W&B, report metrics.

    Uses the last column of the cached ``df_global`` as the target.

    Returns
    -------
    tuple[dict, pandas.DataFrame]
        (test-set metrics, top-7 trials sorted by CV score). On a missing
        upload, returns an error dict and an empty frame instead of crashing.
    """
    if df_global is None:
        # Mirror run_agent's guard instead of raising AttributeError on None.
        return {"error": "Please upload a file first."}, pd.DataFrame()
    wandb.login(key=os.environ.get("WANDB_API_KEY"))
    # The old per-call run_counter was a local that always evaluated to 1,
    # so the run name is simply kept fixed.
    wandb_run = wandb.init(project="huggingface-data-analysis", name="Optuna_Run_1", reinit=True)
    # Convention: last column is the prediction target.
    target = df_global.columns[-1]
    X = df_global.drop(target, axis=1)
    y = df_global[target]
    if y.dtype == "object":
        y = LabelEncoder().fit_transform(y)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    def objective(trial):
        # Small search space keeps each of the 15 trials cheap.
        params = {
            "n_estimators": trial.suggest_int("n_estimators", 50, 200),
            "max_depth": trial.suggest_int("max_depth", 3, 10),
        }
        model = RandomForestClassifier(**params)
        score = cross_val_score(model, X_train, y_train, cv=3).mean()
        wandb.log(params | {"cv_score": score})
        return score

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=15)
    # Refit on the full training split with the best hyperparameters.
    model = RandomForestClassifier(**study.best_params)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    # Weighted averaging handles multi-class and imbalanced targets.
    metrics = {
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
        "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
        "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
    }
    wandb.log(metrics)
    wandb_run.finish()
    # trials_dataframe() already returns a DataFrame; no extra wrap needed.
    top_trials = study.trials_dataframe().sort_values(by="value", ascending=False).head(7)
    return metrics, top_trials
def explainability(_):
    """Fit a quick RandomForest on the cached data and save SHAP + LIME plots.

    Returns
    -------
    tuple[str, str]
        (path to SHAP summary PNG — or an error PNG if plotting failed,
         path to LIME explanation PNG for the first test row).
    """
    import warnings
    warnings.filterwarnings("ignore")
    # Same convention as train_model: last column is the target.
    target = df_global.columns[-1]
    X = df_global.drop(target, axis=1)
    y = df_global[target]
    if y.dtype == "object":
        y = LabelEncoder().fit_transform(y)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    # Untuned forest — explainability only, not the tuned model from training.
    model = RandomForestClassifier()
    model.fit(X_train, y_train)
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_test)
    try:
        # Depending on the shap version, multi-class output is either a list
        # of per-class arrays or a single stacked array.
        if isinstance(shap_values, list):
            class_idx = 0  # explain the first class only
            sv = shap_values[class_idx]
        else:
            sv = shap_values
        # Ensure 2D input shape for the SHAP summary plot.
        if len(sv.shape) > 2:
            sv = sv.reshape(sv.shape[0], -1)  # Flatten any extra dimensions
        # Use safe feature names if mismatch, fallback to dummy names.
        num_features = sv.shape[1]
        if num_features <= X_test.shape[1]:
            feature_names = X_test.columns[:num_features]
        else:
            feature_names = [f"Feature_{i}" for i in range(num_features)]
        # NOTE(review): zeros stand in for the real feature values here, so
        # the red/blue value colouring in the summary plot carries no
        # information — presumably a shape-mismatch workaround; confirm.
        X_shap_safe = pd.DataFrame(np.zeros_like(sv), columns=feature_names)
        shap.summary_plot(sv, X_shap_safe, show=False)
        shap_path = "./shap_plot.png"
        plt.title("SHAP Summary")
        plt.savefig(shap_path)
        if wandb.run:
            wandb.log({"shap_summary": wandb.Image(shap_path)})
        plt.clf()
    except Exception as e:
        # Fall back to rendering the error message itself as the "plot".
        shap_path = "./shap_error.png"
        print("SHAP plotting failed:", e)
        plt.figure(figsize=(6, 3))
        plt.text(0.5, 0.5, f"SHAP Error:\n{str(e)}", ha='center', va='center')
        plt.axis('off')
        plt.savefig(shap_path)
        if wandb.run:
            wandb.log({"shap_error": wandb.Image(shap_path)})
        plt.clf()
    # LIME: local explanation for the first test-set row.
    lime_explainer = lime.lime_tabular.LimeTabularExplainer(
        X_train.values,
        feature_names=X_train.columns.tolist(),
        class_names=[str(c) for c in np.unique(y_train)],
        mode='classification'
    )
    lime_exp = lime_explainer.explain_instance(X_test.iloc[0].values, model.predict_proba)
    lime_fig = lime_exp.as_pyplot_figure()
    lime_path = "./lime_plot.png"
    lime_fig.savefig(lime_path)
    if wandb.run:
        wandb.log({"lime_explanation": wandb.Image(lime_path)})
    plt.clf()
    return shap_path, lime_path
# Gradio UI. The previous version created agent_btn / insights_output twice;
# the first pair was dead (names rebound before .click was attached), leaving
# a do-nothing button and a never-updated textbox in the left column.
with gr.Blocks() as demo:
    gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="Upload CSV or Excel", type="filepath")
            df_output = gr.DataFrame(label="Cleaned Data Preview")
            file_input.change(fn=upload_file, inputs=file_input, outputs=df_output)
    with gr.Row():
        train_btn = gr.Button("Train Model with Optuna + WandB")
        metrics_output = gr.JSON(label="Performance Metrics")
        trials_output = gr.DataFrame(label="Top 7 Hyperparameter Trials")
    with gr.Row():
        explain_btn = gr.Button("SHAP + LIME Explainability")
        shap_img = gr.Image(label="SHAP Summary Plot")
        lime_img = gr.Image(label="LIME Explanation")
    with gr.Row():
        agent_btn = gr.Button("Run AI Agent (5 Insights + 5 Visualizations)")
        insights_output = gr.Textbox(label="Insights from SmolAgent", lines=15)
        # NOTE(review): .style() was removed in gradio 4.x — this pins the
        # app to gradio 3.x; confirm the Space's pinned gradio version.
        visual_output = gr.Gallery(label="Generated Visualizations").style(grid=3, height="auto")
    agent_btn.click(fn=run_agent, inputs=df_output, outputs=[insights_output, visual_output])
    train_btn.click(fn=train_model, inputs=df_output, outputs=[metrics_output, trials_output])
    explain_btn.click(fn=explainability, inputs=df_output, outputs=[shap_img, lime_img])
demo.launch(debug=True)