Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import streamlit as st | |
| import plotly.graph_objects as go # type: ignore | |
| from plotly.subplots import make_subplots # type: ignore | |
| from sklearn.decomposition import NMF | |
| # --------- Plot helper (Plotly) --------- | |
def _build_topic_figure(
    model: NMF,
    feature_names: np.ndarray,
    n_top_words: int,
    title: str,
    n_components: int,
    bar_color: str
) -> go.Figure:
    """Create a Plotly subplot grid of top terms per topic (horizontal bars).

    Every row of ``model.components_`` is drawn in its own subplot as a
    horizontal bar chart of its ``n_top_words`` heaviest terms. Bar lengths
    are each term's share of the topic's total weight, in percent. All
    subplots share one x-axis range so bars are comparable across topics.

    Side effect: the plotted terms and their weights (rounded to 4 decimals)
    are stored in ``st.session_state.top_topics`` as
    ``{topic_idx: {"features": [...], "weights": [...]}}``.

    Args:
        model: Fitted model whose ``components_`` rows are the topics.
        feature_names: Array of terms, indexable by the column indices of
            ``model.components_``.
        n_top_words: Number of highest-weighted terms to show per topic.
        title: Figure title (rendered in bold).
        n_components: Number of topics; drives the grid size and the
            "Topic N" subplot titles.
        bar_color: Fill colour of the bars.

    Returns:
        The assembled Plotly figure.
    """
    # Layout: one column for up to 3 topics, otherwise two columns,
    # with as many rows as needed.
    cols = 2 if n_components > 3 else 1
    rows = int(np.ceil(n_components / cols))
    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=[f"Topic {i+1}" for i in range(n_components)],
        horizontal_spacing=0.25,
        vertical_spacing=0.1,
    )

    top_features_dict = {}
    max_weight = 0
    for topic_idx, topic in enumerate(model.components_):
        top_features, weights = _topic_top_terms(topic, feature_names, n_top_words)
        top_features_dict[topic_idx] = {
            "features": list(top_features),
            "weights": list(np.round(weights, 4)),
        }
        max_weight = max(max_weight, weights.max())

        # Subplot position: fill the grid row by row (1-based indices).
        r = topic_idx // cols + 1
        c = topic_idx % cols + 1
        fig.add_trace(
            go.Bar(
                x=weights,
                y=top_features,
                orientation="h",
                marker=dict(color=bar_color, line=dict(color="white", width=1)),
                text=[f"{w:.2f}" for w in weights],
                textposition="outside",
                hovertemplate="<b>%{y}</b><br>weight=%{x:.2f}%<extra></extra>",
                showlegend=False,
            ),
            row=r,
            col=c,
        )
        # Nicer y ordering (largest at top).
        fig.update_yaxes(autorange="reversed", row=r, col=c)

    # Shared x-axis range with 25% padding so the outside text labels fit.
    # Without row/col, update_xaxes applies to every subplot's x-axis.
    # If every weight is zero, fall back to a non-degenerate [0, 1] range.
    x_upper = max_weight * 1.25 if max_weight > 0 else 1.0
    fig.update_xaxes(range=[0, x_upper])

    # Axis labels only on the bottom row.
    for c in range(1, cols + 1):
        fig.update_xaxes(title_text="Relative Weight (%)", row=rows, col=c)

    fig.update_layout(
        title=f"<b>{title}</b>",
        height=max(350, 330 * rows),
        margin=dict(l=50, r=20, t=60, b=60),
    )
    st.session_state.top_topics = top_features_dict
    return fig


def _topic_top_terms(
    topic: np.ndarray,
    feature_names: np.ndarray,
    n_top_words: int,
) -> tuple:
    """Return the heaviest terms of one topic and their percentage weights.

    Args:
        topic: 1-D array of term weights for a single topic.
        feature_names: Array of terms aligned with ``topic``.
        n_top_words: How many of the highest-weighted terms to return.

    Returns:
        ``(terms, weights)``, both ordered from heaviest to lightest.
        ``weights`` is each term's share of the topic's total weight in
        percent, or all zeros when the topic sums to zero.
    """
    top_idx = topic.argsort()[::-1][:n_top_words]
    total = np.sum(topic)
    if total != 0:
        weights = topic[top_idx] / total * 100
    else:
        # An all-zero topic would otherwise divide by zero and yield NaN
        # weights, which then show up as "nan" labels and in session state.
        weights = np.zeros(len(top_idx), dtype=float)
    return feature_names[top_idx], weights