# reab5555's picture
# Update app.py
# 083effd verified
import gradio as gr
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import io
import chardet
from PIL import Image
import numpy as np
def detect_encoding(file):
    """Guess the text encoding of the file at path ``file``.

    Only the first 10000 bytes are read and handed to ``chardet.detect``.

    Args:
        file: Path of the file to inspect.

    Returns:
        The encoding name reported by chardet, or ``'utf-8'`` when the file
        cannot be read, when chardet reports no encoding, or when the sample
        looks like plain ASCII.
    """
    try:
        with open(file, 'rb') as f:
            raw = f.read(10000)  # Read a chunk of the file
        encoding = chardet.detect(raw)['encoding']
    except Exception as e:
        print(f"Error detecting encoding: {str(e)}")
        return 'utf-8'  # Default to UTF-8 if detection fails

    # chardet reports None when it cannot decide; apply the same UTF-8
    # default instead of leaking None to the caller.
    if not encoding:
        return 'utf-8'
    # Only a prefix was sampled. ASCII is a strict subset of UTF-8, so decoding
    # as UTF-8 is never worse and survives non-ASCII bytes later in the file.
    if encoding.lower() == 'ascii':
        return 'utf-8'
    return encoding
def _figure_to_png(fig):
    """Render ``fig`` into an in-memory PNG (300 dpi) and return the rewound buffer."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=300)
    buf.seek(0)
    return buf


def _grid_axes(grid):
    """Yield ``(row, col, ax)`` for every axes that exists in a seaborn pair grid.

    With ``corner=True`` seaborn stores ``None`` for the removed upper-triangle
    cells of ``grid.axes``; those entries are skipped here.
    """
    for row, axes_row in enumerate(grid.axes):
        for col, ax in enumerate(axes_row):
            if ax is not None:
                yield row, col, ax


def _color_points_by_target(grid, target_values, label):
    """Colour the off-diagonal scatter points of ``grid`` by ``target_values``."""
    norm = plt.Normalize(target_values.min(), target_values.max())
    for row, col, ax in _grid_axes(grid):
        if row == col or not ax.collections:
            continue
        points = ax.collections[0]
        if len(points.get_offsets()) != len(target_values):
            # Rows were dropped while plotting (e.g. missing values), so the
            # target values would not line up with the drawn points.
            continue
        points.set_cmap('viridis')
        points.set_norm(norm)
        points.set_array(target_values.to_numpy())
        grid.fig.colorbar(points, ax=ax, label=label)


def _finish_pair_grid(grid, title):
    """Title the grid, normalise its label sizes and return it as a PNG buffer."""
    grid.fig.suptitle(title, y=1.02, fontsize=16)
    # Adjust label size and spacing
    for _, _, ax in _grid_axes(grid):
        ax.tick_params(labelsize=10)
        if ax.get_xlabel():
            ax.set_xlabel(ax.get_xlabel(), fontsize=12)
        if ax.get_ylabel():
            ax.set_ylabel(ax.get_ylabel(), fontsize=12)
    grid.fig.tight_layout()
    return _figure_to_png(grid.fig)


def create_plots(df, feature_columns, target_column):
    """Render pair plots, per-group feature grids and a correlation heatmap.

    The features are split into two groups when more than three are given,
    otherwise they form a single group. For each group three figures are
    attempted, in this order: a scatter pair plot, a histogram pair plot and a
    feature-by-feature grid titled "Regression Plots". A single Pearson
    correlation heatmap over all selected columns is attempted last.

    Each figure is rendered independently: if one fails, the error is printed
    and that figure is simply omitted from the result.

    Args:
        df: DataFrame holding the data.
        feature_columns: Column names to use as features.
        target_column: Column name of the target. A numeric target is shown as
            a colour scale; any other target is used as the ``hue``.

    Returns:
        A tuple ``(plots, num_groups)`` where ``plots`` is a list of
        ``io.BytesIO`` PNG buffers in the order described above and
        ``num_groups`` is 1 or 2.
    """
    plots = []
    feature_columns = list(feature_columns or [])
    # Determine the number of groups up front so that it can always be
    # returned, even when plotting fails on the very first step.
    num_groups = 2 if len(feature_columns) > 3 else 1
    try:
        # Check if the target column is numeric
        target_values = df[target_column]
        is_numeric_target = pd.api.types.is_numeric_dtype(target_values)

        # Split the features into groups
        if num_groups == 2:
            mid = len(feature_columns) // 2
            feature_groups = [feature_columns[:mid], feature_columns[mid:]]
        else:
            feature_groups = [feature_columns]

        for group, group_features in enumerate(feature_groups, 1):
            # Add target to each feature set. If the target was also picked as
            # a feature, keep it only once to avoid duplicate column labels.
            pair_columns = [c for c in group_features if c != target_column]
            pair_columns.append(target_column)

            # Create scatter plot (pairplot builds its own figure)
            grid = None
            try:
                if is_numeric_target:
                    grid = sns.pairplot(df[pair_columns], kind='scatter',
                                        plot_kws={'alpha': 0.6}, corner=True)
                    _color_points_by_target(grid, target_values, target_column)
                else:
                    grid = sns.pairplot(df[pair_columns], hue=target_column,
                                        kind='scatter', corner=True)
                plots.append(_finish_pair_grid(grid, f'Scatter Plots - Group {group}'))
            except Exception as e:
                print(f"Error in scatter plot for group {group}: {str(e)}")
            finally:
                # Close the grid's figure; if pairplot itself failed, close
                # whatever figure it left current.
                plt.close(grid.fig if grid is not None else None)

            # Create histogram plot
            grid = None
            try:
                if is_numeric_target:
                    # The diagonal gets histograms with a KDE overlay. The
                    # off-diagonal cells are 2D histograms whose meshes hold
                    # bin counts, so they cannot be re-coloured per data row.
                    grid = sns.pairplot(df[pair_columns], kind='hist', diag_kind='hist',
                                        plot_kws={'alpha': 0.6}, diag_kws={'kde': True},
                                        corner=True)
                else:
                    grid = sns.pairplot(df[pair_columns], kind='hist',
                                        hue=target_column, corner=True)
                plots.append(_finish_pair_grid(grid, f'Histogram Plots - Group {group}'))
            except Exception as e:
                print(f"Error in histogram plot for group {group}: {str(e)}")
            finally:
                plt.close(grid.fig if grid is not None else None)

            # Create regression plot
            n_features = len(group_features)  # Target column is not part of this grid
            fig = None
            try:
                # squeeze=False keeps ``axes`` two-dimensional for one feature too
                fig, axes = plt.subplots(n_features, n_features, figsize=(16, 14),
                                         squeeze=False)
                fig.suptitle(f'Regression Plots - Group {group}', y=1.02, fontsize=16)
                for i, feature1 in enumerate(group_features):
                    for j, feature2 in enumerate(group_features):
                        ax = axes[i, j]
                        if i != j:
                            if is_numeric_target:
                                # Numeric target: plain scatter coloured by target
                                scatter = ax.scatter(df[feature1], df[feature2],
                                                     c=target_values,
                                                     cmap='viridis', alpha=0.6)
                                fig.colorbar(scatter, ax=ax, label=target_column)
                            else:
                                sns.regplot(x=feature1, y=feature2, data=df, ax=ax,
                                            scatter_kws={'alpha': 0.6},
                                            line_kws={'color': 'red'})
                        else:
                            sns.histplot(df[feature1], ax=ax, kde=True)
                        ax.set_xlabel(feature1, fontsize=10)
                        ax.set_ylabel(feature2, fontsize=10)
                        ax.tick_params(labelsize=8)
                        ax.set_title(f'{feature1} vs {feature2}', fontsize=12)
                fig.tight_layout()
                plots.append(_figure_to_png(fig))
            except Exception as e:
                print(f"Error in regression plot for group {group}: {str(e)}")
            finally:
                plt.close(fig if fig is not None else None)

        # Create a heatmap of Pearson correlation values
        fig = None
        try:
            corr_columns = [c for c in feature_columns if c != target_column]
            corr_columns.append(target_column)
            # Restrict to numeric/bool columns explicitly: recent pandas
            # versions raise on non-numeric columns instead of dropping them.
            numeric_part = df[corr_columns].select_dtypes(include=['number', 'bool'])
            correlation_matrix = numeric_part.corr()

            fig = plt.figure(figsize=(12, 10))
            heatmap = sns.heatmap(correlation_matrix, annot=True, fmt='.2f',
                                  cmap='coolwarm', square=True,
                                  cbar_kws={'shrink': .8})
            heatmap.set_title('Pearson Correlation Heatmap', fontsize=16)
            plt.xticks(rotation=45, ha='right', fontsize=10)
            plt.yticks(fontsize=10)
            fig.tight_layout()
            plots.append(_figure_to_png(fig))
        except Exception as e:
            print(f"Error in correlation heatmap: {str(e)}")
        finally:
            plt.close(fig if fig is not None else None)
    except Exception as e:
        print(f"Error in create_plots: {str(e)}")
    return plots, num_groups
def process_csv(csv_file):
    """Fill both column dropdowns with the column names of an uploaded CSV.

    Args:
        csv_file: The uploaded file as passed by the ``gr.File`` component
            (an object exposing the path as ``.name``), a plain path, or
            ``None`` when nothing is uploaded.

    Returns:
        A pair of ``gr.update`` values for the feature and target dropdowns.
        Both carry the CSV's column names as ``choices``; both are empty
        updates when no file is given or the file cannot be read.
    """
    try:
        if csv_file is None:
            return gr.update(), gr.update()
        # Accept either an upload object exposing ``.name`` or a plain path.
        path = getattr(csv_file, 'name', csv_file)
        encoding = detect_encoding(path)
        # Only the header is needed here, so skip parsing the data rows.
        header = pd.read_csv(path, encoding=encoding, nrows=0)
        columns = header.columns.tolist()
        return gr.update(choices=columns), gr.update(choices=list(columns))
    except Exception as e:
        print(f"Error in process_csv: {str(e)}")
        return gr.update(), gr.update()
def _arrange_for_outputs(images, num_groups):
    """Order plot images to match the output list of the "Analyze" click handler.

    ``create_plots`` yields ``[scatter, histogram, regression]`` per group
    followed by the heatmap, while the click handler lists its outputs as
    ``plot1, plot4, plot2, plot5, plot3, plot6, heatmap``, i.e.
    ``[scatter 1, scatter 2, histogram 1, histogram 2, regression 1,
    regression 2, heatmap]``.
    """
    expected = 3 * num_groups + 1
    if len(images) != expected:
        # At least one plot failed and was omitted, so the positions are
        # ambiguous; keep the previous behaviour of filling the slots in order.
        return (images + [None] * 7)[:7]

    first_group = images[0:3]
    second_group = images[3:6] if num_groups == 2 else [None] * 3
    arranged = []
    for first, second in zip(first_group, second_group):
        arranged.extend((first, second))
    arranged.append(images[-1])  # the heatmap is always produced last
    return arranged


def run_analysis(csv_file, feature_columns, target_column):
    """Load the uploaded CSV, build all plots and return them for display.

    Args:
        csv_file: The uploaded file as passed by the ``gr.File`` component
            (an object exposing the path as ``.name``) or a plain path.
        feature_columns: Selected feature column names.
        target_column: Selected target column name.

    Returns:
        A list of seven entries (PIL images or ``None``) ordered to match the
        click handler's outputs: scatter 1, scatter 2, histogram 1,
        histogram 2, regression 1, regression 2, heatmap. All entries are
        ``None`` when an input is missing or the analysis fails.
    """
    try:
        # An empty feature selection is as unusable as a missing one.
        if csv_file is None or not feature_columns or target_column in (None, ''):
            return [None] * 7
        # Accept either an upload object exposing ``.name`` or a plain path.
        path = getattr(csv_file, 'name', csv_file)
        encoding = detect_encoding(path)
        df = pd.read_csv(path, encoding=encoding)
        plot_buffers, num_groups = create_plots(df, feature_columns, target_column)
        # Convert BytesIO objects to PIL Images
        images = [Image.open(buf) for buf in plot_buffers]
        return _arrange_for_outputs(images, num_groups)
    except Exception as e:
        print(f"Error in run_analysis: {str(e)}")
        return [None] * 7
# Build the Gradio UI: inputs on top, the "Analyze" button, then the plot grid.
with gr.Blocks() as iface:
    gr.Markdown("# Data Analysis Tool")
    gr.Markdown("Upload a CSV file and select columns to generate plots.")

    # The file upload and the two column pickers share one row.
    with gr.Row():
        csv_file = gr.File(label="Upload CSV file")
        feature_columns = gr.Dropdown(
            label="Select Feature Columns",
            multiselect=True,
        )
        target_column = gr.Dropdown(label="Select Target Column")

    # Uploading a file refreshes the choices of both dropdowns.
    csv_file.upload(
        fn=process_csv,
        inputs=[csv_file],
        outputs=[feature_columns, target_column],
    )

    analyze_btn = gr.Button("Analyze")

    # One row per plot kind: group 1 on the left, group 2 on the right.
    with gr.Row():
        plot1 = gr.Image(label="Scatter Plots - Group 1")
        plot4 = gr.Image(label="Scatter Plots - Group 2")

    with gr.Row():
        plot2 = gr.Image(label="Histogram Plots - Group 1")
        plot5 = gr.Image(label="Histogram Plots - Group 2")

    with gr.Row():
        plot3 = gr.Image(label="Regression Plots - Group 1")
        plot6 = gr.Image(label="Regression Plots - Group 2")

    with gr.Row():
        heatmap = gr.Image(label="Pearson Correlation Heatmap")

    # Values returned by run_analysis are assigned to these outputs by position.
    analysis_inputs = [csv_file, feature_columns, target_column]
    analysis_outputs = [plot1, plot4, plot2, plot5, plot3, plot6, heatmap]
    analyze_btn.click(
        fn=run_analysis,
        inputs=analysis_inputs,
        outputs=analysis_outputs,
    )

# Start serving the app
iface.launch()