import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.subplots as sp


def _present_columns(df, columns):
    """Return the entries of *columns* (may be None) that are columns of *df*, in order."""
    return [c for c in (columns or []) if c in df.columns]


def _is_text_dtype(series):
    """Return True for object-dtype and pandas string-dtype columns."""
    dtype = series.dtype
    return pd.api.types.is_object_dtype(dtype) or pd.api.types.is_string_dtype(dtype)


def _str_lengths(series):
    """Return a float Series of character counts for string cells, NaN for all others.

    Unlike ``series.str.len()`` this never raises: the ``.str`` accessor rejects
    object columns that contain no strings at all (e.g. only ints).
    """
    return pd.Series(
        [len(v) if isinstance(v, str) else np.nan for v in series],
        index=series.index,
        name=series.name,
        dtype=float,
    )


def _short_labels(values, max_len=30):
    """Return unique display labels for *values*, truncated to *max_len* chars plus "...".

    Truncation (or ``str()`` of mixed types such as 1 and "1") can map distinct
    categories onto the same text; plotly would then stack them into a single
    bar, so collisions get a " (2)", " (3)", ... suffix to keep every bar separate.
    """
    labels = []
    seen = set()
    for value in values:
        text = str(value)
        base = text[:max_len] + "..." if len(text) > max_len else text
        label = base
        suffix = 2
        while label in seen:
            label = f"{base} ({suffix})"
            suffix += 1
        seen.add(label)
        labels.append(label)
    return labels


def _top_counts_bar(series, col, top_n, title):
    """Return a horizontal bar chart of the *top_n* most frequent values of *series*."""
    counts = series.value_counts().head(top_n)
    fig = px.bar(
        x=counts.values,
        y=_short_labels(counts.index),
        orientation='h',
        labels={'x': 'Count', 'y': col},
        title=title,
    )
    # Horizontal bars suit long text labels; reverse so the most frequent is on top.
    fig.update_layout(yaxis=dict(autorange="reversed"))
    return fig


def _usable_categorical_columns(df, cat_cols):
    """Return the categorical columns worth a bar chart.

    A text column is skipped as likely free text / paths when its average string
    length exceeds 30 AND it has more than 10 distinct values; long labels with
    low cardinality are kept.
    """
    usable = []
    for col in cat_cols:
        series = df[col]
        if _is_text_dtype(series):
            avg_len = series.astype(str).str.len().mean()
            if avg_len > 30 and series.nunique() > 10:
                continue
        usable.append(col)
    return usable


def _combine_figures(charts):
    """Merge one or more standalone figures into a 2-column subplot dashboard.

    A single chart is returned unchanged.
    """
    if len(charts) == 1:
        return charts[0]

    n_cols = 2
    n_rows = -(-len(charts) // n_cols)  # ceiling division: 2 charts -> 1 row, 3-4 -> 2 rows
    titles = [chart.layout.title.text or "" for chart in charts]
    fig = sp.make_subplots(rows=n_rows, cols=n_cols, subplot_titles=titles)

    for i, chart in enumerate(charts):
        row, col = i // n_cols + 1, i % n_cols + 1
        for trace in chart.data:
            # px places marginal plots (e.g. the box above each histogram) on
            # secondary axes; add_trace(row, col) would force them onto this
            # cell's main axes and draw them over the main plot, so keep only
            # traces that belong to the chart's primary x/y axes.
            if getattr(trace, "xaxis", None) not in (None, "x"):
                continue
            if getattr(trace, "yaxis", None) not in (None, "y"):
                continue
            fig.add_trace(trace, row=row, col=col)
        # Copying traces does not copy axis settings; preserve reversed y-axes
        # (the bar charts reverse theirs so the top category is listed first).
        if chart.layout.yaxis.autorange == "reversed":
            fig.update_yaxes(autorange="reversed", row=row, col=col)

    fig.update_layout(
        height=400 * n_rows,
        title_text="Data Visualization Dashboard",
        showlegend=False,
    )
    return fig


def generate_charts(df, profile):
    """Build a Plotly dashboard summarising *df*.

    Args:
        df: the dataset to visualise.
        profile: mapping with optional 'numerical_columns' and
            'categorical_columns' lists of column names (presumably produced by
            a profiling step — verify against the caller). Names that are not
            columns of *df* are ignored; a None profile is treated as empty.

    Returns:
        None when *df* is None or empty. Otherwise a single plotly Figure:
        a placeholder message when no chart could be built, the chart itself
        when exactly one was built, or a 2-column subplot grid of the first
        four charts.
    """
    if df is None or df.empty:
        return None

    profile = profile or {}
    num_cols = _present_columns(df, profile.get('numerical_columns'))
    cat_cols = _present_columns(df, profile.get('categorical_columns'))

    # Charts in priority order; only the first four reach the dashboard.
    figures = []

    # 1. Correlation heatmap across numerical columns.
    if len(num_cols) > 1:
        corr = df[num_cols].corr()
        figures.append(px.imshow(corr, text_auto=True, aspect="auto", title="Correlation Matrix"))

    # 2. Distributions of the first three numerical columns.
    for col in num_cols[:3]:
        figures.append(
            px.histogram(df, x=col, title=f"Distribution of {col}", marginal="box")
        )

    # 3. Value counts for up to three categorical columns that are not free text.
    valid_cat_cols = _usable_categorical_columns(df, cat_cols)
    for col in valid_cat_cols[:3]:
        if df[col].nunique() < 50:
            figures.append(_top_counts_bar(df[col], col, 15, f"Top Categories in {col}"))
        else:
            figures.append(_top_counts_bar(df[col], col, 10, f"Top 10 Most Frequent in {col}"))

    # 4. Text-length fallback, only when nothing else could be plotted. Columns
    #    rejected above as free text are eligible here; only columns already
    #    charted as categories are excluded.
    if not figures:
        text_cols = [
            c for c in df.columns
            if c not in valid_cat_cols
            and _is_text_dtype(df[c])
            and _str_lengths(df[c]).mean() > 20
        ]
        for col in text_cols[:2]:
            lengths = _str_lengths(df[col])
            figures.append(
                px.histogram(
                    lengths,
                    title=f"Text Length Distribution: {col}",
                    labels={'value': 'Character Count'},
                )
            )

    # 5. Scatter of the first two numerical columns.
    if len(num_cols) >= 2:
        figures.append(
            px.scatter(df, x=num_cols[0], y=num_cols[1], title=f"{num_cols[0]} vs {num_cols[1]}")
        )

    if not figures:
        fig_empty = go.Figure()
        fig_empty.add_annotation(
            text="No visualizations available for this dataset structure.",
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            showarrow=False,
        )
        return fig_empty

    return _combine_figures(figures[:4])