|
|
import os |
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import shap |
|
|
import lime.lime_tabular |
|
|
import optuna |
|
|
import wandb |
|
|
import json |
|
|
import time |
|
|
import psutil |
|
|
import shutil |
|
|
import ast |
|
|
from smolagents import HfApiModel, CodeAgent |
|
|
from huggingface_hub import login |
|
|
from sklearn.model_selection import train_test_split, cross_val_score |
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score |
|
|
from sklearn.metrics import ConfusionMatrixDisplay |
|
|
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier |
|
|
from sklearn.linear_model import LogisticRegression |
|
|
from sklearn.preprocessing import LabelEncoder |
|
|
from datetime import datetime |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# Hugging Face credentials are read from the environment so that no secret
# lives in the source file.
hf_token = os.getenv("HF_TOKEN")

# Only authenticate when a token was actually supplied: calling login() with
# token=None falls back to an interactive prompt, which would stall a headless
# deployment.
if hf_token:
    login(token=hf_token)

# LLM backend handed to the smolagents CodeAgent in analyze_data().
model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)

# Module-level state shared between the Gradio callbacks.
df_global = None  # cleaned DataFrame, assigned by upload_file()
target_column_global = None  # target column name, assigned by set_target_column()
|
|
|
|
|
def clean_data(df):
    """Return a cleaned copy of *df* suitable for the sklearn models below.

    Processing order:
      1. Drop columns that are entirely empty, then rows that are entirely
         empty.
      2. Label-encode every ``object`` column. Values are stringified first,
         so missing entries are encoded as a category of their own.
      3. Fill the NaNs that remain with the mean of their (numeric) column.
    """
    cleaned = df.dropna(how='all', axis=1).dropna(how='all', axis=0)

    object_columns = cleaned.select_dtypes(include='object').columns
    for name in object_columns:
        as_text = cleaned[name].astype(str)
        cleaned[name] = LabelEncoder().fit_transform(as_text)

    column_means = cleaned.mean(numeric_only=True)
    return cleaned.fillna(column_means)
|
|
|
|
|
def upload_file(file):
    """Load an uploaded CSV/Excel file, clean it and cache it in ``df_global``.

    Args:
        file: Value delivered by the ``gr.File`` component. With
            ``type="filepath"`` this is a plain path string; a file-like
            object exposing the path as ``.name`` is accepted as well.
            ``None`` means nothing is uploaded.

    Returns:
        A ``(preview, dropdown_update)`` pair: the head of the cleaned frame
        and a ``gr.update`` whose choices are the cleaned frame's columns.
        When *file* is ``None`` an error frame and an empty choice list are
        returned and ``df_global`` is left untouched.
    """
    global df_global

    if file is None:
        return pd.DataFrame({"Error": ["No file uploaded."]}), gr.update(choices=[])

    # A str path has no ``.name`` attribute, so resolve the path explicitly
    # instead of assuming a tempfile wrapper.
    path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name

    # Compare the extension case-insensitively so "DATA.CSV" is not sent to
    # the Excel reader.
    ext = os.path.splitext(path)[-1].lower()
    df = pd.read_csv(path) if ext == ".csv" else pd.read_excel(path)

    df = clean_data(df)
    df_global = df
    return df.head(), gr.update(choices=df.columns.tolist())
|
|
|
|
|
|
|
|
|
|
|
def set_target_column(col_name):
    """Store *col_name* as the prediction target and return a status message.

    Args:
        col_name: Column name selected in the target dropdown.

    Returns:
        A human-readable confirmation string for the status textbox.
    """
    global target_column_global
    target_column_global = col_name
    # The check-mark emoji is written as an escape so the literal survives
    # encoding round-trips; the mis-decoded original was split over two lines.
    return f"\u2705 Target column set to: {col_name}"
