task1 / app.py
PMI25's picture
Update app.py
359319d verified
import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import time
from typing import Dict, Tuple
import warnings
warnings.filterwarnings('ignore')
# Конфигурация
MAX_LENGTH = 1000
MODELS = {
"cointegrated/rubert-tiny2": "Лёгкая модель (быстрая)",
"s-nlp/rubert-tiny-cased-rured": "Специализированная для классификации",
"ai-forever/ruBert-base": "Точная модель (медленнее)"
}
LABELS = {
0: "Политика",
1: "Экономика",
2: "Наука и технологии",
3: "Культура и искусство",
4: "Спорт",
5: "Здоровье и медицина",
6: "Образование",
7: "Разное"
}
class TopicClassifier:
def __init__(self):
self.models: Dict = {}
self.tokenizers: Dict = {}
def load_model(self, model_name: str):
"""Загрузка модели по требованию"""
if model_name not in self.models:
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
model_name,
num_labels=len(LABELS)
)
# Настройка для CPU
model.eval()
self.models[model_name] = model
self.tokenizers[model_name] = tokenizer
print(f"Модель {model_name} загружена успешно")
except Exception as e:
raise Exception(f"Ошибка загрузки модели: {str(e)}")
def predict(self, text: str, model_name: str) -> Tuple[Dict, float]:
"""Предсказание темы текста"""
if not text.strip():
raise ValueError("Текст не может быть пустым")
if len(text) > MAX_LENGTH:
text = text[:MAX_LENGTH]
gr.Warning(f"Текст обрезан до {MAX_LENGTH} символов")
self.load_model(model_name)
start_time = time.time()
try:
tokenizer = self.tokenizers[model_name]
model = self.models[model_name]
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
)
with torch.no_grad():
outputs = model(**inputs)
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
scores = predictions[0].tolist()
results = {LABELS[i]: round(score * 100, 2) for i, score in enumerate(scores)}
# Сортировка по уверенности
sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
latency = round((time.time() - start_time) * 1000, 2) # в мс
return sorted_results, latency
except Exception as e:
raise Exception(f"Ошибка при обработке: {str(e)}")
# Инициализация классификатора
classifier = TopicClassifier()
def process_text(text: str, model_choice: str) -> Tuple[str, str, str]:
"""Обработка текста с выбранной моделью"""
if not text.strip():
return "⚠️ Введите текст для анализа", "", "0"
try:
predictions, latency = classifier.predict(text, model_choice)
# Форматирование результатов
top_topic = list(predictions.keys())[0]
top_score = predictions[top_topic]
result_text = f"🎯 **Основная тема:** {top_topic} ({top_score}%)\n\n"
result_text += "📊 **Распределение тем:**\n"
for topic, score in predictions.items():
result_text += f"• {topic}: {score}%\n"
# Подготовка JSON для отладки
json_output = "{\n"
for topic, score in predictions.items():
json_output += f' "{topic}": {score},\n'
json_output = json_output.rstrip(",\n") + "\n}"
return result_text, json_output, str(latency)
except ValueError as e:
return f"❌ {str(e)}", "", "0"
except Exception as e:
return f"⚠️ Ошибка: {str(e)}", "", "0"
# Примеры текстов
examples = [
["Российская экономика показала рост в третьем квартале благодаря увеличению экспорта нефти и газа."],
["Ученые создали новый материал для солнечных батарей с эффективностью 45%."],
["На чемпионате мира по футболу сборная Бразилии одержала победу со счетом 3:1."],
["В музее открылась выставка современных художников, посвященная проблемам экологии."],
["Минздрав рекомендовал новые правила вакцинации для населения старше 60 лет."]
]
# Создание интерфейса
with gr.Blocks(title="Классификатор тем текста", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🎯 Классификатор тематики текста")
gr.Markdown("Определите основную тему вашего текста с помощью ИИ-моделей")
with gr.Row():
with gr.Column(scale=2):
model_selector = gr.Dropdown(
choices=list(MODELS.keys()),
value=list(MODELS.keys())[0],
label="📋 Выберите модель",
info="Каждая модель имеет разный баланс скорости и точности"
)
text_input = gr.Textbox(
label="📝 Введите текст для анализа",
placeholder="Введите текст на русском языке...",
lines=5,
max_lines=10
)
process_btn = gr.Button("🔍 Анализировать текст", variant="primary")
gr.Markdown("### 📋 Примеры текстов")
gr.Examples(
examples=examples,
inputs=text_input,
label="Нажмите на пример для быстрой загрузки"
)
with gr.Column(scale=3):
with gr.Row():
latency_display = gr.Textbox(
label="⏱️ Время обработки",
value="0",
interactive=False
)
latency_display.info = "мсек"
output_text = gr.Markdown(
label="📊 Результаты классификации"
)
json_output = gr.Code(
label="📄 JSON-формат результатов",
language="json",
interactive=False
)
# Обработка событий
process_btn.click(
fn=process_text,
inputs=[text_input, model_selector],
outputs=[output_text, json_output, latency_display]
)
# Дополнительная информация
with gr.Accordion("ℹ️ Информация о моделях", open=False):
gr.Markdown("""
**Доступные модели:**
1. **cointegrated/rubert-tiny2** - Быстрая и легкая модель, идеально подходит для CPU
2. **s-nlp/rubert-tiny-cased-rured** - Специализирована для тематической классификации
3. **ai-forever/ruBert-base** - Самая точная, но требует больше времени
**Ограничения:**
- Максимальная длина текста: 1000 символов
- Только русский язык
- Автоматическое определение 8 основных тем
""")
gr.Markdown("---")
gr.Markdown("### 📌 Инструкция")
gr.Markdown("""
1. Выберите модель из списка
2. Введите или вставьте текст для анализа
3. Нажмите кнопку "Анализировать текст"
4. Получите результаты классификации и время обработки
""")
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)