# auto-data-analyst / src/visualization.py
# salihfurkaan's picture
# Refine visualizations: horizontal bars, smart filtering, label truncation
# 6b37a61
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
def _truncate_label(value, limit=30):
    """Return ``str(value)``, cut to ``limit`` characters plus "..." when longer."""
    text = str(value)
    return text[:limit] + "..." if len(text) > limit else text


def _string_lengths(series):
    """Return the character length of every non-null value in ``series``.

    Values are stringified first so that object columns holding non-string
    values do not trip the pandas ``.str`` accessor.
    """
    return series.dropna().astype(str).str.len()


def _category_bar(df, col):
    """Build a horizontal bar chart of the most frequent values in ``df[col]``."""
    # Columns with fewer than 50 distinct values show up to 15 bars;
    # wider columns are limited to their 10 most frequent values.
    if df[col].nunique() < 50:
        top_n, title = 15, f"Top Categories in {col}"
    else:
        top_n, title = 10, f"Top 10 Most Frequent in {col}"

    counts = df[col].value_counts().head(top_n)
    # Truncate labels for display; horizontal bars leave room for text labels.
    short_labels = [_truncate_label(value) for value in counts.index]
    fig_bar = px.bar(x=counts.values, y=short_labels, orientation='h',
                     labels={'x': 'Count', 'y': col}, title=title)
    # Reverse the axis so the most frequent value is drawn at the top.
    fig_bar.update_layout(yaxis=dict(autorange="reversed"))
    return fig_bar


def generate_charts(df, profile):
    """
    Generate Plotly charts for ``df`` based on the data profile.

    Args:
        df: pandas DataFrame to visualize.
        profile: dict that may contain ``'numerical_columns'`` and
            ``'categorical_columns'`` lists of column names. Names that are
            not present in ``df`` are ignored.

    Returns:
        ``None`` when ``df`` is ``None`` or empty.
        A placeholder figure carrying a text annotation when no chart could
        be built.
        The single chart itself when exactly one chart was selected.
        Otherwise one subplot figure (two columns wide) combining the first
        four charts that were built.
    """
    if df is None or df.empty:
        return None

    figures = []

    # Only keep profile columns that actually exist in the frame, so a stale
    # or mismatched profile does not raise KeyError further down.
    num_cols = [c for c in (profile.get('numerical_columns') or []) if c in df.columns]
    cat_cols = [c for c in (profile.get('categorical_columns') or []) if c in df.columns]

    # 1. Correlation heatmap (needs at least two numerical columns)
    if len(num_cols) > 1:
        corr = df[num_cols].corr()
        figures.append(
            px.imshow(corr, text_auto=True, aspect="auto", title="Correlation Matrix")
        )

    # 2. Distributions of the first three numerical columns
    for col in num_cols[:3]:
        figures.append(
            px.histogram(df, x=col, title=f"Distribution of {col}", marginal="box")
        )

    # 3. Categorical counts for the first three usable categorical columns.
    # Object columns that look like free text or paths (long average length
    # and more than 10 distinct values) are skipped.
    valid_cat_cols = []
    for col in cat_cols:
        if df[col].dtype == 'object':
            avg_len = df[col].astype(str).str.len().mean()
            if avg_len > 30 and df[col].nunique() > 10:
                continue
        valid_cat_cols.append(col)

    for col in valid_cat_cols[:3]:
        figures.append(_category_bar(df, col))

    # 4. Text length distribution, used only as a fallback when nothing else
    # could be charted. The candidate scan is deferred to this point so its
    # cost (and any failure on odd object columns) is not paid needlessly.
    if not figures:
        text_cols = [
            c for c in df.columns
            if df[c].dtype == 'object'
            and c not in cat_cols
            and _string_lengths(df[c]).mean() > 20
        ]
        for col in text_cols[:2]:
            lengths = _string_lengths(df[col])
            figures.append(
                px.histogram(lengths, title=f"Text Length Distribution: {col}",
                             labels={'value': 'Character Count'})
            )

    # 5. Scatter plot of the first two numerical columns
    if len(num_cols) >= 2:
        figures.append(
            px.scatter(df, x=num_cols[0], y=num_cols[1],
                       title=f"{num_cols[0]} vs {num_cols[1]}")
        )

    # Nothing could be generated: return a figure that only carries a message.
    if not figures:
        fig_empty = go.Figure()
        fig_empty.add_annotation(text="No visualizations available for this dataset structure.",
                                 xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        return fig_empty

    # The dashboard shows at most four charts.
    charts_to_show = figures[:4]
    n_charts = len(charts_to_show)
    if n_charts == 1:
        return charts_to_show[0]  # A single figure is returned directly

    import plotly.subplots as sp

    # Two charts share one row; three or four use a 2x2 grid.
    rows = 2 if n_charts > 2 else 1
    cols = 2
    subplot_titles = [f.layout.title.text or "" for f in charts_to_show]
    fig = sp.make_subplots(rows=rows, cols=cols, subplot_titles=subplot_titles)

    for i, f in enumerate(charts_to_show):
        row, col = divmod(i, cols)
        row += 1
        col += 1
        for trace in f.data:
            fig.add_trace(trace, row=row, col=col)
        # Traces are copied without their source layout, so carry the y-axis
        # direction over explicitly (the bar charts rely on a reversed axis
        # to list the top category first).
        y_autorange = f.layout.yaxis.autorange
        if y_autorange is not None:
            fig.update_yaxes(autorange=y_autorange, row=row, col=col)

    fig.update_layout(height=400 * rows, title_text="Data Visualization Dashboard", showlegend=False)
    return fig