|
|
|
|
|
|
|
|
def format_analysis_report(raw_output, visuals):
    """Render the agent's analysis result as an HTML report.

    Args:
        raw_output: Either a dict with optional ``'observations'`` and
            ``'insights'`` entries, or any object whose ``str()`` is the
            Python-literal representation of such a dict.
        visuals: List of image paths. It is forwarded to ``format_insights``
            and returned unchanged as the second element.

    Returns:
        ``(html, visuals)``. If *raw_output* cannot be parsed into a dict, or
        rendering fails for any reason, ``str(raw_output)`` is returned in
        place of the HTML so the UI still shows something.
    """
    try:
        if isinstance(raw_output, dict):
            analysis_dict = raw_output
        else:
            try:
                # literal_eval only evaluates literals, never arbitrary code.
                analysis_dict = ast.literal_eval(str(raw_output))
            except (SyntaxError, ValueError) as e:
                print(f"Error parsing CodeAgent output: {e}")
                return str(raw_output), visuals

        # A valid literal that is not a dict (e.g. a number or a list) cannot
        # be rendered as a report; fall back to the raw text explicitly.
        if not isinstance(analysis_dict, dict):
            print(
                "Error parsing CodeAgent output: expected a dict, "
                f"got {type(analysis_dict).__name__}"
            )
            return str(raw_output), visuals

        observations_html = format_observations(analysis_dict.get('observations', {}))
        insights_html = format_insights(analysis_dict.get('insights', {}), visuals)

        # NOTE(review): the heading emoji were mis-decoded in the source. The
        # light bulb is recoverable from the surviving bytes; the bar-chart
        # and magnifier glyphs are best guesses - confirm.
        report = f"""
        <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
            <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">\U0001F4CA Data Analysis Report</h1>
            <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
                <h2 style="color: #2B547E;">\U0001F50D Key Observations</h2>
                {observations_html}
            </div>
            <div style="margin-top: 30px;">
                <h2 style="color: #2B547E;">\U0001F4A1 Insights & Visualizations</h2>
                {insights_html}
            </div>
        </div>
        """
        return report, visuals
    except Exception as e:
        # Best-effort boundary: never let a rendering problem break the UI.
        print(f"Error in format_analysis_report: {e}")
        return str(raw_output), visuals
|
|
|
|
|
def format_observations(observations):
    """Build HTML cards for the observations whose key contains 'proportions'.

    Args:
        observations: Mapping of observation name to value. Anything that is
            not a dict yields an empty string.

    Returns:
        The cards joined by newlines; ``''`` when nothing qualifies. Only
        entries whose key contains the substring ``'proportions'`` are
        rendered; all others are skipped.
    """
    from html import escape

    if not isinstance(observations, dict):
        return ''

    cards = []
    for key, value in observations.items():
        label = str(key)  # tolerate non-string keys
        if 'proportions' not in label:
            continue
        # The text originates from the LLM agent, so escape it before
        # embedding it in markup.
        title = escape(label.replace('_', ' ').title(), quote=False)
        body = escape(str(value), quote=False)
        cards.append(f"""
        <div style="margin: 15px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
            <h3 style="margin: 0 0 10px 0; color: #4A708B;">{title}</h3>
            <pre style="margin: 0; padding: 10px; background: #f8f9fa; border-radius: 4px;">{body}</pre>
        </div>
        """)
    return '\n'.join(cards)
|
|
|
|
|
def format_insights(insights, visuals):
    """Build one numbered HTML card per insight, pairing each with an image.

    Args:
        insights: Mapping of insight name to insight text (only the values
            are rendered). A list or tuple of insight texts is accepted too.
            Any other type yields an empty string.
        visuals: Sequence of image paths. The i-th insight is shown with the
            i-th path when one exists; ``None`` is treated as empty.

    Returns:
        The cards joined by newlines; ``''`` when there are no insights.
    """
    from html import escape

    if isinstance(insights, dict):
        texts = list(insights.values())
    elif isinstance(insights, (list, tuple)):
        texts = list(insights)
    else:
        texts = []

    images = visuals or []

    cards = []
    for idx, insight in enumerate(texts):
        image_tag = ''
        if idx < len(images):
            src = escape(str(images[idx]), quote=True)
            image_tag = (
                f'<img src="/file={src}" style="max-width: 100%; height: auto; '
                'margin-top: 10px; border-radius: 6px; '
                'box-shadow: 0 2px 4px rgba(0,0,0,0.1);">'
            )
        # Insight text comes from the LLM agent; escape it before embedding.
        text = escape(str(insight), quote=False)
        cards.append(f"""
        <div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
            <div style="display: flex; align-items: center; gap: 10px;">
                <div style="background: #2B547E; color: white; width: 30px; height: 30px; border-radius: 50%; display: flex; align-items: center; justify-content: center;">{idx + 1}</div>
                <p style="margin: 0; font-size: 16px;">{text}</p>
            </div>
            {image_tag}
        </div>
        """)
    return '\n'.join(cards)
