import gradio as gr import torch from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification import time from typing import Dict, Tuple import warnings warnings.filterwarnings('ignore') # Конфигурация MAX_LENGTH = 1000 MODELS = { "cointegrated/rubert-tiny2": "Лёгкая модель (быстрая)", "s-nlp/rubert-tiny-cased-rured": "Специализированная для классификации", "ai-forever/ruBert-base": "Точная модель (медленнее)" } LABELS = { 0: "Политика", 1: "Экономика", 2: "Наука и технологии", 3: "Культура и искусство", 4: "Спорт", 5: "Здоровье и медицина", 6: "Образование", 7: "Разное" } class TopicClassifier: def __init__(self): self.models: Dict = {} self.tokenizers: Dict = {} def load_model(self, model_name: str): """Загрузка модели по требованию""" if model_name not in self.models: try: tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained( model_name, num_labels=len(LABELS) ) # Настройка для CPU model.eval() self.models[model_name] = model self.tokenizers[model_name] = tokenizer print(f"Модель {model_name} загружена успешно") except Exception as e: raise Exception(f"Ошибка загрузки модели: {str(e)}") def predict(self, text: str, model_name: str) -> Tuple[Dict, float]: """Предсказание темы текста""" if not text.strip(): raise ValueError("Текст не может быть пустым") if len(text) > MAX_LENGTH: text = text[:MAX_LENGTH] gr.Warning(f"Текст обрезан до {MAX_LENGTH} символов") self.load_model(model_name) start_time = time.time() try: tokenizer = self.tokenizers[model_name] model = self.models[model_name] inputs = tokenizer( text, return_tensors="pt", truncation=True, max_length=512, padding=True ) with torch.no_grad(): outputs = model(**inputs) predictions = torch.nn.functional.softmax(outputs.logits, dim=-1) scores = predictions[0].tolist() results = {LABELS[i]: round(score * 100, 2) for i, score in enumerate(scores)} # Сортировка по уверенности sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True)) latency = round((time.time() - start_time) * 1000, 2) # в мс return sorted_results, latency except Exception as e: raise Exception(f"Ошибка при обработке: {str(e)}") # Инициализация классификатора classifier = TopicClassifier() def process_text(text: str, model_choice: str) -> Tuple[str, str, str]: """Обработка текста с выбранной моделью""" if not text.strip(): return "⚠️ Введите текст для анализа", "", "0" try: predictions, latency = classifier.predict(text, model_choice) # Форматирование результатов top_topic = list(predictions.keys())[0] top_score = predictions[top_topic] result_text = f"🎯 **Основная тема:** {top_topic} ({top_score}%)\n\n" result_text += "📊 **Распределение тем:**\n" for topic, score in predictions.items(): result_text += f"• {topic}: {score}%\n" # Подготовка JSON для отладки json_output = "{\n" for topic, score in predictions.items(): json_output += f' "{topic}": {score},\n' json_output = json_output.rstrip(",\n") + "\n}" return result_text, json_output, str(latency) except ValueError as e: return f"❌ {str(e)}", "", "0" except Exception as e: return f"⚠️ Ошибка: {str(e)}", "", "0" # Примеры текстов examples = [ ["Российская экономика показала рост в третьем квартале благодаря увеличению экспорта нефти и газа."], ["Ученые создали новый материал для солнечных батарей с эффективностью 45%."], ["На чемпионате мира по футболу сборная Бразилии одержала победу со счетом 3:1."], ["В музее открылась выставка современных художников, посвященная проблемам экологии."], ["Минздрав рекомендовал новые правила вакцинации для населения старше 60 лет."] ] # Создание интерфейса with gr.Blocks(title="Классификатор тем текста", theme=gr.themes.Soft()) as demo: gr.Markdown("# 🎯 Классификатор тематики текста") gr.Markdown("Определите основную тему вашего текста с помощью ИИ-моделей") with gr.Row(): with gr.Column(scale=2): model_selector = gr.Dropdown( choices=list(MODELS.keys()), value=list(MODELS.keys())[0], label="📋 Выберите модель", info="Каждая модель имеет разный баланс скорости и точности" ) text_input = gr.Textbox( label="📝 Введите текст для анализа", placeholder="Введите текст на русском языке...", lines=5, max_lines=10 ) process_btn = gr.Button("🔍 Анализировать текст", variant="primary") gr.Markdown("### 📋 Примеры текстов") gr.Examples( examples=examples, inputs=text_input, label="Нажмите на пример для быстрой загрузки" ) with gr.Column(scale=3): with gr.Row(): latency_display = gr.Textbox( label="⏱️ Время обработки", value="0", interactive=False ) latency_display.info = "мсек" output_text = gr.Markdown( label="📊 Результаты классификации" ) json_output = gr.Code( label="📄 JSON-формат результатов", language="json", interactive=False ) # Обработка событий process_btn.click( fn=process_text, inputs=[text_input, model_selector], outputs=[output_text, json_output, latency_display] ) # Дополнительная информация with gr.Accordion("ℹ️ Информация о моделях", open=False): gr.Markdown(""" **Доступные модели:** 1. **cointegrated/rubert-tiny2** - Быстрая и легкая модель, идеально подходит для CPU 2. **s-nlp/rubert-tiny-cased-rured** - Специализирована для тематической классификации 3. **ai-forever/ruBert-base** - Самая точная, но требует больше времени **Ограничения:** - Максимальная длина текста: 1000 символов - Только русский язык - Автоматическое определение 8 основных тем """) gr.Markdown("---") gr.Markdown("### 📌 Инструкция") gr.Markdown(""" 1. Выберите модель из списка 2. Введите или вставьте текст для анализа 3. Нажмите кнопку "Анализировать текст" 4. Получите результаты классификации и время обработки """) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860)