File size: 2,506 Bytes
7c43e3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import streamlit as st
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import json

model = DistilBertForSequenceClassification.from_pretrained('./arxiv_classifier')
tokenizer = DistilBertTokenizer.from_pretrained('./arxiv_classifier')

with open('./arxiv_classifier/index_to_category.json', 'r', encoding='utf-8') as f:
    index_to_category = json.load(f)

def predict(title, summary):
    inputs = tokenizer(title + " " + summary, return_tensors="pt", padding=True, truncation=True)
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    return predictions

st.set_page_config(page_title="ArXiv Article Classifier", layout="centered")

st.title("ArXiv Article Classifier")
st.write("Введите заголовок и аннотацию статьи, чтобы получить предсказание категории.")

title = st.text_input("Заголовок статьи", placeholder="Введите заголовок статьи здесь")
summary = st.text_area("Аннотация статьи", placeholder="Введите аннотацию статьи здесь")

# Кнопка для классификации
if st.button("Классифицировать"):
    if title.strip() == "" and summary.strip() == "":
        st.error("Пожалуйста, введите заголовок или аннотацию статьи.")
    else:
        with st.spinner("Классификация..."):
            predictions = predict(title, summary)
            sorted_indices = torch.argsort(predictions[0], descending=True)
            cumulative_probability = 0.0
            st.subheader("Результаты классификации:")
            for idx in sorted_indices:
                probability = predictions[0][idx].item()
                cumulative_probability += probability
                category_name = index_to_category.get(str(idx.item()), "Unknown")
                st.write(f"Категория {category_name}: {probability:.2f}")
                if cumulative_probability >= 0.95:
                    break

st.markdown(
    """
    <style>
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        padding: 10px 24px;
        border: none;
        border-radius: 4px;
        cursor: pointer;
    }
    .stButton>button:hover {
        background-color: #45a049;
    }
    </style>
    """,
    unsafe_allow_html=True
)