Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import DistilBertTokenizer, DistilBertForSequenceClassification | |
| import torch | |
| import json | |
| model = DistilBertForSequenceClassification.from_pretrained('./arxiv_classifier') | |
| tokenizer = DistilBertTokenizer.from_pretrained('./arxiv_classifier') | |
| with open('./arxiv_classifier/index_to_category.json', 'r', encoding='utf-8') as f: | |
| index_to_category = json.load(f) | |
| def predict(title, summary): | |
| inputs = tokenizer(title + " " + summary, return_tensors="pt", padding=True, truncation=True) | |
| outputs = model(**inputs) | |
| predictions = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
| return predictions | |
| st.set_page_config(page_title="ArXiv Article Classifier", layout="centered") | |
| st.title("ArXiv Article Classifier") | |
| st.write("Введите заголовок и аннотацию статьи, чтобы получить предсказание категории.") | |
| title = st.text_input("Заголовок статьи", placeholder="Введите заголовок статьи здесь") | |
| summary = st.text_area("Аннотация статьи", placeholder="Введите аннотацию статьи здесь") | |
| # Кнопка для классификации | |
| if st.button("Классифицировать"): | |
| if title.strip() == "" and summary.strip() == "": | |
| st.error("Пожалуйста, введите заголовок или аннотацию статьи.") | |
| else: | |
| with st.spinner("Классификация..."): | |
| predictions = predict(title, summary) | |
| sorted_indices = torch.argsort(predictions[0], descending=True) | |
| cumulative_probability = 0.0 | |
| st.subheader("Результаты классификации:") | |
| for idx in sorted_indices: | |
| probability = predictions[0][idx].item() | |
| cumulative_probability += probability | |
| category_name = index_to_category.get(str(idx.item()), "Unknown") | |
| st.write(f"Категория {category_name}: {probability:.2f}") | |
| if cumulative_probability >= 0.95: | |
| break | |
| st.markdown( | |
| """ | |
| <style> | |
| .stButton>button { | |
| background-color: #4CAF50; | |
| color: white; | |
| padding: 10px 24px; | |
| border: none; | |
| border-radius: 4px; | |
| cursor: pointer; | |
| } | |
| .stButton>button:hover { | |
| background-color: #45a049; | |
| } | |
| </style> | |
| """, | |
| unsafe_allow_html=True | |
| ) |