"""Streamlit app for training and inspecting Russian text classifiers.

Sidebar workflow: upload a JSONL dataset -> preprocess -> vectorize
(TF-IDF or RuBERT embeddings) -> train classical models and/or load RuBERT.
Main tabs: classify new text, interpret predictions (SHAP, attention,
Captum), compare model metrics, and inspect misclassified test samples.
"""
import streamlit as st
import numpy as np
import pandas as pd
import copy
import json
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Any, Union
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import shap

# Configure the page before any other Streamlit call: older Streamlit
# releases reject set_page_config once another Streamlit command has run.
st.set_page_config(
    page_title="Text Classifiers",
    layout="wide",
    initial_sidebar_state="expanded"
)

from text_preprocessing import (
    preprocess_text,
    get_contextual_embeddings,
    TextVectorizer
)
from classical_classifiers import (
    get_logistic_regression,
    get_svm_linear,
    get_random_forest,
    get_gradient_boosting,
    get_voting_classifier
)
from neural_classifiers import get_transformer_classifier
from model_evaluation import evaluate_model
from model_interpretation import (
    get_linear_feature_importance,
    analyze_errors,
    get_transformer_attention,
    visualize_attention_weights,
    get_token_importance_captum,
    plot_token_importance
)
import warnings

warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

RUBERT_MODEL_NAME = "DeepPavlov/rubert-base-cased"

# Task type -> name of the label field in each JSONL record.  The task is
# detected by checking these fields, in this order, on the first record.
TASK_LABEL_FIELDS = {
    "binary": "sentiment",
    "multiclass": "category",
    "multilabel": "tags",
}

# Fixed id -> label maps handed to the RuBERT config; multilabel has none.
LABEL_MAPS = {
    "binary": {0: "Negative", 1: "Positive"},
    "multiclass": {0: "Политика", 1: "Экономика", 2: "Спорт", 3: "Культура"},
}

# Sidebar model name -> zero-argument factory returning an unfitted model.
MODEL_FACTORIES = {
    "Logistic Regression": get_logistic_regression,
    "SVM": get_svm_linear,
    "Random Forest": get_random_forest,
    "XGBoost": lambda: get_gradient_boosting("xgb", n_estimators=100),
    "Voting": get_voting_classifier,
}

# Substrings of a model's fully-qualified type name suggesting that
# shap.TreeExplainer can handle it (RandomForest*, XGB*, *Boosting*, ...).
_TREE_MODEL_MARKERS = ("tree", "forest", "xgb", "boost", "lgbm")

# Tokenizer special tokens hidden from the Captum importance chart.
SPECIAL_TOKENS = ("[CLS]", "[SEP]", "[PAD]")

# ---------------------------------------------------------------------------
# Session state
# ---------------------------------------------------------------------------

# Default value of every session-state key the app relies on.
_SESSION_DEFAULTS = {
    "models": {},
    "results": {},
    "dataset": None,
    "task_type": None,
    "preprocessed": None,
    "preprocess_lang": "ru",
    "X": None,
    "y": None,
    "feature_names": None,
    "vectorizer": None,
    "vectorizer_type": None,
    "X_test": None,
    "y_test": None,
    "test_texts": None,
    "label_encoder": None,
    "id2label": None,
    "label2id": None,
    "upload_signature": None,
    "rubert_model": None,
    "rubert_tokenizer": None,
    "rubert_trained": False,
}

# Groups of keys that become stale when an upstream step is redone.
_RUBERT_KEYS = ("rubert_model", "rubert_tokenizer", "rubert_trained")
_TRAINING_KEYS = ("models", "results", "X_test", "y_test", "test_texts", "label_encoder")
_FEATURE_KEYS = ("X", "y", "feature_names", "vectorizer", "vectorizer_type") + _TRAINING_KEYS
_DATASET_KEYS = ("preprocessed",) + _FEATURE_KEYS


def _reset_session(keys):
    """Restore each session-state key in *keys* to a fresh copy of its default."""
    for key in keys:
        # copy.copy gives mutable defaults ({}) a brand-new object each time.
        st.session_state[key] = copy.copy(_SESSION_DEFAULTS[key])


