Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import pandas as pd | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| import io | |
| import chardet | |
| from PIL import Image | |
| import numpy as np | |
def detect_encoding(file):
    """Guess the text encoding of a file from its first 10 000 bytes.

    Args:
        file: Path of the file to inspect.

    Returns:
        The encoding name reported by chardet, or ``'utf-8'`` when the file
        cannot be read, detection fails, chardet reports no encoding, or
        chardet reports plain ASCII.
    """
    try:
        with open(file, 'rb') as f:
            raw = f.read(10000)  # Read a chunk of the file
        encoding = chardet.detect(raw)['encoding']
    except Exception as e:
        # Best effort: detection problems must not block loading the file.
        print(f"Error detecting encoding: {str(e)}")
        return 'utf-8'  # Default to UTF-8 if detection fails

    # chardet reports None when it cannot decide (e.g. an empty file).
    # An 'ascii' verdict only covers the sampled prefix; UTF-8 decodes every
    # ASCII file identically and also survives non-ASCII bytes further in.
    if not encoding or encoding.lower() == 'ascii':
        return 'utf-8'
    return encoding
def _figure_to_png(fig):
    """Render ``fig`` into an in-memory PNG and return the rewound buffer."""
    buf = io.BytesIO()
    # The titles are placed at y=1.02, i.e. above the figure area;
    # bbox_inches='tight' grows the saved region so they are not cropped.
    fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
    buf.seek(0)
    return buf


def _restyle_grid_labels(grid):
    """Apply uniform tick and axis-label font sizes to a seaborn grid."""
    for ax in grid.axes.flatten():
        # With corner=True the upper-triangle entries of grid.axes are None.
        if ax is None:
            continue
        ax.tick_params(labelsize=10)
        if ax.get_xlabel():
            ax.set_xlabel(ax.get_xlabel(), fontsize=12)
        if ax.get_ylabel():
            ax.set_ylabel(ax.get_ylabel(), fontsize=12)


def _color_points_by_target(grid, target, label):
    """Colour every scatter panel of ``grid`` by the numeric ``target``."""
    # Local import: only the pyplot interface is imported at file level.
    from matplotlib.collections import PathCollection

    values = target.to_numpy(dtype=float, na_value=float('nan'))
    norm = plt.Normalize(target.min(), target.max())
    for ax in grid.axes.flatten():
        if ax is None:
            continue
        points = next(
            (c for c in ax.collections if isinstance(c, PathCollection)),
            None,
        )
        # Skip panels without scatter points, and panels where rows were
        # dropped (e.g. NaNs) so the colour array would not line up.
        if points is None or len(points.get_offsets()) != len(values):
            continue
        points.set_cmap('viridis')
        points.set_norm(norm)
        points.set_array(values)
        grid.fig.colorbar(points, ax=ax, label=label)


def _pairgrid_to_png(grid, title, target=None, label=None):
    """Title, restyle and save a seaborn pair grid, then close its figure."""
    try:
        if target is not None:
            _color_points_by_target(grid, target, label)
        grid.fig.suptitle(title, y=1.02, fontsize=16)
        _restyle_grid_labels(grid)
        grid.fig.tight_layout()
        return _figure_to_png(grid.fig)
    finally:
        plt.close(grid.fig)


def _regression_grid_to_png(df, feature_names, target_column,
                            is_numeric_target, title):
    """Draw the feature-vs-feature grid and return it as a PNG buffer."""
    n_features = len(feature_names)
    if n_features == 0:
        raise ValueError("no feature columns to plot")
    # squeeze=False keeps `axes` two-dimensional even for a single feature.
    fig, axes = plt.subplots(n_features, n_features, figsize=(16, 14),
                             squeeze=False)
    try:
        fig.suptitle(title, y=1.02, fontsize=16)
        for i, feature1 in enumerate(feature_names):
            for j, feature2 in enumerate(feature_names):
                ax = axes[i, j]
                if i != j:
                    if is_numeric_target:
                        scatter = ax.scatter(df[feature1], df[feature2],
                                             c=df[target_column],
                                             cmap='viridis', alpha=0.6)
                        fig.colorbar(scatter, ax=ax, label=target_column)
                    else:
                        sns.regplot(x=feature1, y=feature2, data=df, ax=ax,
                                    scatter_kws={'alpha': 0.6},
                                    line_kws={'color': 'red'})
                else:
                    sns.histplot(df[feature1], ax=ax, kde=True)
                ax.set_xlabel(feature1, fontsize=10)
                ax.set_ylabel(feature2, fontsize=10)
                ax.tick_params(labelsize=8)
                ax.set_title(f'{feature1} vs {feature2}', fontsize=12)
        fig.tight_layout()
        return _figure_to_png(fig)
    finally:
        plt.close(fig)


