Spaces:
Sleeping
Sleeping
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import pandas as pd | |
| import numpy as np | |
def _is_text_dtype(series):
    """Return True if *series* holds text (``object`` or pandas ``StringDtype``)."""
    # Depending on the pandas version/options, text columns may use a
    # dedicated string dtype rather than ``object``; treat both as text.
    return pd.api.types.is_object_dtype(series.dtype) or isinstance(series.dtype, pd.StringDtype)


def _text_lengths(series):
    """Return the character length of every non-null value in *series*.

    Values are stringified first, so object columns holding non-string values
    (numbers, dates, mixed types) do not make the ``.str`` accessor raise.
    Missing values are dropped rather than measured as the string ``"nan"``.
    """
    return series.dropna().astype(str).str.len()


def _unique_short_labels(values, width=30):
    """Stringify *values* as axis labels, truncated to *width* characters.

    Labels longer than *width* are cut and suffixed with ``"..."``. Distinct
    values that end up with the same label (after truncation or ``str``) get
    a ``" (2)"``, ``" (3)"``, ... suffix so they stay separate bars instead of
    being stacked onto a single category.
    """
    labels = []
    seen = {}
    for value in values:
        text = str(value)
        label = text[:width] + "..." if len(text) > width else text
        seen[label] = seen.get(label, 0) + 1
        if seen[label] > 1:
            label = f"{label} ({seen[label]})"
        labels.append(label)
    return labels


def _category_bar(df, col):
    """Build a horizontal bar chart of the most frequent values of ``df[col]``."""
    # High-cardinality columns show fewer bars and say so in the title.
    if df[col].nunique() < 50:
        top_n, title = 15, f"Top Categories in {col}"
    else:
        top_n, title = 10, f"Top 10 Most Frequent in {col}"
    counts = df[col].value_counts().head(top_n)
    fig = px.bar(x=counts.values, y=_unique_short_labels(counts.index), orientation='h',
                 labels={'x': 'Count', 'y': col}, title=title)
    # Horizontal bars suit long text labels; reversing the y axis puts the
    # first (most frequent, since value_counts sorts descending) bar on top.
    fig.update_layout(yaxis=dict(autorange="reversed"))
    return fig


def _combine_figures(figures, max_charts=4):
    """Merge the first *max_charts* figures into a two-column subplot figure.

    A single figure is returned unchanged. Only traces are copied into the
    grid, so a reversed y axis on a source figure (e.g. the top-to-bottom
    ordering of the bar charts) is re-applied to its cell. Box traces (the
    marginal box of the distribution histograms) are skipped: placing them
    with ``add_trace(row=, col=)`` would draw them on top of the histogram.
    """
    import plotly.subplots as sp

    charts = figures[:max_charts]
    if len(charts) == 1:
        return charts[0]

    # 2 charts -> 1x2 grid; 3 or 4 charts -> 2x2 grid.
    rows = 2 if len(charts) > 2 else 1
    titles = [f.layout.title.text or "" for f in charts]
    dashboard = sp.make_subplots(rows=rows, cols=2, subplot_titles=titles)
    for i, src in enumerate(charts):
        row, col = divmod(i, 2)
        row += 1
        col += 1
        for trace in src.data:
            if trace.type == "box":
                continue
            dashboard.add_trace(trace, row=row, col=col)
        if src.layout.yaxis.autorange == "reversed":
            dashboard.update_yaxes(autorange="reversed", row=row, col=col)
    dashboard.update_layout(height=400 * rows, title_text="Data Visualization Dashboard", showlegend=False)
    return dashboard


def generate_charts(df, profile):
    """
    Generate Plotly charts for *df* guided by its data *profile*.

    Args:
        df: the DataFrame to visualize (may be None or empty).
        profile: mapping that may contain ``'numerical_columns'`` and
            ``'categorical_columns'`` lists of column names in *df*.

    Returns:
        None if *df* is None or empty; otherwise a single Plotly figure:
        the only chart if exactly one was built, an annotated empty figure
        if none could be built, or a subplot dashboard of the first four.
    """
    if df is None or df.empty:
        return None

    figures = []
    num_cols = profile.get('numerical_columns', [])
    cat_cols = profile.get('categorical_columns', [])

    # 1. Correlation heatmap across all numerical columns.
    if len(num_cols) > 1:
        corr = df[num_cols].corr()
        figures.append(px.imshow(corr, text_auto=True, aspect="auto", title="Correlation Matrix"))

    # 2. Distributions of the first three numerical columns (profile order).
    for col in num_cols[:3]:
        figures.append(px.histogram(df, x=col, title=f"Distribution of {col}", marginal="box"))

    # 3. Categorical counts for up to three columns. Text columns with long
    # values (average length > 30) and more than 10 distinct values are
    # skipped as likely free text / paths rather than categories.
    valid_cat_cols = []
    for col in cat_cols:
        if _is_text_dtype(df[col]):
            avg_len = _text_lengths(df[col]).mean()
            if avg_len > 30 and df[col].nunique() > 10:
                continue
        valid_cat_cols.append(col)
    for col in valid_cat_cols[:3]:
        figures.append(_category_bar(df, col))

    # 4. Fallback for text-only data: length distribution of up to two
    # non-categorical text columns whose average length exceeds 20 chars.
    if not figures:
        text_cols = [
            c for c in df.columns
            if c not in cat_cols and _is_text_dtype(df[c]) and _text_lengths(df[c]).mean() > 20
        ]
        for col in text_cols[:2]:
            lengths = _text_lengths(df[col])
            figures.append(px.histogram(lengths, title=f"Text Length Distribution: {col}",
                                        labels={'value': 'Character Count'}))

    # 5. Scatter of the first two numerical columns.
    if len(num_cols) >= 2:
        figures.append(px.scatter(df, x=num_cols[0], y=num_cols[1],
                                  title=f"{num_cols[0]} vs {num_cols[1]}"))

    if not figures:
        # Nothing could be plotted: return a placeholder with a message.
        fig_empty = go.Figure()
        fig_empty.add_annotation(text="No visualizations available for this dataset structure.",
                                 xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        return fig_empty

    return _combine_figures(figures)