Upload fast parser and ml evaluator
Browse files- fast_parser.py +253 -0
- ml_classifier.py +246 -0
fast_parser.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 📄 src/core/intent_parser/fast_parser.py
|
| 2 |
+
import re
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Dict, Tuple, Optional
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class ParsedIntent:
|
| 9 |
+
"""Универсальный контейнер для распознанного намерения"""
|
| 10 |
+
intent: str
|
| 11 |
+
confidence: float
|
| 12 |
+
original_text: str
|
| 13 |
+
normalized_text: str
|
| 14 |
+
parameters: Dict[str, any]
|
| 15 |
+
source: str = "fast_parser"
|
| 16 |
+
|
| 17 |
+
class FastIntentParser:
|
| 18 |
+
"""
|
| 19 |
+
Быстрый парсер намерений на основе ключевых слов и правил.
|
| 20 |
+
Обрабатывает 80-90% типичных запросов без использования ML.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
self.logger = logging.getLogger(__name__)
|
| 25 |
+
self._setup_domains()
|
| 26 |
+
self._setup_synonyms()
|
| 27 |
+
self._setup_patterns()
|
| 28 |
+
|
| 29 |
+
def _setup_domains(self):
|
| 30 |
+
"""Настройка доменов и ключевых слов"""
|
| 31 |
+
self.domains = {
|
| 32 |
+
'greeting': {
|
| 33 |
+
'keywords': ['привет', 'здравствуй', 'добрый', 'хай', 'салют', 'здаров'],
|
| 34 |
+
'priority': 1,
|
| 35 |
+
'response_templates': [
|
| 36 |
+
"Привет! Готов к работе.",
|
| 37 |
+
"Здравствуйте! Чем могу помочь?",
|
| 38 |
+
"Приветствую! Ariel на связи."
|
| 39 |
+
]
|
| 40 |
+
},
|
| 41 |
+
'system': {
|
| 42 |
+
'keywords': ['будильник', 'таймер', 'открой', 'запусти', 'выключи', 'громкость'],
|
| 43 |
+
'priority': 2,
|
| 44 |
+
'subdomains': {
|
| 45 |
+
'alarm': ['будильник', 'разбуди', 'напомни'],
|
| 46 |
+
'app_launch': ['открой', 'запусти', 'включи'],
|
| 47 |
+
'system_control': ['выключи', 'перезагрузи', 'громкость']
|
| 48 |
+
}
|
| 49 |
+
},
|
| 50 |
+
'visualization': {
|
| 51 |
+
'keywords': ['график', 'диаграмм', 'схем', 'визуализир', 'построй', 'нарисуй'],
|
| 52 |
+
'priority': 3,
|
| 53 |
+
'subdomains': {
|
| 54 |
+
'plot': ['график', 'построй'],
|
| 55 |
+
'chart': ['диаграмм', 'гистограмм'],
|
| 56 |
+
'scheme': ['схем', 'блок-схем']
|
| 57 |
+
}
|
| 58 |
+
},
|
| 59 |
+
'knowledge': {
|
| 60 |
+
'keywords': ['что такое', 'как работает', 'объясни', 'найди информацию', 'база данных'],
|
| 61 |
+
'priority': 4
|
| 62 |
+
},
|
| 63 |
+
'creative': {
|
| 64 |
+
'keywords': ['расскажи', 'пошути', 'придумай', 'рекомендуй', 'советуй'],
|
| 65 |
+
'priority': 5
|
| 66 |
+
},
|
| 67 |
+
'help': {
|
| 68 |
+
'keywords': ['помощь', 'команды', 'что ты умеешь', 'справка'],
|
| 69 |
+
'priority': 6
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
def _setup_synonyms(self):
|
| 74 |
+
"""Настройка синонимов для нормализации текста"""
|
| 75 |
+
self.synonyms = {
|
| 76 |
+
'привет': ['салют', 'здаров', 'хей', 'хай', 'здорово', 'добрый день'],
|
| 77 |
+
'будильник': ['будильничек', 'напоминание', 'оповещение', 'звонок'],
|
| 78 |
+
'поставь': ['заведи', 'установи', 'создай', 'активируй'],
|
| 79 |
+
'открой': ['запусти', 'включи', 'открой', 'запусти'],
|
| 80 |
+
'график': ['графики', 'графичек', 'плотик'],
|
| 81 |
+
'помощь': ['справка', 'хелп', 'помоги', 'подскажи']
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
def _setup_patterns(self):
|
| 85 |
+
"""Настройка regex-паттернов для сложных случаев"""
|
| 86 |
+
self.patterns = {
|
| 87 |
+
'system': [
|
| 88 |
+
# Будильники и таймеры
|
| 89 |
+
r'(поставь|заведи|установи).*будильник.*(\d{1,2}:\d{2})',
|
| 90 |
+
r'будильник.*(\d{1,2}).*(утра|вечера|часов|час)',
|
| 91 |
+
r'разбуди.*(\d{1,2}).*(утра|вечера)',
|
| 92 |
+
# Запуск приложений
|
| 93 |
+
r'(открой|запусти).*(браузер|chrome|хром|firefox|файрфокс)',
|
| 94 |
+
r'(открой|запусти).*(терминал|cmd|командную строку)',
|
| 95 |
+
# Управление системой
|
| 96 |
+
r'(выключи|перезагрузи).*(компьютер|систему)',
|
| 97 |
+
r'(сделай|поставь).*(громче|тише)'
|
| 98 |
+
],
|
| 99 |
+
'visualization': [
|
| 100 |
+
r'построй.*график.*(\w+).*от.*(\d+).*до.*(\d+)',
|
| 101 |
+
r'график.*(sin|синус|cos|косинус|tan|тангенс)',
|
| 102 |
+
r'диаграмм.*(кругова|столбчата|гистограмм)',
|
| 103 |
+
r'нарисуй.*схем.*работы'
|
| 104 |
+
],
|
| 105 |
+
'knowledge': [
|
| 106 |
+
r'что такое (\w+)',
|
| 107 |
+
r'как работает (\w+)',
|
| 108 |
+
r'объясни.*(\w+)',
|
| 109 |
+
r'найди.*информацию.*о (\w+)'
|
| 110 |
+
]
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
def normalize_text(self, text: str) -> str:
|
| 114 |
+
"""Нормализация текста: приведение к нижнему регистру и замена синонимов"""
|
| 115 |
+
if not text:
|
| 116 |
+
return ""
|
| 117 |
+
|
| 118 |
+
text_lower = text.lower().strip()
|
| 119 |
+
|
| 120 |
+
# Замена синонимов на основные формы
|
| 121 |
+
for main_word, synonyms in self.synonyms.items():
|
| 122 |
+
for synonym in synonyms:
|
| 123 |
+
if synonym in text_lower:
|
| 124 |
+
text_lower = text_lower.replace(synonym, main_word)
|
| 125 |
+
self.logger.debug(f"Заменен синоним '{synonym}' -> '{main_word}'")
|
| 126 |
+
|
| 127 |
+
return text_lower
|
| 128 |
+
|
| 129 |
+
def extract_parameters(self, domain: str, text: str) -> Dict[str, any]:
|
| 130 |
+
"""Извлечение параметров из текста команды"""
|
| 131 |
+
normalized_text = self.normalize_text(text)
|
| 132 |
+
parameters = {}
|
| 133 |
+
|
| 134 |
+
if domain == 'system':
|
| 135 |
+
# Извлечение времени для будильников
|
| 136 |
+
time_match = re.search(r'(\d{1,2})(?::(\d{2}))?\s*(утра|вечера|часов|час)?', normalized_text)
|
| 137 |
+
if time_match:
|
| 138 |
+
hour = int(time_match.group(1))
|
| 139 |
+
minute = int(time_match.group(2) or "0")
|
| 140 |
+
period = time_match.group(3) or ""
|
| 141 |
+
|
| 142 |
+
# Конвертация в 24-часовой формат
|
| 143 |
+
if period == 'вечера' and hour < 12:
|
| 144 |
+
hour += 12
|
| 145 |
+
|
| 146 |
+
parameters['time'] = f"{hour:02d}:{minute:02d}"
|
| 147 |
+
parameters['period'] = period
|
| 148 |
+
|
| 149 |
+
# Извлечение названия приложения
|
| 150 |
+
app_matches = re.findall(r'(браузер|хром|chrome|терминал|cmd)', normalized_text)
|
| 151 |
+
if app_matches:
|
| 152 |
+
parameters['app'] = app_matches[0]
|
| 153 |
+
|
| 154 |
+
elif domain == 'visualization':
|
| 155 |
+
# Извлечение математической функции
|
| 156 |
+
func_match = re.search(r'(sin|синус|cos|косинус|tan|тангенс|x\^2)', normalized_text)
|
| 157 |
+
if func_match:
|
| 158 |
+
func_map = {'синус': 'sin', 'косинус': 'cos', 'тангенс': 'tan'}
|
| 159 |
+
parameters['function'] = func_map.get(func_match.group(1), func_match.group(1))
|
| 160 |
+
|
| 161 |
+
# Извлечение диапазона
|
| 162 |
+
range_match = re.search(r'от\s*(\d+)\s*до\s*(\d+)', normalized_text)
|
| 163 |
+
if range_match:
|
| 164 |
+
parameters['x_range'] = [float(range_match.group(1)), float(range_match.group(2))]
|
| 165 |
+
|
| 166 |
+
elif domain == 'knowledge':
|
| 167 |
+
# Извлечение темы для поиска
|
| 168 |
+
topic_match = re.search(r'что такое\s+(\w+)', normalized_text)
|
| 169 |
+
if not topic_match:
|
| 170 |
+
topic_match = re.search(r'как работает\s+(\w+)', normalized_text)
|
| 171 |
+
if not topic_match:
|
| 172 |
+
topic_match = re.search(r'объясни\s+(\w+)', normalized_text)
|
| 173 |
+
|
| 174 |
+
if topic_match:
|
| 175 |
+
parameters['topic'] = topic_match.group(1)
|
| 176 |
+
|
| 177 |
+
return parameters
|
| 178 |
+
|
| 179 |
+
def parse(self, text: str) -> Optional[ParsedIntent]:
|
| 180 |
+
"""
|
| 181 |
+
Основной метод парсинга намерения из текста.
|
| 182 |
+
Возвращает ParsedIntent или None если намерение не распознано.
|
| 183 |
+
"""
|
| 184 |
+
if not text or not text.strip():
|
| 185 |
+
return None
|
| 186 |
+
|
| 187 |
+
normalized_text = self.normalize_text(text)
|
| 188 |
+
self.logger.debug(f"Парсинг текста: '{text}' -> '{normalized_text}'")
|
| 189 |
+
|
| 190 |
+
# Сначала проверяем regex-паттерны (более точные)
|
| 191 |
+
domain_from_patterns = self._check_patterns(normalized_text)
|
| 192 |
+
if domain_from_patterns:
|
| 193 |
+
domain, subdomain, confidence = domain_from_patterns
|
| 194 |
+
parameters = self.extract_parameters(domain, normalized_text)
|
| 195 |
+
|
| 196 |
+
return ParsedIntent(
|
| 197 |
+
intent=domain,
|
| 198 |
+
confidence=confidence,
|
| 199 |
+
original_text=text,
|
| 200 |
+
normalized_text=normalized_text,
|
| 201 |
+
parameters=parameters
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# Затем проверяем ключевые слова
|
| 205 |
+
domain_from_keywords = self._check_keywords(normalized_text)
|
| 206 |
+
if domain_from_keywords:
|
| 207 |
+
domain, subdomain, confidence = domain_from_keywords
|
| 208 |
+
parameters = self.extract_parameters(domain, normalized_text)
|
| 209 |
+
|
| 210 |
+
return ParsedIntent(
|
| 211 |
+
intent=domain,
|
| 212 |
+
confidence=confidence,
|
| 213 |
+
original_text=text,
|
| 214 |
+
normalized_text=normalized_text,
|
| 215 |
+
parameters=parameters
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# Не распознано
|
| 219 |
+
self.logger.debug(f"Не удалось распознать намерение: '{text}'")
|
| 220 |
+
return None
|
| 221 |
+
|
| 222 |
+
def _check_patterns(self, text: str) -> Optional[Tuple[str, str, float]]:
|
| 223 |
+
"""Проверка текста по regex-паттернам"""
|
| 224 |
+
for domain, pattern_list in self.patterns.items():
|
| 225 |
+
for pattern in pattern_list:
|
| 226 |
+
if re.search(pattern, text):
|
| 227 |
+
self.logger.debug(f"Найден паттерн '{pattern}' для домена '{domain}'")
|
| 228 |
+
return domain, None, 0.95 # Высокая уверенность для паттернов
|
| 229 |
+
|
| 230 |
+
return None
|
| 231 |
+
|
| 232 |
+
def _check_keywords(self, text: str) -> Optional[Tuple[str, str, float]]:
|
| 233 |
+
"""Проверка текста по ключевым словам"""
|
| 234 |
+
found_domains = []
|
| 235 |
+
|
| 236 |
+
for domain, domain_config in self.domains.items():
|
| 237 |
+
for keyword in domain_config['keywords']:
|
| 238 |
+
if keyword in text:
|
| 239 |
+
confidence = 0.9 if len(keyword) > 3 else 0.7
|
| 240 |
+
found_domains.append((domain, None, confidence))
|
| 241 |
+
self.logger.debug(f"Найдено ключевое слово '{keyword}' для домена '{domain}'")
|
| 242 |
+
|
| 243 |
+
if not found_domains:
|
| 244 |
+
return None
|
| 245 |
+
|
| 246 |
+
# Возвращаем домен с наивысшим приоритетом
|
| 247 |
+
found_domains.sort(key=lambda x: self.domains[x[0]]['priority'])
|
| 248 |
+
return found_domains[0]
|
| 249 |
+
|
| 250 |
+
# Фабрика для создания парсера
|
| 251 |
+
def create_fast_parser() -> FastIntentParser:
|
| 252 |
+
"""Создание и настройка быстрого парсера"""
|
| 253 |
+
return FastIntentParser()
|
ml_classifier.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 📄 src/core/intent_parser/ml_classifier.py
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Dict, List, Optional, Any
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from torch.quantization import quantize_dynamic
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
# Импорты с обработкой ошибок
|
| 11 |
+
print("Инициализация ML классификатора...")
|
| 12 |
+
try:
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 16 |
+
ML_AVAILABLE = True
|
| 17 |
+
except ImportError as e:
|
| 18 |
+
print(f"⚠️ ML библиотеки не установлены: {e}")
|
| 19 |
+
ML_AVAILABLE = False
|
| 20 |
+
torch = None
|
| 21 |
+
AutoTokenizer = None
|
| 22 |
+
AutoModelForSequenceClassification = None
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class MLClassificationResult:
|
| 26 |
+
"""Результат классификации ML моделью"""
|
| 27 |
+
intent: str
|
| 28 |
+
confidence: float
|
| 29 |
+
all_predictions: List[tuple] # Список всех (интент, уверенность)
|
| 30 |
+
multi_label_predictions: Optional[List[tuple]] = None # Интенты выше порога
|
| 31 |
+
|
| 32 |
+
class MLIntentClassifier:
|
| 33 |
+
"""
|
| 34 |
+
ML классификатор намерений на основе DistilBERT.
|
| 35 |
+
Поддерживает multi-label классификацию как в обученной модели.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, model_path: Optional[str] = None):
|
| 39 |
+
self.logger = logging.getLogger(__name__)
|
| 40 |
+
self.model = None
|
| 41 |
+
self.tokenizer = None
|
| 42 |
+
self.device = None
|
| 43 |
+
self.is_initialized = False
|
| 44 |
+
|
| 45 |
+
# Словарь интентов
|
| 46 |
+
self.intent_to_idx = {}
|
| 47 |
+
self.idx_to_intent = {}
|
| 48 |
+
|
| 49 |
+
# Настройки
|
| 50 |
+
self.confidence_threshold = 0.3
|
| 51 |
+
self.max_length = 128
|
| 52 |
+
|
| 53 |
+
# Путь к модели (по умолчанию из вашей структуры)
|
| 54 |
+
if model_path is None:
|
| 55 |
+
# Автоматически определяем путь в структуре проекта
|
| 56 |
+
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
| 57 |
+
base_dir = "C:/PycharmProjects/Ariel"
|
| 58 |
+
model_path = os.path.join(base_dir, "Data", "Models", "intent_classifier")
|
| 59 |
+
|
| 60 |
+
base_dir = "C:/PycharmProjects/Ariel"
|
| 61 |
+
model_path = os.path.join(base_dir, "Data", "models", "intent_classifier")
|
| 62 |
+
|
| 63 |
+
self.model_path = model_path
|
| 64 |
+
self._initialize_model()
|
| 65 |
+
|
| 66 |
+
def _initialize_model(self):
|
| 67 |
+
"""Инициализация модели с обработкой ошибок"""
|
| 68 |
+
if not ML_AVAILABLE:
|
| 69 |
+
self.logger.warning("ML библиотеки не установлены. Использование заглушки.")
|
| 70 |
+
return
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
# Проверяем существование директории
|
| 74 |
+
if not os.path.exists(self.model_path):
|
| 75 |
+
self.logger.error(f"Не найден файл: {self.model_path}")
|
| 76 |
+
self.logger.info("Проверьте, что вы распаковали архив в правильную папку")
|
| 77 |
+
return
|
| 78 |
+
|
| 79 |
+
# Проверяем наличие ключевых файлов
|
| 80 |
+
required_files = ['config.json']
|
| 81 |
+
weight_files = ['model.safetensors', 'pytorch_model.bin']
|
| 82 |
+
|
| 83 |
+
for file in required_files:
|
| 84 |
+
if not os.path.exists(os.path.join(self.model_path, file)):
|
| 85 |
+
self.logger.error(f"Не найден файл: {os.path.join(self.model_path, file)}")
|
| 86 |
+
return
|
| 87 |
+
|
| 88 |
+
# Проверяем наличие файла весов
|
| 89 |
+
has_weights = any(os.path.exists(os.path.join(self.model_path, wf)) for wf in weight_files)
|
| 90 |
+
if not has_weights:
|
| 91 |
+
self.logger.error(f"Не найден файл весов модели. Ожидается один из: {weight_files}")
|
| 92 |
+
self.logger.info(f"Файлы в директории: {os.listdir(self.model_path)}")
|
| 93 |
+
return
|
| 94 |
+
|
| 95 |
+
# Загружаем vocabulary интентов
|
| 96 |
+
vocab_path = os.path.join(self.model_path, "intent_vocab.json")
|
| 97 |
+
if os.path.exists(vocab_path):
|
| 98 |
+
with open(vocab_path, 'r', encoding='utf-8') as f:
|
| 99 |
+
self.intent_to_idx = json.load(f)
|
| 100 |
+
# Преобразуем индексы в int если они строки
|
| 101 |
+
self.intent_to_idx = {k: int(v) for k, v in self.intent_to_idx.items()}
|
| 102 |
+
self.idx_to_intent = {v: k for k, v in self.intent_to_idx.items()}
|
| 103 |
+
self.logger.info(f"Загружен словарь интентов: {len(self.intent_to_idx)} классов")
|
| 104 |
+
else:
|
| 105 |
+
self.logger.warning("Файл intent_vocab.json не найден. Пытаюсь определить из config.json")
|
| 106 |
+
# Попробуем получить из конфига модели
|
| 107 |
+
pass
|
| 108 |
+
|
| 109 |
+
# Загружаем модель и токенизатор
|
| 110 |
+
self.logger.info(f"Загрузка модели из {self.model_path}...")
|
| 111 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, local_files_only=True)
|
| 112 |
+
|
| 113 |
+
# Загружаем модель с multi-label конфигурацией
|
| 114 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
| 115 |
+
self.model_path,
|
| 116 |
+
local_files_only=True,
|
| 117 |
+
problem_type="multi_label_classification"
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
self.model = quantize_dynamic(self.model, {torch.nn.Linear}, dtype=torch.qint8)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# Настройка устройства
|
| 125 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 126 |
+
self.model.to(self.device)
|
| 127 |
+
self.model.eval()
|
| 128 |
+
|
| 129 |
+
self.is_initialized = True
|
| 130 |
+
self.logger.info(f"✅ Модель загружена успешно!")
|
| 131 |
+
self.logger.info(f" Устройство: {self.device}")
|
| 132 |
+
self.logger.info(f" Классов: {len(self.intent_to_idx) if self.intent_to_idx else 'неизвестно'}")
|
| 133 |
+
|
| 134 |
+
except Exception as e:
|
| 135 |
+
self.logger.error(f"❌ Ошибка загрузки модели: {e}")
|
| 136 |
+
self.is_initialized = False
|
| 137 |
+
|
| 138 |
+
def predict(self, text: str, threshold: Optional[float] = None) -> MLClassificationResult:
|
| 139 |
+
"""Предсказание интентов для текста (multi-label)"""
|
| 140 |
+
if not self.is_initialized:
|
| 141 |
+
self.logger.warning("Модель не инициализирована, возвращаем fallback")
|
| 142 |
+
return self._fallback_prediction(text)
|
| 143 |
+
|
| 144 |
+
try:
|
| 145 |
+
current_threshold = threshold if threshold is not None else self.confidence_threshold
|
| 146 |
+
|
| 147 |
+
# Токенизация
|
| 148 |
+
inputs = self.tokenizer(
|
| 149 |
+
text,
|
| 150 |
+
truncation=True,
|
| 151 |
+
padding='max_length',
|
| 152 |
+
max_length=self.max_length,
|
| 153 |
+
return_tensors="pt"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
# Переносим на нужное устройство
|
| 157 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 158 |
+
|
| 159 |
+
# Предсказание
|
| 160 |
+
with torch.no_grad():
|
| 161 |
+
outputs = self.model(**inputs)
|
| 162 |
+
# Для multi-label используем sigmoid
|
| 163 |
+
probabilities = torch.sigmoid(outputs.logits)
|
| 164 |
+
|
| 165 |
+
# Получаем numpy массив
|
| 166 |
+
probs = probabilities.cpu().numpy()[0]
|
| 167 |
+
|
| 168 |
+
# Собираем результаты
|
| 169 |
+
all_predictions = []
|
| 170 |
+
multi_label_predictions = []
|
| 171 |
+
|
| 172 |
+
for idx, prob in enumerate(probs):
|
| 173 |
+
if idx in self.idx_to_intent:
|
| 174 |
+
intent_name = self.idx_to_intent[idx]
|
| 175 |
+
confidence = float(prob)
|
| 176 |
+
|
| 177 |
+
all_predictions.append((intent_name, confidence))
|
| 178 |
+
|
| 179 |
+
if confidence >= current_threshold:
|
| 180 |
+
multi_label_predictions.append((intent_name, confidence))
|
| 181 |
+
|
| 182 |
+
# Сортируем по уверенности
|
| 183 |
+
all_predictions.sort(key=lambda x: x[1], reverse=True)
|
| 184 |
+
multi_label_predictions.sort(key=lambda x: x[1], reverse=True)
|
| 185 |
+
|
| 186 |
+
# Определяем основной интент
|
| 187 |
+
main_intent = "unknown"
|
| 188 |
+
main_confidence = 0.0
|
| 189 |
+
|
| 190 |
+
if multi_label_predictions:
|
| 191 |
+
main_intent = multi_label_predictions[0][0]
|
| 192 |
+
main_confidence = multi_label_predictions[0][1]
|
| 193 |
+
elif all_predictions:
|
| 194 |
+
main_intent = all_predictions[0][0]
|
| 195 |
+
main_confidence = all_predictions[0][1]
|
| 196 |
+
|
| 197 |
+
return MLClassificationResult(
|
| 198 |
+
intent=main_intent,
|
| 199 |
+
confidence=main_confidence,
|
| 200 |
+
all_predictions=all_predictions,
|
| 201 |
+
multi_label_predictions=multi_label_predictions
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
except Exception as e:
|
| 205 |
+
self.logger.error(f"Ошибка предсказания: {e}")
|
| 206 |
+
return self._fallback_prediction(text)
|
| 207 |
+
|
| 208 |
+
def _fallback_prediction(self, text: str) -> MLClassificationResult:
|
| 209 |
+
"""Заглушка при ошибках"""
|
| 210 |
+
return MLClassificationResult(
|
| 211 |
+
intent="unknown",
|
| 212 |
+
confidence=0.5,
|
| 213 |
+
all_predictions=[("unknown", 1.0)],
|
| 214 |
+
multi_label_predictions=[]
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
def get_model_info(self) -> Dict[str, Any]:
|
| 218 |
+
"""Информация о модели"""
|
| 219 |
+
return {
|
| 220 |
+
"is_initialized": self.is_initialized,
|
| 221 |
+
"model_path": self.model_path,
|
| 222 |
+
"num_intents": len(self.intent_to_idx),
|
| 223 |
+
"intents": list(self.intent_to_idx.keys()) if self.intent_to_idx else [],
|
| 224 |
+
"confidence_threshold": self.confidence_threshold,
|
| 225 |
+
"device": str(self.device) if self.device else None
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
def create_ml_classifier(model_path: Optional[str] = None) -> MLIntentClassifier:
|
| 229 |
+
"""Фабричная функция для создания классификатора"""
|
| 230 |
+
return MLIntentClassifier(model_path)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# Пример использования (раскомментировать):
|
| 235 |
+
start = time.time()
|
| 236 |
+
print("Загрузка модели, ожидайте...")
|
| 237 |
+
classifier = create_ml_classifier("/Data/Models/intent_classifier")
|
| 238 |
+
|
| 239 |
+
print("✅ Модель загружена! Тестируйте:")
|
| 240 |
+
while True:
|
| 241 |
+
text = input("\nВведите текст: ")
|
| 242 |
+
if text.lower() == 'выход': break
|
| 243 |
+
result = classifier.predict(text)
|
| 244 |
+
print(f"Результат: {result.intent} ({result.confidence:.1%})")
|
| 245 |
+
for intent, conf in result.all_predictions[:3]:
|
| 246 |
+
print(f" - {intent}: {conf:.1%}")
|