def _correlation_heatmap_to_png(df, columns):
    """Draw the Pearson correlation heatmap of the numeric ``columns``."""
    # Restrict to numeric columns: a categorical target cannot be correlated.
    correlation_matrix = df[columns].select_dtypes(include='number').corr()
    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        sns.heatmap(correlation_matrix, annot=True, fmt='.2f',
                    cmap='coolwarm', square=True, cbar_kws={'shrink': .8},
                    ax=ax)
        ax.set_title('Pearson Correlation Heatmap', fontsize=16)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=10)
        plt.setp(ax.get_yticklabels(), fontsize=10)
        fig.tight_layout()
        return _figure_to_png(fig)
    finally:
        plt.close(fig)


def create_plots(df, feature_columns, target_column):
    """Render the analysis plots for the chosen features and target.

    More than three features are split into two groups. For each group a
    scatter pair plot, a histogram pair plot and a regression grid are
    rendered, in that order; one correlation heatmap over all selected
    columns is rendered last. A plot that fails is reported with ``print``
    and left out of the result.

    Args:
        df: DataFrame holding the data.
        feature_columns: Names of the feature columns.
        target_column: Name of the target column. A numeric target is shown
            as a colour scale; any other target is used as ``hue``.

    Returns:
        A tuple ``(plots, num_groups)``: ``plots`` is a list of ``io.BytesIO``
        PNG buffers in the order described above, and ``num_groups`` is 1 or 2.
    """
    plots = []
    # A target that is also selected as a feature would duplicate its column.
    feature_columns = [c for c in (feature_columns or []) if c != target_column]

    # Computed before the try block so the return value is always defined,
    # even when the target column turns out to be missing.
    num_groups = 2 if len(feature_columns) > 3 else 1
    if num_groups == 2:
        mid = len(feature_columns) // 2
        feature_groups = [feature_columns[:mid], feature_columns[mid:]]
    else:
        feature_groups = [feature_columns]

    try:
        target = df[target_column]
        is_numeric_target = pd.api.types.is_numeric_dtype(target)

        for group, group_features in enumerate(feature_groups, 1):
            # Add target to each feature set
            columns = group_features + [target_column]

            # Create scatter plot
            try:
                title = f'Scatter Plots - Group {group}'
                if is_numeric_target:
                    grid = sns.pairplot(df[columns], kind='scatter',
                                        plot_kws={'alpha': 0.6}, corner=True)
                    buf = _pairgrid_to_png(grid, title, target=target,
                                           label=target_column)
                else:
                    grid = sns.pairplot(df[columns], hue=target_column,
                                        kind='scatter', corner=True)
                    buf = _pairgrid_to_png(grid, title)
                plots.append(buf)
            except Exception as e:
                print(f"Error in scatter plot for group {group}: {str(e)}")

            # Create histogram plot
            try:
                title = f'Histogram Plots - Group {group}'
                if is_numeric_target:
                    # Off-diagonal panels are 2-D histograms whose colour
                    # encodes bin counts, so they cannot carry per-row target
                    # colours. The diagonal gets a KDE overlay via diag_kws.
                    grid = sns.pairplot(df[columns], kind='hist',
                                        diag_kind='hist',
                                        diag_kws={'kde': True},
                                        plot_kws={'alpha': 0.6}, corner=True)
                else:
                    grid = sns.pairplot(df[columns], kind='hist',
                                        hue=target_column, corner=True)
                plots.append(_pairgrid_to_png(grid, title))
            except Exception as e:
                print(f"Error in histogram plot for group {group}: {str(e)}")

            # Create regression plot
            try:
                plots.append(_regression_grid_to_png(
                    df, group_features, target_column, is_numeric_target,
                    f'Regression Plots - Group {group}'))
            except Exception as e:
                print(f"Error in regression plot for group {group}: {str(e)}")

        # Create a heatmap of Pearson correlation values
        try:
            plots.append(_correlation_heatmap_to_png(
                df, feature_columns + [target_column]))
        except Exception as e:
            print(f"Error in correlation heatmap: {str(e)}")
    except Exception as e:
        print(f"Error in create_plots: {str(e)}")
    return plots, num_groups
