|
|
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

print("Инициализация ML классификатора...")

# The heavy ML stack (torch + transformers) is optional: when it is missing the
# classifier degrades to a fallback stub instead of crashing.
# BUGFIX: `quantize_dynamic` used to be imported unconditionally at module top
# level, which raised ImportError *before* this guard could run on machines
# without torch — it now lives inside the try block with the rest.
try:
    import torch
    import torch.nn.functional as F
    from torch.quantization import quantize_dynamic
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    ML_AVAILABLE = True
except ImportError as e:
    print(f"⚠️ ML библиотеки не установлены: {e}")
    ML_AVAILABLE = False
    torch = None
    F = None
    AutoTokenizer = None
    AutoModelForSequenceClassification = None
    quantize_dynamic = None
|
|
@dataclass
class MLClassificationResult:
    """Result of one ML classification call.

    Produced by MLIntentClassifier.predict(); prediction tuples are
    (intent_name, confidence) pairs sorted by descending confidence.
    """
    # Name of the top intent ("unknown" when nothing could be predicted).
    intent: str
    # Confidence of the top intent, a sigmoid score in [0, 1].
    confidence: float
    # Every known intent with its score, best first.
    all_predictions: List[tuple]
    # Subset of all_predictions whose score passed the threshold; may be empty.
    multi_label_predictions: Optional[List[tuple]] = None
|
|
|
class MLIntentClassifier:
    """DistilBERT-based intent classifier.

    Supports multi-label classification (independent sigmoid per class), the
    same problem formulation the model was fine-tuned with.  All failure modes
    degrade to a stub prediction instead of raising.
    """

    def __init__(self, model_path: Optional[str] = None):
        """Create the classifier and eagerly try to load the model.

        Args:
            model_path: directory holding the fine-tuned model files
                (config.json, tokenizer files, weights, intent_vocab.json).
                Defaults to a hard-coded project location.
        """
        self.logger = logging.getLogger(__name__)
        self.model = None
        self.tokenizer = None
        self.device = None
        self.is_initialized = False

        # Label <-> index maps, filled from intent_vocab.json during init.
        self.intent_to_idx: Dict[str, int] = {}
        self.idx_to_intent: Dict[int, str] = {}

        # Per-label sigmoid cutoff for multi-label decisions.
        self.confidence_threshold = 0.3
        # Tokenizer truncation/padding length.
        self.max_length = 128

        if model_path is None:
            # NOTE(review): machine-specific hard-coded location.  The original
            # also derived a path from __file__ and built a "Data/Models"
            # variant, but both were immediately overwritten by the values
            # below (dead code, removed here).  Consider an environment
            # variable or a path relative to __file__ instead.
            base_dir = "C:/PycharmProjects/Ariel"
            model_path = os.path.join(base_dir, "Data", "models", "intent_classifier")

        self.model_path = model_path
        self._initialize_model()

    def _initialize_model(self) -> None:
        """Load tokenizer, model and intent vocabulary; never raises.

        On any failure the classifier stays uninitialized and predict()
        returns the fallback stub.
        """
        if not ML_AVAILABLE:
            self.logger.warning("ML библиотеки не установлены. Использование заглушки.")
            return

        try:
            if not os.path.exists(self.model_path):
                self.logger.error(f"Не найден файл: {self.model_path}")
                self.logger.info("Проверьте, что вы распаковали архив в правильную папку")
                return

            # config.json is mandatory; the weights may come in either format.
            required_files = ['config.json']
            weight_files = ['model.safetensors', 'pytorch_model.bin']

            for file in required_files:
                if not os.path.exists(os.path.join(self.model_path, file)):
                    self.logger.error(f"Не найден файл: {os.path.join(self.model_path, file)}")
                    return

            has_weights = any(os.path.exists(os.path.join(self.model_path, wf)) for wf in weight_files)
            if not has_weights:
                self.logger.error(f"Не найден файл весов модели. Ожидается один из: {weight_files}")
                self.logger.info(f"Файлы в директории: {os.listdir(self.model_path)}")
                return

            # Intent vocabulary: {"intent_name": index}.
            vocab_path = os.path.join(self.model_path, "intent_vocab.json")
            if os.path.exists(vocab_path):
                with open(vocab_path, 'r', encoding='utf-8') as f:
                    self.intent_to_idx = json.load(f)
                self.intent_to_idx = {k: int(v) for k, v in self.intent_to_idx.items()}
                self.idx_to_intent = {v: k for k, v in self.intent_to_idx.items()}
                self.logger.info(f"Загружен словарь интентов: {len(self.intent_to_idx)} классов")
            else:
                # NOTE(review): the original logged that it would fall back to
                # config.json but never implemented that; without the vocab
                # every prediction index is skipped in predict().
                self.logger.warning("Файл intent_vocab.json не найден. Пытаюсь определить из config.json")

            self.logger.info(f"Загрузка модели из {self.model_path}...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, local_files_only=True)

            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_path,
                local_files_only=True,
                problem_type="multi_label_classification"
            )

            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if self.device.type == "cpu":
                # BUGFIX: dynamic int8 quantization is CPU-only.  The original
                # quantized unconditionally and then tried to move the model to
                # CUDA, which fails for dynamically quantized models; quantize
                # only when we will actually run on CPU.
                self.model = quantize_dynamic(self.model, {torch.nn.Linear}, dtype=torch.qint8)
            self.model.to(self.device)
            self.model.eval()

            self.is_initialized = True
            self.logger.info(f"✅ Модель загружена успешно!")
            self.logger.info(f" Устройство: {self.device}")
            self.logger.info(f" Классов: {len(self.intent_to_idx) if self.intent_to_idx else 'неизвестно'}")

        except Exception as e:
            self.logger.error(f"❌ Ошибка загрузки модели: {e}")
            self.is_initialized = False

    def predict(self, text: str, threshold: Optional[float] = None) -> "MLClassificationResult":
        """Run multi-label intent prediction on *text*.

        Args:
            text: input utterance.
            threshold: per-label cutoff; defaults to self.confidence_threshold.

        Returns:
            MLClassificationResult with the main intent, its confidence, all
            per-label scores (best first) and the above-threshold subset.
            Falls back to a stub result when the model is unavailable or
            prediction raises.
        """
        if not self.is_initialized:
            self.logger.warning("Модель не инициализирована, возвращаем fallback")
            return self._fallback_prediction(text)

        try:
            current_threshold = threshold if threshold is not None else self.confidence_threshold

            inputs = self.tokenizer(
                text,
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)

            # Multi-label head: independent sigmoid per class (not softmax).
            probabilities = torch.sigmoid(outputs.logits)
            probs = probabilities.cpu().numpy()[0]

            all_predictions = []
            multi_label_predictions = []
            for idx, prob in enumerate(probs):
                # Indices missing from the vocabulary are silently skipped.
                if idx in self.idx_to_intent:
                    intent_name = self.idx_to_intent[idx]
                    confidence = float(prob)
                    all_predictions.append((intent_name, confidence))
                    if confidence >= current_threshold:
                        multi_label_predictions.append((intent_name, confidence))

            all_predictions.sort(key=lambda x: x[1], reverse=True)
            multi_label_predictions.sort(key=lambda x: x[1], reverse=True)

            # Main intent: best above-threshold label, else best overall.
            main_intent = "unknown"
            main_confidence = 0.0
            if multi_label_predictions:
                main_intent, main_confidence = multi_label_predictions[0]
            elif all_predictions:
                main_intent, main_confidence = all_predictions[0]

            return MLClassificationResult(
                intent=main_intent,
                confidence=main_confidence,
                all_predictions=all_predictions,
                multi_label_predictions=multi_label_predictions
            )

        except Exception as e:
            self.logger.error(f"Ошибка предсказания: {e}")
            return self._fallback_prediction(text)

    def _fallback_prediction(self, text: str) -> "MLClassificationResult":
        """Stub result returned when the model is missing or prediction fails."""
        # CONSISTENCY FIX: the original listed ("unknown", 1.0) while reporting
        # confidence 0.5 for the same intent; both now agree.
        return MLClassificationResult(
            intent="unknown",
            confidence=0.5,
            all_predictions=[("unknown", 0.5)],
            multi_label_predictions=[]
        )

    def get_model_info(self) -> Dict[str, Any]:
        """Lightweight status snapshot for diagnostics and logging."""
        return {
            "is_initialized": self.is_initialized,
            "model_path": self.model_path,
            "num_intents": len(self.intent_to_idx),
            "intents": list(self.intent_to_idx.keys()) if self.intent_to_idx else [],
            "confidence_threshold": self.confidence_threshold,
            "device": str(self.device) if self.device else None
        }
