Spaces:
Build error
Build error
| import streamlit as st | |
| import torch | |
| import joblib | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from huggingface_hub import hf_hub_download | |
| def load_model(): | |
| repo_id = "YakovPodlesnov/best-article-classifier" | |
| model = AutoModelForSequenceClassification.from_pretrained(repo_id) | |
| tokenizer = AutoTokenizer.from_pretrained(repo_id) | |
| label_encoder_path = hf_hub_download( | |
| repo_id=repo_id, | |
| filename="label_encoder.joblib" | |
| ) | |
| label_encoder = joblib.load(label_encoder_path) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model = model.to(device) | |
| return model, tokenizer, label_encoder, device | |
| def predict(text, model, tokenizer, label_encoder, device): | |
| inputs = tokenizer( | |
| text, | |
| padding="max_length", | |
| truncation=True, | |
| max_length=512, | |
| return_tensors="pt" | |
| ).to(device) | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()[0] | |
| sorted_indices = np.argsort(probs)[::-1] | |
| cumulative = 0.0 | |
| results = [] | |
| for idx in sorted_indices: | |
| if cumulative >= 0.95: | |
| break | |
| label = label_encoder.inverse_transform([idx])[0] | |
| prob = float(probs[idx]) | |
| results.append((label, prob)) | |
| cumulative += prob | |
| return results | |
| st.set_page_config(page_title="arXiv Classifier", layout="wide") | |
| st.title("Классификатор научных статей") | |
| st.markdown("Определяем тематику статьи по заголовку и аннотации (arXiv)") | |
| with st.spinner("Инициализация модели..."): | |
| try: | |
| model, tokenizer, label_encoder, device = load_model() | |
| except Exception as e: | |
| st.error(f"Ошибка загрузки модели: {str(e)}") | |
| st.stop() | |
| col1, col2 = st.columns([1, 1]) | |
| with col1: | |
| title = st.text_input("Название статьи:", | |
| placeholder="Attention Is All You Need", | |
| help="Введите полное название статьи") | |
| with col2: | |
| abstract = st.text_area("Аннотация:", | |
| height=150, | |
| placeholder="We propose a new simple network architecture...", | |
| help="Введите текст аннотации (необязательно)") | |
| if st.button("Определить категории", type="primary"): | |
| if not title: | |
| st.error("Пожалуйста, введите название статьи") | |
| else: | |
| text = title + (" " + abstract if abstract else "") | |
| with st.spinner("Анализируем текст..."): | |
| try: | |
| predictions = predict(text, model, tokenizer, label_encoder, device) | |
| except Exception as e: | |
| st.error(f"Ошибка предсказания: {str(e)}") | |
| st.stop() | |
| st.subheader(" Результаты классификации") | |
| st.markdown(""" | |
| Расшифровку названий, например для cs, можно посмотреть здесь: https://arxiv.org/archive/cs | |
| """) | |
| if not predictions: | |
| st.warning("Не удалось определить категории") | |
| else: | |
| cols = st.columns([2, 3]) | |
| with cols[0]: | |
| st.markdown("**Топ-категории:**") | |
| for label, prob in predictions: | |
| st.markdown(f"▸ {label} ({prob:.1%})") | |
| with cols[1]: | |
| labels = [p[0] for p in predictions] | |
| probs = [p[1] for p in predictions] | |
| st.bar_chart( | |
| dict(zip(labels, probs)), | |
| use_container_width=True, | |
| color="#FF4B4B" | |
| ) | |
| st.markdown("---") | |
| st.markdown(""" | |
| ### Как это работает? | |
| 1. Введите **название** научной статьи | |
| 2. При необходимости добавьте **аннотацию** | |
| 3. Нажмите кнопку для классификации статьи | |
| 4. Получите предсказанные категории arXiv | |
| **Модель: scibert_scivocab_uncased** | |
| """) | |
| if __name__ == "__main__": | |
| pass |