import streamlit as st import torch import joblib import numpy as np from transformers import AutoTokenizer, AutoModelForSequenceClassification from huggingface_hub import hf_hub_download @st.cache_resource def load_model(): repo_id = "YakovPodlesnov/best-article-classifier" model = AutoModelForSequenceClassification.from_pretrained(repo_id) tokenizer = AutoTokenizer.from_pretrained(repo_id) label_encoder_path = hf_hub_download( repo_id=repo_id, filename="label_encoder.joblib" ) label_encoder = joblib.load(label_encoder_path) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) return model, tokenizer, label_encoder, device def predict(text, model, tokenizer, label_encoder, device): inputs = tokenizer( text, padding="max_length", truncation=True, max_length=512, return_tensors="pt" ).to(device) with torch.no_grad(): outputs = model(**inputs) probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()[0] sorted_indices = np.argsort(probs)[::-1] cumulative = 0.0 results = [] for idx in sorted_indices: if cumulative >= 0.95: break label = label_encoder.inverse_transform([idx])[0] prob = float(probs[idx]) results.append((label, prob)) cumulative += prob return results st.set_page_config(page_title="arXiv Classifier", layout="wide") st.title("Классификатор научных статей") st.markdown("Определяем тематику статьи по заголовку и аннотации (arXiv)") with st.spinner("Инициализация модели..."): try: model, tokenizer, label_encoder, device = load_model() except Exception as e: st.error(f"Ошибка загрузки модели: {str(e)}") st.stop() col1, col2 = st.columns([1, 1]) with col1: title = st.text_input("Название статьи:", placeholder="Attention Is All You Need", help="Введите полное название статьи") with col2: abstract = st.text_area("Аннотация:", height=150, placeholder="We propose a new simple network architecture...", help="Введите текст аннотации (необязательно)") if st.button("Определить категории", type="primary"): if not title: st.error("Пожалуйста, введите название статьи") else: text = title + (" " + abstract if abstract else "") with st.spinner("Анализируем текст..."): try: predictions = predict(text, model, tokenizer, label_encoder, device) except Exception as e: st.error(f"Ошибка предсказания: {str(e)}") st.stop() st.subheader(" Результаты классификации") st.markdown(""" Расшифровку названий, например для cs, можно посмотреть здесь: https://arxiv.org/archive/cs """) if not predictions: st.warning("Не удалось определить категории") else: cols = st.columns([2, 3]) with cols[0]: st.markdown("**Топ-категории:**") for label, prob in predictions: st.markdown(f"▸ {label} ({prob:.1%})") with cols[1]: labels = [p[0] for p in predictions] probs = [p[1] for p in predictions] st.bar_chart( dict(zip(labels, probs)), use_container_width=True, color="#FF4B4B" ) st.markdown("---") st.markdown(""" ### Как это работает? 1. Введите **название** научной статьи 2. При необходимости добавьте **аннотацию** 3. Нажмите кнопку для классификации статьи 4. Получите предсказанные категории arXiv **Модель: scibert_scivocab_uncased** """) if __name__ == "__main__": pass