|
|
|
|
|
def analyze_data(csv_file, additional_notes=""):
    """Run the LLM code agent on the uploaded file and build the HTML report.

    The ``./figures`` directory is recreated empty before the run. Execution
    time, the change in resident memory and every image found in
    ``./figures`` afterwards are logged to a Weights & Biases run, which is
    always finished, even if the agent raises.

    Args:
        csv_file: Upload value from ``gr.File``: a path string, an object
            exposing the path as ``.name``, or ``None``.
        additional_notes: Free-form notes forwarded to the agent and stored in
            the W&B run config.

    Returns:
        The ``(html, visuals)`` pair produced by ``format_analysis_report``.
    """
    start_time = time.time()
    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss / 1024 ** 2

    # gr.File(type="filepath") passes a str, which has no ``.name``.
    if not csv_file:
        source_path = None
    elif isinstance(csv_file, (str, os.PathLike)):
        source_path = os.fspath(csv_file)
    else:
        source_path = csv_file.name

    # Start from an empty figures directory so stale plots from a previous
    # run are never attached to this report.
    if os.path.exists('./figures'):
        shutil.rmtree('./figures')
    os.makedirs('./figures', exist_ok=True)

    wandb.login(key=os.environ.get('WANDB_API_KEY'))
    run = wandb.init(project="huggingface-data-analysis", config={
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "additional_notes": additional_notes,
        "source_file": source_path,
    })

    try:
        agent = CodeAgent(
            tools=[],
            model=model,
            additional_authorized_imports=[
                "numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn", "json",
            ],
        )
        analysis_result = agent.run("""
    You are a helpful data analysis agent. Just return insight information and visualization.
    Load the data that is passed.do not create your own.
    Automatically detect numeric columns and names.
    2. 5 data visualizations
    3. at least 5 insights from data
    5. Generate publication-quality visualizations and save to './figures/'.
    Do not use 'open()' or write to files. Just return variables and plots.
    The dictionary should have the following structure:
    {
        'observations': {
            'observation_1_key': 'observation_1_value',
            'observation_2_key': 'observation_2_value',
            ...
        },
        'insights': {
            'insight_1_key': 'insight_1_value',
            'insight_2_key': 'insight_2_value',
            ...
        }
    }
    """, additional_args={"additional_notes": additional_notes, "source_file": source_path})

        execution_time = time.time() - start_time
        final_memory = process.memory_info().rss / 1024 ** 2
        memory_usage = final_memory - initial_memory
        wandb.log({"execution_time_sec": execution_time, "memory_usage_mb": memory_usage})

        # os.listdir order is arbitrary; sort so that the pairing of insights
        # with images is reproducible between runs.
        visuals = sorted(
            os.path.join('./figures', f)
            for f in os.listdir('./figures')
            if f.endswith(('.png', '.jpg', '.jpeg'))
        )
        for viz in visuals:
            wandb.log({os.path.basename(viz): wandb.Image(viz)})
    finally:
        # Close the run on failure as well, otherwise it stays open and later
        # callbacks would log into it.
        run.finish()

    return format_analysis_report(analysis_result, visuals)
