Kolesnikov Dmitry committed
Commit 68545bc · 1 Parent(s): 83b4881
feat: An attempt to vibe-code labs 3 and 4
Browse files
- requirements.txt +20 -1
- results/vectorization_metrics.csv +3 -2
- src/__pycache__/embeddings_train.cpython-313.pyc +0 -0
- src/__pycache__/text_preprocessing.cpython-313.pyc +0 -0
- src/classical_classifiers.py +400 -0
- src/clustering.py +378 -0
- src/embeddings_train.py +131 -20
- src/imbalance_handling.py +385 -0
- src/model_evaluation.py +359 -0
- src/model_interpretation.py +367 -0
- src/neural_classifiers.py +306 -0
- src/streamlit_app.py +290 -3
- src/text_preprocessing.py +368 -0
- src/text_to_vector.py +403 -0
requirements.txt CHANGED

```diff
@@ -23,4 +23,23 @@ umap-learn
 # glove-python-binary  # optional
 # pymorphy2  # incompatible with Python 3.13+
 # transformers  # removed at user request
-# torch  # removed at user request
+# torch  # removed at user request
+
+# Lab 3: text classification
+xgboost  # optional
+lightgbm  # optional
+catboost  # optional
+imbalanced-learn
+# autosklearn  # optional, needs system dependencies
+# tpot  # optional
+# h2o  # optional
+# nlpaug  # optional
+# shap  # optional
+# lime  # optional
+# optuna  # optional
+# hyperopt  # optional
+# tensorflow  # optional, for neural networks
+
+# Lab 4: clustering
+hdbscan
+rank-bm25
```
results/vectorization_metrics.csv CHANGED

```diff
@@ -1,3 +1,4 @@
 Метод,N-граммы,Документов,Признаков,Ненулевых,Плотность,Время fit (с),Время transform (с),Память (MB) ~
-bow,1-
-
+bow,1-5,1000,5225,957750,0.183301,0.3194,0.3208,21.92
+onehot,1-5,1000,5181,947768,0.182931,0.5572,0.3303,21.69
+tfidf,1-5,1000,5225,957750,0.183301,0.3047,0.3245,21.92
```
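As a sanity check on these rows, the density column equals nonzeros divided by documents times features (for bow: 957750 / (1000 * 5225) ≈ 0.183301). Below is a minimal sketch of how such a row can be derived from a fitted sklearn vectorizer; the toy corpus and the exact vectorizer settings are illustrative assumptions, not the repository's actual pipeline:

```python
# Sketch: derive the vectorization_metrics.csv columns from a fitted vectorizer.
# The corpus and CountVectorizer settings here are illustrative assumptions.
from sklearn.feature_extraction.text import CountVectorizer

docs = ["пример текста", "ещё один пример текста"]  # stand-in corpus
vec = CountVectorizer(ngram_range=(1, 5))
X = vec.fit_transform(docs)  # scipy CSR sparse matrix

n_docs, n_features = X.shape
nnz = X.nnz                            # "Ненулевых"
density = nnz / (n_docs * n_features)  # "Плотность"
# rough memory estimate of the CSR buffers, matching "Память (MB) ~"
memory_mb = (X.data.nbytes + X.indices.nbytes + X.indptr.nbytes) / 2**20
print(n_docs, n_features, nnz, round(density, 6), round(memory_mb, 2))
```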
src/__pycache__/embeddings_train.cpython-313.pyc CHANGED
Binary files a/src/__pycache__/embeddings_train.cpython-313.pyc and b/src/__pycache__/embeddings_train.cpython-313.pyc differ

src/__pycache__/text_preprocessing.cpython-313.pyc ADDED
Binary file (16.5 kB)
src/classical_classifiers.py ADDED (+400 lines; docstrings and comments translated from Russian, runtime strings kept as-is)

```python
"""
Classical text classification methods: logistic regression, SVM,
random forest, gradient boosting, ensembles, and AutoML approaches.
"""

from __future__ import annotations

import time
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple, Union

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, VotingClassifier, BaggingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.multioutput import MultiOutputClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, classification_report, confusion_matrix,
    precision_recall_curve, roc_curve
)

try:
    import xgboost as xgb
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False

try:
    import lightgbm as lgb
    LIGHTGBM_AVAILABLE = True
except ImportError:
    LIGHTGBM_AVAILABLE = False

try:
    import catboost as cb
    CATBOOST_AVAILABLE = True
except ImportError:
    CATBOOST_AVAILABLE = False

try:
    import autosklearn.classification
    AUTOSKLEARN_AVAILABLE = True
except ImportError:
    AUTOSKLEARN_AVAILABLE = False

try:
    from tpot import TPOTClassifier
    TPOT_AVAILABLE = True
except ImportError:
    TPOT_AVAILABLE = False

try:
    import h2o
    from h2o.automl import H2OAutoML
    H2O_AVAILABLE = True
except ImportError:
    H2O_AVAILABLE = False


@dataclass
class ClassifierConfig:
    """Classifier configuration."""
    name: str
    model_type: str  # lr, svm, rf, xgb, lgb, catboost, ensemble, autosklearn, tpot, h2o
    params: Optional[Dict[str, Any]] = None
    use_class_weight: bool = True
    multilabel: bool = False  # wrap in MultiOutputClassifier for multilabel tasks


class ClassicalClassifiers:
    """Wrapper around the classical classification methods."""

    def __init__(self, config: ClassifierConfig):
        self.config = config
        self.model = self._create_model()
        self.train_time = 0.0
        self.predict_time = 0.0

    def _create_model(self):
        """Creates a model from the configuration."""
        model_type = self.config.model_type.lower()
        params = self.config.params or {}

        base_model = None

        if model_type == "lr":
            base_model = LogisticRegression(
                max_iter=1000,
                random_state=42,
                class_weight="balanced" if self.config.use_class_weight else None,
                **params
            )

        elif model_type == "svm":
            base_model = SVC(
                probability=True,
                random_state=42,
                class_weight="balanced" if self.config.use_class_weight else None,
                **params
            )

        elif model_type == "rf":
            base_model = RandomForestClassifier(
                n_estimators=100,
                random_state=42,
                class_weight="balanced" if self.config.use_class_weight else None,
                **params
            )

        # Wrap in MultiOutputClassifier for multilabel tasks
        if self.config.multilabel and base_model is not None:
            return MultiOutputClassifier(base_model)

        if base_model is not None:
            return base_model

        if model_type == "xgb" and XGBOOST_AVAILABLE:
            model = xgb.XGBClassifier(
                random_state=42,
                eval_metric='mlogloss',
                **params
            )
            return MultiOutputClassifier(model) if self.config.multilabel else model

        if model_type == "lgb" and LIGHTGBM_AVAILABLE:
            model = lgb.LGBMClassifier(
                random_state=42,
                verbose=-1,
                **params
            )
            return MultiOutputClassifier(model) if self.config.multilabel else model

        if model_type == "catboost" and CATBOOST_AVAILABLE:
            model = cb.CatBoostClassifier(
                random_state=42,
                verbose=False,
                **params
            )
            return MultiOutputClassifier(model) if self.config.multilabel else model

        if model_type == "ensemble":
            # Soft-voting classifier over LR, SVM and RF
            estimators = [
                ('lr', LogisticRegression(max_iter=1000, random_state=42)),
                ('svm', SVC(probability=True, random_state=42)),
                ('rf', RandomForestClassifier(n_estimators=50, random_state=42))
            ]
            model = VotingClassifier(estimators=estimators, voting='soft')
            return MultiOutputClassifier(model) if self.config.multilabel else model

        if model_type == "bagging":
            base = DecisionTreeClassifier(random_state=42)
            model = BaggingClassifier(
                estimator=base,  # scikit-learn >= 1.2; the old `base_estimator` name was removed in 1.4
                n_estimators=10,
                random_state=42,
                **params
            )
            return MultiOutputClassifier(model) if self.config.multilabel else model

        if model_type == "autosklearn" and AUTOSKLEARN_AVAILABLE:
            model = autosklearn.classification.AutoSklearnClassifier(
                time_left_for_this_task=300,  # 5 minutes
                memory_limit=4096,
                **params
            )
            # AutoSklearn may not support multilabel directly
            return model

        if model_type == "tpot" and TPOT_AVAILABLE:
            model = TPOTClassifier(
                generations=5,
                population_size=20,
                verbosity=2,
                random_state=42,
                **params
            )
            # TPOT may not support multilabel directly
            return model

        raise ValueError(f"Неизвестный тип модели: {model_type} или библиотека недоступна")

    def fit(self, X, y):
        """Fits the model and records the training time."""
        start = time.time()
        self.model.fit(X, y)
        self.train_time = time.time() - start
        return self

    def predict(self, X):
        """Predicts class labels and records the prediction time."""
        start = time.time()
        predictions = self.model.predict(X)
        self.predict_time = time.time() - start
        return predictions

    def predict_proba(self, X):
        """Predicts class probabilities when the model supports them."""
        if hasattr(self.model, 'predict_proba'):
            return self.model.predict_proba(X)
        return None

    def get_feature_importance(self):
        """Returns feature importances when the model exposes them."""
        if hasattr(self.model, 'feature_importances_'):
            return self.model.feature_importances_
        elif hasattr(self.model, 'coef_'):
            return np.abs(self.model.coef_[0]) if len(self.model.coef_.shape) > 1 else np.abs(self.model.coef_)
        return None


def evaluate_classifier(y_true, y_pred, y_proba=None,
                        task_type: str = "multiclass") -> Dict[str, Any]:
    """
    Evaluates classifier quality.

    Args:
        y_true: Ground-truth labels
        y_pred: Predicted labels
        y_proba: Class probabilities (optional)
        task_type: Task type (binary, multiclass, multilabel)

    Returns:
        Dictionary of metrics
    """
    metrics = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision_macro": precision_score(y_true, y_pred, average='macro', zero_division=0),
        "recall_macro": recall_score(y_true, y_pred, average='macro', zero_division=0),
        "f1_macro": f1_score(y_true, y_pred, average='macro', zero_division=0),
        "precision_micro": precision_score(y_true, y_pred, average='micro', zero_division=0),
        "recall_micro": recall_score(y_true, y_pred, average='micro', zero_division=0),
        "f1_micro": f1_score(y_true, y_pred, average='micro', zero_division=0),
    }

    # ROC-AUC for binary classification
    if task_type == "binary" and y_proba is not None and y_proba.shape[1] == 2:
        try:
            metrics["roc_auc"] = roc_auc_score(y_true, y_proba[:, 1])
        except Exception:
            metrics["roc_auc"] = np.nan

    # ROC-AUC for multiclass (macro)
    elif task_type == "multiclass" and y_proba is not None:
        try:
            metrics["roc_auc_macro"] = roc_auc_score(y_true, y_proba, average='macro', multi_class='ovr')
        except Exception:
            metrics["roc_auc_macro"] = np.nan

    # Multilabel metrics
    elif task_type == "multilabel":
        # Multilabel needs dedicated metrics
        from sklearn.metrics import hamming_loss, jaccard_score
        try:
            metrics["hamming_loss"] = hamming_loss(y_true, y_pred)
            metrics["jaccard_score"] = jaccard_score(y_true, y_pred, average='macro', zero_division=0)
            # ROC-AUC for multilabel (per class, then averaged)
            if y_proba is not None:
                try:
                    metrics["roc_auc_macro"] = roc_auc_score(y_true, y_proba, average='macro')
                except Exception:
                    metrics["roc_auc_macro"] = np.nan
        except Exception as e:
            print(f"Ошибка при вычислении метрик multilabel: {e}")

    return metrics


def cross_validate_classifier(model, X, y, cv=5, scoring='f1_macro'):
    """Cross-validates a classifier."""
    cv_scores = cross_val_score(model, X, y, cv=StratifiedKFold(n_splits=cv, shuffle=True, random_state=42),
                                scoring=scoring)
    return {
        "mean": cv_scores.mean(),
        "std": cv_scores.std(),
        "scores": cv_scores.tolist()
    }


def compare_classifiers(X_train, y_train, X_test, y_test,
                        configs: List[ClassifierConfig],
                        task_type: str = "multiclass",
                        cv: Optional[int] = None) -> pd.DataFrame:
    """
    Compares several classifiers.

    Args:
        X_train: Training features
        y_train: Training labels
        X_test: Test features
        y_test: Test labels
        configs: List of classifier configurations
        task_type: Task type (binary, multiclass, multilabel)
        cv: Number of cross-validation folds (optional)

    Returns:
        DataFrame with the comparison results
    """
    # Decide whether the task is multilabel
    is_multilabel = task_type == "multilabel"
    if is_multilabel:
        # Switch the configurations to multilabel mode
        for cfg in configs:
            cfg.multilabel = True

    results = []

    for cfg in configs:
        try:
            classifier = ClassicalClassifiers(cfg)

            # Training
            classifier.fit(X_train, y_train)

            # Predictions
            y_pred = classifier.predict(X_test)
            y_proba = classifier.predict_proba(X_test)

            # For multilabel, y_pred may already be 2D
            if is_multilabel and len(y_pred.shape) == 2:
                # y_pred is already in the right multilabel shape
                pass
            elif is_multilabel:
                # If the model returned 1D, reshape it
                y_pred = y_pred.reshape(-1, 1) if len(y_pred.shape) == 1 else y_pred

            # Metrics
            metrics = evaluate_classifier(y_test, y_pred, y_proba, task_type)

            # Cross-validation (when requested)
            cv_results = None
            if cv:
                cv_results = cross_validate_classifier(classifier.model, X_train, y_train, cv=cv)

            result = {
                "Модель": cfg.name,
                "Тип": cfg.model_type,
                "Точность": round(metrics["accuracy"], 4),
                "Precision (macro)": round(metrics["precision_macro"], 4),
                "Recall (macro)": round(metrics["recall_macro"], 4),
                "F1 (macro)": round(metrics["f1_macro"], 4),
                "F1 (micro)": round(metrics["f1_micro"], 4),
                "Время обучения (с)": round(classifier.train_time, 2),
                "Время предсказания (с)": round(classifier.predict_time, 4),
            }

            if "roc_auc" in metrics:
                result["ROC-AUC"] = round(metrics["roc_auc"], 4)
            elif "roc_auc_macro" in metrics:
                result["ROC-AUC (macro)"] = round(metrics["roc_auc_macro"], 4)

            # Extra multilabel metrics
            if task_type == "multilabel":
                if "hamming_loss" in metrics:
                    result["Hamming Loss"] = round(metrics["hamming_loss"], 4)
                if "jaccard_score" in metrics:
                    result["Jaccard Score"] = round(metrics["jaccard_score"], 4)

            if cv_results:
                result["CV F1 (mean)"] = round(cv_results["mean"], 4)
                result["CV F1 (std)"] = round(cv_results["std"], 4)

            results.append(result)

        except Exception as e:
            print(f"Ошибка при обучении {cfg.name}: {e}")
            results.append({
                "Модель": cfg.name,
                "Тип": cfg.model_type,
                "Ошибка": str(e)
            })

    return pd.DataFrame(results)


if __name__ == "__main__":
    # Quick self-test
    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split

    # n_informative=5: the default of 2 cannot support 3 classes x 2 clusters per class
    X, y = make_classification(n_samples=1000, n_features=20, n_classes=3,
                               n_informative=5, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    configs = [
        ClassifierConfig(name="Logistic Regression", model_type="lr"),
        ClassifierConfig(name="SVM", model_type="svm", params={"kernel": "linear"}),
        ClassifierConfig(name="Random Forest", model_type="rf"),
    ]

    if XGBOOST_AVAILABLE:
        configs.append(ClassifierConfig(name="XGBoost", model_type="xgb"))

    results_df = compare_classifiers(X_train, y_train, X_test, y_test, configs)
    print(results_df)
```
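The `__main__` demo above runs on synthetic numeric features from `make_classification`. Here is a minimal sketch of driving the same comparison from raw texts through TF-IDF, assuming `src/` is on the import path; the toy corpus and labels are invented placeholders, not repository data:

```python
# Hypothetical usage sketch: compare_classifiers on TF-IDF features.
# The tiny corpus and labels below are placeholders, not repository data.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from classical_classifiers import ClassifierConfig, compare_classifiers

texts = ["good movie", "bad movie", "great film", "terrible film"] * 50
labels = [1, 0, 1, 0] * 50

X = TfidfVectorizer().fit_transform(texts)  # sparse TF-IDF matrix
X_train, X_test, y_train, y_test = train_test_split(
    X, labels, test_size=0.2, stratify=labels, random_state=42)

configs = [
    ClassifierConfig(name="LogReg", model_type="lr"),
    ClassifierConfig(name="RandomForest", model_type="rf"),
]
print(compare_classifiers(X_train, y_train, X_test, y_test, configs,
                          task_type="binary", cv=3))
```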
src/clustering.py ADDED (+378 lines; docstrings and comments translated from Russian, runtime strings kept as-is)

```python
"""
Clustering module for text data.
Implements the main classical clustering method families:
- Centroid-based: k-Means, Mini-Batch k-Means, Spherical k-Means
- Density-based: DBSCAN, HDBSCAN
- Hierarchical: agglomerative clustering
- Probabilistic: Gaussian Mixture Models, LDA
- Graph-based: spectral clustering
"""

from __future__ import annotations

import time
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
import pandas as pd
from sklearn.cluster import (
    KMeans, MiniBatchKMeans, DBSCAN, AgglomerativeClustering,
    SpectralClustering
)
from sklearn.mixture import GaussianMixture
from sklearn.metrics import (
    silhouette_score, calinski_harabasz_score, davies_bouldin_score,
    adjusted_rand_score, normalized_mutual_info_score, v_measure_score
)

try:
    import hdbscan
    HDBSCAN_AVAILABLE = True
except ImportError:
    HDBSCAN_AVAILABLE = False
    print("⚠️ hdbscan не установлен. HDBSCAN недоступен. Установите: pip install hdbscan")

try:
    from gensim.models import LdaModel
    from gensim.corpora import Dictionary
    GENSIM_AVAILABLE = True
except ImportError:
    GENSIM_AVAILABLE = False
    print("⚠️ gensim не установлен. LDA недоступен.")


@dataclass
class ClusteringConfig:
    """Clustering algorithm configuration."""
    method: str  # kmeans, minibatch_kmeans, spherical_kmeans, dbscan, hdbscan,
                 # agglomerative, gmm, lda, spectral
    n_clusters: Optional[int] = None  # for methods that require a cluster count
    random_state: int = 42
    # Method-specific parameters
    eps: float = 0.5  # DBSCAN
    min_samples: int = 5  # DBSCAN/HDBSCAN
    linkage: str = "ward"  # Agglomerative
    metric: str = "euclidean"  # Agglomerative
    n_components: int = 10  # LDA
    n_neighbors: int = 10  # Spectral


class ClusteringAlgorithms:
    """Wrapper around the clustering algorithms."""

    def __init__(self, config: ClusteringConfig):
        self.config = config
        self.model = self._create_model()
        self.labels_ = None
        self.fit_time = 0.0
        self.predict_time = 0.0

    def _create_model(self):
        """Creates the clustering model."""
        method = self.config.method.lower()

        if method == "kmeans":
            if self.config.n_clusters is None:
                raise ValueError("Для k-Means требуется n_clusters")
            return KMeans(
                n_clusters=self.config.n_clusters,
                random_state=self.config.random_state,
                n_init=10
            )

        elif method == "minibatch_kmeans":
            if self.config.n_clusters is None:
                raise ValueError("Для Mini-Batch k-Means требуется n_clusters")
            return MiniBatchKMeans(
                n_clusters=self.config.n_clusters,
                random_state=self.config.random_state,
                n_init=3,
                batch_size=256
            )

        elif method == "spherical_kmeans":
            # Spherical k-Means via k-Means on L2-normalized vectors
            if self.config.n_clusters is None:
                raise ValueError("Для Spherical k-Means требуется n_clusters")
            return KMeans(
                n_clusters=self.config.n_clusters,
                random_state=self.config.random_state,
                n_init=10
            )

        elif method == "dbscan":
            return DBSCAN(
                eps=self.config.eps,
                min_samples=self.config.min_samples,
                metric='cosine'  # cosine is the usual choice for texts
            )

        elif method == "hdbscan":
            if not HDBSCAN_AVAILABLE:
                raise ImportError("hdbscan не установлен. Установите: pip install hdbscan")
            return hdbscan.HDBSCAN(
                min_cluster_size=self.config.min_samples,
                metric='euclidean',
                cluster_selection_method='eom'
            )

        elif method == "agglomerative":
            if self.config.n_clusters is None:
                raise ValueError("Для Agglomerative требуется n_clusters")
            return AgglomerativeClustering(
                n_clusters=self.config.n_clusters,
                linkage=self.config.linkage,
                metric=self.config.metric
            )

        elif method == "gmm":
            if self.config.n_clusters is None:
                raise ValueError("Для GMM требуется n_clusters")
            return GaussianMixture(
                n_components=self.config.n_clusters,
                random_state=self.config.random_state,
                max_iter=100
            )

        elif method == "spectral":
            if self.config.n_clusters is None:
                raise ValueError("Для Spectral требуется n_clusters")
            return SpectralClustering(
                n_clusters=self.config.n_clusters,
                random_state=self.config.random_state,
                affinity='nearest_neighbors',
                n_neighbors=self.config.n_neighbors
            )

        elif method == "lda":
            # LDA is handled separately since it is a topic model
            return None

        else:
            raise ValueError(f"Неизвестный метод кластеризации: {method}")

    def fit(self, X: np.ndarray):
        """Fits the clustering model."""
        start = time.time()

        # Spherical k-Means requires L2 normalization
        if self.config.method.lower() == "spherical_kmeans":
            from sklearn.preprocessing import normalize
            X = normalize(X, norm='l2')

        # Also normalize for DBSCAN/HDBSCAN with the cosine metric
        if self.config.method.lower() in ["dbscan", "hdbscan"] and self.config.metric == "cosine":
            from sklearn.preprocessing import normalize
            X = normalize(X, norm='l2')

        if self.config.method.lower() == "lda":
            # LDA is handled separately
            raise NotImplementedError("Для LDA используйте метод fit_lda")

        self.model.fit(X)

        if hasattr(self.model, 'labels_'):
            self.labels_ = self.model.labels_
        elif hasattr(self.model, 'predict'):
            self.labels_ = self.model.predict(X)
        else:
            raise ValueError("Модель не вернула метки кластеров")

        self.fit_time = time.time() - start
        return self

    def fit_lda(self, texts: List[str], dictionary: Optional[Any] = None):
        """
        Trains an LDA model for topic-based clustering.

        Args:
            texts: List of texts (raw strings or already tokenized)
            dictionary: Gensim Dictionary (optional)
        """
        if not GENSIM_AVAILABLE:
            raise ImportError("gensim не установлен. Установите: pip install gensim")

        if self.config.n_clusters is None:
            raise ValueError("Для LDA требуется n_clusters (число тем)")

        from gensim.utils import simple_preprocess

        # Tokenize when given raw strings
        tokenized_texts = []
        for text in texts:
            if isinstance(text, str):
                tokens = simple_preprocess(text, deacc=False, min_len=1)
            else:
                tokens = text
            tokenized_texts.append(tokens)

        # Build the dictionary
        if dictionary is None:
            dictionary = Dictionary(tokenized_texts)
            dictionary.filter_extremes(no_below=2, no_above=0.5)

        # Build the corpus
        corpus = [dictionary.doc2bow(text) for text in tokenized_texts]

        # Train LDA
        start = time.time()
        lda_model = LdaModel(
            corpus=corpus,
            num_topics=self.config.n_clusters,
            id2word=dictionary,
            random_state=self.config.random_state,
            passes=10,
            alpha='auto',
            per_word_topics=True
        )
        self.fit_time = time.time() - start

        # Use each document's dominant topic as its cluster label
        self.labels_ = []
        for doc in corpus:
            topic_dist = lda_model.get_document_topics(doc, minimum_probability=0.0)
            # Take the topic with the highest probability
            best_topic = max(topic_dist, key=lambda x: x[1])[0]
            self.labels_.append(best_topic)

        self.labels_ = np.array(self.labels_)
        self.model = lda_model
        self.dictionary = dictionary

        return self

    def predict(self, X: np.ndarray):
        """Predicts clusters for new data."""
        start = time.time()

        if self.config.method.lower() == "lda":
            raise NotImplementedError("LDA predict требует отдельной реализации")

        if hasattr(self.model, 'predict'):
            predictions = self.model.predict(X)
        else:
            # For DBSCAN and a few other methods
            predictions = self.model.fit_predict(X)

        self.predict_time = time.time() - start
        return predictions


def evaluate_clustering(X: np.ndarray, labels: np.ndarray,
                        y_true: Optional[np.ndarray] = None) -> Dict[str, float]:
    """
    Evaluates clustering quality.

    Args:
        X: Features
        labels: Predicted cluster labels
        y_true: Ground-truth labels (optional, for external metrics)

    Returns:
        Dictionary of metrics
    """
    metrics = {}

    # Internal metrics
    # Drop noise points (-1) before computing them
    valid_mask = labels != -1
    if valid_mask.sum() > 1:
        X_valid = X[valid_mask]
        labels_valid = labels[valid_mask]

        if len(np.unique(labels_valid)) > 1:
            metrics["silhouette"] = silhouette_score(X_valid, labels_valid)
            metrics["calinski_harabasz"] = calinski_harabasz_score(X_valid, labels_valid)
            metrics["davies_bouldin"] = davies_bouldin_score(X_valid, labels_valid)
        else:
            metrics["silhouette"] = -1.0
            metrics["calinski_harabasz"] = 0.0
            metrics["davies_bouldin"] = np.inf

    # External metrics (when ground-truth labels are available)
    if y_true is not None:
        metrics["adjusted_rand_index"] = adjusted_rand_score(y_true, labels)
        metrics["normalized_mutual_info"] = normalized_mutual_info_score(y_true, labels)
        metrics["v_measure"] = v_measure_score(y_true, labels)

    # Cluster statistics
    unique_labels, counts = np.unique(labels, return_counts=True)
    metrics["n_clusters"] = len(unique_labels[unique_labels != -1])  # excludes noise
    metrics["n_noise"] = (labels == -1).sum() if -1 in labels else 0
    metrics["avg_cluster_size"] = counts[unique_labels != -1].mean() if len(counts[unique_labels != -1]) > 0 else 0

    return metrics


def compare_clustering_methods(X: np.ndarray,
                               configs: List[ClusteringConfig],
                               y_true: Optional[np.ndarray] = None) -> pd.DataFrame:
    """
    Compares several clustering methods.

    Args:
        X: Features
        configs: List of configurations
        y_true: Ground-truth labels (optional)

    Returns:
        DataFrame with the comparison results
    """
    results = []

    for cfg in configs:
        try:
            clusterer = ClusteringAlgorithms(cfg)

            if cfg.method.lower() == "lda":
                # LDA needs raw texts, so this function skips it
                continue

            clusterer.fit(X)
            metrics = evaluate_clustering(X, clusterer.labels_, y_true)

            result = {
                "Метод": cfg.method,
                "Число кластеров": metrics.get("n_clusters", cfg.n_clusters),
                "Шумовые точки": metrics.get("n_noise", 0),
                "Silhouette": round(metrics.get("silhouette", -1), 4),
                "Calinski-Harabasz": round(metrics.get("calinski_harabasz", 0), 4),
                "Davies-Bouldin": round(metrics.get("davies_bouldin", np.inf), 4),
                "Время обучения (с)": round(clusterer.fit_time, 2),
            }

            if y_true is not None:
                result["ARI"] = round(metrics.get("adjusted_rand_index", 0), 4)
                result["NMI"] = round(metrics.get("normalized_mutual_info", 0), 4)
                result["V-measure"] = round(metrics.get("v_measure", 0), 4)

            results.append(result)

        except Exception as e:
            print(f"Ошибка при кластеризации методом {cfg.method}: {e}")
            results.append({
                "Метод": cfg.method,
                "Ошибка": str(e)
            })

    return pd.DataFrame(results)


if __name__ == "__main__":
    # Quick self-test
    from sklearn.datasets import make_blobs

    X, y = make_blobs(n_samples=300, centers=4, random_state=42)

    configs = [
        ClusteringConfig(method="kmeans", n_clusters=4),
        ClusteringConfig(method="dbscan", eps=0.5, min_samples=5),
        ClusteringConfig(method="agglomerative", n_clusters=4, linkage="ward"),
    ]

    if HDBSCAN_AVAILABLE:
        configs.append(ClusteringConfig(method="hdbscan", min_samples=5))

    results_df = compare_clustering_methods(X, configs, y_true=y)
    print(results_df)
```
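Note that `compare_clustering_methods` deliberately skips `lda` configs, since LDA consumes raw texts rather than feature vectors. A minimal sketch of the separate `fit_lda` path, assuming gensim is installed and `src/` is on the import path; the toy corpus is a placeholder:

```python
# Hypothetical usage sketch for the LDA path that compare_clustering_methods skips.
from clustering import ClusteringAlgorithms, ClusteringConfig

texts = [
    "кошка ловит мышь", "собака грызёт кость",
    "акции выросли на бирже", "банк снизил ставку",
] * 10  # placeholder corpus

cfg = ClusteringConfig(method="lda", n_clusters=2)
clusterer = ClusteringAlgorithms(cfg)  # _create_model returns None for "lda"
clusterer.fit_lda(texts)               # tokenizes, builds a Dictionary, trains LdaModel
print(clusterer.labels_[:8])           # dominant topic id per document
```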
src/embeddings_train.py CHANGED (comments translated from Russian; the bodies of removed gensim-only lines are reconstructed from the re-indented else-branches that replaced them)

```diff
@@ -17,10 +17,21 @@ from gensim.models import Word2Vec, FastText, Doc2Vec
 from gensim.models.doc2vec import TaggedDocument
 from gensim.utils import simple_preprocess
 
+try:
+    from glove import Glove, Corpus
+    GLOVE_AVAILABLE = True
+except ImportError:
+    try:
+        from glove_python import Glove, Corpus
+        GLOVE_AVAILABLE = True
+    except ImportError:
+        GLOVE_AVAILABLE = False
+        print("⚠️ GloVe не установлен. Установите: pip install glove-python-binary")
+
 
 @dataclass
 class TrainConfig:
-    model_type: str  # w2v | fasttext | doc2vec
+    model_type: str  # w2v | fasttext | doc2vec | glove
     vector_size: int = 300
     window: int = 8
     min_count: int = 2
@@ -31,6 +42,9 @@ class TrainConfig:
     negative: int = 5
     hs: int = 0
     seed: int = 42
+    # GloVe-specific parameters
+    alpha: float = 0.75  # for GloVe
+    x_max: int = 100  # for GloVe
 
 
 def _tokenize_corpus(texts: Iterable[str]) -> List[List[str]]:
@@ -104,6 +118,25 @@ def train_doc2vec(texts: Iterable[str], cfg: TrainConfig) -> Doc2Vec:
     return model
 
 
+def train_glove(texts: Iterable[str], cfg: TrainConfig):
+    """Trains a GloVe model."""
+    if not GLOVE_AVAILABLE:
+        raise ImportError("GloVe не установлен. Установите: pip install glove-python-binary")
+
+    sentences = _tokenize_corpus(texts)
+
+    # Build the co-occurrence corpus for GloVe
+    corpus = Corpus()
+    corpus.fit(sentences, window=cfg.window)
+
+    # Train the model
+    model = Glove(no_components=cfg.vector_size, learning_rate=0.05)
+    model.fit(corpus.matrix, epochs=cfg.epochs, no_threads=cfg.workers, verbose=True)
+    model.add_dictionary(corpus.dictionary)
+
+    return model
+
+
 def train_model(texts: Iterable[str], cfg: TrainConfig):
     t0 = time.time()
     if cfg.model_type == "w2v":
@@ -112,15 +145,22 @@ def train_model(texts: Iterable[str], cfg: TrainConfig):
         model = train_fasttext(texts, cfg)
     elif cfg.model_type == "doc2vec":
         model = train_doc2vec(texts, cfg)
+    elif cfg.model_type == "glove":
+        model = train_glove(texts, cfg)
     else:
-        raise ValueError("model_type должен быть 'w2v', 'fasttext' или 'doc2vec'")
+        raise ValueError("model_type должен быть 'w2v', 'fasttext', 'doc2vec' или 'glove'")
     train_time = time.time() - t0
     return model, train_time
 
 
 def save_model(model, out_path: str) -> None:
     Path(os.path.dirname(out_path)).mkdir(parents=True, exist_ok=True)
-    model.save(out_path)
+    # GloVe has its own save method
+    if GLOVE_AVAILABLE and hasattr(model, 'word_vectors') and hasattr(model, 'dictionary'):
+        model.save(out_path)
+    else:
+        # Gensim models
+        model.save(out_path)
 
 
 def load_model(path: str):
@@ -134,36 +174,107 @@ def load_model(path: str):
         return _FT.load(path)
     except Exception:
         pass
-    return _D2V.load(path)
+    try:
+        return _D2V.load(path)
+    except Exception:
+        pass
+    # Try loading a GloVe model
+    if GLOVE_AVAILABLE:
+        try:
+            from glove import Glove
+            return Glove.load(path)
+        except Exception:
+            pass
+    raise ValueError(f"Не удалось загрузить модель из {path}")
 
 
 def evaluate_neighbors(model, test_words: List[str], topn: int = 10) -> Dict[str, List[Tuple[str, float]]]:
     results: Dict[str, List[Tuple[str, float]]] = {}
-    kv = model.wv if hasattr(model, "wv") else model
-    for w in test_words:
-        if w in kv:
-            results[w] = kv.most_similar(w, topn=topn)
-        else:
-            results[w] = []
+    # GloVe exposes a different API
+    if GLOVE_AVAILABLE and hasattr(model, 'word_vectors') and hasattr(model, 'dictionary'):
+        # GloVe model: compute nearest neighbours manually
+        for w in test_words:
+            try:
+                if w in model.dictionary:
+                    vec_w = model.word_vectors[model.dictionary[w]]
+                    similarities = []
+                    for word, idx in model.dictionary.items():
+                        if word != w:
+                            vec = model.word_vectors[idx]
+                            sim = float(np.dot(vec_w, vec) / (np.linalg.norm(vec_w) * np.linalg.norm(vec)))
+                            similarities.append((word, sim))
+                    similarities.sort(key=lambda x: x[1], reverse=True)
+                    results[w] = similarities[:topn]
+                else:
+                    results[w] = []
+            except Exception:
+                results[w] = []
+    else:
+        # Gensim models (Word2Vec, FastText, Doc2Vec)
+        kv = model.wv if hasattr(model, "wv") else model
+        for w in test_words:
+            if w in kv:
+                results[w] = kv.most_similar(w, topn=topn)
+            else:
+                results[w] = []
     return results
 
 
 def cosine_similarity(model, word_pairs: List[Tuple[str, str]]) -> List[Tuple[str, str, float]]:
     out: List[Tuple[str, str, float]] = []
-    kv = model.wv if hasattr(model, "wv") else model
-    for a, b in word_pairs:
-        if a in kv and b in kv:
-            out.append((a, b, float(kv.similarity(a, b))))
-        else:
-            out.append((a, b, np.nan))
+    # GloVe exposes a different API
+    if GLOVE_AVAILABLE and hasattr(model, 'word_vectors') and hasattr(model, 'dictionary'):
+        # GloVe model
+        for a, b in word_pairs:
+            try:
+                if a in model.dictionary and b in model.dictionary:
+                    vec_a = model.word_vectors[model.dictionary[a]]
+                    vec_b = model.word_vectors[model.dictionary[b]]
+                    sim = float(np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b)))
+                    out.append((a, b, sim))
+                else:
+                    out.append((a, b, np.nan))
+            except Exception:
+                out.append((a, b, np.nan))
+    else:
+        # Gensim models
+        kv = model.wv if hasattr(model, "wv") else model
+        for a, b in word_pairs:
+            if a in kv and b in kv:
+                out.append((a, b, float(kv.similarity(a, b))))
+            else:
+                out.append((a, b, np.nan))
    return out
 
 
 def word_analogy(model, a: str, b: str, c: str, topn: int = 10) -> List[Tuple[str, float]]:
-    kv = model.wv if hasattr(model, "wv") else model
-    if all(token in kv for token in [a, b, c]):
-        return kv.most_similar(positive=[b, c], negative=[a], topn=topn)
-    return []
+    # GloVe has no built-in analogy method, so compute it manually
+    if GLOVE_AVAILABLE and hasattr(model, 'word_vectors') and hasattr(model, 'dictionary'):
+        # GloVe model: compute the analogy by hand
+        try:
+            if all(token in model.dictionary for token in [a, b, c]):
+                vec_a = model.word_vectors[model.dictionary[a]]
+                vec_b = model.word_vectors[model.dictionary[b]]
+                vec_c = model.word_vectors[model.dictionary[c]]
+                target = vec_b - vec_a + vec_c
+                # Find the nearest vectors
+                similarities = []
+                for word, idx in model.dictionary.items():
+                    if word not in [a, b, c]:
+                        vec = model.word_vectors[idx]
+                        sim = float(np.dot(target, vec) / (np.linalg.norm(target) * np.linalg.norm(vec)))
+                        similarities.append((word, sim))
+                similarities.sort(key=lambda x: x[1], reverse=True)
+                return similarities[:topn]
+        except Exception:
+            pass
+        return []
+    else:
+        # Gensim models
+        kv = model.wv if hasattr(model, "wv") else model
+        if all(token in kv for token in [a, b, c]):
+            return kv.most_similar(positive=[b, c], negative=[a], topn=topn)
+        return []
 
 
 def export_training_report(cfg: TrainConfig, train_time: float, model_path: str, extra: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
```
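A minimal end-to-end sketch of the new GloVe path, assuming `glove-python-binary` is installed (requirements.txt still lists it as optional and commented out) and `src/` is on the import path; the corpus is a placeholder:

```python
# Hypothetical usage sketch of the GloVe branch added above.
# Requires the optional glove-python-binary package.
from embeddings_train import TrainConfig, train_model, evaluate_neighbors

texts = ["кошка ловит мышь", "собака охраняет дом"] * 100  # placeholder corpus

cfg = TrainConfig(model_type="glove", vector_size=50, window=5, min_count=1)
model, train_time = train_model(texts, cfg)       # dispatches to train_glove
print(f"trained in {train_time:.1f}s")
print(evaluate_neighbors(model, ["кошка"], topn=3))  # manual cosine path for GloVe
```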
src/imbalance_handling.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Методы борьбы с дисбалансом классов в текстовых данных:
|
| 3 |
+
взвешивание классов, сэмплирование, аугментация текстов.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from typing import List, Tuple, Dict, Any, Optional
|
| 9 |
+
from collections import Counter
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from sklearn.utils import resample
|
| 13 |
+
from sklearn.utils.class_weight import compute_class_weight
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from imblearn.over_sampling import SMOTE, ADASYN, RandomOverSampler
|
| 17 |
+
from imblearn.under_sampling import RandomUnderSampler
|
| 18 |
+
IMBLEARN_AVAILABLE = True
|
| 19 |
+
except ImportError:
|
| 20 |
+
IMBLEARN_AVAILABLE = False
|
| 21 |
+
print("⚠️ imbalanced-learn не установлен. SMOTE/ADASYN недоступны.")
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
import nlpaug.augmenter.word as naw
|
| 25 |
+
NLPAUG_AVAILABLE = True
|
| 26 |
+
except ImportError:
|
| 27 |
+
NLPAUG_AVAILABLE = False
|
| 28 |
+
print("⚠️ nlpaug не установлен. Аугментация текстов недоступна.")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def compute_class_weights(y: np.ndarray, method: str = "balanced") -> Dict[int, float]:
|
| 32 |
+
"""
|
| 33 |
+
Вычисляет веса классов.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
y: Массив меток
|
| 37 |
+
method: Метод вычисления весов ('balanced', 'balanced_subsample', или dict)
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
Словарь {класс: вес}
|
| 41 |
+
"""
|
| 42 |
+
classes = np.unique(y)
|
| 43 |
+
weights = compute_class_weight(method, classes=classes, y=y)
|
| 44 |
+
return dict(zip(classes, weights))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def random_oversample(X: np.ndarray, y: np.ndarray,
|
| 48 |
+
strategy: Optional[Dict[int, int]] = None) -> Tuple[np.ndarray, np.ndarray]:
|
| 49 |
+
"""
|
| 50 |
+
Случайная перевыборка миноритарных классов.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
X: Признаки
|
| 54 |
+
y: Метки
|
| 55 |
+
strategy: Словарь {класс: целевое количество} или None для балансировки
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
Перевыбранные X, y
|
| 59 |
+
"""
|
| 60 |
+
if strategy is None:
|
| 61 |
+
# Балансируем до максимального класса
|
| 62 |
+
class_counts = Counter(y)
|
| 63 |
+
max_count = max(class_counts.values())
|
| 64 |
+
strategy = {cls: max_count for cls in class_counts.keys()}
|
| 65 |
+
|
| 66 |
+
X_resampled = []
|
| 67 |
+
y_resampled = []
|
| 68 |
+
|
| 69 |
+
    for cls in strategy.keys():
        mask = y == cls
        X_cls = X[mask]
        y_cls = y[mask]

        if len(X_cls) < strategy[cls]:
            # Oversample the class up to the target count
            # (sklearn's resample draws with replacement by default)
            X_cls_resampled, y_cls_resampled = resample(
                X_cls, y_cls,
                n_samples=strategy[cls],
                random_state=42
            )
        else:
            X_cls_resampled, y_cls_resampled = X_cls, y_cls

        X_resampled.append(X_cls_resampled)
        y_resampled.append(y_cls_resampled)

    X_resampled = np.vstack(X_resampled)
    y_resampled = np.hstack(y_resampled)

    # Shuffle
    indices = np.random.permutation(len(X_resampled))
    return X_resampled[indices], y_resampled[indices]


def random_undersample(X: np.ndarray, y: np.ndarray,
                       strategy: Optional[Dict[int, int]] = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Random undersampling of the majority classes.

    Args:
        X: Features
        y: Labels
        strategy: Dict {class: target count}, or None to balance all classes

    Returns:
        Undersampled X, y
    """
    if strategy is None:
        # Balance down to the smallest class
        class_counts = Counter(y)
        min_count = min(class_counts.values())
        strategy = {cls: min_count for cls in class_counts.keys()}

    X_resampled = []
    y_resampled = []

    for cls in strategy.keys():
        mask = y == cls
        X_cls = X[mask]
        y_cls = y[mask]

        if len(X_cls) > strategy[cls]:
            # Undersample without replacement so no sample gets duplicated
            X_cls_resampled, y_cls_resampled = resample(
                X_cls, y_cls,
                replace=False,
                n_samples=strategy[cls],
                random_state=42
            )
        else:
            X_cls_resampled, y_cls_resampled = X_cls, y_cls

        X_resampled.append(X_cls_resampled)
        y_resampled.append(y_cls_resampled)

    X_resampled = np.vstack(X_resampled)
    y_resampled = np.hstack(y_resampled)

    # Shuffle
    indices = np.random.permutation(len(X_resampled))
    return X_resampled[indices], y_resampled[indices]


def smote_oversample(X: np.ndarray, y: np.ndarray,
                     k_neighbors: int = 5) -> Tuple[np.ndarray, np.ndarray]:
    """
    SMOTE (Synthetic Minority Oversampling Technique) for vectorized texts.

    Args:
        X: Vectorized features
        y: Labels
        k_neighbors: Number of neighbors for SMOTE

    Returns:
        Resampled X, y
    """
    if not IMBLEARN_AVAILABLE:
        raise ImportError("imbalanced-learn не установлен. Установите: pip install imbalanced-learn")

    smote = SMOTE(k_neighbors=k_neighbors, random_state=42)
    X_resampled, y_resampled = smote.fit_resample(X, y)
    return X_resampled, y_resampled


def adasyn_oversample(X: np.ndarray, y: np.ndarray,
                      n_neighbors: int = 5) -> Tuple[np.ndarray, np.ndarray]:
    """
    ADASYN (Adaptive Synthetic Sampling) for vectorized texts.

    Args:
        X: Vectorized features
        y: Labels
        n_neighbors: Number of neighbors for ADASYN

    Returns:
        Resampled X, y
    """
    if not IMBLEARN_AVAILABLE:
        raise ImportError("imbalanced-learn не установлен. Установите: pip install imbalanced-learn")

    adasyn = ADASYN(n_neighbors=n_neighbors, random_state=42)
    X_resampled, y_resampled = adasyn.fit_resample(X, y)
    return X_resampled, y_resampled

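# Usage note (illustrative, variable names assumed): SMOTE and ADASYN in
# imbalanced-learn accept scipy sparse matrices as well as dense arrays, so a
# TF-IDF matrix can be passed to the helpers above directly:
#     X_bal, y_bal = smote_oversample(tfidf_matrix, labels, k_neighbors=5)
# k_neighbors must stay below the size of the smallest class, otherwise
# imbalanced-learn raises a nearest-neighbors error.
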
def synonym_replacement(text: str, num_replacements: int = 1) -> str:
    """
    Replace words with synonyms (simplified version).

    Note: a real implementation needs a synonym dictionary or WordNet.
    """
    # Simplified version: just return the source text unchanged.
    # Proper synonym replacement needs a synonym dictionary or something
    # like pymorphy2 plus dictionaries.
    return text


def random_deletion(text: str, p: float = 0.1) -> str:
    """
    Randomly delete words from the text.

    Args:
        text: Source text
        p: Probability of deleting each word

    Returns:
        Text with some words removed
    """
    words = text.split()
    if len(words) == 0:
        return text

    # Drop each word with probability p
    kept_words = [w for w in words if np.random.random() > p]

    if len(kept_words) == 0:
        # If everything was deleted, return a single random word
        return np.random.choice(words)

    return ' '.join(kept_words)


def random_insertion(text: str, num_insertions: int = 1) -> str:
    """
    Randomly insert words into the text (simplified version).

    Args:
        text: Source text
        num_insertions: Number of insertions

    Returns:
        Text with inserted words
    """
    words = text.split()
    if len(words) == 0:
        return text

    for _ in range(num_insertions):
        # Insert a random word at a random position
        random_word = np.random.choice(words)
        random_pos = np.random.randint(0, len(words) + 1)
        words.insert(random_pos, random_word)

    return ' '.join(words)


def random_swap(text: str, num_swaps: int = 1) -> str:
    """
    Randomly swap word pairs in the text.

    Args:
        text: Source text
        num_swaps: Number of swaps

    Returns:
        Text with swapped words
    """
    words = text.split()
    if len(words) < 2:
        return text

    for _ in range(num_swaps):
        idx1, idx2 = np.random.choice(len(words), size=2, replace=False)
        words[idx1], words[idx2] = words[idx2], words[idx1]

    return ' '.join(words)


def easy_data_augmentation(text: str,
                           alpha_sr: float = 0.1,
                           alpha_ri: float = 0.1,
                           alpha_rs: float = 0.1,
                           num_aug: int = 1) -> List[str]:
    """
    Easy Data Augmentation (EDA) for text.

    Args:
        alpha_sr: Parameter for synonym replacement (also reused below as the random-deletion probability)
        alpha_ri: Parameter for random insertion
        alpha_rs: Parameter for random swap
        num_aug: Number of augmented variants

    Returns:
        List of augmented texts
    """
    num_words = len(text.split())
    augmented_texts = []

    for _ in range(num_aug):
        augmented = text

        # Synonym replacement
        if np.random.random() < alpha_sr:
            augmented = synonym_replacement(augmented)

        # Random insertion
        if np.random.random() < alpha_ri:
            n_insert = max(1, int(alpha_ri * num_words))
            augmented = random_insertion(augmented, n_insert)

        # Random swap
        if np.random.random() < alpha_rs:
            n_swap = max(1, int(alpha_rs * num_words))
            augmented = random_swap(augmented, n_swap)

        # Random deletion
        if np.random.random() < alpha_sr:
            augmented = random_deletion(augmented, alpha_sr)

        augmented_texts.append(augmented)

    return augmented_texts


def augment_texts(texts: List[str], labels: List[int],
                  target_class: Optional[int] = None,
                  num_aug: int = 1,
                  method: str = "eda") -> Tuple[List[str], List[int]]:
    """
    Text augmentation for class balancing.

    Args:
        texts: List of texts
        labels: List of labels
        target_class: Class to augment (None = all minority classes)
        num_aug: Number of augmented variants per text
        method: Augmentation method ('eda', 'nlpaug')

    Returns:
        Extended lists of texts and labels
    """
    augmented_texts = list(texts)
    augmented_labels = list(labels)

    if target_class is None:
        # Find the minority classes
        class_counts = Counter(labels)
        min_count = min(class_counts.values())
        target_classes = [cls for cls, count in class_counts.items() if count == min_count]
    else:
        target_classes = [target_class]

    for cls in target_classes:
        cls_texts = [text for text, label in zip(texts, labels) if label == cls]

        for text in cls_texts:
            if method == "eda":
                aug_texts = easy_data_augmentation(text, num_aug=num_aug)
            elif method == "nlpaug" and NLPAUG_AVAILABLE:
                # nlpaug-based augmentation (needs extra configuration)
                aug_texts = [text]  # Stub
            else:
                aug_texts = [text]

            augmented_texts.extend(aug_texts)
            augmented_labels.extend([cls] * len(aug_texts))

    return augmented_texts, augmented_labels


if __name__ == "__main__":
    # Smoke test
    import numpy as np

    # Build an imbalanced dataset
    X = np.random.randn(100, 50)
    y = np.array([0] * 80 + [1] * 20)

    print(f"Исходное распределение: {Counter(y)}")

    # Random oversampling
    X_resampled, y_resampled = random_oversample(X, y)
    print(f"После перевыборки: {Counter(y_resampled)}")

    # SMOTE (if available)
    if IMBLEARN_AVAILABLE:
        X_smote, y_smote = smote_oversample(X, y)
        print(f"После SMOTE: {Counter(y_smote)}")

    # Text augmentation
    texts = ["Это тестовый текст", "Другой пример текста"] * 50
    labels = [0] * 80 + [1] * 20

    aug_texts, aug_labels = augment_texts(texts, labels, num_aug=2)
    print(f"После аугментации: {len(aug_texts)} текстов, распределение: {Counter(aug_labels)}")
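One caveat worth making explicit: resampling belongs after the train/test split, otherwise duplicated or synthetic minority samples leak into the test set and inflate every metric. A minimal sketch of the intended call order with the helpers above (synthetic data, illustrative variable names):

import numpy as np
from collections import Counter
from sklearn.model_selection import train_test_split
from src.imbalance_handling import random_oversample

X = np.random.randn(500, 50)
y = np.array([0] * 400 + [1] * 100)

# Split first, then balance only the training part
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
X_train_bal, y_train_bal = random_oversample(X_train, y_train)

print(Counter(y_train_bal))  # balanced training set
print(Counter(y_test))       # untouched, still imbalanced
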
src/model_evaluation.py
ADDED

@@ -0,0 +1,359 @@
"""
Module for evaluating classification models and tuning hyperparameters.
Covers cross-validation, hyperparameter search and a full set of metrics.
"""

from __future__ import annotations

import time
from typing import List, Dict, Any, Optional, Tuple, Union
from dataclasses import dataclass

import numpy as np
import pandas as pd
from sklearn.model_selection import (
    GridSearchCV, RandomizedSearchCV, StratifiedKFold,
    cross_val_score, train_test_split
)
# sklearn's own cross_validate is needed for train-fold scores (see below);
# alias it so it does not clash with the local cross_validate defined here.
from sklearn.model_selection import cross_validate as sk_cross_validate
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, classification_report, confusion_matrix,
    precision_recall_curve, roc_curve, average_precision_score
)

try:
    import optuna
    OPTUNA_AVAILABLE = True
except ImportError:
    OPTUNA_AVAILABLE = False
    print("⚠️ Optuna не установлен. Bayesian optimization недоступен.")

try:
    from hyperopt import fmin, tpe, hp, Trials, STATUS_OK
    HYPEROPT_AVAILABLE = True
except ImportError:
    HYPEROPT_AVAILABLE = False
    print("⚠️ Hyperopt не установлен. Bayesian optimization недоступен.")


@dataclass
class EvaluationMetrics:
    """Container for evaluation metrics."""
    accuracy: float
    precision_macro: float
    recall_macro: float
    f1_macro: float
    precision_micro: float
    recall_micro: float
    f1_micro: float
    roc_auc: Optional[float] = None
    pr_auc: Optional[float] = None
    train_time: float = 0.0
    predict_time: float = 0.0


def evaluate_classifier(y_true: np.ndarray,
                        y_pred: np.ndarray,
                        y_proba: Optional[np.ndarray] = None,
                        task_type: str = "multiclass") -> EvaluationMetrics:
    """
    Full evaluation of a classifier.

    Args:
        y_true: Ground-truth labels
        y_pred: Predicted labels
        y_proba: Class probabilities
        task_type: Task type (binary, multiclass, multilabel)

    Returns:
        An EvaluationMetrics object
    """
    metrics = EvaluationMetrics(
        accuracy=accuracy_score(y_true, y_pred),
        precision_macro=precision_score(y_true, y_pred, average='macro', zero_division=0),
        recall_macro=recall_score(y_true, y_pred, average='macro', zero_division=0),
        f1_macro=f1_score(y_true, y_pred, average='macro', zero_division=0),
        precision_micro=precision_score(y_true, y_pred, average='micro', zero_division=0),
        recall_micro=recall_score(y_true, y_pred, average='micro', zero_division=0),
        f1_micro=f1_score(y_true, y_pred, average='micro', zero_division=0),
    )

    # ROC-AUC for binary classification
    if task_type == "binary" and y_proba is not None:
        if y_proba.shape[1] == 2:
            try:
                metrics.roc_auc = roc_auc_score(y_true, y_proba[:, 1])
                metrics.pr_auc = average_precision_score(y_true, y_proba[:, 1])
            except Exception:
                pass
        elif y_proba.shape[1] == 1:
            try:
                metrics.roc_auc = roc_auc_score(y_true, y_proba.flatten())
                metrics.pr_auc = average_precision_score(y_true, y_proba.flatten())
            except Exception:
                pass

    # ROC-AUC for the multiclass case (macro-averaged, one-vs-rest)
    elif task_type == "multiclass" and y_proba is not None:
        try:
            metrics.roc_auc = roc_auc_score(y_true, y_proba, average='macro', multi_class='ovr')
        except Exception:
            pass

    return metrics


def cross_validate(model, X: np.ndarray, y: np.ndarray,
                   cv: int = 5,
                   scoring: str = 'f1_macro',
                   return_train_score: bool = False) -> Dict[str, Any]:
    """
    Cross-validate a model.

    Args:
        model: Model with the sklearn interface
        X: Features
        y: Labels
        cv: Number of folds
        scoring: Scoring metric
        return_train_score: Whether to also report train-fold scores

    Returns:
        Dict with the cross-validation results
    """
    # cross_val_score has no return_train_score argument, so use
    # sklearn.model_selection.cross_validate (aliased as sk_cross_validate).
    cv_results = sk_cross_validate(
        model, X, y,
        cv=StratifiedKFold(n_splits=cv, shuffle=True, random_state=42),
        scoring=scoring,
        return_train_score=return_train_score
    )
    test_scores = cv_results["test_score"]

    result = {
        "mean": float(test_scores.mean()),
        "std": float(test_scores.std()),
        "scores": test_scores.tolist()
    }

    if return_train_score:
        result["train_mean"] = float(cv_results["train_score"].mean())
        result["train_std"] = float(cv_results["train_score"].std())

    return result

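# Usage note (illustrative, names assumed): passing a Pipeline keeps the
# vectorizer inside each fold, so TF-IDF is never fitted on validation data:
#     from sklearn.pipeline import make_pipeline
#     from sklearn.feature_extraction.text import TfidfVectorizer
#     from sklearn.linear_model import LogisticRegression
#     pipe = make_pipeline(TfidfVectorizer(), LogisticRegression(max_iter=1000))
#     cross_validate(pipe, texts, labels, cv=5, scoring="f1_macro")
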
def grid_search(model, X: np.ndarray, y: np.ndarray,
                param_grid: Dict[str, List[Any]],
                cv: int = 5,
                scoring: str = 'f1_macro',
                n_jobs: int = -1) -> Dict[str, Any]:
    """
    Hyperparameter tuning with Grid Search.

    Args:
        model: Model with the sklearn interface
        X: Features
        y: Labels
        param_grid: Parameter grid
        cv: Number of folds
        scoring: Scoring metric
        n_jobs: Number of parallel jobs

    Returns:
        Dict with the best parameters and the search results
    """
    search = GridSearchCV(
        model,
        param_grid,
        cv=StratifiedKFold(n_splits=cv, shuffle=True, random_state=42),
        scoring=scoring,
        n_jobs=n_jobs,
        verbose=1
    )

    start = time.time()
    search.fit(X, y)
    search_time = time.time() - start

    return {
        "best_params": search.best_params_,
        "best_score": float(search.best_score_),
        "best_model": search.best_estimator_,
        "search_time": search_time,
        "cv_results": search.cv_results_
    }


def random_search(model, X: np.ndarray, y: np.ndarray,
                  param_distributions: Dict[str, List[Any]],
                  n_iter: int = 50,
                  cv: int = 5,
                  scoring: str = 'f1_macro',
                  n_jobs: int = -1) -> Dict[str, Any]:
    """
    Hyperparameter tuning with Random Search.

    Args:
        model: Model with the sklearn interface
        X: Features
        y: Labels
        param_distributions: Parameter distributions
        n_iter: Number of iterations
        cv: Number of folds
        scoring: Scoring metric
        n_jobs: Number of parallel jobs

    Returns:
        Dict with the best parameters and the search results
    """
    search = RandomizedSearchCV(
        model,
        param_distributions,
        n_iter=n_iter,
        cv=StratifiedKFold(n_splits=cv, shuffle=True, random_state=42),
        scoring=scoring,
        n_jobs=n_jobs,
        random_state=42,
        verbose=1
    )

    start = time.time()
    search.fit(X, y)
    search_time = time.time() - start

    return {
        "best_params": search.best_params_,
        "best_score": float(search.best_score_),
        "best_model": search.best_estimator_,
        "search_time": search_time,
        "cv_results": search.cv_results_
    }


def optuna_optimize(model_class, X: np.ndarray, y: np.ndarray,
                    param_space: Dict[str, Any],
                    n_trials: int = 50,
                    cv: int = 5,
                    scoring: str = 'f1_macro') -> Dict[str, Any]:
    """
    Hyperparameter tuning with Bayesian optimization (Optuna).

    Args:
        model_class: Model class
        X: Features
        y: Labels
        param_space: Parameter space (callables that take an Optuna trial)
        n_trials: Number of trials
        cv: Number of folds
        scoring: Scoring metric

    Returns:
        Dict with the best parameters and the search results
    """
    if not OPTUNA_AVAILABLE:
        raise ImportError("Optuna не установлен. Установите: pip install optuna")

    def objective(trial):
        params = {}
        for param_name, param_func in param_space.items():
            params[param_name] = param_func(trial)

        model = model_class(**params)
        scores = cross_val_score(
            model, X, y,
            cv=StratifiedKFold(n_splits=cv, shuffle=True, random_state=42),
            scoring=scoring
        )
        return scores.mean()

    study = optuna.create_study(direction='maximize', study_name='classifier_optimization')
    start = time.time()
    study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
    search_time = time.time() - start

    # Refit the best model on the full data
    best_model = model_class(**study.best_params)
    best_model.fit(X, y)

    return {
        "best_params": study.best_params,
        "best_score": float(study.best_value),
        "best_model": best_model,
        "search_time": search_time,
        "study": study
    }


def create_confusion_matrix_plot(y_true: np.ndarray, y_pred: np.ndarray,
                                 class_names: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Build a confusion matrix.

    Args:
        y_true: Ground-truth labels
        y_pred: Predicted labels
        class_names: Class names

    Returns:
        DataFrame with the confusion matrix
    """
    cm = confusion_matrix(y_true, y_pred)

    if class_names is None:
        class_names = [f"Класс {i}" for i in range(len(cm))]

    df = pd.DataFrame(cm, index=class_names, columns=class_names)
    return df


def create_classification_report_df(y_true: np.ndarray, y_pred: np.ndarray,
                                    class_names: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Build a classification report.

    Args:
        y_true: Ground-truth labels
        y_pred: Predicted labels
        class_names: Class names

    Returns:
        DataFrame with the report
    """
    report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
    df = pd.DataFrame(report).transpose()
    return df


if __name__ == "__main__":
    # Smoke test
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    # n_informative=5: the default of 2 informative features cannot host
    # 3 classes x 2 clusters per class
    X, y = make_classification(n_samples=1000, n_features=20, n_classes=3,
                               n_informative=5, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Train a model
    model = LogisticRegression(max_iter=1000, random_state=42)
    model.fit(X_train, y_train)

    # Evaluate
    y_pred = model.predict(X_test)
    y_proba = model.predict_proba(X_test)

    metrics = evaluate_classifier(y_test, y_pred, y_proba, task_type="multiclass")
    print("Метрики:")
    print(f"Accuracy: {metrics.accuracy:.4f}")
    print(f"F1 (macro): {metrics.f1_macro:.4f}")
    print(f"ROC-AUC: {metrics.roc_auc:.4f}" if metrics.roc_auc is not None else "ROC-AUC: N/A")

    # Cross-validation
    cv_results = cross_validate(model, X_train, y_train, cv=5)
    print(f"\nКросс-валидация F1: {cv_results['mean']:.4f} ± {cv_results['std']:.4f}")

    # Grid Search (note: penalty='l1' needs solver='liblinear' or 'saga')
    param_grid = {
        'C': [0.1, 1, 10],
        'penalty': ['l1', 'l2']
    }
    # grid_results = grid_search(model, X_train, y_train, param_grid, cv=3)
    # print(f"\nЛучшие параметры (Grid Search): {grid_results['best_params']}")
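Because param_space in optuna_optimize is a dict of callables rather than a plain grid, the calling convention is easy to get wrong; a short usage sketch (parameter ranges are illustrative, X_train/y_train as in the smoke test above):

from sklearn.linear_model import LogisticRegression

# Each entry asks the Optuna trial object for one parameter value.
param_space = {
    "C": lambda trial: trial.suggest_float("C", 1e-3, 1e2, log=True),
    "max_iter": lambda trial: trial.suggest_categorical("max_iter", [500, 1000]),
}

result = optuna_optimize(LogisticRegression, X_train, y_train,
                         param_space, n_trials=20, cv=3)
print(result["best_params"], f"{result['best_score']:.4f}")
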
src/model_interpretation.py
ADDED

@@ -0,0 +1,367 @@
"""
Module for interpreting classification models: SHAP, LIME, feature importance,
attention-weight visualization for neural networks and transformers.
"""

from __future__ import annotations

from typing import List, Dict, Any, Optional, Tuple
import numpy as np
import pandas as pd

try:
    import shap
    SHAP_AVAILABLE = True
except ImportError:
    SHAP_AVAILABLE = False
    print("⚠️ SHAP не установлен. Установите: pip install shap")

try:
    from lime import lime_text
    from lime.lime_text import LimeTextExplainer
    LIME_AVAILABLE = True
except ImportError:
    LIME_AVAILABLE = False
    print("⚠️ LIME не установлен. Установите: pip install lime")

try:
    import matplotlib.pyplot as plt
    import seaborn as sns
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False
    print("⚠️ Matplotlib не установлен. Визуализация недоступна.")


def get_feature_importance_linear(model, feature_names: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Extract feature importances from linear models (LR, SVM).

    Args:
        model: Trained model
        feature_names: Feature names

    Returns:
        DataFrame with feature importances
    """
    if hasattr(model, 'coef_'):
        coef = model.coef_
        if len(coef.shape) > 1:
            # Multiclass case: average the absolute coefficients over classes
            importance = np.abs(coef).mean(axis=0)
        else:
            importance = np.abs(coef)

        if feature_names is None:
            feature_names = [f"Признак {i}" for i in range(len(importance))]

        df = pd.DataFrame({
            "Признак": feature_names,
            "Важность": importance
        }).sort_values("Важность", ascending=False)

        return df

    return pd.DataFrame()


def get_feature_importance_tree(model, feature_names: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Extract feature importances from tree-based models (RF, XGBoost, etc.).

    Args:
        model: Trained model
        feature_names: Feature names

    Returns:
        DataFrame with feature importances
    """
    if hasattr(model, 'feature_importances_'):
        importance = model.feature_importances_

        if feature_names is None:
            feature_names = [f"Признак {i}" for i in range(len(importance))]

        df = pd.DataFrame({
            "Признак": feature_names,
            "Важность": importance
        }).sort_values("Важность", ascending=False)

        return df

    return pd.DataFrame()


def get_tfidf_important_words(vectorizer, model, class_idx: int = 0, top_k: int = 20) -> pd.DataFrame:
    """
    Extract the most important words for a TF-IDF vectorization.

    Args:
        vectorizer: Fitted vectorizer
        model: Trained model
        class_idx: Class index
        top_k: Number of top words

    Returns:
        DataFrame with the important words
    """
    if not hasattr(model, 'coef_'):
        return pd.DataFrame()

    coef = model.coef_[class_idx] if len(model.coef_.shape) > 1 else model.coef_

    if hasattr(vectorizer, 'get_feature_names_out'):
        feature_names = vectorizer.get_feature_names_out()
    elif hasattr(vectorizer, 'get_feature_names'):
        feature_names = vectorizer.get_feature_names()
    else:
        return pd.DataFrame()

    # Sort by absolute importance
    indices = np.argsort(np.abs(coef))[-top_k:][::-1]

    df = pd.DataFrame({
        "Слово": [feature_names[i] for i in indices],
        "Коэффициент": [coef[i] for i in indices],
        "Абсолютное значение": [np.abs(coef[i]) for i in indices]
    })

    return df

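# Usage note (illustrative, names assumed): for linear models the sign of
# "Коэффициент" matters. Positive values push the prediction toward the chosen
# class_idx, negative values push it away:
#     df = get_tfidf_important_words(vectorizer, model, class_idx=1, top_k=15)
#     print(df[df["Коэффициент"] > 0])
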
def explain_with_shap(model, X: np.ndarray,
                      feature_names: Optional[List[str]] = None,
                      max_samples: int = 100) -> Optional["shap.Explanation"]:
    """
    Explain model predictions with SHAP.

    Args:
        model: Trained model with a predict_proba method
        X: Features to explain
        feature_names: Feature names
        max_samples: Maximum number of samples to explain

    Returns:
        A SHAP Explanation object, or None
    """
    if not SHAP_AVAILABLE:
        print("SHAP не установлен. Установите: pip install shap")
        return None

    # Cap the number of samples for performance
    if len(X) > max_samples:
        indices = np.random.choice(len(X), max_samples, replace=False)
        X_sample = X[indices]
    else:
        X_sample = X

    try:
        # Pick an explainer depending on the model type
        if hasattr(model, 'predict_proba'):
            explainer = shap.Explainer(model, X_sample)
        else:
            # Models without predict_proba fall back to KernelExplainer
            explainer = shap.KernelExplainer(model.predict, X_sample)

        shap_values = explainer(X_sample)

        if feature_names is not None:
            shap_values.feature_names = feature_names

        return shap_values

    except Exception as e:
        print(f"Ошибка при создании SHAP объяснения: {e}")
        return None


def explain_with_lime_text(model, texts: List[str],
                           vectorizer: Any,
                           class_names: Optional[List[str]] = None,
                           num_features: int = 10) -> List[Dict[str, Any]]:
    """
    Explain text predictions with LIME.

    Args:
        model: Trained model
        texts: Texts to explain
        vectorizer: Text vectorizer
        class_names: Class names
        num_features: Number of important features to show

    Returns:
        A list with one explanation per text
    """
    if not LIME_AVAILABLE:
        print("LIME не установлен. Установите: pip install lime")
        return []

    explainer = LimeTextExplainer(class_names=class_names)

    def predict_proba_wrapper(texts_list):
        """predict_proba wrapper that vectorizes the raw texts first."""
        X = vectorizer.transform(texts_list)
        if hasattr(model, 'predict_proba'):
            return model.predict_proba(X)
        else:
            # Models without predict_proba
            predictions = model.predict(X)
            # Build one-hot pseudo-probabilities
            proba = np.zeros((len(predictions), len(np.unique(predictions))))
            for i, pred in enumerate(predictions):
                proba[i, pred] = 1.0
            return proba

    explanations = []
    for text in texts:
        try:
            explanation = explainer.explain_instance(
                text,
                predict_proba_wrapper,
                num_features=num_features
            )

            # Pull out the influential words
            exp_list = explanation.as_list()
            explanations.append({
                "text": text,
                "important_words": exp_list,
                "prediction": explanation.predict_proba.argmax() if hasattr(explanation, 'predict_proba') else None
            })
        except Exception as e:
            print(f"Ошибка при объяснении текста: {e}")
            explanations.append({
                "text": text,
                "important_words": [],
                "prediction": None
            })

    return explanations


def visualize_attention_weights(attention_weights: np.ndarray,
                                tokens: List[str],
                                save_path: Optional[str] = None) -> None:
    """
    Visualize attention weights of transformer models.

    Args:
        attention_weights: Attention matrix of shape (n_heads, seq_len, seq_len) or (seq_len, seq_len)
        tokens: List of tokens
        save_path: Path to save the image to
    """
    if not MATPLOTLIB_AVAILABLE:
        print("Matplotlib не установлен. Визуализация недоступна.")
        return

    # Average over heads if several attention heads are given
    if len(attention_weights.shape) == 3:
        attention_weights = attention_weights.mean(axis=0)

    # Cap the sequence length so the heatmap stays readable
    max_len = min(50, len(tokens))
    attention_weights = attention_weights[:max_len, :max_len]
    tokens = tokens[:max_len]

    plt.figure(figsize=(12, 10))
    sns.heatmap(
        attention_weights,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='Blues',
        cbar=True
    )
    plt.title("Визуализация весов внимания")
    plt.xlabel("Токены")
    plt.ylabel("Токены")
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()


def analyze_error_cases(y_true: np.ndarray, y_pred: np.ndarray,
                        texts: Optional[List[str]] = None,
                        top_k: int = 10) -> pd.DataFrame:
    """
    Inspect the cases where the model is wrong.

    Args:
        y_true: Ground-truth labels
        y_pred: Predicted labels
        texts: Texts (optional)
        top_k: Number of examples to show

    Returns:
        DataFrame with error examples
    """
    errors = y_true != y_pred
    error_indices = np.where(errors)[0]

    if len(error_indices) == 0:
        return pd.DataFrame({"Сообщение": ["Ошибок не найдено"]})

    # Cap the number of shown examples
    if len(error_indices) > top_k:
        error_indices = np.random.choice(error_indices, top_k, replace=False)

    results = []
    for idx in error_indices:
        result = {
            "Индекс": int(idx),
            "Истинный класс": int(y_true[idx]),
            "Предсказанный класс": int(y_pred[idx])
        }
        if texts is not None:
            result["Текст"] = texts[idx][:200] + "..." if len(texts[idx]) > 200 else texts[idx]
        results.append(result)

    return pd.DataFrame(results)


if __name__ == "__main__":
    # Smoke test
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.feature_extraction.text import TfidfVectorizer

    # Toy data
    texts = [
        "Это положительный отзыв о продукте",
        "Отрицательный отзыв не понравилось",
        "Нейтральный отзыв нормально",
    ] * 10

    vectorizer = TfidfVectorizer()
    X = vectorizer.fit_transform(texts).toarray()
    y = np.array([0, 1, 2] * 10)

    # Train a model
    model = LogisticRegression(max_iter=1000, random_state=42)
    model.fit(X, y)

    # Feature importances
    feature_importance = get_feature_importance_linear(model)
    print("Важность признаков (топ-10):")
    print(feature_importance.head(10))

    # Important words for TF-IDF
    important_words = get_tfidf_important_words(vectorizer, model, class_idx=0, top_k=10)
    print("\nВажные слова для класса 0:")
    print(important_words)

    # SHAP (if available)
    if SHAP_AVAILABLE:
        shap_values = explain_with_shap(model, X[:5], max_samples=5)
        if shap_values is not None:
            print("\nSHAP объяснение создано успешно")

    # LIME (if available)
    if LIME_AVAILABLE:
        lime_explanations = explain_with_lime_text(model, texts[:3], vectorizer)
        print(f"\nLIME объяснения: {len(lime_explanations)} создано")
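For completeness, a small sketch of consuming the output of explain_with_lime_text (reusing the toy model and vectorizer from the smoke test above; each entry of important_words is a (word, weight) pair as returned by LIME's as_list()):

explanations = explain_with_lime_text(
    model, ["Отрицательный отзыв не понравилось"], vectorizer,
    class_names=["positive", "negative", "neutral"], num_features=5
)
for exp in explanations:
    print(exp["text"])
    for word, weight in exp["important_words"]:
        # positive weight supports the explained class, negative opposes it
        print(f"  {word}: {weight:+.3f}")
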
src/neural_classifiers.py
ADDED

@@ -0,0 +1,306 @@
"""
Neural text classification: MLP, CNN, LSTM, GRU and hybrid architectures.
Note: transformer models (BERT, RuBERT) require the transformers and torch packages.
"""

from __future__ import annotations

import time
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import pandas as pd

try:
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers, models, callbacks
    TENSORFLOW_AVAILABLE = True
except ImportError:
    TENSORFLOW_AVAILABLE = False
    print("⚠️ TensorFlow не установлен. Нейросетевые модели недоступны.")

try:
    import torch
    import torch.nn as nn
    from transformers import AutoTokenizer, AutoModel
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("⚠️ PyTorch/Transformers не установлены. Трансформерные модели недоступны.")


@dataclass
class NeuralConfig:
    """Configuration of a neural model."""
    model_type: str  # mlp, cnn, lstm, gru, cnn_lstm, birnn_attention
    input_dim: int
    num_classes: int
    embedding_dim: int = 300
    hidden_dim: int = 128
    dropout: float = 0.5
    learning_rate: float = 0.001
    epochs: int = 10
    batch_size: int = 32
    validation_split: float = 0.2


class NeuralClassifiers:
    """Wrapper around the neural classifiers."""

    def __init__(self, config: NeuralConfig):
        if not TENSORFLOW_AVAILABLE:
            raise ImportError("TensorFlow не установлен. Установите: pip install tensorflow")

        self.config = config
        self.model = self._create_model()
        self.history = None
        self.train_time = 0.0
        self.predict_time = 0.0

    def _create_model(self):
        """Build the neural model."""
        model_type = self.config.model_type.lower()

        if model_type == "mlp":
            return self._create_mlp()
        elif model_type == "cnn":
            return self._create_cnn()
        elif model_type == "lstm":
            return self._create_lstm()
        elif model_type == "gru":
            return self._create_gru()
        elif model_type == "cnn_lstm":
            return self._create_cnn_lstm()
        elif model_type == "birnn_attention":
            return self._create_birnn_attention()
        else:
            raise ValueError(f"Неизвестный тип модели: {model_type}")

    def _create_mlp(self):
        """Multilayer perceptron."""
        model = models.Sequential([
            layers.Dense(self.config.hidden_dim, activation='relu', input_dim=self.config.input_dim),
            layers.Dropout(self.config.dropout),
            layers.Dense(self.config.hidden_dim // 2, activation='relu'),
            layers.Dropout(self.config.dropout),
            layers.Dense(self.config.num_classes, activation='softmax')
        ])
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=self.config.learning_rate),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        return model

    def _create_cnn(self):
        """Convolutional network for text (Kim CNN)."""
        # A real text CNN wants a token sequence plus an embedding layer;
        # this simplified version runs directly on already vectorized data.
        model = models.Sequential([
            layers.Reshape((self.config.input_dim, 1), input_shape=(self.config.input_dim,)),
            layers.Conv1D(128, 3, activation='relu'),
            layers.MaxPooling1D(2),
            layers.Conv1D(64, 3, activation='relu'),
            layers.GlobalMaxPooling1D(),
            layers.Dense(self.config.hidden_dim, activation='relu'),
            layers.Dropout(self.config.dropout),
            layers.Dense(self.config.num_classes, activation='softmax')
        ])
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=self.config.learning_rate),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        return model

    def _create_lstm(self):
        """LSTM network."""
        model = models.Sequential([
            layers.Reshape((self.config.input_dim, 1), input_shape=(self.config.input_dim,)),
            layers.LSTM(self.config.hidden_dim, return_sequences=False),
            layers.Dropout(self.config.dropout),
            layers.Dense(self.config.num_classes, activation='softmax')
        ])
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=self.config.learning_rate),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        return model

    def _create_gru(self):
        """GRU network."""
        model = models.Sequential([
            layers.Reshape((self.config.input_dim, 1), input_shape=(self.config.input_dim,)),
            layers.GRU(self.config.hidden_dim, return_sequences=False),
            layers.Dropout(self.config.dropout),
            layers.Dense(self.config.num_classes, activation='softmax')
        ])
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=self.config.learning_rate),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        return model

    def _create_cnn_lstm(self):
        """Hybrid CNN + LSTM architecture."""
        model = models.Sequential([
            layers.Reshape((self.config.input_dim, 1), input_shape=(self.config.input_dim,)),
            layers.Conv1D(64, 3, activation='relu'),
            layers.MaxPooling1D(2),
            layers.LSTM(self.config.hidden_dim, return_sequences=False),
            layers.Dropout(self.config.dropout),
            layers.Dense(self.config.num_classes, activation='softmax')
        ])
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=self.config.learning_rate),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        return model

    def _create_birnn_attention(self):
        """Bidirectional RNN with an attention mechanism (simplified)."""
        # Simplified version without a real attention mechanism
        model = models.Sequential([
            layers.Reshape((self.config.input_dim, 1), input_shape=(self.config.input_dim,)),
            layers.Bidirectional(layers.LSTM(self.config.hidden_dim, return_sequences=True)),
            layers.GlobalAveragePooling1D(),  # Simple aggregation instead of attention
            layers.Dropout(self.config.dropout),
            layers.Dense(self.config.num_classes, activation='softmax')
        ])
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=self.config.learning_rate),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        return model

    def fit(self, X, y, validation_data=None):
        """Train the model."""
        if not TENSORFLOW_AVAILABLE:
            raise ImportError("TensorFlow не установлен")

        start = time.time()

        callbacks_list = [
            callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True),
            callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-7)
        ]

        if validation_data is None and self.config.validation_split > 0:
            self.history = self.model.fit(
                X, y,
                epochs=self.config.epochs,
                batch_size=self.config.batch_size,
                validation_split=self.config.validation_split,
                callbacks=callbacks_list,
                verbose=1
            )
        else:
            self.history = self.model.fit(
                X, y,
                epochs=self.config.epochs,
                batch_size=self.config.batch_size,
                validation_data=validation_data,
                callbacks=callbacks_list,
                verbose=1
            )

        self.train_time = time.time() - start
        return self

    def predict(self, X):
        """Predict class labels."""
        start = time.time()
        predictions = self.model.predict(X, verbose=0)
        self.predict_time = time.time() - start
        return np.argmax(predictions, axis=1)

    def predict_proba(self, X):
        """Predict class probabilities."""
        return self.model.predict(X, verbose=0)

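# Usage note (illustrative; compute_class_weights from src.imbalance_handling
# is assumed here to return a {class: weight} dict, its definition is outside
# this excerpt): Keras fit() also accepts class_weight, a cheap alternative to
# resampling for imbalanced data:
#     clf = NeuralClassifiers(config)
#     clf.model.fit(X_train, y_train, epochs=config.epochs,
#                   class_weight=compute_class_weights(y_train))
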
class TransformerClassifier:
    """
    Transformer-based classifier (BERT, RuBERT).
    Requires the transformers and torch packages.
    """

    def __init__(self, model_name: str = "DeepPavlov/rubert-base-cased",
                 num_classes: int = 2,
                 max_length: int = 512,
                 learning_rate: float = 2e-5,
                 epochs: int = 3,
                 batch_size: int = 16):
        if not TRANSFORMERS_AVAILABLE:
            raise ImportError(
                "PyTorch и Transformers не установлены. "
                "Установите: pip install torch transformers"
            )

        self.model_name = model_name
        self.num_classes = num_classes
        self.max_length = max_length
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

        # Classification head on top of the encoder
        self.classifier = nn.Sequential(
            nn.Linear(self.model.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.classifier.to(self.device)

    def fit(self, texts: List[str], labels: List[int]):
        """Train the transformer model."""
        # A full training loop needs considerably more logic;
        # this is only a stub.
        raise NotImplementedError(
            "Полная реализация обучения трансформеров требует дополнительной настройки. "
            "Рекомендуется использовать готовые решения из библиотеки transformers."
        )

    def predict(self, texts: List[str]):
        """Predict class labels."""
        raise NotImplementedError("См. fit()")


if __name__ == "__main__":
    # Smoke test (only runs when TensorFlow is available)
    if TENSORFLOW_AVAILABLE:
        from sklearn.datasets import make_classification
        from sklearn.model_selection import train_test_split

        # n_informative=10: the default of 2 informative features cannot host
        # 3 classes x 2 clusters per class
        X, y = make_classification(n_samples=1000, n_features=100, n_classes=3,
                                   n_informative=10, random_state=42)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        config = NeuralConfig(
            model_type="mlp",
            input_dim=100,
            num_classes=3,
            epochs=5
        )

        classifier = NeuralClassifiers(config)
        classifier.fit(X_train, y_train)
        predictions = classifier.predict(X_test)

        from sklearn.metrics import accuracy_score
        print(f"Точность: {accuracy_score(y_test, predictions):.4f}")
    else:
        print("TensorFlow не установлен. Тесты пропущены.")
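The Conv1D/LSTM variants above reshape a fixed-length feature vector into a pseudo-sequence, which runs but discards word order. A more conventional text CNN feeds integer token ids through an Embedding layer; a minimal Keras sketch (vocabulary size, sequence length and filter width are illustrative, not taken from the commit):

from tensorflow.keras import layers, models

vocab_size, seq_len, num_classes = 20000, 200, 3

model = models.Sequential([
    layers.Input(shape=(seq_len,)),            # integer token ids
    layers.Embedding(vocab_size, 128),         # learned word embeddings
    layers.Conv1D(128, 5, activation="relu"),  # n-gram feature detectors
    layers.GlobalMaxPooling1D(),               # strongest activation per filter
    layers.Dense(64, activation="relu"),
    layers.Dropout(0.5),
    layers.Dense(num_classes, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
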
src/streamlit_app.py
CHANGED
|
@@ -13,6 +13,7 @@ from typing import List, Dict, Any, Optional
|
|
| 13 |
|
| 14 |
import streamlit as st
|
| 15 |
import pandas as pd
|
|
|
|
| 16 |
import plotly.express as px
|
| 17 |
import plotly.graph_objects as go
|
| 18 |
from plotly.subplots import make_subplots
|
|
@@ -41,6 +42,14 @@ from src.classical_vectorizers import (
|
|
| 41 |
from src.dimensionality import SVDConfig, run_lsa, embed_2d, explained_variance_table, top_terms_dataframe
|
| 42 |
from src.embeddings_train import TrainConfig as EmbTrainConfig, train_model as train_embeddings_model, save_model as save_embedding_model, evaluate_neighbors as eval_neighbors, cosine_similarity as eval_cosine, word_analogy as eval_analogy
|
| 43 |
from src.semantic_experiments import vector_arithmetic, semantic_axis, nearest_neighbors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
# Настройка страницы
|
|
@@ -317,8 +326,8 @@ def main():
|
|
| 317 |
st.session_state["processed_texts"] = processed_texts
|
| 318 |
texts = processed_texts
|
| 319 |
|
| 320 |
-
# Главные вкладки ЛР1/ЛР2
|
| 321 |
-
main_tabs = st.tabs(["Токенизация", "Векторизация", "Эмбеддинги"])
|
| 322 |
|
| 323 |
# ======== Токенизация (ЛР1) ========
|
| 324 |
with main_tabs[0]:
|
|
@@ -501,7 +510,7 @@ def main():
|
|
| 501 |
index=0, horizontal=True,
|
| 502 |
help="Предобработанные = применены настройки из блока Предобработка на левой панели"
|
| 503 |
)
|
| 504 |
-
model_type = st.selectbox("Модель", ["w2v", "fasttext", "doc2vec"], index=0)
|
| 505 |
vector_size = st.slider("Размерность", 50, 600, 300, step=50)
|
| 506 |
window = st.slider("Окно контекста", 2, 15, 8)
|
| 507 |
min_count = st.slider("Min count", 1, 20, 2)
|
|
@@ -580,6 +589,284 @@ def main():
|
|
| 580 |
if st.button("🧩 Аналогия"):
|
| 581 |
st.write(eval_analogy(model, ana_a, ana_b, ana_c, topn=10))
|
| 582 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+    # ======== Classification (lab 3) ========
+    with main_tabs[3]:
+        st.subheader("📊 Классификация текстов")
+
+        if not texts:
+            st.warning("⚠️ Загрузите тексты для классификации.")
+        else:
+            # Task type selection
+            task_type = st.radio(
+                "Тип задачи классификации:",
+                ["Бинарная", "Многоклассовая", "Многометочная"],
+                horizontal=True
+            )
+
+            # Labeling (simplified: a real run expects data labeled in advance)
+            st.info("💡 Для полноценной работы требуется размеченный датасет. Здесь показана демонстрация на синтетических данных.")
+
+            # Synthetic labels for the demonstration
+            if "labels" not in st.session_state or st.session_state.get("task_type") != task_type:
+                if task_type == "Бинарная":
+                    st.session_state["labels"] = np.random.choice([0, 1], size=len(texts))
+                elif task_type == "Многоклассовая":
+                    st.session_state["labels"] = np.random.choice([0, 1, 2, 3], size=len(texts))
+                elif task_type == "Многометочная":
+                    # Multilabel: one binary indicator per category,
+                    # so a single document may carry several labels at once
+                    num_labels = 4
+                    st.session_state["labels"] = np.random.randint(0, 2, size=(len(texts), num_labels))
+                    st.session_state["num_labels"] = num_labels
+                st.session_state["task_type"] = task_type
+
+            labels = st.session_state["labels"]
+
+            # Preprocessing
+            st.subheader("🔧 Предобработка")
+            preprocess_config = PreprocessingConfig(
+                lowercase=True,
+                remove_html=True,
+                lemmatize=False,  # kept off for speed
+                remove_stopwords=False
+            )
+            preprocessor = TextPreprocessor(preprocess_config)
+            processed_texts = preprocessor.preprocess_batch(texts[:min(100, len(texts))])  # capped for the demo
+
+            # Vectorization
+            st.subheader("🧮 Векторизация")
+            vectorization_method = st.selectbox(
+                "Метод векторизации:",
+                ["tfidf", "bow"]
+            )
+
+            if st.button("🔨 Векторизовать тексты", key="vectorize_for_classification"):
+                with st.spinner("Векторизация..."):
+                    X, vectorizer = vectorize_with_classical(
+                        processed_texts,
+                        method=vectorization_method,
+                        ngram_range=(1, 2),
+                        max_features=1000
+                    )
+                    st.session_state["X_classification"] = X
+                    st.session_state["vectorizer_classification"] = vectorizer
+                    st.success(f"Векторизовано {len(processed_texts)} текстов, размерность: {X.shape}")
+
+            # Classification
+            if "X_classification" in st.session_state:
+                X = st.session_state["X_classification"]
+                y = labels[:len(processed_texts)]
+
+                # Train/test split
+                from sklearn.model_selection import train_test_split
+                # stratify is not directly supported for multilabel targets
+                if task_type == "Многометочная":
+                    X_train, X_test, y_train, y_test = train_test_split(
+                        X, y, test_size=0.2, random_state=42
+                    )
+                else:
+                    X_train, X_test, y_train, y_test = train_test_split(
+                        X, y, test_size=0.2, random_state=42, stratify=y
+                    )
+
+                st.subheader("🎯 Обучение классификаторов")
+                selected_models = st.multiselect(
+                    "Выберите модели:",
+                    ["Logistic Regression", "SVM", "Random Forest"],
+                    default=["Logistic Regression", "Random Forest"]
+                )
+
+                if st.button("🚀 Обучить модели", key="train_classifiers"):
+                    configs = []
+                    if "Logistic Regression" in selected_models:
+                        configs.append(ClassifierConfig(name="Logistic Regression", model_type="lr"))
+                    if "SVM" in selected_models:
+                        configs.append(ClassifierConfig(name="SVM", model_type="svm", params={"kernel": "linear"}))
+                    if "Random Forest" in selected_models:
+                        configs.append(ClassifierConfig(name="Random Forest", model_type="rf"))
+
+                    with st.spinner("Обучение моделей..."):
+                        # Map the UI choice to the task-type string
+                        if task_type == "Многометочная":
+                            task_type_str = "multilabel"
+                        elif task_type == "Многоклассовая":
+                            task_type_str = "multiclass"
+                        else:
+                            task_type_str = "binary"
+
+                        results_df = compare_classifiers(
+                            X_train, y_train, X_test, y_test,
+                            configs,
+                            task_type=task_type_str
+                        )
+                        st.session_state["classification_results"] = results_df
+
+                if "classification_results" in st.session_state:
+                    st.subheader("📊 Результаты классификации")
+                    st.dataframe(st.session_state["classification_results"], use_container_width=True)
+
+                # Feature importance
+                if "vectorizer_classification" in st.session_state:
+                    st.subheader("🔍 Важные слова")
+                    vectorizer = st.session_state["vectorizer_classification"]
+                    # Guard: a plain LogisticRegression cannot fit 2-D multilabel targets
+                    if "Logistic Regression" in selected_models and task_type != "Многометочная":
+                        # Fit a simple model just for the demonstration
+                        from sklearn.linear_model import LogisticRegression
+                        model = LogisticRegression(max_iter=1000, random_state=42)
+                        model.fit(X_train, y_train)
+                        important_words = get_tfidf_important_words(vectorizer, model, class_idx=0, top_k=20)
+                        st.dataframe(important_words, use_container_width=True)
+
+    # ======== Clustering (lab 4) ========
+    with main_tabs[4]:
+        st.subheader("🔍 Кластеризация текстов")
+
+        if not texts:
+            st.warning("⚠️ Загрузите тексты для кластеризации.")
+        else:
+            # Preprocessing
+            st.subheader("🔧 Предобработка")
+            preprocess_config = PreprocessingConfig(
+                lowercase=True,
+                remove_html=True,
+                lemmatize=False,
+                remove_stopwords=False
+            )
+            preprocessor = TextPreprocessor(preprocess_config)
+            processed_texts = preprocessor.preprocess_batch(texts[:min(200, len(texts))])  # capped for the demo
+
+            # Vectorization
+            st.subheader("🧮 Векторизация")
+            vectorization_method = st.selectbox(
+                "Метод векторизации:",
+                ["tfidf", "bm25"],
+                key="clustering_vectorization"
+            )
+
+            if st.button("🔨 Векторизовать тексты", key="vectorize_for_clustering"):
+                with st.spinner("Векторизация..."):
+                    try:
+                        # max_features applies to TF-IDF only; vectorize_bm25 takes no such argument
+                        extra = {"max_features": 500} if vectorization_method == "tfidf" else {}
+                        X, vectorizer_obj = vectorize_texts(
+                            processed_texts,
+                            method=vectorization_method,
+                            **extra
+                        )
+                        st.session_state["X_clustering"] = X
+                        st.session_state["vectorizer_clustering"] = vectorizer_obj
+                        st.success(f"Векторизовано {len(processed_texts)} текстов, размерность: {X.shape}")
+                    except Exception as e:
+                        st.error(f"Ошибка векторизации: {e}")
+
+            # Clustering
+            if "X_clustering" in st.session_state:
+                X = st.session_state["X_clustering"]
+
+                st.subheader("🎯 Кластеризация")
+                clustering_method = st.selectbox(
+                    "Метод кластеризации:",
+                    ["kmeans", "dbscan", "agglomerative", "gmm"],
+                    key="clustering_method"
+                )
+
+                n_clusters = None
+                if clustering_method in ["kmeans", "agglomerative", "gmm"]:
+                    n_clusters = st.slider("Число кластеров", 2, 20, 5, key="n_clusters")
+
+                if clustering_method == "dbscan":
+                    eps = st.slider("EPS", 0.1, 1.0, 0.5, 0.1, key="dbscan_eps")
+                    min_samples = st.slider("Min samples", 2, 10, 5, key="dbscan_min_samples")
+                else:
+                    eps = 0.5
+                    min_samples = 5
+
+                if st.button("🚀 Выполнить кластеризацию", key="run_clustering"):
+                    with st.spinner("Кластеризация..."):
+                        try:
+                            config = ClusteringConfig(
+                                method=clustering_method,
+                                n_clusters=n_clusters,
+                                eps=eps,
+                                min_samples=min_samples
+                            )
+                            clusterer = ClusteringAlgorithms(config)
+                            clusterer.fit(X)
+
+                            # Quality metrics
+                            metrics = evaluate_clustering(X, clusterer.labels_)
+
+                            st.session_state["clustering_labels"] = clusterer.labels_
+                            st.session_state["clustering_metrics"] = metrics
+                            st.session_state["clustering_model"] = clusterer
+
+                            st.success("Кластеризация завершена!")
+                        except Exception as e:
+                            st.error(f"Ошибка кластеризации: {e}")
+
+                if "clustering_labels" in st.session_state:
+                    labels = st.session_state["clustering_labels"]
+                    metrics = st.session_state["clustering_metrics"]
+
+                    st.subheader("📊 Результаты кластеризации")
+
+                    # Metrics
+                    col1, col2, col3, col4 = st.columns(4)
+                    with col1:
+                        st.metric("Число кластеров", metrics.get("n_clusters", 0))
+                    with col2:
+                        st.metric("Silhouette", round(metrics.get("silhouette", -1), 3))
+                    with col3:
+                        st.metric("Calinski-Harabasz", round(metrics.get("calinski_harabasz", 0), 2))
+                    with col4:
+                        st.metric("Davies-Bouldin", round(metrics.get("davies_bouldin", np.inf), 3))
+
+                    # Cluster size distribution
+                    unique_labels, counts = np.unique(labels, return_counts=True)
+                    cluster_df = pd.DataFrame({
+                        "Кластер": unique_labels,
+                        "Количество документов": counts
+                    })
+                    st.dataframe(cluster_df, use_container_width=True)
+
+                    # Sample documents per cluster (the DBSCAN noise cluster -1 is excluded)
+                    st.subheader("📝 Примеры документов по кластерам")
+                    selected_cluster = st.selectbox(
+                        "Выберите кластер:",
+                        unique_labels[unique_labels != -1] if -1 in labels else unique_labels,
+                        key="selected_cluster"
+                    )
+
+                    cluster_indices = np.where(labels == selected_cluster)[0]
+                    if len(cluster_indices) > 0:
+                        sample_indices = cluster_indices[:5]  # show the first five
+                        for idx in sample_indices:
+                            st.text_area(
+                                f"Документ {idx}",
+                                processed_texts[idx][:200] + "..." if len(processed_texts[idx]) > 200 else processed_texts[idx],
+                                height=100,
+                                key=f"doc_{idx}"
+                            )
+
+                    # 2-D visualization (when the feature space is higher-dimensional)
+                    if X.shape[1] > 2:
+                        st.subheader("📈 Визуализация кластеров")
+                        try:
+                            from sklearn.decomposition import PCA
+                            pca = PCA(n_components=2)
+                            X_2d = pca.fit_transform(X)
+
+                            viz_df = pd.DataFrame({
+                                "x": X_2d[:, 0],
+                                "y": X_2d[:, 1],
+                                "Кластер": labels.astype(str)
+                            })
+                            fig = px.scatter(viz_df, x="x", y="y", color="Кластер",
+                                             title="Проекция кластеров (PCA)")
+                            fig.update_traces(marker_size=5)
+                            st.plotly_chart(fig, use_container_width=True)
+                        except Exception as e:
+                            st.warning(f"Не удалось создать визуализацию: {e}")
+
 
 if __name__ == "__main__":
     main()
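The two new tabs wire the lab-3/lab-4 modules into the app. As a minimal offline sketch of the same classification flow — assuming only the APIs this commit introduces (TextPreprocessor, vectorize_with_classical, ClassifierConfig, compare_classifiers) behave as their call sites above suggest, and using synthetic labels exactly as the demo does:

# Sketch: the "Классификация" tab's pipeline without Streamlit.
import numpy as np
from sklearn.model_selection import train_test_split

from src.text_preprocessing import TextPreprocessor, PreprocessingConfig, vectorize_with_classical
from src.classical_classifiers import ClassifierConfig, compare_classifiers

texts = ["первый документ", "второй документ", "третий документ"] * 50
labels = np.random.choice([0, 1], size=len(texts))            # synthetic binary labels

pre = TextPreprocessor(PreprocessingConfig(lemmatize=False))  # lemmatization off for speed
X, vectorizer = vectorize_with_classical(
    pre.preprocess_batch(texts), method="tfidf", ngram_range=(1, 2), max_features=1000
)

X_tr, X_te, y_tr, y_te = train_test_split(
    X, labels, test_size=0.2, random_state=42, stratify=labels
)
configs = [
    ClassifierConfig(name="Logistic Regression", model_type="lr"),
    ClassifierConfig(name="Random Forest", model_type="rf"),
]
print(compare_classifiers(X_tr, y_tr, X_te, y_te, configs, task_type="binary"))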
src/text_preprocessing.py
ADDED
@@ -0,0 +1,368 @@
"""
Text preprocessing module for classification tasks.
Covers cleaning, tokenization, lemmatization, vectorization, and meta-feature extraction.
"""

from __future__ import annotations

import re
from typing import List, Any, Optional, Tuple
from dataclasses import dataclass

import numpy as np
import spacy
from gensim.models import Doc2Vec
from gensim.utils import simple_preprocess

from src.text_cleaner import remove_html, normalize_whitespace
from src.classical_vectorizers import ClassicalVectorizers, VectorizationConfig


@dataclass
class PreprocessingConfig:
    """Text preprocessing configuration."""
    lowercase: bool = True
    remove_html: bool = True
    remove_urls: bool = True
    remove_emails: bool = True
    remove_numbers: bool = False
    lemmatize: bool = True
    remove_stopwords: bool = False
    min_token_length: int = 2
    emoji_to_text: bool = True


class TextPreprocessor:
    """Preprocesses texts for classification."""

    def __init__(self, config: Optional[PreprocessingConfig] = None):
        self.config = config or PreprocessingConfig()
        self.nlp = None
        if self.config.lemmatize:
            try:
                self.nlp = spacy.load("ru_core_news_sm")
            except OSError:
                try:
                    self.nlp = spacy.load("ru_core_news_md")
                except OSError:
                    print("⚠️ spaCy русская модель не найдена. Лемматизация отключена.")
                    self.config.lemmatize = False

    def _remove_urls(self, text: str) -> str:
        """Strips URLs from the text (simplified pattern: scheme up to the next whitespace)."""
        return re.sub(r'https?://\S+', '', text)

    def _remove_emails(self, text: str) -> str:
        """Strips email addresses from the text."""
        email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'
        return re.sub(email_pattern, '', text)

    def _emoji_to_text(self, text: str) -> str:
        """Replaces emoji with short Russian text descriptions (simplified mapping)."""
        # Basic replacements for Russian-language content
        emoji_map = {
            '😀': ' улыбка ',
            '😃': ' радость ',
            '😄': ' смех ',
            '😁': ' веселье ',
            '😆': ' хохот ',
            '😅': ' пот ',
            '😂': ' слезы радости ',
            '🤣': ' хохот ',
            '😊': ' улыбка ',
            '😇': ' ангел ',
            '🙂': ' улыбка ',
            '🙃': ' перевернутое лицо ',
            '😉': ' подмигивание ',
            '😌': ' облегчение ',
            '😍': ' любовь ',
            '🥰': ' любовь ',
            '😘': ' поцелуй ',
            '😗': ' поцелуй ',
            '😙': ' поцелуй ',
            '😚': ' поцелуй ',
            '😋': ' вкусно ',
            '😛': ' язык ',
            '😜': ' подмигивание ',
            '😝': ' язык ',
            '😞': ' грусть ',
            '😟': ' беспокойство ',
            '😠': ' злость ',
            '😡': ' ярость ',
            '😢': ' плач ',
            '😣': ' страдание ',
            '😤': ' упрямство ',
            '😥': ' разочарование ',
            '😦': ' удивление ',
            '😧': ' шок ',
            '😨': ' страх ',
            '😩': ' усталость ',
            '😪': ' сонливость ',
            '😫': ' усталость ',
            '😬': ' напряжение ',
            '😭': ' плач ',
            '😮': ' удивление ',
            '😯': ' удивление ',
            '😰': ' тревога ',
            '😱': ' ужас ',
            '😲': ' шок ',
            '😳': ' смущение ',
            '😴': ' сон ',
            '😵': ' головокружение ',
            '😶': ' без слов ',
            '😷': ' маска ',
            '🤐': ' молчание ',
            '🤒': ' болезнь ',
            '🤕': ' травма ',
            '🤢': ' тошнота ',
            '🤤': ' слюни ',
            '🤥': ' ложь ',
            '🤧': ' чихание ',
            '🤨': ' подозрение ',
            '🤩': ' звезды ',
            '🤪': ' безумие ',
            '🤫': ' тишина ',
            '🤬': ' ругательство ',
            '🤭': ' секрет ',
            '🤮': ' рвота ',
            '🤯': ' взрыв мозга ',
        }
        for emoji, replacement in emoji_map.items():
            text = text.replace(emoji, replacement)
        return text

    def preprocess(self, text: str) -> str:
        """Runs the full preprocessing pipeline on a single text."""
        if not text:
            return ""

        # Strip HTML
        if self.config.remove_html:
            text = remove_html(text)

        # Strip URLs
        if self.config.remove_urls:
            text = self._remove_urls(text)

        # Strip emails
        if self.config.remove_emails:
            text = self._remove_emails(text)

        # Replace emoji
        if self.config.emoji_to_text:
            text = self._emoji_to_text(text)

        # Normalize whitespace
        text = normalize_whitespace(text)

        # Lowercase
        if self.config.lowercase:
            text = text.lower()

        # Strip digits (optional)
        if self.config.remove_numbers:
            text = re.sub(r'\d+', '', text)

        # Lemmatization
        if self.config.lemmatize and self.nlp:
            doc = self.nlp(text)
            tokens = [token.lemma_ for token in doc if not token.is_punct and not token.is_space]
            text = ' '.join(tokens)
        else:
            # Plain tokenization
            tokens = simple_preprocess(text, deacc=False, min_len=self.config.min_token_length)
            text = ' '.join(tokens)

        # Stop-word removal (when spaCy lemmatization was not applied)
        if self.config.remove_stopwords and not (self.config.lemmatize and self.nlp):
            from src.text_cleaner import remove_stopwords_tokens
            tokens = text.split()
            tokens = remove_stopwords_tokens(tokens)
            text = ' '.join(tokens)

        # Final whitespace normalization
        text = normalize_whitespace(text)

        return text

    def preprocess_batch(self, texts: List[str]) -> List[str]:
        """Preprocesses a list of texts."""
        return [self.preprocess(text) for text in texts]


def extract_meta_features(texts: List[str]) -> np.ndarray:
    """
    Extracts meta-features from texts.

    Returns:
        An array of shape (n_texts, n_features) with:
        - text length (characters)
        - average word length
        - number of unique words
        - punctuation ratio
        - uppercase-letter ratio
        - digit ratio
    """
    features = []

    for text in texts:
        if not text:
            features.append([0, 0, 0, 0, 0, 0])
            continue

        # Text length
        text_length = len(text)

        # Tokens
        tokens = text.split()
        if not tokens:
            features.append([text_length, 0, 0, 0, 0, 0])
            continue

        # Average word length
        avg_word_length = np.mean([len(token) for token in tokens])

        # Unique word count
        unique_words = len(set(tokens))

        # Punctuation ratio
        punct_count = sum(1 for c in text if c in '.,;:!?()[]{}"\'-')
        punct_ratio = punct_count / text_length if text_length > 0 else 0

        # Uppercase-letter ratio
        upper_count = sum(1 for c in text if c.isupper())
        upper_ratio = upper_count / text_length if text_length > 0 else 0

        # Digit ratio
        digit_count = sum(1 for c in text if c.isdigit())
        digit_ratio = digit_count / text_length if text_length > 0 else 0

        features.append([
            text_length,
            avg_word_length,
            unique_words,
            punct_ratio,
            upper_ratio,
            digit_ratio
        ])

    return np.array(features)


def vectorize_with_classical(texts: List[str], method: str = "tfidf",
                             ngram_range: Tuple[int, int] = (1, 2),
                             max_features: Optional[int] = None) -> Tuple[np.ndarray, Any]:
    """
    Vectorizes texts with classical methods.

    Args:
        texts: List of texts
        method: Vectorization method ("tfidf" or "bow")
        ngram_range: N-gram range
        max_features: Maximum number of features

    Returns:
        The feature matrix and the fitted vectorizer
    """
    config = VectorizationConfig(
        method=method,
        ngram_range=ngram_range,
        max_features=max_features
    )
    vectorizer = ClassicalVectorizers(config)
    X, _ = vectorizer.fit_transform(texts)
    return X.toarray() if hasattr(X, 'toarray') else X, vectorizer


def vectorize_with_embeddings(texts: List[str],
                              model: Any,
                              aggregation: str = "mean") -> np.ndarray:
    """
    Vectorizes texts with trained embeddings.

    Args:
        texts: List of texts (already tokenizable as whitespace-joined tokens)
        model: Trained model (Word2Vec, FastText or Doc2Vec)
        aggregation: Aggregation over word vectors ("mean", "max" or "sum")

    Returns:
        A matrix of document embeddings
    """
    if isinstance(model, Doc2Vec):
        # Doc2Vec infers document vectors directly
        vectors = []
        for text in texts:
            tokens = simple_preprocess(text, deacc=False, min_len=1)
            if tokens:
                vec = model.infer_vector(tokens)
            else:
                vec = np.zeros(model.vector_size)
            vectors.append(vec)
        return np.array(vectors)

    # Word2Vec / FastText
    kv = model.wv if hasattr(model, 'wv') else model
    vector_size = kv.vector_size if hasattr(kv, 'vector_size') else model.vector_size

    vectors = []
    for text in texts:
        tokens = simple_preprocess(text, deacc=False, min_len=1)
        word_vectors = []
        for token in tokens:
            if token in kv:
                word_vectors.append(kv[token])

        if not word_vectors:
            vectors.append(np.zeros(vector_size))
            continue

        word_vectors = np.array(word_vectors)

        if aggregation == "mean":
            doc_vector = np.mean(word_vectors, axis=0)
        elif aggregation == "max":
            doc_vector = np.max(word_vectors, axis=0)
        elif aggregation == "sum":
            doc_vector = np.sum(word_vectors, axis=0)
        else:
            doc_vector = np.mean(word_vectors, axis=0)

        vectors.append(doc_vector)

    return np.array(vectors)


if __name__ == "__main__":
    # Smoke test
    sample_texts = [
        "Это тестовый текст для проверки предобработки. https://example.com test@email.ru",
        "Второй текст с эмодзи 😀 и HTML <p>тегами</p>.",
        "Третий текст 123 с числами и ПРОПИСНЫМИ буквами!"
    ]

    config = PreprocessingConfig(
        lowercase=True,
        remove_html=True,
        remove_urls=True,
        remove_emails=True,
        lemmatize=False,  # disabled for the test
        remove_stopwords=False
    )

    preprocessor = TextPreprocessor(config)
    processed = preprocessor.preprocess_batch(sample_texts)

    print("Обработанные тексты:")
    for i, (orig, proc) in enumerate(zip(sample_texts, processed)):
        print(f"\n{i+1}. Исходный: {orig[:50]}...")
        print(f"   Обработанный: {proc[:50]}...")

    # Meta-features
    meta_features = extract_meta_features(processed)
    print(f"\nМета-признаки (форма: {meta_features.shape}):")
    print(meta_features)
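A small usage sketch for this module — assuming only the functions defined above — showing how the meta-features can be stacked with TF-IDF features into one matrix. The meta-features are computed on the raw texts here so that the uppercase and punctuation ratios stay informative (lowercasing would zero them out):

# Sketch: combining TF-IDF features with meta-features via np.hstack.
import numpy as np
from src.text_preprocessing import (TextPreprocessor, PreprocessingConfig,
                                    extract_meta_features, vectorize_with_classical)

texts = ["Первый отзыв!!! ОЧЕНЬ эмоциональный 😀", "второй отзыв, спокойный и длинный"]
pre = TextPreprocessor(PreprocessingConfig(lemmatize=False))
processed = pre.preprocess_batch(texts)

X_tfidf, _ = vectorize_with_classical(processed, method="tfidf")  # (n_docs, n_terms)
X_meta = extract_meta_features(texts)                             # (n_docs, 6), on raw texts
X = np.hstack([X_tfidf, X_meta])                                  # combined feature matrix
print(X.shape)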
src/text_to_vector.py
ADDED
@@ -0,0 +1,403 @@
"""
Text vectorization module for clustering.
Reuses the lab-2 models (Word2Vec, FastText, GloVe) alongside TF-IDF and BM25.
"""

from __future__ import annotations

import os
from typing import List, Any, Optional, Tuple
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize

try:
    from rank_bm25 import BM25Okapi
    BM25_AVAILABLE = True
except ImportError:
    BM25_AVAILABLE = False
    print("⚠️ rank-bm25 не установлен. BM25 недоступен. Установите: pip install rank-bm25")

from gensim.models import Word2Vec, FastText, Doc2Vec
from gensim.utils import simple_preprocess


def load_embedding_model(model_path: str):
    """
    Loads a trained embedding model from lab 2.

    Args:
        model_path: Path to the saved model

    Returns:
        The loaded model (Word2Vec, FastText or Doc2Vec)
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Модель не найдена: {model_path}")

    # Try Word2Vec first
    try:
        return Word2Vec.load(model_path)
    except Exception:
        pass

    # Then FastText
    try:
        return FastText.load(model_path)
    except Exception:
        pass

    # Then Doc2Vec
    try:
        return Doc2Vec.load(model_path)
    except Exception:
        pass

    raise ValueError(f"Не удалось загрузить модель из {model_path}")


def vectorize_tfidf(texts: List[str],
                    max_features: Optional[int] = None,
                    ngram_range: Tuple[int, int] = (1, 2),
                    normalize_vectors: bool = True) -> Tuple[np.ndarray, Any]:
    """
    Vectorizes texts with TF-IDF.

    Args:
        texts: List of texts
        max_features: Maximum number of features
        ngram_range: N-gram range
        normalize_vectors: Whether to L2-normalize the vectors

    Returns:
        The vector matrix and the fitted vectorizer
    """
    vectorizer = TfidfVectorizer(
        max_features=max_features,
        ngram_range=ngram_range,
        lowercase=True,
        min_df=1
    )

    X = vectorizer.fit_transform(texts).toarray()

    if normalize_vectors:
        X = normalize(X, norm='l2')

    return X, vectorizer


def vectorize_bm25(texts: List[str],
                   tokenize: bool = True) -> Tuple[np.ndarray, Any]:
    """
    Vectorizes texts with BM25.

    Each document is represented by its BM25 scores against the whole
    collection: row i holds the scores of document i's tokens queried
    against every document, giving a document-by-document matrix.

    Args:
        texts: List of texts
        tokenize: Whether to tokenize the texts with simple_preprocess

    Returns:
        The score matrix and the BM25 index
    """
    if not BM25_AVAILABLE:
        raise ImportError("rank-bm25 не установлен. Установите: pip install rank-bm25")

    if tokenize:
        tokenized_texts = [simple_preprocess(text, deacc=False, min_len=1) for text in texts]
    else:
        tokenized_texts = [text.split() for text in texts]

    bm25 = BM25Okapi(tokenized_texts)

    # Document-by-document BM25 score matrix
    X = np.array([bm25.get_scores(doc) for doc in tokenized_texts])

    # L2-normalize rows
    X = normalize(X, norm='l2')

    return X, bm25


def vectorize_with_word2vec(texts: List[str],
                            model: Word2Vec,
                            aggregation: str = "mean",
                            normalize_vectors: bool = True) -> np.ndarray:
    """
    Vectorizes texts with a lab-2 Word2Vec model.

    Args:
        texts: List of texts
        model: Trained Word2Vec model
        aggregation: Aggregation over word vectors ("mean", "max" or "sum")
        normalize_vectors: Whether to L2-normalize the vectors

    Returns:
        A matrix of document vectors
    """
    kv = model.wv
    vector_size = kv.vector_size
    vectors = []

    for text in texts:
        tokens = simple_preprocess(text, deacc=False, min_len=1)
        word_vectors = []

        for token in tokens:
            if token in kv:
                word_vectors.append(kv[token])

        if not word_vectors:
            vectors.append(np.zeros(vector_size))
            continue

        word_vectors = np.array(word_vectors)

        if aggregation == "mean":
            doc_vector = np.mean(word_vectors, axis=0)
        elif aggregation == "max":
            doc_vector = np.max(word_vectors, axis=0)
        elif aggregation == "sum":
            doc_vector = np.sum(word_vectors, axis=0)
        else:
            doc_vector = np.mean(word_vectors, axis=0)

        vectors.append(doc_vector)

    X = np.array(vectors)

    if normalize_vectors:
        X = normalize(X, norm='l2')

    return X


def vectorize_with_fasttext(texts: List[str],
                            model: FastText,
                            aggregation: str = "mean",
                            normalize_vectors: bool = True) -> np.ndarray:
    """
    Vectorizes texts with a lab-2 FastText model.

    Args:
        texts: List of texts
        model: Trained FastText model
        aggregation: Aggregation over word vectors ("mean", "max" or "sum")
        normalize_vectors: Whether to L2-normalize the vectors

    Returns:
        A matrix of document vectors
    """
    kv = model.wv
    vector_size = kv.vector_size
    vectors = []

    for text in texts:
        tokens = simple_preprocess(text, deacc=False, min_len=1)
        word_vectors = []

        for token in tokens:
            # FastText handles OOV words via character n-grams
            if token in kv:
                word_vectors.append(kv[token])
            else:
                # Synthesize a vector for the OOV word
                word_vectors.append(kv.get_vector(token))

        if not word_vectors:
            vectors.append(np.zeros(vector_size))
            continue

        word_vectors = np.array(word_vectors)

        if aggregation == "mean":
            doc_vector = np.mean(word_vectors, axis=0)
        elif aggregation == "max":
            doc_vector = np.max(word_vectors, axis=0)
        elif aggregation == "sum":
            doc_vector = np.sum(word_vectors, axis=0)
        else:
            doc_vector = np.mean(word_vectors, axis=0)

        vectors.append(doc_vector)

    X = np.array(vectors)

    if normalize_vectors:
        X = normalize(X, norm='l2')

    return X


def vectorize_with_doc2vec(texts: List[str],
                           model: Doc2Vec,
                           normalize_vectors: bool = True) -> np.ndarray:
    """
    Vectorizes texts with a lab-2 Doc2Vec model.

    Args:
        texts: List of texts
        model: Trained Doc2Vec model
        normalize_vectors: Whether to L2-normalize the vectors

    Returns:
        A matrix of document vectors
    """
    vectors = []

    for text in texts:
        tokens = simple_preprocess(text, deacc=False, min_len=1)
        if tokens:
            vec = model.infer_vector(tokens)
        else:
            vec = np.zeros(model.vector_size)
        vectors.append(vec)

    X = np.array(vectors)

    if normalize_vectors:
        X = normalize(X, norm='l2')

    return X


def vectorize_with_glove(texts: List[str],
                         model_path: str,
                         aggregation: str = "mean",
                         normalize_vectors: bool = True) -> np.ndarray:
    """
    Vectorizes texts with a lab-2 GloVe model.

    Note: GloVe is usually stored as a plain text file or via gensim.
    Here the model is assumed to be loadable through gensim or a
    compatible interface.

    Args:
        texts: List of texts
        model_path: Path to the GloVe model
        aggregation: Aggregation over word vectors ("mean", "max" or "sum")
        normalize_vectors: Whether to L2-normalize the vectors

    Returns:
        A matrix of document vectors
    """
    # Try loading as KeyedVectors (if saved via gensim)
    try:
        from gensim.models import KeyedVectors
        kv = KeyedVectors.load(model_path)
    except Exception:
        # Fall back to loading as Word2Vec (compatibility)
        try:
            model = Word2Vec.load(model_path)
            kv = model.wv
        except Exception:
            raise ValueError(f"Не удалось загрузить GloVe модель из {model_path}")

    vector_size = kv.vector_size
    vectors = []

    for text in texts:
        tokens = simple_preprocess(text, deacc=False, min_len=1)
        word_vectors = []

        for token in tokens:
            if token in kv:
                word_vectors.append(kv[token])

        if not word_vectors:
            vectors.append(np.zeros(vector_size))
            continue

        word_vectors = np.array(word_vectors)

        if aggregation == "mean":
            doc_vector = np.mean(word_vectors, axis=0)
        elif aggregation == "max":
            doc_vector = np.max(word_vectors, axis=0)
        elif aggregation == "sum":
            doc_vector = np.sum(word_vectors, axis=0)
        else:
            doc_vector = np.mean(word_vectors, axis=0)

        vectors.append(doc_vector)

    X = np.array(vectors)

    if normalize_vectors:
        X = normalize(X, norm='l2')

    return X


def vectorize_texts(texts: List[str],
                    method: str = "tfidf",
                    model_path: Optional[str] = None,
                    **kwargs) -> Tuple[np.ndarray, Any]:
    """
    Universal text-vectorization entry point.

    Args:
        texts: List of texts
        method: Vectorization method (tfidf, bm25, w2v, fasttext, doc2vec, glove)
        model_path: Path to the model (for w2v, fasttext, doc2vec, glove)
        **kwargs: Extra method-specific parameters

    Returns:
        The vector matrix and the vectorizer/model object
    """
    method = method.lower()

    if method == "tfidf":
        return vectorize_tfidf(texts, **kwargs)

    elif method == "bm25":
        return vectorize_bm25(texts, **kwargs)

    elif method in ("w2v", "word2vec"):
        if model_path is None:
            raise ValueError("Для Word2Vec требуется model_path")
        model = load_embedding_model(model_path)
        X = vectorize_with_word2vec(texts, model, **kwargs)
        return X, model

    elif method == "fasttext":
        if model_path is None:
            raise ValueError("Для FastText требуется model_path")
        model = load_embedding_model(model_path)
        X = vectorize_with_fasttext(texts, model, **kwargs)
        return X, model

    elif method in ("doc2vec", "d2v"):
        if model_path is None:
            raise ValueError("Для Doc2Vec требуется model_path")
        model = load_embedding_model(model_path)
        X = vectorize_with_doc2vec(texts, model, **kwargs)
        return X, model

    elif method == "glove":
        if model_path is None:
            raise ValueError("Для GloVe требуется model_path")
        X = vectorize_with_glove(texts, model_path, **kwargs)
        return X, None

    else:
        raise ValueError(f"Неизвестный метод векторизации: {method}")


if __name__ == "__main__":
    # Smoke test
    sample_texts = [
        "Это первый тестовый текст для проверки векторизации",
        "Второй текст содержит другую информацию",
        "Третий текст также используется для тестирования"
    ]

    # TF-IDF
    X_tfidf, vectorizer = vectorize_tfidf(sample_texts)
    print(f"TF-IDF векторы: форма {X_tfidf.shape}")

    # BM25 (if available)
    if BM25_AVAILABLE:
        X_bm25, bm25 = vectorize_bm25(sample_texts)
        print(f"BM25 векторы: форма {X_bm25.shape}")
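To close the loop with the lab-4 pipeline, a minimal sketch feeding vectorize_texts into the clustering module. The ClusteringConfig, ClusteringAlgorithms, fit/labels_, and evaluate_clustering calls mirror their usage in the Streamlit tab above and are assumed to behave the same way here; swapping method="tfidf" for "bm25" would use the document-by-document BM25 matrix instead:

# Sketch: TF-IDF vectorization from this module feeding the lab-4 clustering step.
from src.text_to_vector import vectorize_texts
from src.clustering import ClusteringAlgorithms, ClusteringConfig, evaluate_clustering

docs = ["кошки и собаки", "собаки любят гулять", "биржевые котировки выросли",
        "акции упали на бирже", "кошки спят днем", "рынок акций нестабилен"]

X, _ = vectorize_texts(docs, method="tfidf", max_features=200)   # rows are L2-normalized
clusterer = ClusteringAlgorithms(ClusteringConfig(method="kmeans", n_clusters=2))
clusterer.fit(X)
print(clusterer.labels_)
print(evaluate_clustering(X, clusterer.labels_))  # silhouette, Calinski-Harabasz, Davies-Bouldin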