import numpy as np
import streamlit as st
import plotly.graph_objects as go  # type: ignore
from plotly.subplots import make_subplots  # type: ignore
from sklearn.decomposition import NMF


# --------- Plot helper (Plotly) ---------
def _build_topic_figure(
    model: NMF,
    feature_names: np.ndarray,
    n_top_words: int,
    title: str,
    n_components: int,
    bar_color: str,
) -> go.Figure:
    """Create a Plotly subplot grid of top terms per topic (horizontal bars).

    Each row of ``model.components_`` is treated as one topic. For every topic
    the ``n_top_words`` highest-weighted terms are plotted, with each weight
    expressed as a percentage of that topic's total weight.

    Side effect: the per-topic terms and weights (rounded to 4 decimals) are
    stored in ``st.session_state.top_topics`` as
    ``{topic_idx: {"features": [...], "weights": [...]}}``.

    Args:
        model: Fitted model exposing ``components_`` (one row per topic).
        feature_names: Vocabulary terms, indexable by column of ``components_``.
        n_top_words: Number of top terms to show per topic.
        title: Figure title.
        n_components: Number of topics used to size the subplot grid and its
            titles. NOTE(review): presumably equals ``len(model.components_)``;
            a larger model would address subplot cells outside the grid.
        bar_color: Fill colour for all bars.

    Returns:
        The assembled Plotly figure.
    """
    # Accept a plain list of terms as well as an ndarray (fancy indexing below).
    feature_names = np.asarray(feature_names)

    # Layout: up to 2 columns, as many rows as needed
    cols = 2 if n_components > 3 else 1
    rows = int(np.ceil(n_components / cols))

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=[f"Topic {i + 1}" for i in range(n_components)],
        horizontal_spacing=0.25,
        vertical_spacing=0.1,
    )

    top_features_dict = {}
    max_weight = 0.0

    for topic_idx, topic in enumerate(model.components_):
        top_features_ind = topic.argsort()[::-1][:n_top_words]
        top_features = feature_names[top_features_ind]

        # An all-zero topic would otherwise divide by zero and produce NaN
        # weights, which would also poison the shared x-axis range.
        total = np.sum(topic)
        if total != 0:
            weights = topic[top_features_ind] / total * 100
        else:
            weights = np.zeros(len(top_features_ind), dtype=float)

        # .tolist() yields native Python str/float values rather than NumPy
        # scalars, which keeps the session-state payload plain and serializable.
        top_features_dict[topic_idx] = {
            "features": top_features.tolist(),
            "weights": np.round(weights, 4).tolist(),
        }

        # max() on an empty array raises; nothing to contribute in that case.
        if weights.size:
            max_weight = max(max_weight, float(weights.max()))

        # subplot position (row-major fill of the grid, 1-based)
        r = topic_idx // cols + 1
        c = topic_idx % cols + 1

        fig.add_trace(
            go.Bar(
                x=weights,
                y=top_features,
                orientation="h",
                marker=dict(color=bar_color, line=dict(color="white", width=1)),
                text=[f"{w:.2f}" for w in weights],
                textposition="outside",
                # Plotly hover text uses <br> for line breaks.
                hovertemplate="%{y}<br>weight=%{x:.2f}%",
                showlegend=False,
            ),
            row=r,
            col=c,
        )
        # nicer y ordering (largest at top)
        fig.update_yaxes(autorange="reversed", row=r, col=c)

    # Share one x-range across all subplots so bars are comparable between
    # topics; 25% padding leaves room for the outside text labels. Skip when
    # there is nothing to scale to, leaving Plotly's autorange in place.
    if max_weight > 0:
        fig.update_xaxes(range=[0, max_weight * 1.25])

    # Axes labels for the bottom row
    for c in range(1, cols + 1):
        fig.update_xaxes(title_text="Relative Weight (%)", row=rows, col=c)

    fig.update_layout(
        title=title,
        height=max(350, 330 * rows),
        margin=dict(l=50, r=20, t=60, b=60),
    )

    st.session_state.top_topics = top_features_dict

    return fig