theformatisvalid committed on
Commit
d745776
·
verified ·
1 Parent(s): 2153792

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +541 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,544 @@
1
- import altair as alt
2
  import numpy as np
3
  import pandas as pd
4
- import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
# Streamlit app: train, compare, and interpret text classifiers
# (classical sklearn-style models + a RuBERT transformer).

import streamlit as st
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Any, Union
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import shap

# Page config is issued before the project-local imports below —
# presumably so st.set_page_config runs before any other Streamlit
# call (a Streamlit requirement); confirm if reordering imports.
st.set_page_config(
    page_title="Text Classifiers",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Project-local helpers: preprocessing / vectorization.
from text_preprocessing import (
    preprocess_text, get_contextual_embeddings, TextVectorizer
)
# Factories for the classical (sklearn-style) classifiers.
from classical_classifiers import (
    get_logistic_regression, get_svm_linear, get_random_forest,
    get_gradient_boosting, get_voting_classifier
)
from neural_classifiers import get_transformer_classifier
from model_evaluation import evaluate_model
# Interpretation utilities (feature importance, attention, Captum).
from model_interpretation import (
    get_linear_feature_importance,
    analyze_errors,
    get_transformer_attention,
    visualize_attention_weights,
    get_token_importance_captum,
    plot_token_importance
)

import warnings

# Keep library warnings out of the Streamlit UI.
warnings.filterwarnings("ignore")
# One-time initialization of st.session_state.
#
# Streamlit re-executes this script on every widget interaction, so each
# key is created only when absent to preserve previously computed state.
# The original code spelled this out as 17 separate `if ... not in` blocks;
# a defaults table + loop is equivalent and easier to extend.
_SESSION_DEFAULTS = {
    'models': {},             # trained classical models, keyed by display name
    'results': {},            # evaluation metrics per model name
    'dataset': None,          # raw JSONL records (list of dicts)
    'task_type': None,        # "binary" | "multiclass" | "multilabel"
    'preprocessed': None,     # preprocessed texts
    'X': None,                # feature matrix for classical models
    'y': None,                # labels aligned with X
    'feature_names': None,    # TF-IDF vocabulary (None for embeddings)
    'vectorizer': None,       # fitted TextVectorizer (TF-IDF only)
    'vectorizer_type': None,  # "TF-IDF" or "RuBERT Embeddings"
    'X_test': None,           # held-out features from the last training run
    'y_test': None,           # held-out labels
    'test_texts': None,       # raw texts of the held-out split
    'label_encoder': None,    # sklearn LabelEncoder (multiclass only)
    'rubert_model': None,     # transformer model used for inference
    'rubert_tokenizer': None,
    'rubert_trained': False,  # whether RuBERT is ready for inference
}

for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
st.sidebar.title("Setup")

st.sidebar.subheader("1. Upload Dataset (JSONL)")
uploaded_file = st.sidebar.file_uploader("Upload .jsonl file", type=["jsonl"])

# Parse the uploaded JSONL and infer the task type from the label field
# present in the first record: 'sentiment' -> binary, 'category' ->
# multiclass, 'tags' -> multilabel.  The original also built a `labels`
# list in every branch that was never used; it has been removed.
if uploaded_file:
    try:
        raw_data = []
        lines = uploaded_file.getvalue().decode("utf-8").splitlines()
        for line in lines:
            if line.strip():  # skip blank lines
                raw_data.append(json.loads(line))
        st.session_state.dataset = raw_data

        # NOTE(review): raw_data[0] raises IndexError on an empty file;
        # the surrounding except reports it as a parse failure.
        first = raw_data[0]
        if 'sentiment' in first:
            st.session_state.task_type = "binary"
        elif 'category' in first:
            st.session_state.task_type = "multiclass"
        elif 'tags' in first:
            st.session_state.task_type = "multilabel"
        else:
            st.sidebar.error("No label field found")
            st.session_state.task_type = None
            st.session_state.dataset = None

        if st.session_state.task_type:
            st.sidebar.success(f"Loaded {len(raw_data)} samples. Task: {st.session_state.task_type}")
            # Fixed label maps used when configuring RuBERT later.
            if st.session_state.task_type == "binary":
                id2label = {0: "Negative", 1: "Positive"}
                label2id = {"Negative": 0, "Positive": 1}
            elif st.session_state.task_type == "multiclass":
                id2label = {0: "Политика", 1: "Экономика", 2: "Спорт", 3: "Культура"}
                label2id = {"Политика": 0, "Экономика": 1, "Спорт": 2, "Культура": 3}
            else:
                # Multilabel: no fixed id/label mapping.
                id2label = None
                label2id = None

            st.session_state.id2label = id2label
            st.session_state.label2id = label2id
    except Exception as e:
        st.sidebar.error(f"Failed to parse JSONL: {e}")
        st.session_state.dataset = None
if st.session_state.dataset is not None:
    st.sidebar.subheader("2. Preprocess Text")
    # BUG FIX: the selected language was previously ignored — 'ru' was
    # hard-coded both in session state and in the preprocess_text call,
    # so choosing "en" had no effect.  Honor the widget value instead.
    lang = st.sidebar.selectbox("Language", ["ru", "en"], index=0)
    st.session_state.preprocess_lang = lang
    if st.sidebar.button("Run Preprocessing"):
        with st.spinner("Preprocessing..."):
            texts = [item['text'] for item in st.session_state.dataset]
            preprocessed = [
                preprocess_text(text, lang=lang, remove_stopwords=False)
                for text in texts
            ]
            st.session_state.preprocessed = preprocessed
            st.sidebar.success("Preprocessing done!")
if st.session_state.preprocessed is not None:
    st.sidebar.subheader("3. Vectorization (Classical)")
    vectorizer_type = st.sidebar.selectbox("Method", ["TF-IDF", "RuBERT Embeddings"])
    if st.sidebar.button("Vectorize"):
        with st.spinner("Vectorizing..."):
            if vectorizer_type == "TF-IDF":
                tfidf = TextVectorizer()
                # preprocess_text may return token lists; TF-IDF needs strings.
                if not isinstance(st.session_state.preprocessed[0], str):
                    st.session_state.preprocessed = [
                        ' '.join(text) for text in st.session_state.preprocessed
                    ]
                st.sidebar.write("Using max_features=5000")
                features = tfidf.tfidf(st.session_state.preprocessed, max_features=5000)
                st.sidebar.write(f"X shape: {features.shape}")
                st.session_state.vectorizer = tfidf
                st.session_state.feature_names = tfidf.tfidf_vectorizer.get_feature_names_out()
            else:
                # One contextual embedding per document, computed one text
                # at a time, then stacked into a dense matrix.
                features = np.array([
                    get_contextual_embeddings([doc], model_name="DeepPavlov/rubert-base-cased")[0]
                    for doc in st.session_state.preprocessed
                ])
                st.session_state.vectorizer = None
                st.session_state.feature_names = None
            st.session_state.X = features
            st.session_state.vectorizer_type = vectorizer_type

            # Pull labels out of the raw records to match the task type.
            task = st.session_state.task_type
            if task == "binary":
                targets = np.array([item['sentiment'] for item in st.session_state.dataset])
            elif task == "multiclass":
                targets = np.array([item['category'] for item in st.session_state.dataset])
            else:
                # Multilabel: keep the per-item tag lists as-is.
                targets = [item['tags'] for item in st.session_state.dataset]
            st.session_state.y = targets
            st.sidebar.success("Vectorization complete!")
if st.session_state.X is not None:
    st.sidebar.subheader("4. Train Classical Models")
    model_options = ["Logistic Regression", "SVM", "Random Forest", "XGBoost", "Voting"]
    selected_models = st.sidebar.multiselect("Models", model_options)
    if st.sidebar.button("Train Classical Models"):
        from sklearn.model_selection import train_test_split
        from sklearn.preprocessing import LabelEncoder

        X = st.session_state.X
        y = st.session_state.y

        # Encode string class labels for a stratified split (multiclass);
        # for multilabel there is no single label, so a proxy (tag count)
        # stands in — it is only used when stratification applies.
        if st.session_state.task_type == "multiclass":
            le = LabelEncoder()
            y_encoded = le.fit_transform(y)
            st.session_state.label_encoder = le
            y_for_split = y_encoded
        else:
            y_for_split = y if st.session_state.task_type == "binary" else np.array([len(tags) for tags in y])

        if st.session_state.task_type == "multilabel":
            # Plain 80/20 sequential split: list-of-lists labels can't be
            # stratified by train_test_split.
            split_idx = int(0.8 * len(X))
            X_train, X_test = X[:split_idx], X[split_idx:]
            y_train, y_test = y[:split_idx], y[split_idx:]
            test_texts = [item['text'] for item in st.session_state.dataset[split_idx:]]
        else:
            # Stratified split; indices are carried along so the raw test
            # texts can be recovered for later error analysis.
            # (The original guarded stratify with a `!= "multilabel"` check
            # that is always true inside this branch — removed.)
            indices = np.arange(len(X))
            X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
                X, y_for_split, indices, test_size=0.2,
                stratify=y_for_split,
                random_state=42
            )
            test_texts = [st.session_state.dataset[i]['text'] for i in idx_test]
            if st.session_state.task_type == "multiclass":
                # Restore original string labels for fitting/evaluation.
                y_train = le.inverse_transform(y_train)
                y_test = le.inverse_transform(y_test)

        st.session_state.X_test = X_test
        st.session_state.y_test = y_test
        st.session_state.test_texts = test_texts

        # Dispatch table replaces the original if/elif chain that repeated
        # the same fit-and-store sequence for every model name.
        model_factories = {
            "Logistic Regression": get_logistic_regression,
            "SVM": get_svm_linear,
            "Random Forest": get_random_forest,
            "XGBoost": lambda: get_gradient_boosting("xgb", n_estimators=100),
            "Voting": get_voting_classifier,
        }

        for name in selected_models:
            try:
                with st.spinner(f"Training {name}..."):
                    model = model_factories[name]()
                    model.fit(X_train, y_train)
                    st.session_state.models[name] = model

                    # evaluate_model is skipped for multilabel (metrics
                    # computed there assume single-label targets).
                    if st.session_state.task_type != "multilabel":
                        metrics = evaluate_model(model, X_test, y_test)
                        st.session_state.results[name] = metrics
            except Exception as e:
                # Keep training the remaining models after one failure.
                st.sidebar.error(f"Failed to train {name}: {e}")
                continue
        st.sidebar.success("Classical models trained!")