for _key in _SESSION_DEFAULTS:
    if _key not in st.session_state:
        _reset_session((_key,))

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _as_text(tokens):
    """Return *tokens* as one string, joining a token sequence with spaces."""
    return tokens if isinstance(tokens, str) else " ".join(tokens)


def _set_task_type(task_type):
    """Store *task_type*, dropping any loaded RuBERT when the task changes.

    The RuBERT config is built from the task's label map, so a model loaded
    for a different task would predict the wrong label set.
    """
    if task_type != st.session_state.task_type:
        _reset_session(_RUBERT_KEYS)
    st.session_state.task_type = task_type


def _load_dataset(uploaded_file):
    """Parse the uploaded JSONL file into session state and detect the task.

    On any failure an error is shown in the sidebar and ``dataset`` is set
    to None.
    """
    try:
        lines = uploaded_file.getvalue().decode("utf-8").splitlines()
        raw_data = [json.loads(line) for line in lines if line.strip()]
        if not raw_data:
            raise ValueError("file contains no records")
        st.session_state.dataset = raw_data

        first = raw_data[0]
        task_type = next(
            (task for task, field in TASK_LABEL_FIELDS.items() if field in first),
            None,
        )
        if task_type is None:
            st.sidebar.error("No label field found")
            _set_task_type(None)
            st.session_state.dataset = None
            return

        # Every record must carry the label field, not just the first one.
        field = TASK_LABEL_FIELDS[task_type]
        for number, item in enumerate(raw_data, start=1):
            if not isinstance(item, dict) or field not in item:
                raise ValueError(f"record {number} has no '{field}' field")

        _set_task_type(task_type)
        st.sidebar.success(f"Loaded {len(raw_data)} samples. Task: {task_type}")
        id2label = LABEL_MAPS.get(task_type)
        st.session_state.id2label = id2label
        st.session_state.label2id = (
            None if id2label is None else {label: idx for idx, label in id2label.items()}
        )
    except Exception as e:
        st.sidebar.error(f"Failed to parse JSONL: {e}")
        st.session_state.dataset = None


def _vectorize_dataset(vectorizer_type):
    """Build features and labels for the whole preprocessed dataset.

    Clears previously trained models, since they target the old feature space.
    """
    texts = [_as_text(text) for text in st.session_state.preprocessed]
    if vectorizer_type == "TF-IDF":
        vectorizer = TextVectorizer()
        st.sidebar.write("Using max_features=5000")
        X = vectorizer.tfidf(texts, max_features=5000)
        st.sidebar.write(f"X shape: {X.shape}")
        feature_names = vectorizer.tfidf_vectorizer.get_feature_names_out()
    else:
        vectorizer = None
        feature_names = None
        # One call per text, as before.  NOTE(review): a batched call may be
        # faster if get_contextual_embeddings batches internally — confirm.
        X = np.array([
            get_contextual_embeddings([text], model_name=RUBERT_MODEL_NAME)[0]
            for text in texts
        ])

    field = TASK_LABEL_FIELDS[st.session_state.task_type]
    labels = [item[field] for item in st.session_state.dataset]

    _reset_session(_TRAINING_KEYS)
    st.session_state.X = X
    # Multilabel targets stay a list of tag lists; others become an array.
    st.session_state.y = labels if st.session_state.task_type == "multilabel" else np.array(labels)
    st.session_state.vectorizer = vectorizer
    st.session_state.feature_names = feature_names
    st.session_state.vectorizer_type = vectorizer_type


def _split_train_test():
    """Split X/y 80/20 and return (X_train, X_test, y_train, y_test, test_texts).

    Multilabel data is split sequentially (first 80% train). Other tasks use
    a stratified random split with a fixed seed; if stratification is
    impossible (e.g. a class with a single sample), an unstratified split is
    used and a warning is shown.
    """
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import LabelEncoder

    X = st.session_state.X
    y = st.session_state.y
    dataset = st.session_state.dataset
    task_type = st.session_state.task_type
    # shape[0] works for dense arrays and sparse matrices alike (len() does not).
    n_samples = X.shape[0]

    if task_type == "multilabel":
        split_idx = int(0.8 * n_samples)
        test_texts = [item['text'] for item in dataset[split_idx:]]
        return X[:split_idx], X[split_idx:], y[:split_idx], y[split_idx:], test_texts

    if task_type == "multiclass":
        # Fitted and stored as before.  Stratifying on the raw labels gives
        # the same split as stratifying on their encoded ids.
        st.session_state.label_encoder = LabelEncoder().fit(y)

    split_kwargs = {"test_size": 0.2, "random_state": 42}
    indices = np.arange(n_samples)
    try:
        X_train, X_test, y_train, y_test, _, idx_test = train_test_split(
            X, y, indices, stratify=y, **split_kwargs
        )
    except ValueError as err:
        st.sidebar.warning(f"Stratified split failed ({err}); using a random split instead.")
        X_train, X_test, y_train, y_test, _, idx_test = train_test_split(
            X, y, indices, **split_kwargs
        )
    test_texts = [dataset[i]['text'] for i in idx_test]
    return X_train, X_test, y_train, y_test, test_texts


