from __future__ import annotations from pathlib import Path from typing import List, Optional, Tuple import numpy as np import pandas as pd from sklearn.compose import ColumnTransformer from sklearn.decomposition import NMF from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.pipeline import make_pipeline, Pipeline from utils.build_plotly import _build_topic_figure from utils.load_data import get_data_directory import plotly.graph_objects as go # type: ignore import streamlit as st from nltk.corpus import stopwords # type: ignore from utils.remove_html import remove_html_tags # --------- Defaults / Paths --------- # ROOT = Path(__file__).resolve().parents[1] # DEFAULT_DATA_DIR = ROOT / "review_data" DEFAULT_DATA_DIR = get_data_directory() COLOR_WHEEL = { "All_Beauty": "#d946ef", # magenta-ish "Appliances": "#800000", # maroon "Baby_Products": "#87ceeb", # skyblue "Electronics": "#ffd700", # gold "Health_and_Household": "#3cb371", # mediumseagreen "Movies_and_TV": "#663399" # rebeccapurple } # Build stopword list (don’t mutate across calls) BASE_STOPWORDS = set(stopwords.words("english")) CUSTOM_KEEP = { 'not','no','but','ain','don',"don't",'aren',"aren't",'couldn',"couldn't", 'didn',"didn't",'doesn',"doesn't",'hadn',"hadn't",'hasn',"hasn't",'haven', "haven't",'isn',"isn't",'mightn',"mightn't",'mustn',"mustn't",'needn', "needn't",'shan',"shan't",'shouldn',"shouldn't",'wasn',"wasn't",'weren', "weren't",'won',"won't",'wouldn',"wouldn't",'very','too' } DEFAULT_STOPWORDS = sorted(list(BASE_STOPWORDS - CUSTOM_KEEP)) # --------- Data loading / modeling --------- def _load_category_df( data_dir: Path | str, category: str, lemmatize: bool, nrows: int ) -> pd.DataFrame: """Load parquet for category; choose lemma or raw; basic cleaning.""" data_dir = Path(data_dir) path = data_dir / f"{category}.parquet" lemma_path = data_dir / f"lemma_data/{category}.parquet" if lemmatize: df = pd.read_parquet(lemma_path) else: df = pd.read_parquet(path) if "text" in df.columns: df["text"] = df["text"].astype(str).str.strip().apply(remove_html_tags) return df.iloc[:nrows, :].copy() #@st.cache_data(show_spinner="One moment please!", show_time=True) def make_topics( category: str, topic_columns: str, lemmatize: bool, n1: int, n2: int, n_components: int, rating: Optional[List[int]] = None, helpful_vote: Optional[int] = None, new_words: Optional[List[str]] = None, n_top_words: int = 5, data_dir: Optional[str | Path] = None, nrows: int = 10_000 ) -> Tuple[ColumnTransformer | Pipeline, go.Figure]: """ Fit TF-IDF + NMF topic model and return (pipeline, Plotly figure). Returns: (topic_pipeline, fig) """ data_dir = data_dir or DEFAULT_DATA_DIR df = _load_category_df(data_dir, category, lemmatize, nrows=nrows) # Optional filters if rating is not None and "rating" in df.columns: df = df[df["rating"].isin(rating)] if helpful_vote is not None and "helpful_vote" in df.columns: df = df[df["helpful_vote"] > helpful_vote] # Columns to model topic_columns = (topic_columns or "").strip().lower() # Make a fresh stopword list each call to avoid global mutation stop_list = list(DEFAULT_STOPWORDS) if new_words: stop_list.extend(new_words) tfidf_text = TfidfVectorizer(stop_words=stop_list, ngram_range=(n1, n2)) tfidf_title = TfidfVectorizer(stop_words=stop_list, ngram_range=(n1, n2)) if topic_columns == "both": preprocessor = ColumnTransformer([ ("title", tfidf_title, "title"), ("text", tfidf_text, "text") ]) elif topic_columns == "text": preprocessor = ColumnTransformer([("text", tfidf_text, "text")]) else: # default to title if not 'both' or 'text' preprocessor = ColumnTransformer([("title", tfidf_title, "title")]) nmf = NMF( n_components=n_components, init="nndsvda", solver="mu", beta_loss=1, random_state=10 ) topic_pipeline = make_pipeline(preprocessor, nmf) # Fit on only the columns the preprocessor expects fit_cols = [c for c in ["title", "text"] if c in df.columns] topic_pipeline.fit(df[fit_cols]) feature_names = topic_pipeline[0].get_feature_names_out() nmf_model: NMF = topic_pipeline[1] # Choose color from map (fallback if category label differs) bar_color = COLOR_WHEEL.get(category, "#184A90") fig = _build_topic_figure( model=nmf_model, feature_names=feature_names, n_top_words=n_top_words, title=category, n_components=n_components, bar_color=bar_color ) return topic_pipeline, fig