Spaces:
Sleeping
Sleeping
File size: 2,506 Bytes
7c43e3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import streamlit as st
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import json
model = DistilBertForSequenceClassification.from_pretrained('./arxiv_classifier')
tokenizer = DistilBertTokenizer.from_pretrained('./arxiv_classifier')
with open('./arxiv_classifier/index_to_category.json', 'r', encoding='utf-8') as f:
index_to_category = json.load(f)
def predict(title, summary):
inputs = tokenizer(title + " " + summary, return_tensors="pt", padding=True, truncation=True)
outputs = model(**inputs)
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
return predictions
st.set_page_config(page_title="ArXiv Article Classifier", layout="centered")
st.title("ArXiv Article Classifier")
st.write("Введите заголовок и аннотацию статьи, чтобы получить предсказание категории.")
title = st.text_input("Заголовок статьи", placeholder="Введите заголовок статьи здесь")
summary = st.text_area("Аннотация статьи", placeholder="Введите аннотацию статьи здесь")
# Кнопка для классификации
if st.button("Классифицировать"):
if title.strip() == "" and summary.strip() == "":
st.error("Пожалуйста, введите заголовок или аннотацию статьи.")
else:
with st.spinner("Классификация..."):
predictions = predict(title, summary)
sorted_indices = torch.argsort(predictions[0], descending=True)
cumulative_probability = 0.0
st.subheader("Результаты классификации:")
for idx in sorted_indices:
probability = predictions[0][idx].item()
cumulative_probability += probability
category_name = index_to_category.get(str(idx.item()), "Unknown")
st.write(f"Категория {category_name}: {probability:.2f}")
if cumulative_probability >= 0.95:
break
st.markdown(
"""
<style>
.stButton>button {
background-color: #4CAF50;
color: white;
padding: 10px 24px;
border: none;
border-radius: 4px;
cursor: pointer;
}
.stButton>button:hover {
background-color: #45a049;
}
</style>
""",
unsafe_allow_html=True
) |