|
|
|
|
|
def compare_models():
    """Cross-validate three classifiers on the loaded dataset and chart them.

    Reads ``df_global`` and ``target_column_global``. For each of
    RandomForest, LogisticRegression and GradientBoosting it computes 5-fold
    accuracy scores plus weighted F1/precision/recall on 5-fold out-of-fold
    predictions, logs the numeric metrics to W&B (starting a run when none is
    active) and finally saves a bar chart to ``./model_comparison.png``.

    Returns:
        ``(results_df, plot_path)``. When no dataset is loaded or the target
        column is missing, an error frame and ``None`` are returned instead.
    """
    # seaborn and cross_val_predict are only needed here and are not imported
    # at file level.
    import seaborn as sns
    from sklearn.model_selection import cross_val_predict

    if df_global is None:
        return pd.DataFrame({"Error": ["Please upload and preprocess a dataset first."]}), None

    target = target_column_global
    # Without this guard ``drop(None)`` fails with an opaque KeyError.
    if target is None or target not in df_global.columns:
        return pd.DataFrame({"Error": ["Please select a target column first."]}), None

    X = df_global.drop(target, axis=1)
    y = df_global[target]

    if y.dtype == 'object':
        y = LabelEncoder().fit_transform(y)

    models = {
        "RandomForest": RandomForestClassifier(),
        "LogisticRegression": LogisticRegression(max_iter=1000),
        "GradientBoosting": GradientBoostingClassifier(),
    }

    # Loop-invariant: make sure a run exists before logging anything.
    if wandb.run is None:
        wandb.init(project="model_comparison", name="compare_models", reinit=True)

    results = []
    # ``estimator`` rather than ``model`` so the module-level LLM ``model`` is
    # not shadowed inside this function.
    for name, estimator in models.items():
        scores = cross_val_score(estimator, X, y, cv=5)
        y_pred = cross_val_predict(estimator, X, y, cv=5)

        metrics = {
            "Model": name,
            "CV Mean Accuracy": np.mean(scores),
            "CV Std Dev": np.std(scores),
            "F1 Score": f1_score(y, y_pred, average="weighted", zero_division=0),
            "Precision": precision_score(y, y_pred, average="weighted", zero_division=0),
            "Recall": recall_score(y, y_pred, average="weighted", zero_division=0),
        }
        # Only numeric entries are logged; "Model" is a string and is skipped.
        wandb.log({
            f"{name}_{k.replace(' ', '_').lower()}": v
            for k, v in metrics.items()
            if isinstance(v, (float, int))
        })
        results.append(metrics)

    results_df = pd.DataFrame(results)

    plt.figure(figsize=(8, 5))
    sns.barplot(data=results_df, x="Model", y="CV Mean Accuracy", palette="Blues_d")
    plt.title("Model Comparison (CV Mean Accuracy)")
    plt.ylim(0, 1)
    plt.tight_layout()

    plot_path = "./model_comparison.png"
    plt.savefig(plot_path)
    plt.close()

    return results_df, plot_path
|
|
|
|
|
|
|
|
|
|
|
def prepare_data(df):
    """Split *df* into train/test features and labels using the chosen target.

    Args:
        df: DataFrame containing the column named by ``target_column_global``.

    Returns:
        ``X_train, X_test, y_train, y_test`` from a 70/30 split with
        ``random_state=42``.

    Raises:
        ValueError: If no target column has been set or no dataset is given.
    """
    # The original tested an undefined name (``target_column``), which raised
    # NameError instead of the intended ValueError.
    if target_column_global is None:
        raise ValueError("Target column not set.")
    if df is None:
        raise ValueError("No dataset loaded.")

    X = df.drop(columns=[target_column_global])
    y = df[target_column_global]

    return train_test_split(X, y, test_size=0.3, random_state=42)
