M.E.S.A._Intentions / ml_classifier.py
Alex-Watchman's picture
Upload fast parser and ml evaluator
7d9276b verified
# 📄 src/core/intent_parser/ml_classifier.py
import json
import os
import logging
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from torch.quantization import quantize_dynamic
import time
# Импорты с обработкой ошибок
print("Инициализация ML классификатора...")
try:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
ML_AVAILABLE = True
except ImportError as e:
print(f"⚠️ ML библиотеки не установлены: {e}")
ML_AVAILABLE = False
torch = None
AutoTokenizer = None
AutoModelForSequenceClassification = None
@dataclass
class MLClassificationResult:
"""Результат классификации ML моделью"""
intent: str
confidence: float
all_predictions: List[tuple] # Список всех (интент, уверенность)
multi_label_predictions: Optional[List[tuple]] = None # Интенты выше порога
class MLIntentClassifier:
"""
ML классификатор намерений на основе DistilBERT.
Поддерживает multi-label классификацию как в обученной модели.
"""
def __init__(self, model_path: Optional[str] = None):
self.logger = logging.getLogger(__name__)
self.model = None
self.tokenizer = None
self.device = None
self.is_initialized = False
# Словарь интентов
self.intent_to_idx = {}
self.idx_to_intent = {}
# Настройки
self.confidence_threshold = 0.3
self.max_length = 128
# Путь к модели (по умолчанию из вашей структуры)
if model_path is None:
# Автоматически определяем путь в структуре проекта
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
base_dir = "C:/PycharmProjects/Ariel"
model_path = os.path.join(base_dir, "Data", "Models", "intent_classifier")
base_dir = "C:/PycharmProjects/Ariel"
model_path = os.path.join(base_dir, "Data", "models", "intent_classifier")
self.model_path = model_path
self._initialize_model()
def _initialize_model(self):
"""Инициализация модели с обработкой ошибок"""
if not ML_AVAILABLE:
self.logger.warning("ML библиотеки не установлены. Использование заглушки.")
return
try:
# Проверяем существование директории
if not os.path.exists(self.model_path):
self.logger.error(f"Не найден файл: {self.model_path}")
self.logger.info("Проверьте, что вы распаковали архив в правильную папку")
return
# Проверяем наличие ключевых файлов
required_files = ['config.json']
weight_files = ['model.safetensors', 'pytorch_model.bin']
for file in required_files:
if not os.path.exists(os.path.join(self.model_path, file)):
self.logger.error(f"Не найден файл: {os.path.join(self.model_path, file)}")
return
# Проверяем наличие файла весов
has_weights = any(os.path.exists(os.path.join(self.model_path, wf)) for wf in weight_files)
if not has_weights:
self.logger.error(f"Не найден файл весов модели. Ожидается один из: {weight_files}")
self.logger.info(f"Файлы в директории: {os.listdir(self.model_path)}")
return
# Загружаем vocabulary интентов
vocab_path = os.path.join(self.model_path, "intent_vocab.json")
if os.path.exists(vocab_path):
with open(vocab_path, 'r', encoding='utf-8') as f:
self.intent_to_idx = json.load(f)
# Преобразуем индексы в int если они строки
self.intent_to_idx = {k: int(v) for k, v in self.intent_to_idx.items()}
self.idx_to_intent = {v: k for k, v in self.intent_to_idx.items()}
self.logger.info(f"Загружен словарь интентов: {len(self.intent_to_idx)} классов")
else:
self.logger.warning("Файл intent_vocab.json не найден. Пытаюсь определить из config.json")
# Попробуем получить из конфига модели
pass
# Загружаем модель и токенизатор
self.logger.info(f"Загрузка модели из {self.model_path}...")
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, local_files_only=True)
# Загружаем модель с multi-label конфигурацией
self.model = AutoModelForSequenceClassification.from_pretrained(
self.model_path,
local_files_only=True,
problem_type="multi_label_classification"
)
self.model = quantize_dynamic(self.model, {torch.nn.Linear}, dtype=torch.qint8)
# Настройка устройства
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
self.model.eval()
self.is_initialized = True
self.logger.info(f"✅ Модель загружена успешно!")
self.logger.info(f" Устройство: {self.device}")
self.logger.info(f" Классов: {len(self.intent_to_idx) if self.intent_to_idx else 'неизвестно'}")
except Exception as e:
self.logger.error(f"❌ Ошибка загрузки модели: {e}")
self.is_initialized = False
def predict(self, text: str, threshold: Optional[float] = None) -> MLClassificationResult:
"""Предсказание интентов для текста (multi-label)"""
if not self.is_initialized:
self.logger.warning("Модель не инициализирована, возвращаем fallback")
return self._fallback_prediction(text)
try:
current_threshold = threshold if threshold is not None else self.confidence_threshold
# Токенизация
inputs = self.tokenizer(
text,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors="pt"
)
# Переносим на нужное устройство
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Предсказание
with torch.no_grad():
outputs = self.model(**inputs)
# Для multi-label используем sigmoid
probabilities = torch.sigmoid(outputs.logits)
# Получаем numpy массив
probs = probabilities.cpu().numpy()[0]
# Собираем результаты
all_predictions = []
multi_label_predictions = []
for idx, prob in enumerate(probs):
if idx in self.idx_to_intent:
intent_name = self.idx_to_intent[idx]
confidence = float(prob)
all_predictions.append((intent_name, confidence))
if confidence >= current_threshold:
multi_label_predictions.append((intent_name, confidence))
# Сортируем по уверенности
all_predictions.sort(key=lambda x: x[1], reverse=True)
multi_label_predictions.sort(key=lambda x: x[1], reverse=True)
# Определяем основной интент
main_intent = "unknown"
main_confidence = 0.0
if multi_label_predictions:
main_intent = multi_label_predictions[0][0]
main_confidence = multi_label_predictions[0][1]
elif all_predictions:
main_intent = all_predictions[0][0]
main_confidence = all_predictions[0][1]
return MLClassificationResult(
intent=main_intent,
confidence=main_confidence,
all_predictions=all_predictions,
multi_label_predictions=multi_label_predictions
)
except Exception as e:
self.logger.error(f"Ошибка предсказания: {e}")
return self._fallback_prediction(text)
def _fallback_prediction(self, text: str) -> MLClassificationResult:
"""Заглушка при ошибках"""
return MLClassificationResult(
intent="unknown",
confidence=0.5,
all_predictions=[("unknown", 1.0)],
multi_label_predictions=[]
)
def get_model_info(self) -> Dict[str, Any]:
"""Информация о модели"""
return {
"is_initialized": self.is_initialized,
"model_path": self.model_path,
"num_intents": len(self.intent_to_idx),
"intents": list(self.intent_to_idx.keys()) if self.intent_to_idx else [],
"confidence_threshold": self.confidence_threshold,
"device": str(self.device) if self.device else None
}
def create_ml_classifier(model_path: Optional[str] = None) -> MLIntentClassifier:
"""Фабричная функция для создания классификатора"""
return MLIntentClassifier(model_path)
# Пример использования (раскомментировать):
start = time.time()
print("Загрузка модели, ожидайте...")
classifier = create_ml_classifier("/Data/Models/intent_classifier")
print("✅ Модель загружена! Тестируйте:")
while True:
text = input("\nВведите текст: ")
if text.lower() == 'выход': break
result = classifier.predict(text)
print(f"Результат: {result.intent} ({result.confidence:.1%})")
for intent, conf in result.all_predictions[:3]:
print(f" - {intent}: {conf:.1%}")