def process_csv(csv_file):
    """Fill both column dropdowns with the header of the uploaded CSV.

    Args:
        csv_file: The uploaded file object (its ``.name`` is used as the
            path), or ``None`` when nothing is uploaded.

    Returns:
        A pair of ``gr.update`` objects for the feature and target dropdowns.
        They carry the CSV's column names as choices, or are empty updates
        when there is no file or the header cannot be read.
    """
    try:
        if csv_file is not None:
            encoding = detect_encoding(csv_file.name)
            # nrows=0 parses only the header row; the column names are all
            # that is needed here, so the data rows are not loaded.
            header = pd.read_csv(csv_file.name, encoding=encoding, nrows=0)
            columns = header.columns.tolist()
            return gr.update(choices=columns), gr.update(choices=columns)
        return gr.update(), gr.update()
    except Exception as e:
        print(f"Error in process_csv: {str(e)}")
        return gr.update(), gr.update()
def _place_in_output_slots(images, num_groups):
    """Map images from create_plots order onto the seven UI output slots.

    create_plots yields scatter, histogram and regression for each group,
    followed by the heatmap. The click handler's outputs are ordered
    scatter G1, scatter G2, histogram G1, histogram G2, regression G1,
    regression G2, heatmap.
    """
    slots = [None] * 7
    if len(images) != 3 * num_groups + 1:
        # At least one plot failed, so it is unknown which ones are present;
        # fall back to filling the slots sequentially.
        print("Warning: some plots are missing; image placement may not "
              "match the labels")
        for index, image in enumerate(images[:7]):
            slots[index] = image
        return slots

    for group_index in range(num_groups):
        scatter, hist, regression = images[3 * group_index:3 * group_index + 3]
        slots[group_index] = scatter
        slots[2 + group_index] = hist
        slots[4 + group_index] = regression
    slots[6] = images[-1]  # The heatmap always goes in the last slot
    return slots


def run_analysis(csv_file, feature_columns, target_column):
    """Load the uploaded CSV, render all plots and return them for display.

    Args:
        csv_file: The uploaded file object (its ``.name`` is used as the
            path), or ``None``.
        feature_columns: Names of the selected feature columns.
        target_column: Name of the selected target column.

    Returns:
        A list of seven entries (PIL images or ``None``), ordered to match
        the click handler's outputs: scatter G1, scatter G2, histogram G1,
        histogram G2, regression G1, regression G2, heatmap. All entries are
        ``None`` when an input is missing or the analysis fails.
    """
    try:
        # An empty multiselect arrives as [], not None, so test truthiness.
        if csv_file is None or not feature_columns or not target_column:
            return [None] * 7
        encoding = detect_encoding(csv_file.name)
        df = pd.read_csv(csv_file.name, encoding=encoding)
        plot_buffers, num_groups = create_plots(df, feature_columns, target_column)
        # Convert BytesIO objects to PIL Images
        images = [Image.open(buf) for buf in plot_buffers]
        return _place_in_output_slots(images, num_groups)
    except Exception as e:
        print(f"Error in run_analysis: {str(e)}")
        return [None] * 7
# Create Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Data Analysis Tool")
    gr.Markdown("Upload a CSV file and select columns to generate plots.")

    # Input row: the file upload next to the two column pickers.
    with gr.Row():
        csv_file = gr.File(label="Upload CSV file")
        feature_columns = gr.Dropdown(
            label="Select Feature Columns",
            multiselect=True,
        )
        target_column = gr.Dropdown(label="Select Target Column")

    # Uploading a file refreshes the choices of both dropdowns.
    csv_file.upload(
        fn=process_csv,
        inputs=[csv_file],
        outputs=[feature_columns, target_column],
    )

    analyze_btn = gr.Button("Analyze")

    # One row per plot type, group 1 on the left and group 2 on the right.
    with gr.Row():
        plot1 = gr.Image(label="Scatter Plots - Group 1")
        plot4 = gr.Image(label="Scatter Plots - Group 2")
    with gr.Row():
        plot2 = gr.Image(label="Histogram Plots - Group 1")
        plot5 = gr.Image(label="Histogram Plots - Group 2")
    with gr.Row():
        plot3 = gr.Image(label="Regression Plots - Group 1")
        plot6 = gr.Image(label="Regression Plots - Group 2")
    with gr.Row():
        heatmap = gr.Image(label="Pearson Correlation Heatmap")

    # The order of `outputs` fixes which returned image lands in which slot.
    analyze_btn.click(
        fn=run_analysis,
        inputs=[csv_file, feature_columns, target_column],
        outputs=[plot1, plot4, plot2, plot5, plot3, plot6, heatmap],
    )

# Launch the app
iface.launch()