| | import gradio as gr |
| | import torch |
| | from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification |
| | import time |
| | from typing import Dict, Tuple |
| | import warnings |
| | warnings.filterwarnings('ignore') |
| |
|
| | |
| | MAX_LENGTH = 1000 |
| | MODELS = { |
| | "cointegrated/rubert-tiny2": "Лёгкая модель (быстрая)", |
| | "s-nlp/rubert-tiny-cased-rured": "Специализированная для классификации", |
| | "ai-forever/ruBert-base": "Точная модель (медленнее)" |
| | } |
| | LABELS = { |
| | 0: "Политика", |
| | 1: "Экономика", |
| | 2: "Наука и технологии", |
| | 3: "Культура и искусство", |
| | 4: "Спорт", |
| | 5: "Здоровье и медицина", |
| | 6: "Образование", |
| | 7: "Разное" |
| | } |
| |
|
| | class TopicClassifier: |
| | def __init__(self): |
| | self.models: Dict = {} |
| | self.tokenizers: Dict = {} |
| | |
| | def load_model(self, model_name: str): |
| | """Загрузка модели по требованию""" |
| | if model_name not in self.models: |
| | try: |
| | tokenizer = AutoTokenizer.from_pretrained(model_name) |
| | model = AutoModelForSequenceClassification.from_pretrained( |
| | model_name, |
| | num_labels=len(LABELS) |
| | ) |
| | |
| | |
| | model.eval() |
| | |
| | self.models[model_name] = model |
| | self.tokenizers[model_name] = tokenizer |
| | |
| | print(f"Модель {model_name} загружена успешно") |
| | except Exception as e: |
| | raise Exception(f"Ошибка загрузки модели: {str(e)}") |
| | |
| | def predict(self, text: str, model_name: str) -> Tuple[Dict, float]: |
| | """Предсказание темы текста""" |
| | if not text.strip(): |
| | raise ValueError("Текст не может быть пустым") |
| | |
| | if len(text) > MAX_LENGTH: |
| | text = text[:MAX_LENGTH] |
| | gr.Warning(f"Текст обрезан до {MAX_LENGTH} символов") |
| | |
| | self.load_model(model_name) |
| | |
| | start_time = time.time() |
| | |
| | try: |
| | tokenizer = self.tokenizers[model_name] |
| | model = self.models[model_name] |
| | |
| | inputs = tokenizer( |
| | text, |
| | return_tensors="pt", |
| | truncation=True, |
| | max_length=512, |
| | padding=True |
| | ) |
| | |
| | with torch.no_grad(): |
| | outputs = model(**inputs) |
| | predictions = torch.nn.functional.softmax(outputs.logits, dim=-1) |
| | |
| | scores = predictions[0].tolist() |
| | results = {LABELS[i]: round(score * 100, 2) for i, score in enumerate(scores)} |
| | |
| | |
| | sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True)) |
| | |
| | latency = round((time.time() - start_time) * 1000, 2) |
| | |
| | return sorted_results, latency |
| | |
| | except Exception as e: |
| | raise Exception(f"Ошибка при обработке: {str(e)}") |
| |
|
| | |
| | classifier = TopicClassifier() |
| |
|
| | def process_text(text: str, model_choice: str) -> Tuple[str, str, str]: |
| | """Обработка текста с выбранной моделью""" |
| | if not text.strip(): |
| | return "⚠️ Введите текст для анализа", "", "0" |
| | |
| | try: |
| | predictions, latency = classifier.predict(text, model_choice) |
| | |
| | |
| | top_topic = list(predictions.keys())[0] |
| | top_score = predictions[top_topic] |
| | |
| | result_text = f"🎯 **Основная тема:** {top_topic} ({top_score}%)\n\n" |
| | result_text += "📊 **Распределение тем:**\n" |
| | |
| | for topic, score in predictions.items(): |
| | result_text += f"• {topic}: {score}%\n" |
| | |
| | |
| | json_output = "{\n" |
| | for topic, score in predictions.items(): |
| | json_output += f' "{topic}": {score},\n' |
| | json_output = json_output.rstrip(",\n") + "\n}" |
| | |
| | return result_text, json_output, str(latency) |
| | |
| | except ValueError as e: |
| | return f"❌ {str(e)}", "", "0" |
| | except Exception as e: |
| | return f"⚠️ Ошибка: {str(e)}", "", "0" |
| |
|
| | |
| | examples = [ |
| | ["Российская экономика показала рост в третьем квартале благодаря увеличению экспорта нефти и газа."], |
| | ["Ученые создали новый материал для солнечных батарей с эффективностью 45%."], |
| | ["На чемпионате мира по футболу сборная Бразилии одержала победу со счетом 3:1."], |
| | ["В музее открылась выставка современных художников, посвященная проблемам экологии."], |
| | ["Минздрав рекомендовал новые правила вакцинации для населения старше 60 лет."] |
| | ] |
| |
|
| | |
| | with gr.Blocks(title="Классификатор тем текста", theme=gr.themes.Soft()) as demo: |
| | gr.Markdown("# 🎯 Классификатор тематики текста") |
| | gr.Markdown("Определите основную тему вашего текста с помощью ИИ-моделей") |
| | |
| | with gr.Row(): |
| | with gr.Column(scale=2): |
| | model_selector = gr.Dropdown( |
| | choices=list(MODELS.keys()), |
| | value=list(MODELS.keys())[0], |
| | label="📋 Выберите модель", |
| | info="Каждая модель имеет разный баланс скорости и точности" |
| | ) |
| | |
| | text_input = gr.Textbox( |
| | label="📝 Введите текст для анализа", |
| | placeholder="Введите текст на русском языке...", |
| | lines=5, |
| | max_lines=10 |
| | ) |
| | |
| | process_btn = gr.Button("🔍 Анализировать текст", variant="primary") |
| | |
| | gr.Markdown("### 📋 Примеры текстов") |
| | gr.Examples( |
| | examples=examples, |
| | inputs=text_input, |
| | label="Нажмите на пример для быстрой загрузки" |
| | ) |
| | |
| | with gr.Column(scale=3): |
| | with gr.Row(): |
| | latency_display = gr.Textbox( |
| | label="⏱️ Время обработки", |
| | value="0", |
| | interactive=False |
| | ) |
| | latency_display.info = "мсек" |
| | |
| | output_text = gr.Markdown( |
| | label="📊 Результаты классификации" |
| | ) |
| | |
| | json_output = gr.Code( |
| | label="📄 JSON-формат результатов", |
| | language="json", |
| | interactive=False |
| | ) |
| | |
| | |
| | process_btn.click( |
| | fn=process_text, |
| | inputs=[text_input, model_selector], |
| | outputs=[output_text, json_output, latency_display] |
| | ) |
| | |
| | |
| | with gr.Accordion("ℹ️ Информация о моделях", open=False): |
| | gr.Markdown(""" |
| | **Доступные модели:** |
| | |
| | 1. **cointegrated/rubert-tiny2** - Быстрая и легкая модель, идеально подходит для CPU |
| | 2. **s-nlp/rubert-tiny-cased-rured** - Специализирована для тематической классификации |
| | 3. **ai-forever/ruBert-base** - Самая точная, но требует больше времени |
| | |
| | **Ограничения:** |
| | - Максимальная длина текста: 1000 символов |
| | - Только русский язык |
| | - Автоматическое определение 8 основных тем |
| | """) |
| | |
| | gr.Markdown("---") |
| | gr.Markdown("### 📌 Инструкция") |
| | gr.Markdown(""" |
| | 1. Выберите модель из списка |
| | 2. Введите или вставьте текст для анализа |
| | 3. Нажмите кнопку "Анализировать текст" |
| | 4. Получите результаты классификации и время обработки |
| | """) |
| |
|
| | if __name__ == "__main__": |
| | demo.launch(server_name="0.0.0.0", server_port=7860) |