# Hugging Face Spaces status banner captured together with the source:
# the Space reported "Runtime error" at the time this file was copied.
| import os | |
| import sys | |
| import json | |
| import pickle | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import re | |
| from typing import Dict, List, Any, Optional | |
| from collections import defaultdict, Counter | |
| import networkx as nx | |
| import pymorphy3 | |
| import requests | |
| from fastapi import FastAPI, Request, Form, HTTPException | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.templating import Jinja2Templates | |
| import uvicorn | |
| from transformers import BertTokenizer, BertModel | |
| from sklearn.preprocessing import LabelEncoder | |
| import warnings | |
# Silence library warnings to keep server logs readable.
warnings.filterwarnings('ignore')
# Pick GPU when available; all models are moved to this device at load time.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Используется устройство: {device}")
def clean_russian_text(text):
    """Normalize a Russian text snippet for model input.

    Lowercases, strips URLs and e-mail addresses, maps common emoticons to
    sentiment pseudo-tokens, removes exotic characters and collapses
    whitespace. Non-string input yields an empty string.
    """
    if not isinstance(text, str):
        return ""
    text = text.lower()
    # Strip URLs and e-mail addresses.
    text = re.sub(r'http\S+|www\S+|https\S+', '', text)
    text = re.sub(r'\S+@\S+', '', text)
    # Map emoticons to sentiment tokens. Keys must be lowercase because the
    # text was lowercased above — the original ':D' key could never match.
    smileys = {
        ':)': ' смайлик_радость ', ')': ' смайлик_радость ',
        ':(': ' смайлик_грусть ', '(': ' смайлик_грусть ',
        ':d': ' смайлик_смех ', ';)': ' смайлик_подмигивание ',
    }
    for smiley, replacement in smileys.items():
        text = text.replace(smiley, replacement)
    # Keep word chars, whitespace, Cyrillic letters and basic punctuation.
    text = re.sub(r'[^\w\sа-яё.,!?;:)(-]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
| # ============================================================ | |
| # ПОЛНЫЙ КЛАСС ОНТОЛОГИИ (исправленный) | |
| # ============================================================ | |
class OntologyEmotionModel:
    """Ontology/rule layer that refines neural emotion predictions.

    Combines a sentiment lexicon (statistical and/or RuSentiLex-based),
    linguistic rules and an emotion relation graph to adjust model outputs
    and to record hypotheses about model/rule disagreements.
    """
    def __init__(self, emotions: List[str], train_texts: Optional[List[str]] = None, train_labels: Optional[List[int]] = None):
        """Build the ontology for *emotions*; optionally learn a lexicon from training data."""
        self.emotions = emotions
        self.morph = pymorphy3.MorphAnalyzer()
        # Directed graph of emotions and their relations (edge relation='opposite').
        self.ontology_graph = nx.DiGraph()
        # Per-emotion store of observed examples (capped at 1000 each).
        self.empirical_base = defaultdict(list)
        self.hypotheses_db = {}
        self.verified_hypotheses = defaultdict(list)
        # lemma -> dominant emotion label.
        self.sentiment_lexicon = {}
        self.rule_stats = {}
        # The statistical lexicon is only built when training data is supplied.
        if train_texts is not None and train_labels is not None:
            self._build_sentiment_lexicon(train_texts, train_labels)
        self._load_rusentilex()
        self.init_ontology_level1()
        self.init_ontology_level2()
| def _build_sentiment_lexicon(self, texts: List[str], labels: List[int]): | |
| word_class_counts = defaultdict(lambda: np.zeros(len(self.emotions))) | |
| for text, label in zip(texts, labels): | |
| words = set(clean_russian_text(text).split()) | |
| for word in words: | |
| lemma = self.morph.parse(word)[0].normal_form | |
| word_class_counts[lemma][label] += 1 | |
| for lemma, counts in word_class_counts.items(): | |
| prob = counts / (counts.sum() + 1e-10) | |
| if prob.max() > 0.6 and counts.sum() > 5: | |
| dominant_class = self.emotions[np.argmax(prob)] | |
| self.sentiment_lexicon[lemma] = dominant_class | |
| def _parse_rusentilex(self, content): | |
| lines = content.splitlines() | |
| added = 0 | |
| for line in lines[1:]: # пропускаем заголовок | |
| try: | |
| parts = line.strip().split(',') | |
| if len(parts) >= 3: | |
| word = parts[0].strip().lower() | |
| sentiment = parts[2].strip().lower() | |
| lemma = self.morph.parse(word)[0].normal_form | |
| if sentiment == 'positive': | |
| self.sentiment_lexicon[lemma] = 'радость' | |
| added += 1 | |
| elif sentiment == 'negative': | |
| self.sentiment_lexicon[lemma] = 'грусть' | |
| added += 1 | |
| except Exception as e: | |
| continue | |
| print(f" Добавлено слов из RuSentiLex: {added}") | |
| def _load_rusentilex(self): | |
| """Загружает RuSentiLex из локального файла в папке model""" | |
| import os | |
| # Пути для поиска файла RuSentiLex | |
| possible_paths = [ | |
| 'model/rusentilex.csv', | |
| 'rusentilex.csv', | |
| '/app/model/rusentilex.csv', | |
| os.path.join(os.path.dirname(__file__), 'model', 'rusentilex.csv') | |
| ] | |
| loaded = False | |
| print("📂 Поиск RuSentiLex...") | |
| # Пробуем загрузить из локального файла | |
| for path in possible_paths: | |
| if os.path.exists(path): | |
| try: | |
| with open(path, 'r', encoding='utf-8') as f: | |
| content = f.read() | |
| self._parse_rusentilex(content) | |
| print(f"✅ RuSentiLex загружен из файла: {path}") | |
| loaded = True | |
| break | |
| except Exception as e: | |
| print(f"⚠️ Ошибка при загрузке {path}: {e}") | |
| # Если локально не нашли, пробуем скачать из интернета | |
| if not loaded: | |
| print("⚠️ Локальный файл RuSentiLex не найден, пробуем скачать...") | |
| url = "https://raw.githubusercontent.com/nicolay-r/sentiment-relation-classifiers/master/data/rusentilex.csv" | |
| try: | |
| r = requests.get(url, timeout=10) | |
| if r.status_code == 200: | |
| self._parse_rusentilex(r.text) | |
| print("✅ RuSentiLex загружен из репозитория") | |
| loaded = True | |
| except Exception as e: | |
| print(f"⚠️ Не удалось загрузить RuSentiLex из репозитория: {e}") | |
| if not loaded: | |
| print("⚠️ RuSentiLex не загружен. Используется только статистический лексикон.") | |
| # Выводим статистику | |
| print(f"📊 Всего слов в лексиконе: {len(self.sentiment_lexicon)}") | |
| def init_ontology_level1(self): | |
| self.emotion_definitions = { | |
| 'радость': { | |
| 'valence': 'positive', 'arousal': 'high', | |
| 'definition': 'Позитивное эмоциональное состояние', | |
| 'opposite': ['грусть', 'злость'] | |
| }, | |
| 'грусть': { | |
| 'valence': 'negative', 'arousal': 'low', | |
| 'definition': 'Негативное эмоциональное состояние', | |
| 'opposite': ['радость'] | |
| }, | |
| 'злость': { | |
| 'valence': 'negative', 'arousal': 'high', | |
| 'definition': 'Негативное эмоциональное состояние', | |
| 'opposite': ['радость'] | |
| }, | |
| 'страх': { | |
| 'valence': 'negative', 'arousal': 'high', | |
| 'definition': 'Эмоциональная реакция на угрозу', | |
| 'opposite': ['уверенность', 'спокойствие'] | |
| }, | |
| 'сарказм': { | |
| 'valence': 'negative', 'arousal': 'high', | |
| 'definition': 'Язвительная насмешка', | |
| 'opposite': ['радость'] | |
| } | |
| } | |
| for emotion in self.emotions: | |
| if emotion in self.emotion_definitions: | |
| self.ontology_graph.add_node(emotion, **self.emotion_definitions[emotion]) | |
| else: | |
| self.ontology_graph.add_node(emotion, valence='neutral', arousal='neutral') | |
| for emotion, data in self.emotion_definitions.items(): | |
| if 'opposite' in data: | |
| for opposite in data['opposite']: | |
| if opposite in self.emotions: | |
| self.ontology_graph.add_edge(emotion, opposite, relation='opposite') | |
| def init_ontology_level2(self): | |
| self.linguistic_rules = { | |
| 'усилители': { | |
| 'words': ['очень', 'сильно', 'крайне', 'чрезвычайно', 'невероятно', 'абсолютно'], | |
| 'effect': 'increase_arousal', | |
| 'weight': 0.3, | |
| 'learnable': True | |
| }, | |
| 'ослабители': { | |
| 'words': ['слегка', 'немного', 'чуть-чуть', 'отчасти', 'несколько'], | |
| 'effect': 'decrease_arousal', | |
| 'weight': -0.2, | |
| 'learnable': True | |
| }, | |
| 'отрицания': { | |
| 'words': ['не', 'ни', 'нет', 'нельзя', 'невозможно'], | |
| 'effect': 'negation', | |
| 'weight': -0.5, | |
| 'learnable': True | |
| }, | |
| 'восклицания': { | |
| 'patterns': [r'!+', r'\?+'], | |
| 'effect': 'increase_arousal', | |
| 'weight': 0.4, | |
| 'learnable': True | |
| }, | |
| 'вопросительные': { | |
| 'patterns': [r'\?+'], | |
| 'effect': 'uncertainty', | |
| 'weight': 0.2, | |
| 'learnable': True | |
| }, | |
| 'сарказм_маркеры': { | |
| 'words': ['какой', 'такой', 'прям', 'ага', 'ну да', 'конечно', 'отличная работа', 'прекрасно', 'замечательно', 'как всегда'], | |
| 'effect': 'sarcasm', | |
| 'weight': 0.6, | |
| 'learnable': True | |
| } | |
| } | |
| def add_empirical_knowledge(self, text: str, emotion: str, confidence: float): | |
| self.empirical_base[emotion].append({'text': text, 'confidence': confidence}) | |
| if len(self.empirical_base[emotion]) > 1000: | |
| self.empirical_base[emotion] = self.empirical_base[emotion][-1000:] | |
| def formulate_hypothesis(self, text: str, model_prediction: Dict, rule_based_prediction: Dict) -> Dict: | |
| hypothesis_id = f"hyp_{len(self.hypotheses_db) + 1:06d}" | |
| hypothesis = { | |
| 'id': hypothesis_id, 'text': text, | |
| 'model_prediction': model_prediction, | |
| 'rule_based_prediction': rule_based_prediction, | |
| 'disagreement': self.calculate_disagreement(model_prediction, rule_based_prediction), | |
| 'status': 'pending' | |
| } | |
| self.hypotheses_db[hypothesis_id] = hypothesis | |
| return hypothesis | |
| def verify_hypothesis(self, hypothesis_id: str, actual_emotion: str = None) -> Dict: | |
| if hypothesis_id not in self.hypotheses_db: | |
| return None | |
| hypothesis = self.hypotheses_db[hypothesis_id] | |
| if actual_emotion: | |
| model_correct = hypothesis['model_prediction']['emotion'] == actual_emotion | |
| rule_correct = hypothesis['rule_based_prediction']['emotion'] == actual_emotion | |
| if model_correct and not rule_correct: | |
| hypothesis['status'] = 'model_superior' | |
| elif rule_correct and not model_correct: | |
| hypothesis['status'] = 'rule_superior' | |
| elif model_correct and rule_correct: | |
| hypothesis['status'] = 'both_correct' | |
| else: | |
| hypothesis['status'] = 'both_incorrect' | |
| return hypothesis | |
    def apply_linguistic_rules(self, text: str) -> Dict:
        """Analyse *text* with lexicon lookups and rule heuristics.

        Returns a dict with:
            'rules_applied': human-readable descriptions of every fired rule,
            'adjustments': accumulated deltas for valence/arousal/uncertainty/sarcasm,
            'lemmas': normal forms of all tokens.
        """
        rules_applied = []
        adjustments = {'valence': 0, 'arousal': 0, 'uncertainty': 0, 'sarcasm': 0}
        words = text.lower().split()
        # Morphological normalisation: lemma and part of speech per token.
        parsed = [self.morph.parse(w)[0] for w in words]
        lemmas = [p.normal_form for p in parsed]
        pos_tags = [p.tag.POS for p in parsed]
        # Lexicon lookup: shift valence per positive/negative lemma.
        for lemma in lemmas:
            sentiment = self.sentiment_lexicon.get(lemma, 'neutral')
            if sentiment == 'радость':
                rules_applied.append(f"позитивное слово: {lemma}")
                adjustments['valence'] += 0.2
            elif sentiment in ('грусть', 'злость', 'страх'):
                rules_applied.append(f"негативное слово: {lemma}")
                adjustments['valence'] -= 0.2
        # Configured word- and pattern-based rules (see init_ontology_level2).
        for category, rule in self.linguistic_rules.items():
            if 'words' in rule:
                for word in rule['words']:
                    if word in lemmas:
                        rules_applied.append(f"{category}: {word}")
                        effect = rule['effect']
                        weight = rule['weight']
                        # NOTE(review): increase/decrease both add the signed
                        # weight to arousal; 'decrease' rules carry negative weights.
                        if effect == 'increase_arousal':
                            adjustments['arousal'] += weight
                        elif effect == 'decrease_arousal':
                            adjustments['arousal'] += weight
                        elif effect == 'negation':
                            adjustments['valence'] += weight
                        elif effect == 'sarcasm':
                            adjustments['sarcasm'] += weight
            if 'patterns' in rule:
                for pattern in rule['patterns']:
                    if re.search(pattern, text):
                        rules_applied.append(f"{category}: {pattern}")
                        weight = rule['weight']
                        if rule['effect'] == 'increase_arousal':
                            adjustments['arousal'] += weight
                        elif rule['effect'] == 'uncertainty':
                            adjustments['uncertainty'] += weight
        # Negation handling — only the first 'не' in the text is considered.
        if 'не' in lemmas:
            idx = lemmas.index('не')
            if idx + 1 < len(lemmas) and lemmas[idx+1] == 'очень':
                # "не очень" softens both arousal and valence.
                adjustments['arousal'] -= 0.2
                adjustments['valence'] -= 0.1
                rules_applied.append("сочетание: не очень")
            else:
                # Otherwise invert the first adjective/adverb within 3 tokens.
                for j in range(idx+1, min(idx+4, len(lemmas))):
                    if pos_tags[j] in ('ADJF', 'ADJS', 'ADVB'):
                        target_word = lemmas[j]
                        sentiment = self.sentiment_lexicon.get(target_word, 'neutral')
                        if sentiment in ('грусть', 'злость', 'страх'):
                            adjustments['valence'] += 1.0
                            rules_applied.append(f"инверсия негатива: не {target_word}")
                        elif sentiment == 'радость':
                            adjustments['valence'] -= 1.0
                            rules_applied.append(f"инверсия позитива: не {target_word}")
                        break
        # Mixed positive and negative vocabulary hints at sarcasm.
        pos_words = [w for w in lemmas if self.sentiment_lexicon.get(w) == 'радость']
        neg_words = [w for w in lemmas if self.sentiment_lexicon.get(w) in ('грусть', 'злость', 'страх')]
        if pos_words and neg_words:
            adjustments['sarcasm'] += 0.5
            rules_applied.append(f"контраст тональности: позитив {pos_words[:2]} vs негатив {neg_words[:2]}")
        # Hard-coded sarcastic stock phrases (substring match on raw text).
        sarcasm_phrases = ['конечно', 'ага', 'ну да', 'как всегда', 'отличная работа', 'прекрасно', 'замечательно']
        for phrase in sarcasm_phrases:
            if phrase in text.lower():
                adjustments['sarcasm'] += 0.6
                rules_applied.append(f"саркастическая фраза: {phrase}")
        if adjustments['sarcasm'] > 0.5:
            rules_applied.append("обнаружен сарказм")
        return {'rules_applied': rules_applied, 'adjustments': adjustments, 'lemmas': lemmas}
| def calculate_disagreement(self, pred1: Dict, pred2: Dict) -> float: | |
| if pred1['emotion'] == pred2['emotion']: | |
| return 0.0 | |
| emotions = list(self.emotion_definitions.keys()) | |
| idx1 = emotions.index(pred1['emotion']) if pred1['emotion'] in emotions else -1 | |
| idx2 = emotions.index(pred2['emotion']) if pred2['emotion'] in emotions else -1 | |
| if idx1 == -1 or idx2 == -1: | |
| return 0.5 | |
| distance = abs(idx1 - idx2) / len(emotions) | |
| return 0.7 * distance | |
| def explain_transition(self, from_emotion: str, to_emotion: str) -> List[str]: | |
| try: | |
| return nx.shortest_path(self.ontology_graph, source=from_emotion, target=to_emotion) | |
| except: | |
| return [] | |
| def adjust_prediction_with_rules(self, prediction: Dict, rule_analysis: Dict) -> Dict: | |
| original_emotion = prediction['emotion'] | |
| original_confidence = prediction['confidence'] | |
| adj = rule_analysis['adjustments'] | |
| rules = rule_analysis['rules_applied'] | |
| # Сохраняем исходную уверенность для проверки коррекции | |
| original_confidence_value = original_confidence | |
| was_corrected = len(rules) > 0 | |
| conf_mult = 1.0 + adj['arousal'] * 0.2 + adj['uncertainty'] * 0.1 - abs(adj['valence']) * 0.1 | |
| conf_mult = np.clip(conf_mult, 0.5, 1.5) | |
| new_confidence = original_confidence * conf_mult | |
| new_emotion = original_emotion | |
| # Если есть негативные слова и нет позитивных, корректируем эмоцию | |
| has_negative = any('негативное слово' in r for r in rules) | |
| has_positive = any('позитивное слово' in r for r in rules) | |
| if has_negative and not has_positive: | |
| if original_emotion == 'радость': | |
| new_emotion = 'грусть' | |
| new_confidence *= 0.8 | |
| rules.append("коррекция: негативные слова без позитивных") | |
| elif original_emotion == 'сарказм': | |
| new_emotion = 'грусть' | |
| new_confidence *= 0.9 | |
| elif has_positive and not has_negative and original_emotion in ('грусть', 'злость', 'страх'): | |
| new_emotion = 'радость' | |
| rules.append("коррекция: позитивные слова") | |
| # Инверсия на основе правил | |
| for rule in rules: | |
| if rule.startswith("инверсия негатива:"): | |
| new_emotion = 'радость' | |
| break | |
| elif rule.startswith("инверсия позитива:"): | |
| if adj['arousal'] > 0.3: | |
| new_emotion = 'злость' | |
| else: | |
| new_emotion = 'грусть' | |
| break | |
| # Сарказм (контраст + маркеры) | |
| sarcasm_flag = adj['sarcasm'] > 0.5 | |
| if sarcasm_flag: | |
| new_emotion = 'сарказм' | |
| new_confidence = min(new_confidence * 0.8, 0.9) | |
| if "саркастическая фраза" in str(rules): | |
| new_confidence = min(new_confidence * 1.1, 0.95) | |
| # Восклицания | |
| if any('восклицание' in r for r in rules): | |
| new_confidence = min(new_confidence * 1.2, 1.0) | |
| # Если онтология не применила коррекции, а уверенность была менее 90%, | |
| # то повышаем уверенность на 10% (но не более 100%) | |
| if not was_corrected and original_confidence_value < 0.9: | |
| new_confidence = min(new_confidence * 1.10, 1.0) | |
| # Ограничиваем максимум 1.0 (100%) | |
| new_confidence = min(new_confidence, 1.0) | |
| return { | |
| 'emotion': new_emotion, | |
| 'confidence': new_confidence, | |
| 'rules_applied': rules | |
| } | |
| def get_ontology_analysis(self, text: str, model_prediction: Dict) -> Dict: | |
| rule_analysis = self.apply_linguistic_rules(text) | |
| adjusted = self.adjust_prediction_with_rules(model_prediction, rule_analysis) | |
| disagreement = self.calculate_disagreement(model_prediction, adjusted) | |
| hypothesis = self.formulate_hypothesis(text, model_prediction, adjusted) if disagreement > 0.2 else None | |
| return { | |
| 'rule_analysis': rule_analysis, | |
| 'adjusted_prediction': adjusted, | |
| 'disagreement': disagreement, | |
| 'hypothesis': hypothesis | |
| } | |
| def get_statistics(self) -> Dict: | |
| return { | |
| 'ontology_nodes': len(self.ontology_graph.nodes), | |
| 'ontology_edges': len(self.ontology_graph.edges), | |
| 'linguistic_rules': len(self.linguistic_rules), | |
| 'emotions_covered': len(self.emotions), | |
| 'pending_hypotheses': len([h for h in self.hypotheses_db.values() if h['status'] == 'pending']) | |
| } | |
| # ============================================================ | |
| # КЛАССЫ МОДЕЛЕЙ LSTM и BERT | |
| # ============================================================ | |
class EmotionLSTM(nn.Module):
    """Bidirectional LSTM text classifier over padded token-id sequences.

    Submodule names and the classifier's Sequential layout are load-bearing:
    checkpoints are restored via load_state_dict, so they must not change.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_classes=3, dropout=0.3, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, 128), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, num_classes)
        )

    def forward(self, x, return_confidence=False):
        """Return logits for x of shape (batch, seq); optionally also the max softmax probability per sample."""
        # Concatenate the last layer's forward and backward final hidden states.
        _, (h_n, _) = self.lstm(self.embedding(x))
        summary = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)
        logits = self.classifier(self.dropout(summary))
        if not return_confidence:
            return logits
        confidence = torch.softmax(logits, dim=1).max(dim=1).values
        return logits, confidence
class EmotionBERT(nn.Module):
    """BERT encoder with an MLP head; classifies from the [CLS] representation.

    The classifier's Sequential layout must stay unchanged so that saved
    state dicts keep loading.
    """

    def __init__(self, bert_model_name, num_classes, dropout=0.3):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        hidden = self.bert.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(hidden, 256), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, input_ids, attention_mask, return_confidence=False):
        """Return logits; optionally also the max softmax probability per sample."""
        encoded = self.bert(input_ids, attention_mask, return_dict=True)
        # The [CLS] token embedding summarizes the sequence.
        logits = self.classifier(encoded.last_hidden_state[:, 0, :])
        if not return_confidence:
            return logits
        confidence = torch.softmax(logits, dim=1).max(dim=1).values
        return logits, confidence
| # ============================================================ | |
| # КАСКАДНЫЙ КЛАССИФИКАТОР | |
| # ============================================================ | |
class CascadeEmotionClassifier:
    """Two-stage emotion classifier: fast LSTM first, BERT as fallback.

    The LSTM prediction (refined by the ontology) is accepted when its
    adjusted confidence reaches `threshold`; otherwise the text is
    re-classified with BERT and refined the same way.
    """
    def __init__(self, lstm_model, bert_model, vocab, tokenizer, label_encoder, ontology_model, threshold=0.95, device='cpu', max_length_lstm=100, max_length_bert=128):
        self.lstm_model = lstm_model
        self.bert_model = bert_model
        self.vocab = vocab
        self.tokenizer = tokenizer
        self.label_encoder = label_encoder
        self.ontology_model = ontology_model
        self.threshold = threshold
        self.device = device
        self.max_length_lstm = max_length_lstm
        self.max_length_bert = max_length_bert
        # Inference only: freeze dropout behaviour and move both models to the device.
        self.lstm_model.eval()
        self.bert_model.eval()
        self.lstm_model.to(device)
        self.bert_model.to(device)
        # Counters for how often each stage handled a request.
        self.stats = {'total': 0, 'lstm': 0, 'bert': 0, 'corrections': 0}
    def text_to_sequence(self, text):
        """Convert text to a fixed-length list of vocab ids (pad id 0, unk id 1 by default)."""
        words = str(text).split()[:self.max_length_lstm]
        sequence = [self.vocab.get(word, self.vocab.get('<UNK>', 1)) for word in words]
        if len(sequence) < self.max_length_lstm:
            sequence += [self.vocab.get('<PAD>', 0)] * (self.max_length_lstm - len(sequence))
        return sequence[:self.max_length_lstm]
    def predict(self, text):
        """Classify *text*; returns emotion, confidence, the stage used and fired rules."""
        self.stats['total'] += 1
        text_clean = clean_russian_text(text)
        seq = torch.LongTensor([self.text_to_sequence(text_clean)]).to(self.device)
        with torch.no_grad():
            lstm_logits, lstm_conf = self.lstm_model(seq, return_confidence=True)
            lstm_probs = torch.softmax(lstm_logits, dim=1)
            lstm_pred = lstm_probs.argmax().item()
        lstm_emo = self.label_encoder.inverse_transform([lstm_pred])[0]
        lstm_pred_dict = {'emotion': lstm_emo, 'confidence': lstm_conf.item(), 'probabilities': lstm_probs[0].cpu().numpy().tolist()}
        # Refine the LSTM prediction with the ontology rules.
        lstm_onto = self.ontology_model.get_ontology_analysis(text_clean, lstm_pred_dict)
        if lstm_onto['adjusted_prediction']['confidence'] >= self.threshold:
            self.stats['lstm'] += 1
            final = lstm_onto['adjusted_prediction']
            used = "LSTM + онтология"
            rules_applied = lstm_onto['rule_analysis']['rules_applied']
        else:
            # LSTM not confident enough — escalate to BERT.
            self.stats['bert'] += 1
            enc = self.tokenizer(text_clean, truncation=True, padding=True, max_length=self.max_length_bert, return_tensors='pt').to(self.device)
            with torch.no_grad():
                bert_logits, bert_conf = self.bert_model(enc['input_ids'], enc['attention_mask'], return_confidence=True)
                bert_probs = torch.softmax(bert_logits, dim=1)
                bert_pred = bert_probs.argmax().item()
            bert_emo = self.label_encoder.inverse_transform([bert_pred])[0]
            bert_pred_dict = {'emotion': bert_emo, 'confidence': bert_conf.item(), 'probabilities': bert_probs[0].cpu().numpy().tolist()}
            # Refine the BERT prediction with the ontology rules.
            bert_onto = self.ontology_model.get_ontology_analysis(text_clean, bert_pred_dict)
            final = bert_onto['adjusted_prediction']
            used = "BERT + онтология"
            rules_applied = bert_onto['rule_analysis']['rules_applied']
        return {
            'text': text,
            'predicted_emotion': final['emotion'],
            'confidence': float(final['confidence']),
            'used_model': used,
            'rules_applied': rules_applied,
            'was_corrected_by_ontology': len(rules_applied) > 0
        }
| # ============================================================ | |
| # ЗАГРУЗКА МОДЕЛИ | |
| # ============================================================ | |
def load_model():
    """Load all artefacts from ./model and assemble the cascade classifier.

    Expects model_info.json, vocab.json, lstm_model.pth and bert_model.pth
    inside the 'model' directory. Returns (cascade, model_info).
    """
    print("Загрузка модели...")
    model_dir = 'model'
    # Model metadata: classes, sizes, threshold, BERT checkpoint name.
    with open(f'{model_dir}/model_info.json', 'r', encoding='utf-8') as f:
        model_info = json.load(f)
    # Token -> id vocabulary for the LSTM branch.
    with open(f'{model_dir}/vocab.json', 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    # Rebuild the label encoder from the stored class list.
    print("📂 Создание label_encoder...")
    label_encoder = LabelEncoder()
    label_encoder.classes_ = np.array(model_info['classes'])
    print(f"✅ label_encoder создан, классы: {list(label_encoder.classes_)}")
    # Build the rule/ontology layer (no training data at serve time).
    print("📂 Создание онтологии...")
    ontology_model = OntologyEmotionModel(
        emotions=list(label_encoder.classes_),
        train_texts=None,
        train_labels=None
    )
    print("✅ Онтология создана")
    # LSTM branch.
    print("📂 Загрузка LSTM...")
    lstm_model = EmotionLSTM(
        vocab_size=len(vocab),
        embed_dim=model_info.get('embed_dim', 300),
        hidden_dim=256,
        num_classes=model_info['num_classes'],
        dropout=0.3,
        num_layers=2
    )
    # NOTE(review): weights_only=False unpickles arbitrary objects — acceptable
    # only because the checkpoint ships with this app; never load untrusted files.
    checkpoint = torch.load(f'{model_dir}/lstm_model.pth', map_location=device, weights_only=False)
    lstm_model.load_state_dict(checkpoint['model_state_dict'])
    print("✅ LSTM загружена")
    # BERT branch.
    print("📂 Загрузка BERT...")
    bert_model = EmotionBERT(
        bert_model_name=model_info['bert_model_name'],
        num_classes=model_info['num_classes'],
        dropout=0.3
    )
    bert_model.load_state_dict(torch.load(f'{model_dir}/bert_model.pth', map_location=device, weights_only=False))
    print("✅ BERT загружена")
    # Tokenizer: prefer the local copy, fall back to the Hugging Face Hub.
    print("📂 Загрузка токенизатора...")
    try:
        tokenizer = BertTokenizer.from_pretrained(model_dir)
        print("✅ Токенизатор загружен из model_dir")
    except Exception as e:
        print(f"⚠️ Ошибка: {e}")
        print("🔄 Загружаем токенизатор из Hugging Face...")
        tokenizer = BertTokenizer.from_pretrained('DeepPavlov/rubert-base-cased')
        print("✅ Токенизатор загружен из Hugging Face")
    # Wire everything into the cascade.
    print("📂 Создание каскадного классификатора...")
    cascade = CascadeEmotionClassifier(
        lstm_model=lstm_model,
        bert_model=bert_model,
        vocab=vocab,
        tokenizer=tokenizer,
        label_encoder=label_encoder,
        ontology_model=ontology_model,
        threshold=model_info.get('threshold', 0.95),
        device=device,
        max_length_lstm=model_info.get('max_length_lstm', 100),
        max_length_bert=model_info.get('max_length_bert', 128)
    )
    print("✅ Модель успешно загружена!")
    return cascade, model_info
| # ============================================================ | |
| # FASTAPI ПРИЛОЖЕНИЕ | |
| # ============================================================ | |
app = FastAPI(title="Emotion Analysis with BERT and Ontology")
templates = Jinja2Templates(directory="templates")

# Populated by the startup hook below.
classifier = None
model_info = None


# BUG FIX: these handlers had no route/event decorators, so FastAPI never
# registered them — no routes existed and the model was never loaded.
@app.on_event("startup")
async def startup_event():
    """Load the cascade classifier once when the server starts."""
    global classifier, model_info
    classifier, model_info = load_model()


@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the single-page UI."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/predict")
async def predict(text: str = Form(...)):
    """Classify *text* and return emotion, confidence and fired rules as JSON."""
    if not classifier:
        raise HTTPException(status_code=503, detail="Модель не загружена")
    if not text or len(text.strip()) < 3:
        return JSONResponse({"error": "Введите хотя бы 3 символа."}, status_code=400)
    try:
        result = classifier.predict(text)
        # Render at most 10 fired rules as HTML tags for the UI.
        rules_display = []
        for rule in result['rules_applied'][:10]:
            if ':' in rule:
                cat, val = rule.split(':', 1)
                rules_display.append(f"<span class='rule-tag'>{cat}: {val}</span>")
            else:
                rules_display.append(f"<span class='rule-tag'>{rule}</span>")
        return JSONResponse({
            "success": True,
            "emotion": result['predicted_emotion'],
            "confidence": f"{result['confidence']*100:.1f}%",
            "used_model": result['used_model'],
            "rules": "".join(rules_display) if rules_display else "Нет правил",
            "was_corrected": str(result['was_corrected_by_ontology'])
        })
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)


@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the model finished loading."""
    return {"status": "healthy", "model_loaded": classifier is not None}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)