|
|
|
def create_ml_classifier(model_path: Optional[str] = None) -> MLIntentClassifier:
    """Factory: build an MLIntentClassifier for the given model directory.

    Passing None lets the classifier fall back to its default model location.
    """
    classifier = MLIntentClassifier(model_path)
    return classifier
|
|
|
|
|
|
|
|
|
def _interactive_demo() -> None:
    """Console smoke test: load the model once, then classify typed lines.

    Type 'выход' to quit the loop.
    """
    start = time.time()
    print("Загрузка модели, ожидайте...")
    # NOTE(review): leading-slash path is absolute and differs from the class
    # default (C:/PycharmProjects/Ariel/Data/models/...) — confirm it is right.
    classifier = create_ml_classifier("/Data/Models/intent_classifier")
    # `start` was assigned but never used in the original; report load time.
    print(f"Время загрузки: {time.time() - start:.1f} с")

    print("✅ Модель загружена! Тестируйте:")
    while True:
        text = input("\nВведите текст: ")
        if text.lower() == 'выход':
            break
        result = classifier.predict(text)
        print(f"Результат: {result.intent} ({result.confidence:.1%})")
        # Show the three highest-scoring intents for context.
        for intent, conf in result.all_predictions[:3]:
            print(f" - {intent}: {conf:.1%}")


# BUGFIX: the demo used to run unconditionally at import time; guard it so
# importing this module does not block on model loading and stdin.
if __name__ == "__main__":
    _interactive_demo()
|
|
|