File size: 4,907 Bytes
5d4981c
 
 
 
 
 
 
 
 
 
 
 
3001d07
5d4981c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from __future__ import annotations

from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import NMF
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline, Pipeline
from utils.build_plotly import _build_topic_figure
from utils.load_data import get_data_directory

import plotly.graph_objects as go # type: ignore

import streamlit as st
from nltk.corpus import stopwords  # type: ignore

from utils.remove_html import remove_html_tags 

# --------- Defaults / Paths ---------
DEFAULT_DATA_DIR = get_data_directory()

# Fixed bar color per review category; categories missing from this map get a
# fallback color inside make_topics().
COLOR_WHEEL = {
    "All_Beauty": "#d946ef",            # magenta-ish
    "Appliances": "#800000",            # maroon
    "Baby_Products": "#87ceeb",         # skyblue
    "Electronics": "#ffd700",           # gold
    "Health_and_Household": "#3cb371",  # mediumseagreen
    "Movies_and_TV": "#663399"          # rebeccapurple
}

# Stopword setup: start from NLTK's English list, but KEEP negations and
# intensifiers (e.g. 'not', "won't", 'very') because they carry sentiment
# signal that plain stopword removal would destroy.
BASE_STOPWORDS = set(stopwords.words("english"))
CUSTOM_KEEP = {
    'not','no','but','ain','don',"don't",'aren',"aren't",'couldn',"couldn't",
    'didn',"didn't",'doesn',"doesn't",'hadn',"hadn't",'hasn',"hasn't",'haven',
    "haven't",'isn',"isn't",'mightn',"mightn't",'mustn',"mustn't",'needn',
    "needn't",'shan',"shan't",'shouldn',"shouldn't",'wasn',"wasn't",'weren',
    "weren't",'won',"won't",'wouldn',"wouldn't",'very','too'
}
# Sorted for determinism; sorted() already returns a list, so no extra list()
# wrapper is needed (ruff C414). Callers copy this list rather than mutate it.
DEFAULT_STOPWORDS = sorted(BASE_STOPWORDS - CUSTOM_KEEP)


# --------- Data loading / modeling ---------
def _load_category_df(
    data_dir: Path | str,
    category: str,
    lemmatize: bool,
    nrows: int
) -> pd.DataFrame:
    """Read the per-category parquet file and return its first *nrows* rows.

    When ``lemmatize`` is True the pre-lemmatized file under ``lemma_data/``
    is returned as-is; otherwise the raw file is read and its ``text`` column
    (if present) is stringified, stripped, and scrubbed of HTML tags.
    """
    base = Path(data_dir)
    if lemmatize:
        df = pd.read_parquet(base / "lemma_data" / f"{category}.parquet")
    else:
        df = pd.read_parquet(base / f"{category}.parquet")
        if "text" in df.columns:
            stripped = df["text"].astype(str).str.strip()
            df["text"] = stripped.apply(remove_html_tags)
    # head(nrows) is positional, same rows as iloc[:nrows, :].
    return df.head(nrows).copy()


#@st.cache_data(show_spinner="One moment please!", show_time=True)
def make_topics(
    category: str,
    topic_columns: str,
    lemmatize: bool,
    n1: int,
    n2: int,
    n_components: int,
    rating: Optional[List[int]] = None,
    helpful_vote: Optional[int] = None,
    new_words: Optional[List[str]] = None,
    n_top_words: int = 5,
    data_dir: Optional[str | Path] = None,
    nrows: int = 10_000
) -> Tuple[ColumnTransformer | Pipeline, go.Figure]:
    """Fit a TF-IDF + NMF topic model for one review category.

    Args:
        category: Review category; also selects the parquet file and bar color.
        topic_columns: Which column(s) to model — "both", "text", or anything
            else (falls back to "title").
        lemmatize: Use the pre-lemmatized data if True.
        n1, n2: Lower/upper bounds of the TF-IDF n-gram range.
        n_components: Number of NMF topics.
        rating: Optional whitelist of star ratings to keep.
        helpful_vote: Optional strict lower bound on helpful votes.
        new_words: Extra stopwords appended for this call only.
        n_top_words: Words shown per topic in the figure.
        data_dir: Data directory override; defaults to DEFAULT_DATA_DIR.
        nrows: Maximum number of rows loaded.

    Returns:
        (fitted_pipeline, plotly_figure)
    """
    df = _load_category_df(
        data_dir or DEFAULT_DATA_DIR, category, lemmatize, nrows=nrows
    )

    # Optional row filters; each applies only when its column exists.
    if rating is not None and "rating" in df.columns:
        df = df[df["rating"].isin(rating)]
    if helpful_vote is not None and "helpful_vote" in df.columns:
        df = df[df["helpful_vote"] > helpful_vote]

    choice = (topic_columns or "").strip().lower()

    # Copy the module-level stopword list so per-call additions never leak.
    stop_list = list(DEFAULT_STOPWORDS)
    if new_words:
        stop_list.extend(new_words)

    def _vectorizer() -> TfidfVectorizer:
        # One fresh vectorizer per modeled column, sharing stopwords/ngrams.
        return TfidfVectorizer(stop_words=stop_list, ngram_range=(n1, n2))

    if choice == "both":
        transformers = [
            ("title", _vectorizer(), "title"),
            ("text", _vectorizer(), "text"),
        ]
    elif choice == "text":
        transformers = [("text", _vectorizer(), "text")]
    else:
        # Anything other than "both"/"text" falls back to title-only.
        transformers = [("title", _vectorizer(), "title")]
    preprocessor = ColumnTransformer(transformers)

    nmf = NMF(
        n_components=n_components,
        init="nndsvda",
        solver="mu",
        beta_loss=1,  # generalized Kullback-Leibler divergence
        random_state=10
    )

    pipeline = make_pipeline(preprocessor, nmf)
    # Restrict fitting to the columns actually present in the frame.
    present = [col for col in ("title", "text") if col in df.columns]
    pipeline.fit(df[present])

    fig = _build_topic_figure(
        model=pipeline[1],
        feature_names=pipeline[0].get_feature_names_out(),
        n_top_words=n_top_words,
        title=category,
        n_components=n_components,
        # Fall back to a neutral blue for categories outside COLOR_WHEEL.
        bar_color=COLOR_WHEEL.get(category, "#184A90")
    )

    return pipeline, fig