|
|
|
|
|
|
|
|
def train_model(_=None):
    """Tune a RandomForest with Optuna, evaluate it, and log to W&B.

    Fifteen Optuna trials search ``n_estimators`` (50-200) and ``max_depth``
    (3-10) by 3-fold cross-validation on the training split. The best
    parameters are then used to fit the final model, which is scored on the
    held-out test split.

    Args:
        _: Ignored. Present because the button passes the file input; it
            defaults to ``None`` so the function can also be called bare.

    Returns:
        ``(metrics, trials_df)``: a dict with accuracy, precision, recall and
        f1_score, and a frame with the parameters and score of the seven best
        trials. On any error ``({}, empty DataFrame)`` is returned and the
        error is printed.
    """
    wandb_run = None
    try:
        wandb.login(key=os.environ.get("WANDB_API_KEY"))
        wandb_run = wandb.init(
            project="huggingface-data-analysis",
            name=f"Optuna_Run_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            reinit=True,
        )

        X_train, X_test, y_train, y_test = prepare_data(df_global)

        def objective(trial):
            params = {
                "n_estimators": trial.suggest_int("n_estimators", 50, 200),
                "max_depth": trial.suggest_int("max_depth", 3, 10),
            }
            candidate = RandomForestClassifier(**params)
            score = cross_val_score(candidate, X_train, y_train, cv=3).mean()
            if wandb.run is None:
                wandb.init(
                    project="model_optimization",
                    name=f"optuna_trial_{trial.number}",
                    reinit=True,
                )
            wandb.log({**params, "cv_score": score})
            return score

        study = optuna.create_study(direction="maximize")
        study.optimize(objective, n_trials=15)

        # Fit the final model with the tuned parameters; the original computed
        # ``best_params`` but then trained with the defaults.
        best_params = study.best_params
        best_model = RandomForestClassifier(**best_params)
        best_model.fit(X_train, y_train)
        y_pred = best_model.predict(X_test)

        metrics = {
            "accuracy": accuracy_score(y_test, y_pred),
            "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
            "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
            "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
        }
        wandb.log(metrics)

        # Trials without a value cannot be ordered, so leave them out.
        scored_trials = [t for t in study.trials if t.value is not None]
        top_trials = sorted(scored_trials, key=lambda t: t.value, reverse=True)[:7]
        trials_df = pd.DataFrame(
            [{**t.params, "score": t.value} for t in top_trials]
        )

        return metrics, trials_df

    except Exception as e:
        # Best-effort UI callback: report the problem and return empty outputs.
        print(f"Training Error: {e}")
        return {}, pd.DataFrame()

    finally:
        # Close the run on both the success and the failure path.
        if wandb_run is not None:
            wandb_run.finish()
|
|
|
|
|
|
|
|
def explainability(_=None):
    """Produce a SHAP summary plot and a LIME explanation for a RandomForest.

    A fresh RandomForest is trained on an 80/20 split of ``df_global``. The
    SHAP summary is saved to ``./shap_plot.png``; if plotting fails, an image
    containing the error text is saved to ``./shap_error.png`` instead. The
    LIME explanation of the first test row is saved to ``./lime_plot.png``.
    Images are also logged to W&B when a run is active.

    Args:
        _: Ignored. It defaults to ``None`` because the button is wired with
            ``inputs=[]`` and therefore calls this function without arguments.

    Returns:
        ``(shap_path, lime_path)``.

    Raises:
        ValueError: If no dataset is loaded or no target column is selected.
    """
    import warnings
    warnings.filterwarnings("ignore")

    if df_global is None or target_column_global is None:
        raise ValueError(
            "Upload a dataset and select a target column before running explainability."
        )

    target = target_column_global
    X = df_global.drop(target, axis=1)
    y = df_global[target]

    if y.dtype == "object":
        y = LabelEncoder().fit_transform(y)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    forest = RandomForestClassifier()
    forest.fit(X_train, y_train)

    explainer = shap.TreeExplainer(forest)
    shap_values = explainer.shap_values(X_test)

    try:
        # A list holds one array per class; use the first class.
        if isinstance(shap_values, list):
            sv = np.asarray(shap_values[0])
        else:
            sv = np.asarray(shap_values)

        if sv.ndim > 2:
            if sv.shape[1] == X_test.shape[1]:
                # Assumed layout (samples, features, classes) because axis 1
                # matches the feature count: take class 0, mirroring the list
                # branch, rather than mixing classes by flattening.
                sv = sv[:, :, 0]
            else:
                # Unknown layout: keep the original flattening fallback.
                sv = sv.reshape(sv.shape[0], -1)

        num_features = sv.shape[1]
        if num_features == X_test.shape[1]:
            # Shapes agree, so plot the real feature values; this lets the
            # summary plot colour points by feature magnitude.
            plot_features = X_test.reset_index(drop=True)
        else:
            if num_features < X_test.shape[1]:
                feature_names = list(X_test.columns[:num_features])
            else:
                feature_names = [f"Feature_{i}" for i in range(num_features)]
            # Placeholder values keep the plot call valid when the SHAP
            # matrix cannot be aligned with the test frame.
            plot_features = pd.DataFrame(np.zeros_like(sv), columns=feature_names)

        shap.summary_plot(sv, plot_features, show=False)
        shap_path = "./shap_plot.png"
        plt.title("SHAP Summary")
        plt.savefig(shap_path)
        if wandb.run:
            wandb.log({"shap_summary": wandb.Image(shap_path)})
        # Close instead of clear so figures do not pile up across clicks.
        plt.close()

    except Exception as e:
        # Best effort: show the failure as an image rather than breaking the
        # whole callback, so the LIME output is still produced.
        shap_path = "./shap_error.png"
        print("SHAP plotting failed:", e)
        plt.figure(figsize=(6, 3))
        plt.text(0.5, 0.5, f"SHAP Error:\n{str(e)}", ha='center', va='center')
        plt.axis('off')
        plt.savefig(shap_path)
        if wandb.run:
            wandb.log({"shap_error": wandb.Image(shap_path)})
        plt.close()

    lime_explainer = lime.lime_tabular.LimeTabularExplainer(
        X_train.values,
        feature_names=X_train.columns.tolist(),
        class_names=[str(c) for c in np.unique(y_train)],
        mode='classification',
    )
    lime_exp = lime_explainer.explain_instance(X_test.iloc[0].values, forest.predict_proba)
    lime_fig = lime_exp.as_pyplot_figure()
    lime_path = "./lime_plot.png"
    lime_fig.savefig(lime_path)
    if wandb.run:
        wandb.log({"lime_explanation": wandb.Image(lime_path)})
    plt.close(lime_fig)

    return shap_path, lime_path
