Spaces:
Sleeping
Sleeping
File size: 4,706 Bytes
bb9980b 841e95d bb9980b 6b37a61 841e95d 6b37a61 841e95d 6b37a61 bb9980b 841e95d 6b37a61 841e95d bb9980b 841e95d bb9980b 841e95d bbdd10b bb9980b bbdd10b 841e95d bbdd10b 841e95d bbdd10b 841e95d bbdd10b 841e95d bbdd10b 841e95d bbdd10b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
# Heuristic limits used by generate_charts and its helpers.
_MAX_DASHBOARD_CHARTS = 4  # a 2x2 grid is the largest dashboard built
_LABEL_MAX_CHARS = 30      # category labels longer than this are truncated


def _is_text_like(series):
    """Return True when *series* has object dtype or a pandas string dtype."""
    return pd.api.types.is_object_dtype(series) or pd.api.types.is_string_dtype(series)


def _text_lengths(series):
    """Return the character length of every non-null value in *series*.

    Values are stringified first: the ``.str`` accessor raises AttributeError
    on object columns whose values are not strings (dates, Decimals, ...).
    """
    return series.dropna().astype(str).str.len()


def _shorten(label):
    """Return *label* as a string, truncated with an ellipsis for display."""
    text = str(label)
    if len(text) > _LABEL_MAX_CHARS:
        return text[:_LABEL_MAX_CHARS] + "..."
    return text


def _plottable_categoricals(df, cat_cols, limit):
    """Return up to *limit* categorical columns that are not free text."""
    selected = []
    for col in cat_cols:
        if len(selected) == limit:
            break
        if _is_text_like(df[col]):
            avg_len = df[col].astype(str).str.len().mean()
            # Long values with many distinct entries are likely text or
            # paths, which make unreadable bar charts.
            if avg_len > 30 and df[col].nunique() > 10:
                continue
        selected.append(col)
    return selected


def _category_bar(df, col):
    """Build a horizontal bar chart of the most frequent values in *col*.

    Returns None when the column has no non-null values to count.
    """
    if df[col].nunique() < 50:
        top_n, title = 15, f"Top Categories in {col}"
    else:
        top_n, title = 10, f"Top 10 Most Frequent in {col}"

    counts = df[col].value_counts().head(top_n)
    if counts.empty:
        return None

    # Horizontal bars leave more room for text labels.
    fig_bar = px.bar(
        x=counts.values,
        y=[_shorten(value) for value in counts.index],
        orientation='h',
        labels={'x': 'Count', 'y': col},
        title=title,
    )
    fig_bar.update_layout(yaxis=dict(autorange="reversed"))  # Top to bottom
    return fig_bar


def _on_primary_axes(trace):
    """Return True when *trace* is drawn on its figure's main x/y axes."""
    return (
        getattr(trace, "xaxis", None) in (None, "x")
        and getattr(trace, "yaxis", None) in (None, "y")
    )


def _combine_into_dashboard(charts):
    """Merge two to four figures into one subplot grid that is two columns wide."""
    from plotly.subplots import make_subplots

    n_rows = 2 if len(charts) > 2 else 1
    titles = [chart.layout.title.text or "" for chart in charts]
    dashboard = make_subplots(rows=n_rows, cols=2, subplot_titles=titles)

    for index, chart in enumerate(charts):
        row, col = divmod(index, 2)
        row, col = row + 1, col + 1
        for trace in chart.data:
            # Marginal plots (e.g. the box above a histogram) live on
            # secondary axes; on a shared subplot they would overlap the
            # main trace, so only primary-axes traces are copied.
            if _on_primary_axes(trace):
                dashboard.add_trace(trace, row=row, col=col)
        # add_trace copies traces only, so carry over a reversed y-axis
        # (bar charts, heatmap) to keep the original ordering.
        if chart.layout.yaxis.autorange == "reversed":
            dashboard.update_yaxes(autorange="reversed", row=row, col=col)

    dashboard.update_layout(
        height=400 * n_rows,
        title_text="Data Visualization Dashboard",
        showlegend=False,
    )
    return dashboard


def generate_charts(df, profile):
    """
    Generate Plotly charts for a DataFrame based on its data profile.

    Candidate charts are built in this order: correlation heatmap,
    histograms of the first three numerical columns, bar charts of up to
    three categorical columns, text-length histograms (only when nothing
    else could be built), and a scatter plot of the first two numerical
    columns. At most the first four candidates are shown.

    Args:
        df: pandas DataFrame to visualize.
        profile: mapping with optional 'numerical_columns' and
            'categorical_columns' lists of column names. Names that are not
            columns of ``df`` are ignored.

    Returns:
        None when ``df`` is None or empty; a single Plotly Figure when
        exactly one chart was built (or a placeholder figure with a message
        when none could be); otherwise one combined subplot Figure.
    """
    if df is None or df.empty:
        return None

    profile = profile or {}
    # A profile that is out of sync with the frame must not raise KeyError.
    num_cols = [c for c in (profile.get('numerical_columns') or []) if c in df.columns]
    cat_cols = [c for c in (profile.get('categorical_columns') or []) if c in df.columns]

    figures = []

    # 1. Correlation Heatmap (Numerical)
    if len(num_cols) > 1:
        corr = df[num_cols].corr()
        figures.append(
            px.imshow(corr, text_auto=True, aspect="auto", title="Correlation Matrix")
        )

    # 2. Distributions (Numerical) - first three columns
    for col in num_cols[:3]:
        figures.append(
            px.histogram(df, x=col, title=f"Distribution of {col}", marginal="box")
        )

    # 3. Categorical Counts - up to three non-text columns
    for col in _plottable_categoricals(df, cat_cols, limit=3):
        fig_bar = _category_bar(df, col)
        if fig_bar is not None:
            figures.append(fig_bar)

    # 4. Text Length Distribution (fallback, computed only when needed)
    if not figures:
        text_cols = []
        for col in df.columns:
            if len(text_cols) == 2:
                break
            if col in cat_cols or not _is_text_like(df[col]):
                continue
            if _text_lengths(df[col]).mean() > 20:
                text_cols.append(col)
        for col in text_cols:
            figures.append(
                px.histogram(
                    _text_lengths(df[col]),
                    title=f"Text Length Distribution: {col}",
                    labels={'value': 'Character Count'},
                )
            )

    # 5. Scatter plot of the first two numerical columns
    if len(num_cols) >= 2:
        figures.append(
            px.scatter(
                df,
                x=num_cols[0],
                y=num_cols[1],
                title=f"{num_cols[0]} vs {num_cols[1]}",
            )
        )

    # If absolutely no charts could be generated, return a message figure.
    if not figures:
        fig_empty = go.Figure()
        fig_empty.add_annotation(
            text="No visualizations available for this dataset structure.",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
        )
        return fig_empty

    charts_to_show = figures[:_MAX_DASHBOARD_CHARTS]
    if len(charts_to_show) == 1:
        return charts_to_show[0]  # A single figure needs no subplot grid
    return _combine_into_dashboard(charts_to_show)
|