def _train_classical_models(selected_models):
    """Split the data, then fit and (except for multilabel) evaluate each selected model."""
    if not selected_models:
        st.sidebar.warning("Select at least one model to train.")
        return

    X_train, X_test, y_train, y_test, test_texts = _split_train_test()
    st.session_state.X_test = X_test
    st.session_state.y_test = y_test
    st.session_state.test_texts = test_texts

    evaluate = st.session_state.task_type != "multilabel"
    trained = 0
    for name in selected_models:
        try:
            with st.spinner(f"Training {name}..."):
                model = MODEL_FACTORIES[name]()
                model.fit(X_train, y_train)
                st.session_state.models[name] = model
                if evaluate:
                    st.session_state.results[name] = evaluate_model(model, X_test, y_test)
            trained += 1
        except Exception as e:
            st.sidebar.error(f"Failed to train {name}: {e}")
    if trained:
        st.sidebar.success("Classical models trained!")


def _load_rubert():
    """Load the RuBERT tokenizer and sequence-classification model into session state.

    NOTE(review): despite the "Train RuBERT" button label, no fine-tuning
    happens here. Unless the checkpoint ships a classification head, the head
    created by from_pretrained starts from random weights — confirm before
    trusting its predictions.
    """
    from transformers import AutoConfig

    id2label = st.session_state.id2label
    config = AutoConfig.from_pretrained(
        RUBERT_MODEL_NAME,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=st.session_state.label2id,
    )
    tokenizer = AutoTokenizer.from_pretrained(RUBERT_MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(RUBERT_MODEL_NAME, config=config)
    st.session_state.rubert_model = model
    st.session_state.rubert_tokenizer = tokenizer
    st.session_state.rubert_trained = True


def _vectorize_input(text):
    """Turn one raw text into a 1-row feature matrix matching the trained features.

    Uses the same preprocessing language and vectorizer as training.

    Returns:
        (X_input, feature_names): a 2-D array with one row, and the name of
        each column.
    """
    preprocessed = _as_text(
        preprocess_text(text, lang=st.session_state.preprocess_lang, remove_stopwords=False)
    )
    if st.session_state.vectorizer_type == "TF-IDF":
        X_input = st.session_state.vectorizer.tfidf_vectorizer.transform([preprocessed]).toarray()
        return X_input, st.session_state.feature_names
    X_input = np.asarray(get_contextual_embeddings([preprocessed], model_name=RUBERT_MODEL_NAME))
    return X_input, [f"emb_{i}" for i in range(X_input.shape[1])]


def _looks_like_tree_model(model):
    """Heuristically decide whether shap.TreeExplainer is worth trying on *model*."""
    type_name = str(type(model)).lower()
    return any(marker in type_name for marker in _TREE_MODEL_MARKERS)


def _score_function(model):
    """Return the per-class score function SHAP should explain.

    Prefers predict_proba. Models without it (e.g. a linear SVM built without
    probability estimates) fall back to decision_function.
    """
    if hasattr(model, "predict_proba"):
        return model.predict_proba
    return model.decision_function


def _select_single_shap(shap_values, expected_value, n_features, target_class):
    """Reduce single-sample SHAP output to (per-feature values, base value).

    SHAP's layout varies by version and explainer: a list with one array per
    class, or an array of rank 1, 2 or 3. *target_class* is a zero-argument
    callable giving the class to explain; it is only called for multi-output
    layouts.

    Raises:
        ValueError: for more than one sample or an unsupported rank.
    """
    if isinstance(shap_values, list):
        cls = target_class()
        return shap_values[cls][0], expected_value[cls]

    sv = np.asarray(shap_values)
    per_class_base = np.ndim(expected_value) > 0
    if sv.ndim == 1:
        return sv, expected_value
    if sv.ndim == 2:
        if sv.shape[0] == 1:
            return sv[0], expected_value
        if sv.shape[0] == n_features:
            # (features, classes): take the explained class's column.
            cls = target_class()
            return sv[:, cls], (expected_value[cls] if per_class_base else expected_value)
        return sv[0], expected_value
    if sv.ndim == 3:
        if sv.shape[0] != 1:
            raise ValueError("SHAP explanation for more than one sample not supported")
        cls = target_class()
        if per_class_base and len(expected_value) == sv.shape[2]:
            return sv[0, :, cls], expected_value[cls]
        return sv[0, :, cls], expected_value
    raise ValueError(f"Unsupported SHAP shape: {sv.shape}")


def _as_scalar_base(value):
    """Unwrap a one-element array/list base value into a plain number."""
    if np.ndim(value) > 0 and np.size(value) == 1:
        return np.asarray(value).reshape(-1)[0]
    return value


def _explain_with_shap(model, text, top_k):
    """Render a SHAP waterfall plot for *model*'s prediction on *text*.

    Tree-like models use shap.TreeExplainer. Everything else, and any tree
    model TreeExplainer rejects, uses shap.KernelExplainer with up to 100
    training rows as background. For TF-IDF features, only terms that occur
    in *text* are plotted.

    Raises:
        ValueError: if the SHAP output has an unexpected shape or length.
    """
    X_input, feature_names = _vectorize_input(text)
    score_fn = _score_function(model)

    def target_class():
        # Explain the class the model scores highest for this text.
        return int(np.argmax(score_fn(X_input)[0]))

    shap_values = None
    expected_value = None
    if _looks_like_tree_model(model):
        try:
            explainer = shap.TreeExplainer(model)
            shap_values = explainer.shap_values(X_input)
            expected_value = explainer.expected_value
        except Exception:
            # TreeExplainer rejects some ensemble types; use KernelExplainer instead.
            shap_values = None
    if shap_values is None:
        background = st.session_state.X[:100]
        explainer = shap.KernelExplainer(score_fn, background)
        shap_values = explainer.shap_values(X_input, nsamples=200)
        expected_value = explainer.expected_value

    single_shap, base_value = _select_single_shap(
        shap_values, expected_value, X_input.shape[1], target_class
    )
    single_shap = np.asarray(single_shap).flatten()
    if single_shap.shape[0] != X_input.shape[1]:
        raise ValueError(
            f"SHAP vector length {single_shap.shape[0]} != input features {X_input.shape[1]}")
    base_value = _as_scalar_base(base_value)

    row = X_input[0]
    if st.session_state.vectorizer_type == "TF-IDF":
        present = np.flatnonzero(row)
        if present.size == 0:
            st.warning("No known words from training vocabulary found in this text.")
            return
        values = single_shap[present]
        data = row[present]
        names = [feature_names[i] for i in present]
    else:
        values, data, names = single_shap, row, feature_names

    explanation = shap.Explanation(
        values=values,
        base_values=base_value,
        data=data,
        feature_names=names
    )
    plt.figure(figsize=(10, min(8, top_k * 0.3)))
    shap.plots.waterfall(explanation, max_display=top_k, show=False)
    fig = plt.gcf()
    st.pyplot(fig)
    plt.close(fig)


# ---------------------------------------------------------------------------
# Sidebar: data -> preprocessing -> features -> models
# ---------------------------------------------------------------------------

st.sidebar.title("Setup")
st.sidebar.subheader("1. Upload Dataset (JSONL)")
uploaded_file = st.sidebar.file_uploader("Upload .jsonl file", type=["jsonl"])

if uploaded_file is not None:
    signature = (uploaded_file.name, uploaded_file.size)
    if st.session_state.upload_signature != signature:
        # A different file: everything derived from the previous dataset is stale.
        _reset_session(_DATASET_KEYS)
        st.session_state.upload_signature = signature
    _load_dataset(uploaded_file)

if st.session_state.dataset is not None:
    st.sidebar.subheader("2. Preprocess Text")
    lang = st.sidebar.selectbox("Language", ["ru", "en"], index=0)
    if st.sidebar.button("Run Preprocessing"):
        with st.spinner("Preprocessing..."):
            texts = [item['text'] for item in st.session_state.dataset]
            preprocessed = [
                _as_text(preprocess_text(text, lang=lang, remove_stopwords=False))
                for text in texts
            ]
            # Features and models built from the previous preprocessing are stale.
            _reset_session(_FEATURE_KEYS)
            st.session_state.preprocessed = preprocessed
            # Remember the language so inference preprocesses text the same way.
            st.session_state.preprocess_lang = lang
        st.sidebar.success("Preprocessing done!")

if st.session_state.preprocessed is not None:
    st.sidebar.subheader("3. Vectorization (Classical)")
    vectorizer_type = st.sidebar.selectbox("Method", ["TF-IDF", "RuBERT Embeddings"])
    if st.sidebar.button("Vectorize"):
        with st.spinner("Vectorizing..."):
            _vectorize_dataset(vectorizer_type)
        st.sidebar.success("Vectorization complete!")

if st.session_state.X is not None:
    st.sidebar.subheader("4. Train Classical Models")
    selected_models = st.sidebar.multiselect("Models", list(MODEL_FACTORIES))
    if st.sidebar.button("Train Classical Models"):
        _train_classical_models(selected_models)

if st.session_state.dataset is not None and st.session_state.task_type in ("binary", "multiclass"):
    st.sidebar.subheader("5. Train RuBERT (Transformer)")
    if st.sidebar.button("Train RuBERT"):
        with st.spinner("Loading RuBERT..."):
            try:
                _load_rubert()
                st.sidebar.success("RuBERT loaded with correct labels!")
            except Exception as e:
                st.sidebar.error(f"RuBERT loading failed: {e}")
                st.exception(e)

# ---------------------------------------------------------------------------
# Main area
# ---------------------------------------------------------------------------

st.title("Text Classifiers")
tab_classify, tab_interpret, tab_compare, tab_errors = st.tabs([
    "Classify",
    "Interpret",
    "Compare",
    "Error Analysis"
])

with tab_classify:
    st.subheader("Classify New Text")
    input_text = st.text_area("Enter text", "Сегодня прошёл важный матч по хоккею.")
    if st.button("Classify"):
        col_classical, col_bert = st.columns(2)

        with col_classical:
            st.markdown("### Classical Models")
            if not st.session_state.models:
                st.info("No classical models trained")
            else:
                X_input, _ = _vectorize_input(input_text)
                for name, model in st.session_state.models.items():
                    # One failing model must not stop the rest of the page.
                    try:
                        pred = model.predict(X_input)[0]
                        st.write(f"**{name}**: {pred}")
                        if hasattr(model, "predict_proba"):
                            proba = model.predict_proba(X_input)[0]
                            st.write(f"Probabilities: {dict(zip(model.classes_, proba))}")
                    except Exception as e:
                        st.error(f"{name} prediction failed: {e}")

        with col_bert:
            st.markdown("### RuBERT")
            if not st.session_state.rubert_trained:
                st.info("Train RuBERT in sidebar")
            else:
                try:
                    classifier = pipeline(
                        "text-classification",
                        model=st.session_state.rubert_model,
                        tokenizer=st.session_state.rubert_tokenizer,
                        device=-1
                    )
                    top = classifier(input_text)[0]
                    label = top['label']
                    id2label = st.session_state.id2label
                    # Map generic "LABEL_<id>" names back to readable labels.
                    if label.startswith("LABEL_") and id2label:
                        label = id2label.get(int(label.replace("LABEL_", "")), label)
                    st.write(f"**Prediction**: {label}")
                    st.write(f"**Confidence**: {top['score']:.3f}")
                except Exception as e:
                    st.error(f"RuBERT inference failed: {e}")

with tab_interpret:
    sub_shap, sub_attention, sub_captum = st.tabs(["SHAP / LIME", "Attention Map", "Captum Heatmap"])

    with sub_shap:
        st.subheader("SHAP: Local Explanation for One Text")
        if not st.session_state.models:
            st.info("Train a classical model first")
        else:
            model_name = st.selectbox("Model", list(st.session_state.models.keys()), key="shap_model")
            text_for_explain = st.text_area("Text to explain", "Прекрасная новость о росте экономики!",
                                            key="shap_text")
            top_k = st.slider("Top features to show", 5, 30, 15)
            if st.button("Explain with SHAP"):
                try:
                    _explain_with_shap(st.session_state.models[model_name], text_for_explain, top_k)
                except Exception as e:
                    st.error(f"SHAP error: {e}")
                    st.exception(e)

    with sub_attention:
        st.subheader("Transformer Attention Map")
        if not st.session_state.rubert_trained:
            st.info("Train RuBERT first")
        else:
            text_att = st.text_area("Text for attention", "Матч завершился победой ЦСКА", key="att_text")
            layer = st.slider("Layer", 0, 11, 6)
            head = st.slider("Head", 0, 11, 0)
            if st.button("Visualize Attention"):
                try:
                    tokens, attn = get_transformer_attention(
                        st.session_state.rubert_model,
                        st.session_state.rubert_tokenizer,
                        text_att,
                        device="cpu"
                    )
                    # Assumes attn is indexed (layer, head, query, key) — TODO confirm
                    # against get_transformer_attention.
                    weights = attn[layer, head, :len(tokens), :len(tokens)]
                    fig, ax = plt.subplots(figsize=(10, 4))
                    sns.heatmap(
                        weights,
                        xticklabels=tokens,
                        yticklabels=tokens,
                        cmap="viridis",
                        ax=ax
                    )
                    plt.xticks(rotation=45, ha="right")
                    plt.yticks(rotation=0)
                    plt.title(f"Attention: Layer {layer}, Head {head}")
                    st.pyplot(fig)
                    plt.close(fig)
                except Exception as e:
                    st.error(f"Attention failed: {e}")
                    st.exception(e)

    with sub_captum:
        st.subheader("Token Importance (Captum)")
        if not st.session_state.rubert_trained:
            st.info("Train RuBERT first")
        else:
            text_captum = st.text_area("Text for Captum", "Это очень плохая новость для политики",
                                       key="captum_text")
            if st.button("Compute Token Importance"):
                try:
                    tokens, importance = get_token_importance_captum(
                        st.session_state.rubert_model,
                        st.session_state.rubert_tokenizer,
                        text_captum,
                        device="cpu"
                    )
                    valid = [(tok, imp) for tok, imp in zip(tokens, importance)
                             if tok not in SPECIAL_TOKENS]
                    if valid:
                        tokens_clean, imp_clean = zip(*valid)
                        # The 15 tokens with the largest absolute attribution, largest first.
                        order = np.argsort(np.abs(imp_clean))[-15:][::-1]
                        tokens_top = [tokens_clean[i] for i in order]
                        imp_top = [imp_clean[i] for i in order]

                        fig, ax = plt.subplots(figsize=(8, 6))
                        colors = ["red" if x < 0 else "green" for x in imp_top]
                        ax.barh(range(len(imp_top)), imp_top, color=colors)
                        ax.set_yticks(range(len(imp_top)))
                        ax.set_yticklabels(tokens_top)
                        ax.invert_yaxis()
                        ax.set_xlabel("Attribution Score")
                        ax.set_title("Token Importance")
                        st.pyplot(fig)
                        plt.close(fig)
                    else:
                        st.warning("No valid tokens")
                except Exception as e:
                    st.error(f"Captum failed: {e}")
                    st.exception(e)

with tab_compare:
    st.subheader("Model Comparison")
    if st.session_state.results:
        df = pd.DataFrame(st.session_state.results).T
        st.dataframe(df)
    else:
        st.info("Train models to see metrics")

with tab_errors:
    st.subheader("Error Analysis")
    # A test split can exist even when every model failed to train.
    if st.session_state.X_test is None or not st.session_state.models:
        st.info("Train models first")
    else:
        model_name = st.selectbox("Model for error analysis", list(st.session_state.models.keys()),
                                  key="err_model")
        if st.button("Analyze Errors"):
            model = st.session_state.models[model_name]
            y_pred = model.predict(st.session_state.X_test)
            errors = analyze_errors(
                st.session_state.y_test,
                y_pred,
                st.session_state.test_texts
            )
            st.dataframe(errors[['text', 'true_label', 'pred_label']].head(20))