File size: 2,625 Bytes
5d4981c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
import streamlit as st
import plotly.graph_objects as go # type: ignore
from plotly.subplots import make_subplots # type: ignore
from sklearn.decomposition import NMF


# --------- Plot helper (Plotly) ---------
def _build_topic_figure(
    model: NMF,
    feature_names: np.ndarray,
    n_top_words: int,
    title: str,
    n_components: int,
    bar_color: str
) -> go.Figure:
    """Create a Plotly subplot grid of top terms per topic (horizontal bars).

    Each row of ``model.components_`` gets its own subplot showing the
    ``n_top_words`` highest-weighted terms. Bar lengths are the term weights
    expressed as a percentage of that row's total weight.

    Args:
        model: Fitted model exposing ``components_`` (one row per topic).
        feature_names: Array of term names, indexable by the column indices
            of ``model.components_``.
        n_top_words: Number of top terms to show per topic.
        title: Figure title (rendered in bold).
        n_components: Number of topics; determines the grid size and the
            subplot titles. NOTE(review): presumably equals
            ``model.components_.shape[0]`` -- confirm against the caller.
        bar_color: Fill colour for every bar.

    Returns:
        The assembled Plotly figure.

    Side effects:
        Stores ``{topic_idx: {"features": [...], "weights": [...]}}`` (weights
        rounded to 4 decimals) in ``st.session_state.top_topics``.
    """
    # Layout: two columns once there are more than three topics, else one.
    cols = 2 if n_components > 3 else 1
    rows = int(np.ceil(n_components / cols))

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=[f"Topic {i+1}" for i in range(n_components)],
        horizontal_spacing=0.25,
        vertical_spacing=0.1
    )

    top_features_dict = {}
    max_weight = 0

    for topic_idx, topic in enumerate(model.components_):
        top_features_ind = topic.argsort()[::-1][:n_top_words]
        top_features = feature_names[top_features_ind]

        # Guard the normalisation: an all-zero topic row would otherwise
        # divide by zero and produce NaN weights and "nan" bar labels.
        topic_total = np.sum(topic)
        if topic_total > 0:
            weights = topic[top_features_ind] / topic_total * 100
        else:
            weights = np.zeros(len(top_features_ind))

        top_features_dict[topic_idx] = {
            "features": list(top_features),
            "weights": list(np.round(weights, 4)),
        }

        # Track the global maximum so all subplots share one x-axis scale;
        # skip empty selections, on which ``max()`` would raise.
        if weights.size:
            max_weight = max(max_weight, weights.max())

        # Subplots are filled row by row, left to right (1-based indices).
        r = topic_idx // cols + 1
        c = topic_idx % cols + 1

        fig.add_trace(
            go.Bar(
                x=weights,
                y=top_features,
                orientation="h",
                marker=dict(color=bar_color, line=dict(color="white", width=1)),
                text=[f"{w:.2f}" for w in weights],
                textposition="outside",
                hovertemplate="<b>%{y}</b><br>weight=%{x:.2f}%<extra></extra>",
                showlegend=False
            ),
            row=r, col=c
        )

        # Nicer y ordering (largest at top).
        fig.update_yaxes(autorange="reversed", row=r, col=c)

    # Shared x-axis range with 25% padding so the outside text labels fit.
    # Fall back to a non-degenerate range when every weight is zero.
    x_upper = max_weight * 1.25 if max_weight > 0 else 1.0
    # Without row/col selectors this applies to every x-axis in the grid.
    fig.update_xaxes(range=[0, x_upper])

    # Axis labels for the bottom row only.
    for c in range(1, cols + 1):
        fig.update_xaxes(title_text="Relative Weight (%)", row=rows, col=c)

    fig.update_layout(
        title=f"<b>{title}</b>",
        height=max(350, 330 * rows),
        margin=dict(l=50, r=20, t=60, b=60)
    )

    # Expose the per-topic top terms to the rest of the Streamlit app.
    st.session_state.top_topics = top_features_dict

    return fig