File size: 9,281 Bytes
c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec 359319d 84c6dec c1ca391 84c6dec c1ca391 84c6dec c1ca391 84c6dec | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 | import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import time
from typing import Dict, Tuple
import warnings
warnings.filterwarnings('ignore')
# Конфигурация
MAX_LENGTH = 1000
MODELS = {
"cointegrated/rubert-tiny2": "Лёгкая модель (быстрая)",
"s-nlp/rubert-tiny-cased-rured": "Специализированная для классификации",
"ai-forever/ruBert-base": "Точная модель (медленнее)"
}
LABELS = {
0: "Политика",
1: "Экономика",
2: "Наука и технологии",
3: "Культура и искусство",
4: "Спорт",
5: "Здоровье и медицина",
6: "Образование",
7: "Разное"
}
class TopicClassifier:
def __init__(self):
self.models: Dict = {}
self.tokenizers: Dict = {}
def load_model(self, model_name: str):
"""Загрузка модели по требованию"""
if model_name not in self.models:
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
model_name,
num_labels=len(LABELS)
)
# Настройка для CPU
model.eval()
self.models[model_name] = model
self.tokenizers[model_name] = tokenizer
print(f"Модель {model_name} загружена успешно")
except Exception as e:
raise Exception(f"Ошибка загрузки модели: {str(e)}")
def predict(self, text: str, model_name: str) -> Tuple[Dict, float]:
"""Предсказание темы текста"""
if not text.strip():
raise ValueError("Текст не может быть пустым")
if len(text) > MAX_LENGTH:
text = text[:MAX_LENGTH]
gr.Warning(f"Текст обрезан до {MAX_LENGTH} символов")
self.load_model(model_name)
start_time = time.time()
try:
tokenizer = self.tokenizers[model_name]
model = self.models[model_name]
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
)
with torch.no_grad():
outputs = model(**inputs)
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
scores = predictions[0].tolist()
results = {LABELS[i]: round(score * 100, 2) for i, score in enumerate(scores)}
# Сортировка по уверенности
sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
latency = round((time.time() - start_time) * 1000, 2) # в мс
return sorted_results, latency
except Exception as e:
raise Exception(f"Ошибка при обработке: {str(e)}")
# Инициализация классификатора
classifier = TopicClassifier()
def process_text(text: str, model_choice: str) -> Tuple[str, str, str]:
"""Обработка текста с выбранной моделью"""
if not text.strip():
return "⚠️ Введите текст для анализа", "", "0"
try:
predictions, latency = classifier.predict(text, model_choice)
# Форматирование результатов
top_topic = list(predictions.keys())[0]
top_score = predictions[top_topic]
result_text = f"🎯 **Основная тема:** {top_topic} ({top_score}%)\n\n"
result_text += "📊 **Распределение тем:**\n"
for topic, score in predictions.items():
result_text += f"• {topic}: {score}%\n"
# Подготовка JSON для отладки
json_output = "{\n"
for topic, score in predictions.items():
json_output += f' "{topic}": {score},\n'
json_output = json_output.rstrip(",\n") + "\n}"
return result_text, json_output, str(latency)
except ValueError as e:
return f"❌ {str(e)}", "", "0"
except Exception as e:
return f"⚠️ Ошибка: {str(e)}", "", "0"
# Примеры текстов
examples = [
["Российская экономика показала рост в третьем квартале благодаря увеличению экспорта нефти и газа."],
["Ученые создали новый материал для солнечных батарей с эффективностью 45%."],
["На чемпионате мира по футболу сборная Бразилии одержала победу со счетом 3:1."],
["В музее открылась выставка современных художников, посвященная проблемам экологии."],
["Минздрав рекомендовал новые правила вакцинации для населения старше 60 лет."]
]
# Создание интерфейса
with gr.Blocks(title="Классификатор тем текста", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🎯 Классификатор тематики текста")
gr.Markdown("Определите основную тему вашего текста с помощью ИИ-моделей")
with gr.Row():
with gr.Column(scale=2):
model_selector = gr.Dropdown(
choices=list(MODELS.keys()),
value=list(MODELS.keys())[0],
label="📋 Выберите модель",
info="Каждая модель имеет разный баланс скорости и точности"
)
text_input = gr.Textbox(
label="📝 Введите текст для анализа",
placeholder="Введите текст на русском языке...",
lines=5,
max_lines=10
)
process_btn = gr.Button("🔍 Анализировать текст", variant="primary")
gr.Markdown("### 📋 Примеры текстов")
gr.Examples(
examples=examples,
inputs=text_input,
label="Нажмите на пример для быстрой загрузки"
)
with gr.Column(scale=3):
with gr.Row():
latency_display = gr.Textbox(
label="⏱️ Время обработки",
value="0",
interactive=False
)
latency_display.info = "мсек"
output_text = gr.Markdown(
label="📊 Результаты классификации"
)
json_output = gr.Code(
label="📄 JSON-формат результатов",
language="json",
interactive=False
)
# Обработка событий
process_btn.click(
fn=process_text,
inputs=[text_input, model_selector],
outputs=[output_text, json_output, latency_display]
)
# Дополнительная информация
with gr.Accordion("ℹ️ Информация о моделях", open=False):
gr.Markdown("""
**Доступные модели:**
1. **cointegrated/rubert-tiny2** - Быстрая и легкая модель, идеально подходит для CPU
2. **s-nlp/rubert-tiny-cased-rured** - Специализирована для тематической классификации
3. **ai-forever/ruBert-base** - Самая точная, но требует больше времени
**Ограничения:**
- Максимальная длина текста: 1000 символов
- Только русский язык
- Автоматическое определение 8 основных тем
""")
gr.Markdown("---")
gr.Markdown("### 📌 Инструкция")
gr.Markdown("""
1. Выберите модель из списка
2. Введите или вставьте текст для анализа
3. Нажмите кнопку "Анализировать текст"
4. Получите результаты классификации и время обработки
""")
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860) |