File size: 4,706 Bytes
bb9980b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
841e95d
bb9980b
 
 
 
 
 
 
 
 
 
 
6b37a61
841e95d
6b37a61
 
 
 
 
 
 
 
 
 
 
841e95d
6b37a61
 
 
 
 
 
 
 
 
bb9980b
841e95d
 
6b37a61
 
 
 
841e95d
bb9980b
841e95d
 
 
 
 
 
 
 
 
 
bb9980b
 
 
 
 
841e95d
 
 
 
 
 
 
 
bbdd10b
 
bb9980b
bbdd10b
 
 
841e95d
 
 
 
bbdd10b
841e95d
 
bbdd10b
841e95d
 
bbdd10b
 
841e95d
bbdd10b
 
 
 
841e95d
bbdd10b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np

def generate_charts(df, profile):
    """
    Generate a Plotly dashboard for ``df`` guided by ``profile``.

    ``profile`` is read for two optional keys, ``'numerical_columns'`` and
    ``'categorical_columns'`` (lists of column names in ``df``); missing keys,
    or a falsy ``profile``, mean "no such columns".

    Candidate charts are built in this order, and the first four are kept:
      1. a correlation heatmap (when there are 2+ numerical columns);
      2. histograms with box marginals for the first three numerical columns;
      3. horizontal bar charts of the most frequent values for up to three
         categorical columns (columns that look like free text are skipped);
      4. text-length histograms for up to two long-text columns, built only
         when nothing above produced a chart;
      5. a scatter of the first two numerical columns.

    Returns:
        None if ``df`` is None or empty; a placeholder figure carrying an
        explanatory annotation if no chart could be built; the chart itself
        if exactly one was built; otherwise a two-column subplot figure
        combining up to four charts (convenient for a single Gradio output).
    """
    if df is None or df.empty:
        return None

    profile = profile or {}
    num_cols = profile.get('numerical_columns', [])
    cat_cols = profile.get('categorical_columns', [])

    figures = []

    # 1. Correlation heatmap (needs at least two numerical columns).
    if len(num_cols) > 1:
        corr = df[num_cols].corr()
        # Two decimals keep the per-cell annotations legible.
        figures.append(px.imshow(corr, text_auto=".2f", aspect="auto", title="Correlation Matrix"))

    # 2. Distributions of the first three numerical columns, in profile order.
    for col in num_cols[:3]:
        figures.append(px.histogram(df, x=col, title=f"Distribution of {col}", marginal="box"))

    # 3. Most frequent values of up to three categorical columns.
    for col in _select_categorical_columns(df, cat_cols)[:3]:
        fig_bar = _category_bar(df, col)
        if fig_bar is not None:
            figures.append(fig_bar)

    # 4. Text-length fallback, only for datasets nothing above could chart.
    #    Computed lazily: scanning every object column is wasted work otherwise.
    if not figures:
        for col in _text_columns(df, cat_cols)[:2]:
            lengths = _string_lengths(df[col])
            figures.append(px.histogram(lengths, title=f"Text Length Distribution: {col}",
                                        labels={'value': 'Character Count'}))

    # 5. Scatter of the first two numerical columns.
    if len(num_cols) >= 2:
        figures.append(px.scatter(df, x=num_cols[0], y=num_cols[1],
                                  title=f"{num_cols[0]} vs {num_cols[1]}"))

    if not figures:
        fig_empty = go.Figure()
        fig_empty.add_annotation(text="No visualizations available for this dataset structure.",
                                 xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        return fig_empty

    return _combine_figures(figures)


def _is_text_dtype(series):
    """Return True for object columns and for pandas' dedicated string dtype."""
    return pd.api.types.is_object_dtype(series.dtype) or isinstance(series.dtype, pd.StringDtype)


def _string_lengths(series):
    """Return the character length of each ``str`` element of ``series``, NaN for anything else.

    Unlike ``Series.str.len()``, this never raises on object columns that hold
    no strings (e.g. ints or Decimals stored with object dtype).
    """
    return series.map(lambda value: len(value) if isinstance(value, str) else np.nan)


def _short_labels(values, max_len=30):
    """Stringify ``values`` for axis labels, truncating to ``max_len`` chars plus "...".

    Labels that collide (after truncation, or because distinct values
    stringify identically) get a " (k)" suffix. Otherwise the bar chart
    would merge them into one stacked bar and misreport the counts.
    """
    labels = []
    seen = {}
    for value in values:
        text = str(value)
        label = text[:max_len] + "..." if len(text) > max_len else text
        if label in seen:
            seen[label] += 1
            label = f"{label} ({seen[label]})"
        else:
            seen[label] = 1
        labels.append(label)
    return labels


def _select_categorical_columns(df, cat_cols, max_avg_len=30, max_unique=10):
    """Return ``cat_cols`` minus text columns that look like free text or paths.

    A text column is dropped when its values average more than
    ``max_avg_len`` characters AND it has more than ``max_unique`` distinct
    values. Long labels with only a few distinct values are still charted.
    """
    selected = []
    for col in cat_cols:
        series = df[col]
        if _is_text_dtype(series):
            avg_len = series.astype(str).str.len().mean()
            if avg_len > max_avg_len and series.nunique() > max_unique:
                continue  # likely free text / file paths
        selected.append(col)
    return selected


def _text_columns(df, cat_cols, min_avg_len=20):
    """Return non-categorical text columns whose string values average more than ``min_avg_len`` chars."""
    excluded = set(cat_cols)
    return [
        col for col in df.columns
        if _is_text_dtype(df[col]) and col not in excluded
        and _string_lengths(df[col]).mean() > min_avg_len
    ]


def _category_bar(df, col):
    """Build a horizontal bar chart of the most frequent values in ``df[col]``.

    Columns with fewer than 50 distinct values show their top 15; busier
    columns show their top 10 under a title that says so. Returns None when
    the column has no non-missing values.
    """
    if df[col].nunique() < 50:
        top_n, title = 15, f"Top Categories in {col}"
    else:
        top_n, title = 10, f"Top 10 Most Frequent in {col}"

    counts = df[col].value_counts().head(top_n)
    if counts.empty:
        # All-missing column: an empty chart would only waste a grid slot.
        return None

    fig = px.bar(x=counts.values, y=_short_labels(counts.index), orientation='h',
                 labels={'x': 'Count', 'y': col}, title=title)
    # Horizontal bars plot bottom-up; reverse so the most frequent value is on top.
    fig.update_layout(yaxis=dict(autorange="reversed"))
    return fig


def _is_primary_trace(trace):
    """Return True if ``trace`` sits on its figure's main x/y axes (an unset axis counts as main)."""
    return (getattr(trace, 'xaxis', None) in (None, 'x')
            and getattr(trace, 'yaxis', None) in (None, 'y'))


def _combine_figures(figures, max_charts=4):
    """Lay the first ``max_charts`` of ``figures`` out in a two-column subplot grid.

    A single figure is returned unchanged. ``figures`` must be non-empty.
    """
    import plotly.subplots as sp

    charts = figures[:max_charts]
    if len(charts) == 1:
        return charts[0]

    n_cols = 2
    n_rows = -(-len(charts) // n_cols)  # ceil: 2 charts -> 1 row, 3-4 charts -> 2 rows
    titles = [chart.layout.title.text or "" for chart in charts]
    fig = sp.make_subplots(rows=n_rows, cols=n_cols, subplot_titles=titles)

    for index, chart in enumerate(charts):
        row, col = divmod(index, n_cols)
        row, col = row + 1, col + 1
        # Marginal traces (e.g. the box above a histogram) live on secondary
        # axes in their source figure. add_trace(row=, col=) would move them
        # onto the main panel, on top of the data, so drop them. Fall back to
        # every trace if none is on the main axes.
        traces = [t for t in chart.data if _is_primary_trace(t)] or list(chart.data)
        for trace in traces:
            fig.add_trace(trace, row=row, col=col)
        # Axis settings don't travel with traces. Keep reversed y axes
        # (top-down bar charts, heatmap row order) as in the source figure.
        if chart.layout.yaxis.autorange == "reversed":
            fig.update_yaxes(autorange="reversed", row=row, col=col)

    fig.update_layout(height=400 * n_rows, title_text="Data Visualization Dashboard", showlegend=False)
    return fig