import numpy as np
import streamlit as st
import plotly.graph_objects as go # type: ignore
from plotly.subplots import make_subplots # type: ignore
from sklearn.decomposition import NMF
# --------- Plot helper (Plotly) ---------
def _build_topic_figure(
    model: NMF,
    feature_names: np.ndarray,
    n_top_words: int,
    title: str,
    n_components: int,
    bar_color: str,
) -> go.Figure:
    """Create a Plotly subplot grid of top terms per topic (horizontal bars).

    Each row of ``model.components_`` is treated as one topic. For every topic
    the ``n_top_words`` highest-weighted features are selected and their
    weights are expressed as a percentage of that topic's total weight.

    Args:
        model: Fitted model exposing a 2-D ``components_`` array
            (one row per topic, one column per feature).
        feature_names: Feature labels indexed consistently with the columns
            of ``model.components_``.
        n_top_words: Number of top features to display per topic.
        title: Figure title.
        n_components: Number of topics; determines the subplot grid size and
            the subplot titles.
        bar_color: Fill color used for every bar.

    Returns:
        The assembled Plotly figure.

    Side effects:
        Stores a ``{topic_idx: {"features": [...], "weights": [...]}}`` mapping
        (weights rounded to 4 decimals) in ``st.session_state.top_topics``.
    """
    # Layout: up to 2 columns, as many rows as needed
    cols = 2 if n_components > 3 else 1
    rows = int(np.ceil(n_components / cols))

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=[f"Topic {i + 1}" for i in range(n_components)],
        horizontal_spacing=0.25,
        vertical_spacing=0.1,
    )

    # Accept any array-like of names; fancy indexing below needs an ndarray.
    feature_names = np.asarray(feature_names)

    top_features_dict = {}
    max_weight = 0.0

    for topic_idx, topic in enumerate(model.components_):
        top_features_ind = topic.argsort()[::-1][:n_top_words]
        top_features = feature_names[top_features_ind]

        # Relative weight in percent. A topic whose weights sum to zero would
        # otherwise divide by zero and yield NaN/inf bars.
        total = np.sum(topic)
        if total > 0:
            weights = topic[top_features_ind] / total * 100
        else:
            weights = np.zeros(len(top_features_ind), dtype=float)

        top_features_dict[topic_idx] = {
            "features": list(top_features),
            "weights": list(np.round(weights, 4)),
        }

        # ``max()`` on an empty array raises, e.g. when n_top_words == 0.
        if weights.size:
            max_weight = max(max_weight, float(weights.max()))

        # Subplot position (1-based, filled row by row)
        r = topic_idx // cols + 1
        c = topic_idx % cols + 1

        fig.add_trace(
            go.Bar(
                x=weights,
                y=top_features,
                orientation="h",
                marker=dict(color=bar_color, line=dict(color="white", width=1)),
                text=[f"{w:.2f}" for w in weights],
                textposition="outside",
                # "<br>" is Plotly's line break; "<extra></extra>" hides the
                # secondary hover box that would only show "trace N".
                hovertemplate="%{y}<br>weight=%{x:.2f}%<extra></extra>",
                showlegend=False,
            ),
            row=r,
            col=c,
        )
        # Nicer y ordering (largest at top)
        fig.update_yaxes(autorange="reversed", row=r, col=c)

    # Share one x-range across all subplots so bars are comparable; the 25%
    # padding leaves room for the outside text labels. Skip when there is
    # nothing to scale against, to avoid a degenerate [0, 0] range.
    if max_weight > 0:
        fig.update_xaxes(range=[0, max_weight * 1.25])

    # Axes labels for the bottom row
    for c in range(1, cols + 1):
        fig.update_xaxes(title_text="Relative Weight (%)", row=rows, col=c)

    fig.update_layout(
        title=f"{title}",
        height=max(350, 330 * rows),
        margin=dict(l=50, r=20, t=60, b=60),
    )

    st.session_state.top_topics = top_features_dict
    return fig