# RuBERT is only applicable to single-label tasks (binary / multiclass).
if st.session_state.dataset is not None and st.session_state.task_type in ["binary", "multiclass"]:
    st.sidebar.subheader("5. Train RuBERT (Transformer)")
    if st.sidebar.button("Train RuBERT"):
        with st.spinner("Loading RuBERT..."):
            try:
                from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

                checkpoint = "DeepPavlov/rubert-base-cased"
                num_labels = 2 if st.session_state.task_type == "binary" else 4

                # Bake the human-readable label maps into the model config
                # so pipeline() outputs named labels, not LABEL_<n>.
                hf_config = AutoConfig.from_pretrained(
                    checkpoint,
                    num_labels=num_labels,
                    id2label=st.session_state.id2label,
                    label2id=st.session_state.label2id
                )

                hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
                hf_model = AutoModelForSequenceClassification.from_pretrained(checkpoint, config=hf_config)

                st.session_state.rubert_model = hf_model
                st.session_state.rubert_tokenizer = hf_tokenizer
                st.session_state.rubert_trained = True
                st.sidebar.success("RuBERT loaded with correct labels!")
            except Exception as e:
                st.sidebar.error(f"RuBERT loading failed: {e}")
                st.exception(e)

st.title("Text Classifiers")

# Main page layout: four top-level tabs.
tab1, tab2, tab3, tab4 = st.tabs([
    "Classify",
    "Interpret",
    "Compare",
    "Error Analysis"
])
with tab1:
    # Classify a single user-supplied text with every trained model,
    # classical models in the left column and RuBERT in the right.
    st.subheader("Classify New Text")
    input_text = st.text_area("Enter text", "Сегодня прошёл важный матч по хоккею.")

    if st.button("Classify"):
        cols = st.columns(2)
        with cols[0]:
            st.markdown("### Classical Models")
            if not st.session_state.models:
                st.info("No classical models trained")
            else:
                # Re-apply the same preprocessing/vectorization pipeline
                # that was used on the training data.
                tokens = preprocess_text(input_text, lang='ru', remove_stopwords=False)
                preprocessed = " ".join(tokens)
                if st.session_state.vectorizer_type == "TF-IDF":
                    X_input = st.session_state.vectorizer.tfidf_vectorizer.transform([preprocessed]).toarray()
                else:
                    X_input = get_contextual_embeddings([preprocessed], model_name="DeepPavlov/rubert-base-cased")

                for name, model in st.session_state.models.items():
                    pred = model.predict(X_input)[0]
                    st.write(f"**{name}**: {pred}")
                    if hasattr(model, "predict_proba"):
                        proba = model.predict_proba(X_input)[0]
                        st.write(f"Probabilities: {dict(zip(model.classes_, proba))}")

        with cols[1]:
            st.markdown("### RuBERT")
            if not st.session_state.rubert_trained:
                st.info("Train RuBERT in sidebar")
            else:
                try:
                    # NOTE: redundant local `from transformers import pipeline`
                    # removed — pipeline is already imported at module level.
                    pipe = pipeline(
                        "text-classification",
                        model=st.session_state.rubert_model,
                        tokenizer=st.session_state.rubert_tokenizer,
                        device=-1  # force CPU inference
                    )
                    result = pipe(input_text)
                    label = result[0]['label']
                    confidence = result[0]['score']

                    # Map generic LABEL_<n> outputs back to readable names
                    # when the config did not carry id2label.
                    if label.startswith("LABEL_") and st.session_state.id2label:
                        label_id = int(label.replace("LABEL_", ""))
                        readable_label = st.session_state.id2label.get(label_id, label)
                    else:
                        readable_label = label

                    st.write(f"**Prediction**: {readable_label}")
                    st.write(f"**Confidence**: {confidence:.3f}")
                except Exception as e:
                    st.error(f"RuBERT inference failed: {e}")