|
|
|
|
|
|
|
|
|
|
|
def update_target_choices():
    """Return a dropdown update listing the loaded frame's columns.

    The choice list is empty when no dataset has been loaded yet.
    """
    columns = [] if df_global is None else df_global.columns.tolist()
    return gr.update(choices=columns)
|
|
|
|
|
|
|
|
# Gradio UI: an upload/target column beside the agent report, followed by one
# row each for training, explainability and model comparison.
with gr.Blocks() as demo:
    # NOTE(review): the leading emoji was mis-decoded in the source; the bar
    # chart is a best guess - confirm.
    gr.Markdown("## \U0001F4CA AI-Powered Data Analysis with Hyperparameter Optimization")

    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="Upload CSV or Excel", type="filepath")
            df_output = gr.DataFrame(label="Cleaned Data Preview")
            target_dropdown = gr.Dropdown(label="Select Target Column", choices=[], interactive=True)
            target_status = gr.Textbox(label="Target Column Status", interactive=False)

            # Uploading a file refreshes both the preview and the dropdown
            # choices.
            file_input.change(fn=upload_file, inputs=file_input, outputs=[df_output, target_dropdown])

            # Picking a column records it as the shared target.
            target_dropdown.change(fn=set_target_column, inputs=target_dropdown, outputs=target_status)

        with gr.Column():
            insights_output = gr.HTML(label="Insights from SmolAgent")
            visual_output = gr.Gallery(label="Visualizations (Auto-generated by Agent)", columns=2)
            agent_btn = gr.Button("Run AI Agent (5 Insights + 5 Visualizations)")

    with gr.Row():
        train_btn = gr.Button("Train Model with Optuna + WandB")
        metrics_output = gr.JSON(label="Performance Metrics")
        trials_output = gr.DataFrame(label="Top 7 Hyperparameter Trials")

    with gr.Row():
        explain_btn = gr.Button("SHAP + LIME Explainability")
        shap_img = gr.Image(label="SHAP Summary Plot")
        lime_img = gr.Image(label="LIME Explanation")

    with gr.Row():
        compare_btn = gr.Button("Compare Models (A/B Testing)")
        compare_output = gr.DataFrame(label="Model Comparison (CV + Metrics)")
        compare_img = gr.Image(label="Model Accuracy Plot")

    agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
    train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])
    explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
    compare_btn.click(fn=compare_models, inputs=[], outputs=[compare_output, compare_img])

demo.launch(debug=True)