Spaces:
Sleeping
Sleeping
| import marimo | |
| __generated_with = "0.14.13" | |
| app = marimo.App( | |
| width="full", | |
| app_title="Text classification using Logistic Regression", | |
| ) | |
| with app.setup: | |
| import glob | |
| import altair as alt | |
| import eli5 | |
| import marimo as mo | |
| import numpy as np | |
| import pandas as pd | |
| from eli5 import format_as_html | |
| from sklearn.calibration import calibration_curve | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.metrics import ( | |
| brier_score_loss, | |
| classification_report, | |
| confusion_matrix, | |
| ) | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.preprocessing import label_binarize | |
| def _(): | |
| mo.md( | |
| r""" | |
| # テキスト分類モデルの可視化と解釈 | |
| このノートブックでは、テキスト分類モデルの学習と解釈方法をインタラクティブに探求します。 | |
| [scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html#logisticregression)のロジスティック回帰モデルを使用し、ELI5ライブラリで予測の説明を可視化します。 | |
| """ | |
| ) | |
| return | |
| def _(): | |
| mo.md( | |
| r""" | |
| ## ロジスティック回帰の内部(テキスト分類向け) | |
| **二値分類(クラスが2つ)** | |
| $$ | |
| s=\mathbf{w}^\top\mathbf{x}+b | |
| =\sum_{i=1}^{d} w_i\,x_i + b | |
| = w_1x_1+w_2x_2+\cdots+w_dx_d+b. | |
| $$ | |
| **多クラス(クラス $c$ のスコア)** | |
| $$ | |
| s_c=\mathbf{w}_c^\top\mathbf{x}+b_c | |
| =\sum_{i=1}^{d} w_{c,i}\,x_i + b_c | |
| = w_{c,1}x_1+w_{c,2}x_2+\cdots+w_{c,d}x_d+b_c. | |
| $$ | |
| **全クラス同時表示(行列形)** | |
| $$ | |
| \mathbf{s}=W\mathbf{x}+\mathbf{b},\quad | |
| W\in\mathbb{R}^{K\times d},\ \mathbf{s}\in\mathbb{R}^{K}. | |
| $$ | |
| **テキスト分類の対応づけ(例)** | |
| $$ | |
| x_i=\mathrm{tfidf}(\text{term}_i,\ \text{doc}),\qquad | |
| \text{contrib}_{i\to c}=x_i\,w_{c,i}. | |
| $$ | |
| **二値の確率化(シグモイド)** | |
| $$ | |
| p(y=1\mid \mathbf{x})=\sigma(s)=\frac{1}{1+e^{-s}},\quad | |
| \log\frac{p}{1-p}=s | |
| $$ | |
| **多クラスの確率化(ソフトマックス)** | |
| $$ | |
| p(y=c\mid \mathbf{x})=\frac{e^{s_c}}{\sum_{k} e^{s_k}},\quad | |
| s_c=\mathbf{w}_c^\top \mathbf{x}+b_c | |
| $$ | |
| **決定境界と閾値** | |
| 二値では$s=0$が閾値(通常は$p=0.5$)。 | |
| 多クラスでは$\{s_c\}$の最大を選ぶ。 | |
| **学習と正則化** | |
| $$ | |
| \min_{\{\mathbf{w}_c,b_c\}} | |
| \left[-\sum_{i}\log p(y_i\mid \mathbf{x}_i)\right] | |
| +\frac{\lambda}{2}\sum_{c}\|\mathbf{w}_c\|_2^2 | |
| $$ | |
| scikit-learnの$C$は$\lambda$の逆数に相当($C$が小さければ$\Rightarrow$正則化強)。 | |
| """ | |
| ) | |
| return | |
| def _(): | |
| mo.md( | |
| """ | |
| ## データセットの確認 | |
| 4つの作家の小説から構成される小さなテキストコーパスです。 | |
| 各作品からは先頭100トークンのチャンク100個があり、トークンはレマ化されています。 | |
| - A. Merritt: The Moon Pool (Science fiction) | |
| - E. R. Eddison: The Worm Ouroboros (Fantasy) | |
| - H. G. Wells: The Wonderful Visit (Fantasy) | |
| - Mark Twain: A Connecticut Yankee in King Arthur's Court (Historical fiction/Fantasy) | |
| (StandardEbooksより) | |
| """ | |
| ) | |
| return | |
| def _(): | |
| rng = np.random.default_rng(42) | |
| df = pd.DataFrame(columns=["text", "label"]) | |
| for csv_file in glob.glob("./*.csv"): | |
| d = pd.read_csv(csv_file) | |
| df = pd.concat([df, d]).reset_index(drop=True) | |
| mo.vstack([mo.md("### コーパス"), df]) | |
| return (df,) | |
| def _(): | |
| mo.md( | |
| """ | |
| ## ハイパーパラメータの調整 | |
| 以下のスライダーとドロップダウンでモデルのハイパーパラメータを調整できます: | |
| - **テスト比率**: 訓練データとテストデータの分割比率 | |
| - **最大n-gram**: 使用するn-gramの最大長(1=単語、2=連続する2単語、など) | |
| - **最小出現回数**: 特徴量として扱う単語の最小出現回数 | |
| - **正則化C**: ロジスティック回帰の正則化強度(小さいほど強い正則化) | |
| """ | |
| ) | |
| return | |
| def _(): | |
| split = mo.ui.slider(0.1, 0.9, value=0.3, step=0.05, label="テスト比率") | |
| max_ng = mo.ui.slider(1, 4, value=2, step=1, label="最大n-gram") | |
| min_df = mo.ui.slider(1, 10, value=1, step=1, label="最小出現回数") | |
| C_pick = mo.ui.dropdown(options=[0.01, 0.1, 1.0, 10.0], value=1.0, label="正則化C") | |
| k_top = mo.ui.slider(5, 40, value=20, step=5, label="表示上位語数") | |
| cls_for_cal = mo.ui.dropdown(options=["All"], value="All", label="クラス(信頼性曲線)") | |
| mo.vstack([split, max_ng, min_df, C_pick, k_top, cls_for_cal]) | |
| return C_pick, cls_for_cal, max_ng, min_df, split | |
| def _(C_pick, df, max_ng, min_df, split): | |
| for _ in mo.status.progress_bar( | |
| range(10), | |
| title="Training Logistic Regression model", | |
| subtitle="Please wait", | |
| show_eta=True, | |
| show_rate=True | |
| ): | |
| X_train_text, X_test_text, y_train, y_test = train_test_split( | |
| df["text"].tolist(), | |
| df["label"].tolist(), | |
| test_size=split.value, | |
| random_state=0, | |
| stratify=df["label"].tolist(), | |
| ) | |
| vec = TfidfVectorizer(ngram_range=(1, max_ng.value), min_df=min_df.value) | |
| X_train = vec.fit_transform(X_train_text) | |
| X_test = vec.transform(X_test_text) | |
| clf = LogisticRegression( | |
| C=float(C_pick.value), max_iter=2000, solver="lbfgs" | |
| ) | |
| clf.fit(X_train, y_train) | |
| classes = clf.classes_ | |
| return X_test, X_test_text, X_train, classes, clf, vec, y_test, y_train | |
| def _(X_test, classes, clf, vec, y_test): | |
| feature_names = vec.get_feature_names_out() | |
| W = clf.coef_ | |
| b = clf.intercept_ | |
| weights_df = ( | |
| pd.DataFrame(W, index=classes, columns=feature_names) | |
| .stack() | |
| .rename("weight") | |
| .reset_index() | |
| .rename(columns={"level_0": "class", "level_1": "term"}) | |
| ) | |
| y_pred = clf.predict(X_test) | |
| y_prob = clf.predict_proba(X_test) | |
| probs_df = pd.DataFrame(y_prob, columns=classes) | |
| cm = ( | |
| pd.DataFrame( | |
| confusion_matrix(y_test, y_pred, labels=classes), | |
| index=classes, | |
| columns=classes, | |
| ) | |
| .reset_index() | |
| .melt(id_vars="index", var_name="pred", value_name="count") | |
| .rename(columns={"index": "true"}) | |
| ) | |
| heat = ( | |
| alt.Chart(cm) | |
| .mark_rect() | |
| .encode( | |
| x="pred:N", | |
| y="true:N", | |
| color=alt.Color("count:Q", scale=alt.Scale(scheme="blues")), | |
| ) | |
| .properties(title="混同行列", width=300, height=300) | |
| ) | |
| text = ( | |
| alt.Chart(cm) | |
| .mark_text(baseline="middle") | |
| .encode( | |
| x="pred:N", | |
| y="true:N", | |
| text="count:Q", | |
| color=alt.condition( | |
| alt.datum.count > 5, | |
| alt.value("white"), | |
| alt.value("black") | |
| ) | |
| ) | |
| ) | |
| cm_chart = alt.layer(heat, text).resolve_scale(color="independent") | |
| mo.ui.altair_chart(cm_chart) | |
| return weights_df, y_pred | |
| def _(): | |
| mo.md( | |
| """ | |
| ## 重みと文書頻度の関係 | |
| 各クラスに対する特徴語(単語/n-gram)の重みを可視化します。 | |
| 横軸は文書頻度の対数、縦軸は重みの大きさ(または符号付き)を表します。 | |
| これは、頻度の高い単語ほど重みが大きいわけではないことを示しています。 | |
| """ | |
| ) | |
| return | |
| def _(X_train, classes, vec, weights_df): | |
| df_counts = pd.Series((X_train>0).sum(axis=0).A1, index=vec.get_feature_names_out(), name="doc_freq").reset_index().rename(columns={"index":"term"}) | |
| wdf = weights_df.merge(df_counts, on="term", how="left") | |
| wdf["doc_freq"] = wdf["doc_freq"].fillna(0).astype(int) | |
| wdf["log_df"] = np.log10(wdf["doc_freq"]+1) | |
| w_cls_pick = mo.ui.dropdown(options=list(classes), value=list(classes)[0], label="クラス") | |
| abs_toggle = mo.ui.switch(label="絶対値で表示", value=True) | |
| mo.hstack([w_cls_pick, abs_toggle]) | |
| return abs_toggle, w_cls_pick, wdf | |
| def _(abs_toggle, w_cls_pick, wdf): | |
| sub = wdf[wdf["class"]==w_cls_pick.value].copy() | |
| sub["y"] = sub["weight"].abs() if abs_toggle.value else sub["weight"] | |
| w_chart = alt.Chart(sub).mark_text().encode( | |
| x=alt.X("log_df:Q", title="log10(文書頻度+1)"), | |
| y=alt.Y("y:Q", title="重み"), | |
| text="term:N", | |
| tooltip=["term", "doc_freq", "weight"] | |
| ).properties(width=500, height=360, title=f"重み{'(絶対値)' if abs_toggle.value else ''}と出現頻度の関係") | |
| mo.ui.altair_chart(w_chart) | |
| return | |
| def _(): | |
| topK = mo.ui.slider(10, 80, value=30, step=5, label="対象語数") | |
| mo.hstack([topK]) | |
| return (topK,) | |
| def _(): | |
| mo.md( | |
| """ | |
| ## 共起相関行列 | |
| 選択したクラスについて、上位の特徴語同士の共起相関をヒートマップで表示します。 | |
| 赤が正の相関(一緒に出現しやすい)、青が負の相関(同時に出現しにくい)です。 | |
| これにより、どの単語がセットでモデルに影響を与えているかを理解できます。 | |
| """ | |
| ) | |
| return | |
| def _(X_train, topK, vec, w_cls_pick, weights_df): | |
| wabs = (weights_df.assign(absw=weights_df["weight"].abs()) | |
| .sort_values(["class","absw"], ascending=[True, False])) | |
| terms_ = wabs[wabs["class"]==w_cls_pick.value].head(topK.value)["term"].tolist() | |
| indices = pd.Series(range(len(vec.get_feature_names_out())), index=vec.get_feature_names_out()) | |
| cols = [indices[t] for t in terms_ if t in indices] | |
| Xm = (X_train[:, cols] > 0).astype("float32") | |
| C_ = (Xm.T @ Xm).A | |
| n = Xm.shape[0] | |
| p = Xm.mean(axis=0).A1 | |
| std = np.sqrt(p*(1-p)+1e-9) | |
| corr = (C_/n - np.outer(p,p)) / (np.outer(std,std)+1e-9) | |
| corr_df = pd.DataFrame(corr, index=terms_, columns=terms_).stack().rename("corr").reset_index().rename(columns={"level_0":"t1","level_1":"t2"}) | |
| heat_ = alt.Chart(corr_df).mark_rect().encode( | |
| x=alt.X("t1:N", sort=terms_), | |
| y=alt.Y("t2:N", sort=terms_), | |
| color=alt.Color("corr:Q", scale=alt.Scale(scheme="redblue", domain=[-1,1])), | |
| tooltip=["t1","t2","corr"] | |
| ).properties(width=420, height=420, title=f"共起相関({w_cls_pick.value} 上位語)") | |
| mo.ui.altair_chart(heat_) | |
| return | |
| def _(): | |
| # cls_select = mo.ui.dropdown( | |
| # options=list(classes), value=list(classes)[0], label="クラス(重み)" | |
| # ) | |
| k_slider = mo.ui.slider(5, 40, value=5, step=5, label="上位語数") | |
| # mo.hstack([cls_select, k_slider]) | |
| mo.hstack([k_slider]) | |
| return (k_slider,) | |
| def _(): | |
| # dfc = weights_df[weights_df["class"] == cls_select.value] | |
| # top_pos = dfc.nlargest(k_slider.value, "weight") | |
| # top_neg = dfc.nsmallest(k_slider.value, "weight") | |
| # pos_chart = ( | |
| # alt.Chart(top_pos.assign(dir="pos")) | |
| # .mark_bar() | |
| # .encode(x=alt.X("weight:Q", title="重み"), y=alt.Y("term:N", sort="-x")) | |
| # .properties(width=420, height=420, title="正の寄与") | |
| # ) | |
| # neg_chart = ( | |
| # alt.Chart(top_neg.assign(dir="neg")) | |
| # .mark_bar() | |
| # .encode(x=alt.X("weight:Q", title="重み"), y=alt.Y("term:N", sort="x")) | |
| # .properties(width=420, height=420, title="負の寄与") | |
| # ) | |
| # mo.hstack([mo.ui.altair_chart(pos_chart), mo.ui.altair_chart(neg_chart)]) | |
| return | |
| def _(): | |
| mo.md( | |
| """ | |
| ## クラスごとの特徴語 | |
| 各クラスで最も影響力の大きい(正と負の)特徴語を表示します。 | |
| 棒グラフで重みが大きい単語が分かりやすく可視化され、どの単語がクラスに寄与しているかを確認できます。 | |
| """ | |
| ) | |
| return | |
| def _(classes, clf, k_slider, vec): | |
| html_global = format_as_html( | |
| eli5.explain_weights(clf, vec=vec, target_names=list(classes), top=(k_slider.value, k_slider.value)) | |
| ) | |
| mo.Html(html_global) | |
| return | |
| def _(): | |
| mo.md( | |
| """ | |
| ## 個別予測の説明 | |
| テストデータから特定の文書を選択して、モデルの予測結果を詳しく調べます。 | |
| [ELI5ライブラリ](https://eli5.readthedocs.io/en/latest/overview.html)を使用して、どの単語がどの程度予測に寄与したかを可視化します。 | |
| """ | |
| ) | |
| return | |
| def _(X_test_text): | |
| idx = mo.ui.slider(0, len(X_test_text) - 1, value=0, step=1, label="予測対象文書の選択(インデックス)") | |
| mo.hstack([idx]) | |
| return (idx,) | |
| def _(X_test_text, classes, clf, idx, vec, y_test): | |
| x = X_test_text[idx.value] | |
| proba = clf.predict_proba(vec.transform([x]))[0] | |
| pred = classes[np.argmax(proba)] | |
| table = pd.DataFrame({"class": classes, "prob": proba}).sort_values( | |
| "prob", ascending=False | |
| ) | |
| mo.vstack([mo.md(f"**予測:** {pred} (正解:{y_test[idx.value]})"), mo.ui.table(table)]) | |
| return (x,) | |
| def _(classes, clf, vec, x): | |
| html_local = format_as_html( | |
| eli5.explain_prediction(clf, x, vec=vec, target_names=list(classes), top=20) | |
| ) | |
| mo.Html(html_local) | |
| return | |
| def _(classes, clf, vec, x): | |
| html_cf = format_as_html( | |
| eli5.explain_prediction( | |
| clf, x, vec=vec, target_names=list(classes), top=(20, 20) | |
| ) | |
| ) | |
| mo.Html(html_cf) | |
| return | |
| def _(): | |
| mo.md( | |
| r""" | |
| **How to read these explanations** | |
| - Weights show how the model learned to associate words with a class, but words often occur together. Interpret groups, not single weights. | |
| - Coefficients depend on feature scaling. Compare contributions in a specific text ($\textrm{value} \times \textrm{weight}$), not raw weights across different feature types. | |
| - Rare words can have large weights yet seldom matter. Check frequency vs. weight and the per-example highlights above. | |
| """ | |
| ) | |
| return | |
| def _(): | |
| mo.md( | |
| """ | |
| ## 反事実的分析(What-if分析) | |
| 選択した文書を編集し、テキストの変更が予測確率にどのように影響するかを観察できます。 | |
| これはモデルの動作をより深く理解し、モデルに対する信頼を築くのに役立ちます。 | |
| """ | |
| ) | |
| return | |
| def _(x): | |
| editor = mo.ui.text_area( | |
| label="テキスト編集(反事実)", | |
| value=x, | |
| full_width=True, | |
| ) | |
| editor | |
| return (editor,) | |
| def _(classes, clf, editor, vec, x): | |
| x2 = editor.value | |
| p1 = clf.predict_proba(vec.transform([x]))[0] | |
| p2 = clf.predict_proba(vec.transform([x2]))[0] | |
| delta = pd.DataFrame( | |
| {"class": classes, "before": p1, "after": p2, "diff": p2 - p1} | |
| ).sort_values("after", ascending=False) | |
| mo.ui.table(delta) | |
| return | |
| def _(): | |
| mo.md( | |
| """ | |
| ## 正則化パスの可視化 | |
| 正則化係数Cを変化させたときの特徴語の重みの変化をプロットします。 | |
| これにより、どの特徴がロバストで、どの特徴が過学習の可能性があるかを理解できます。 | |
| """ | |
| ) | |
| return | |
| def _(X_train, vec, y_train): | |
| Cgrid = [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0] | |
| rows = [] | |
| terms = [] | |
| if hasattr(vec, "get_feature_names_out"): | |
| terms = list(vec.get_feature_names_out()) | |
| terms_pick = terms[:1000] | |
| pick = terms_pick[:5] | |
| for C in Cgrid: | |
| m = LogisticRegression( | |
| C=C, max_iter=2000, solver="lbfgs" | |
| ).fit(X_train, y_train) | |
| W2 = m.coef_ | |
| fn = vec.get_feature_names_out() | |
| dfW = ( | |
| pd.DataFrame(W2, index=m.classes_, columns=fn) | |
| .stack() | |
| .rename("w") | |
| .reset_index() | |
| .rename(columns={"level_0": "class", "level_1": "term"}) | |
| ) | |
| rows.append(dfW[dfW["term"].isin(pick)].assign(C=C)) | |
| path_df = ( | |
| pd.concat(rows) if rows else pd.DataFrame(columns=["class", "term", "w", "C"]) | |
| ) | |
| chart = ( | |
| alt.Chart(path_df) | |
| .mark_line() | |
| .encode( | |
| x=alt.X("C:Q", scale=alt.Scale(type="log")), | |
| y="w:Q", | |
| color="term:N", | |
| facet="class:N", | |
| ) | |
| .properties(width=260, height=160, title="正則化で重みがどう変わるか") | |
| ) | |
| mo.ui.altair_chart(chart) | |
| return | |
| def _(): | |
| mo.md( | |
| """ | |
| ## 信頼性曲線 | |
| モデルの予測確率がどれだけ信頼できるかを評価するための信頼性曲線を表示します。 | |
| 理想的には45度線に近い方が良く、予測確率が実際の正解率と一致していることを意味します。 | |
| """ | |
| ) | |
| return | |
| def _(X_test, classes, clf, y_test): | |
| probs = clf.predict_proba(X_test) | |
| dfp = pd.DataFrame(probs, columns=classes) | |
| df_true = pd.Series(y_test, name="true") | |
| cls_pick = classes[0] | |
| bins = [] | |
| lines = [] | |
| for cls in classes: | |
| y_bin = (df_true == cls).astype(int).to_numpy() | |
| prob_cls = dfp[cls].to_numpy() | |
| f_obs, f_pred = calibration_curve(y_bin, prob_cls, n_bins=8, strategy="uniform") | |
| bins.append( | |
| pd.DataFrame({"class": cls, "mean_pred": f_pred, "empirical": f_obs}) | |
| ) | |
| cal_df = pd.concat(bins) | |
| base = ( | |
| alt.Chart(cal_df) | |
| .mark_line() | |
| .encode( | |
| x=alt.X("mean_pred:Q", title="平均予測確率 (below x=y is over-confident)"), | |
| y=alt.Y("empirical:Q", title="実測精度 (above x=y is under-confident)"), | |
| color="class:N", | |
| ) | |
| .properties(width=420, height=300, title="信頼性曲線") | |
| .interactive() | |
| ) | |
| diag = ( | |
| alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]})) | |
| .mark_line(strokeDash=[5, 5], color="black") | |
| .encode(x="x:Q", y="y:Q") | |
| .interactive() | |
| ) | |
| combined = base + diag | |
| mo.ui.altair_chart(combined) | |
| return | |
| def _(X_test, classes, clf, cls_for_cal, y_test): | |
| cal_sel_m = cls_for_cal.value if 'cls_for_cal' in globals() else "All" | |
| cal_proba_matrix_m = clf.predict_proba(X_test) | |
| cal_df_proba_m = pd.DataFrame(cal_proba_matrix_m, columns=classes) | |
| cal_true_series_m = pd.Series(y_test, name="true") | |
| if cal_sel_m == "All": | |
| cal_class_list_m = list(classes) | |
| else: | |
| cal_class_list_m = [cal_sel_m] | |
| cal_rows_m = [] | |
| for cal_c_m in cal_class_list_m: | |
| cal_ybin_m = (cal_true_series_m == cal_c_m).astype(int).to_numpy() | |
| cal_p_m = cal_df_proba_m[cal_c_m].to_numpy() | |
| cal_fobs_m, cal_fpred_m = calibration_curve(cal_ybin_m, cal_p_m, n_bins=10, strategy="uniform") | |
| cal_n_m = len(cal_p_m) | |
| cal_bin_m = pd.cut(cal_p_m, bins=np.linspace(0,1,11), right=False, include_lowest=True) | |
| cal_group_m = pd.DataFrame({"bin": cal_bin_m, "p": cal_p_m, "y": cal_ybin_m}).groupby("bin") | |
| cal_ece_m = (cal_group_m.apply(lambda g: abs(g["y"].mean() - g["p"].mean()) * len(g) / cal_n_m)).sum() | |
| cal_rows_m.append({"class": cal_c_m, "ECE": cal_ece_m}) | |
| cal_ece_df_m = pd.DataFrame(cal_rows_m).sort_values("ECE").reset_index(drop=True) | |
| cal_class_show_m = cal_class_list_m[0] | |
| cal_hist_chart_m = alt.Chart(pd.DataFrame({"p": cal_df_proba_m[cal_class_show_m]})).mark_bar().encode( | |
| x=alt.X("p:Q", bin=alt.Bin(maxbins=20), title=f"予測確率 p({cal_class_show_m})"), | |
| y=alt.Y("count()", title="件数") | |
| ).properties(width=360, height=220, title="信頼度ヒストグラム") | |
| mo.hstack([mo.ui.table(cal_ece_df_m), mo.ui.altair_chart(cal_hist_chart_m)]) | |
| return | |
| def _(): | |
| mo.md( | |
| """ | |
| ## 混同行列の詳細分析 | |
| 特定の真のクラスと予測クラスの組み合わせについて、実際の予測例を確認できます。 | |
| これにより、モデルがどのような誤分類をしているかを具体的に観察できます。 | |
| """ | |
| ) | |
| return | |
| def _(X_test_text, classes, y_pred, y_test): | |
| df_show = pd.DataFrame({"text": X_test_text, "true": y_test, "pred": y_pred}) | |
| sel_true = mo.ui.dropdown( | |
| options=list(classes), value=list(classes)[0], label="True" | |
| ) | |
| sel_pred = mo.ui.dropdown( | |
| options=list(classes), value=list(classes)[0], label="Pred" | |
| ) | |
| mo.hstack([sel_true, sel_pred]) | |
| return df_show, sel_pred, sel_true | |
| def _(df_show, sel_pred, sel_true): | |
| subset = df_show[ | |
| (df_show["true"] == sel_true.value) & (df_show["pred"] == sel_pred.value) | |
| ].reset_index(drop=True) | |
| mo.ui.table(subset) | |
| return | |
| def _(): | |
| return | |
| def _(): | |
| return | |
| if __name__ == "__main__": | |
| app.run() | |