with tab2:
    # Interpretation tab: SHAP for classical models, attention maps and
    # Captum token attributions for RuBERT.
    subtab1, subtab2, subtab3 = st.tabs(["SHAP / LIME", "Attention Map", "Captum Heatmap"])

    with subtab1:
        st.subheader("SHAP: Local Explanation for One Text")
        if not st.session_state.models:
            st.info("Train a classical model first")
        else:
            model_name = st.selectbox("Model", list(st.session_state.models.keys()), key="shap_model")
            text_for_explain = st.text_area("Text to explain", "Прекрасная новость о росте экономики!", key="shap_text")
            top_k = st.slider("Top features to show", 5, 30, 15)

            if st.button("Explain with SHAP"):
                try:
                    # (Removed: redundant local `import shap` — imported at
                    # module level — and three commented-out DEBUG writes.)
                    model = st.session_state.models[model_name]
                    tokens = preprocess_text(text_for_explain, lang='ru', remove_stopwords=False)
                    preprocessed = " ".join(tokens)

                    # Vectorize the text exactly as the training data was.
                    if st.session_state.vectorizer_type == "TF-IDF":
                        X_input = st.session_state.vectorizer.tfidf_vectorizer.transform([preprocessed]).toarray()
                        feature_names = st.session_state.feature_names
                    else:
                        X_input = get_contextual_embeddings([preprocessed], model_name="DeepPavlov/rubert-base-cased")
                        feature_names = [f"emb_{i}" for i in range(X_input.shape[1])]

                    # KernelExplainer background sample (first 100 rows).
                    background = st.session_state.X[:100]
                    if "tree" in str(type(model)).lower():
                        explainer = shap.TreeExplainer(model)
                        shap_values = explainer.shap_values(X_input)
                    else:
                        explainer = shap.KernelExplainer(model.predict_proba, background)
                        shap_values = explainer.shap_values(X_input, nsamples=200)

                    # Normalize SHAP output to a single 1-D vector for the
                    # predicted class: different shap versions/explainers
                    # return a list per class, a 2-D array, or a 3-D array.
                    if isinstance(shap_values, list):
                        probs = model.predict_proba(X_input)[0]
                        target_class = int(np.argmax(probs))
                        single_shap = shap_values[target_class][0]
                        expected_val = explainer.expected_value[target_class]
                    else:
                        sv = shap_values
                        if sv.ndim == 1:
                            single_shap = sv
                            expected_val = explainer.expected_value
                        elif sv.ndim == 2:
                            if sv.shape[0] == 1:
                                # Shape (1, n_features): single sample.
                                single_shap = sv[0]
                                expected_val = explainer.expected_value
                            elif sv.shape[1] == X_input.shape[1]:
                                # Shape (n_classes?, n_features): pick the
                                # predicted class column.
                                probs = model.predict_proba(X_input)[0]
                                target_class = int(np.argmax(probs))
                                single_shap = sv[:, target_class]
                                expected_val = explainer.expected_value[target_class] if isinstance(
                                    explainer.expected_value, (list, np.ndarray)) else explainer.expected_value
                            else:
                                single_shap = sv[0]
                                expected_val = explainer.expected_value
                        elif sv.ndim == 3:
                            # Shape (1, n_features, n_classes).
                            if sv.shape[0] != 1:
                                raise ValueError("SHAP explanation for more than one sample not supported")
                            probs = model.predict_proba(X_input)[0]
                            target_class = int(np.argmax(probs))
                            single_shap = sv[0, :, target_class]
                            if isinstance(explainer.expected_value, (list, np.ndarray)) and len(
                                    explainer.expected_value) == sv.shape[2]:
                                expected_val = explainer.expected_value[target_class]
                            else:
                                expected_val = explainer.expected_value
                        else:
                            raise ValueError(f"Unsupported SHAP shape: {sv.shape}")

                    single_shap = np.array(single_shap).flatten()
                    if single_shap.shape[0] != X_input.shape[1]:
                        raise ValueError(
                            f"SHAP vector length {single_shap.shape[0]} != input features {X_input.shape[1]}")

                    if st.session_state.vectorizer_type == "TF-IDF":
                        # Restrict the plot to vocabulary words actually
                        # present in this text (nonzero TF-IDF weights).
                        text_vector = X_input[0]
                        nonzero_indices = np.where(text_vector != 0)[0]
                        if len(nonzero_indices) == 0:
                            st.warning("No known words from training vocabulary found in this text.")
                        else:
                            filtered_shap = single_shap[nonzero_indices]
                            filtered_features = text_vector[nonzero_indices]
                            filtered_names = [st.session_state.feature_names[i] for i in nonzero_indices]

                            explanation = shap.Explanation(
                                values=filtered_shap,
                                base_values=expected_val,
                                data=filtered_features,
                                feature_names=filtered_names
                            )

                            plt.figure(figsize=(10, min(8, top_k * 0.3)))
                            shap.plots.waterfall(explanation, max_display=top_k, show=False)
                            st.pyplot(plt.gcf())
                            plt.close()
                    else:
                        # Embedding features: plot all dimensions.
                        explanation = shap.Explanation(
                            values=single_shap,
                            base_values=expected_val,
                            data=X_input[0],
                            feature_names=feature_names
                        )
                        plt.figure(figsize=(10, min(8, top_k * 0.3)))
                        shap.plots.waterfall(explanation, max_display=top_k, show=False)
                        st.pyplot(plt.gcf())
                        plt.close()

                except Exception as e:
                    st.error(f"SHAP error: {e}")
                    st.exception(e)

    with subtab2:
        st.subheader("Transformer Attention Map")
        if not st.session_state.rubert_trained:
            st.info("Train RuBERT first")
        else:
            text_att = st.text_area("Text for attention", "Матч завершился победой ЦСКА", key="att_text")
            # 12-layer / 12-head model assumed by the slider ranges.
            layer = st.slider("Layer", 0, 11, 6)
            head = st.slider("Head", 0, 11, 0)
            if st.button("Visualize Attention"):
                try:
                    tokens, attn = get_transformer_attention(
                        st.session_state.rubert_model,
                        st.session_state.rubert_tokenizer,
                        text_att,
                        device="cpu"
                    )
                    # Crop to the actual token count for a square heatmap.
                    weights = attn[layer, head, :len(tokens), :len(tokens)]

                    fig, ax = plt.subplots(figsize=(10, 4))
                    sns.heatmap(
                        weights,
                        xticklabels=tokens,
                        yticklabels=tokens,
                        cmap="viridis",
                        ax=ax
                    )
                    plt.xticks(rotation=45, ha="right")
                    plt.yticks(rotation=0)
                    plt.title(f"Attention: Layer {layer}, Head {head}")
                    st.pyplot(fig)
                    plt.close(fig)
                except Exception as e:
                    st.error(f"Attention failed: {e}")
                    st.exception(e)

    with subtab3:
        st.subheader("Token Importance (Captum)")
        if not st.session_state.rubert_trained:
            st.info("Train RuBERT first")
        else:
            text_captum = st.text_area("Text for Captum", "Это очень плохая новость для политики", key="captum_text")
            # (Removed: unused `method = "IntegratedGradients"` local.)
            if st.button("Compute Token Importance"):
                try:
                    tokens, importance = get_token_importance_captum(
                        st.session_state.rubert_model,
                        st.session_state.rubert_tokenizer,
                        text_captum,
                        device="cpu"
                    )
                    # Drop special tokens before ranking attributions.
                    valid = [(t, imp) for t, imp in zip(tokens, importance) if t not in ["[CLS]", "[SEP]", "[PAD]"]]
                    if valid:
                        tokens_clean, imp_clean = zip(*valid)
                        # Top-15 tokens by absolute attribution, largest first.
                        indices = np.argsort(np.abs(imp_clean))[-15:][::-1]
                        tokens_top = [tokens_clean[i] for i in indices]
                        imp_top = [imp_clean[i] for i in indices]

                        fig, ax = plt.subplots(figsize=(8, 6))
                        colors = ["red" if x < 0 else "green" for x in imp_top]
                        ax.barh(range(len(imp_top)), imp_top, color=colors)
                        ax.set_yticks(range(len(imp_top)))
                        ax.set_yticklabels(tokens_top)
                        ax.invert_yaxis()
                        ax.set_xlabel("Attribution Score")
                        ax.set_title("Token Importance")
                        st.pyplot(fig)
                        plt.close(fig)
                    else:
                        st.warning("No valid tokens")
                except Exception as e:
                    st.error(f"Captum failed: {e}")
                    st.exception(e)
with tab3:
    # Side-by-side metrics table for all evaluated classical models.
    st.subheader("Model Comparison")
    if not st.session_state.results:
        st.info("Train models to see metrics")
    else:
        metrics_df = pd.DataFrame(st.session_state.results).T
        st.dataframe(metrics_df)

with tab4:
    # Inspect misclassified held-out examples for a chosen model.
    st.subheader("Error Analysis")
    if st.session_state.X_test is None:
        st.info("Train models first")
    else:
        model_name = st.selectbox("Model for error analysis", list(st.session_state.models.keys()), key="err_model")
        if st.button("Analyze Errors"):
            chosen = st.session_state.models[model_name]
            predictions = chosen.predict(st.session_state.X_test)
            error_df = analyze_errors(
                st.session_state.y_test,
                predictions,
                st.session_state.test_texts
            )
            st.dataframe(error_df[['text', 'true_label', 'pred_label']].head(20))