| import streamlit as st |
| import torch |
| import numpy as np |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
@st.cache_resource
def load_model():
    """Load the article classifier and its tokenizer from the Hugging Face Hub.

    The ``st.cache_resource`` decorator keeps the loaded objects around, so
    Streamlit does not reload them on every script rerun.

    Returns:
        tuple: ``(model, tokenizer)``, both loaded from the same repository.
    """
    repo_id = "aaaleksandrasimonova/arxiv_article_model1"
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForSequenceClassification.from_pretrained(repo_id)
    return model, tokenizer
|
|
model, tokenizer = load_model()

# Class-index -> label lookup taken from the model config; used to turn
# predicted class indices into tag strings.
id2label = model.config.id2label

# Inference only: put the model into evaluation mode (disables dropout etc.).
model.eval()
|
|
def predict_proba(title, abstract):
    """Return the model's class probabilities for one article.

    The title and abstract are passed to the tokenizer as a text pair and
    truncated to at most 256 tokens. Inference runs without gradient tracking
    on whatever device the model lives on.

    Args:
        title: Article title.
        abstract: Article abstract (may be an empty string).

    Returns:
        numpy.ndarray: 1-D array of softmax probabilities, one per class index.
    """
    encoded = tokenizer(
        title,
        abstract,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    # Move every input tensor to the model's device before the forward pass.
    batch = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**batch).logits.cpu()

    # A batch of one: softmax over the class axis, then take the single row.
    return torch.softmax(logits, dim=1).numpy()[0]
|
|
|
|
def top_95(probs, threshold=0.95, labels=None):
    """Return the most likely labels whose cumulative probability reaches ``threshold``.

    Classes are visited in descending order of probability and appended until
    the running total is at least ``threshold``. The class that crosses the
    threshold is included. If the threshold is never reached, every class is
    returned.

    Args:
        probs: 1-D array-like of class probabilities indexed by class id
            (e.g. the output of ``predict_proba``).
        threshold: Cumulative probability mass to cover. Defaults to 0.95,
            the "top-95%" set the UI describes.
        labels: Mapping (or sequence) from class index to label string.
            Defaults to the model's ``id2label``.

    Returns:
        list[tuple[str, float]]: ``(label, probability)`` pairs, most
        probable first.
    """
    if labels is None:
        labels = id2label
    probs = np.asarray(probs)

    result = []
    total = 0.0  # accumulate as Python float rather than the array's dtype
    for idx in np.argsort(probs)[::-1]:
        p = float(probs[idx])
        total += p
        result.append((labels[int(idx)], p))
        if total >= threshold:
            break
    return result
|
|
# Human-readable names for the arXiv top-level category tags shown in the UI.
class_name = dict(
    [
        ("cs", "Computer Science"),
        ("econ", "Economics"),
        ("eess", "Electrical Engineering and Systems Science"),
        ("math", "Mathematics"),
        ("physics", "Physics"),
        ("q-bio", "Quantitative Biology"),
        ("q-fin", "Quantitative Finance"),
        ("stat", "Statistics"),
    ]
)

# One "✅ tag - name" entry per paragraph, listing every supported category.
tags_str = "\n\n ".join(f"✅ {tag} - {name}" for tag, name in class_name.items())
|
|
|
|
st.title("📚 Arxiv Article Classifier")

# Usage instructions (fixes typos "abstrcat" and "пресказание" in the UI text).
st.markdown(
    "Введите название статьи и abstract (можно оставить пустым).\n\nПолучите предсказание тематики статьи (top-95%)"
)

# List of the categories the classifier can assign.
st.markdown(f"Тематики:\n\n {tags_str}")
|
|
|
|
# Input fields: only the title is required; the abstract may be left empty.
title = st.text_input("📝 Title статьи")
abstract = st.text_area("📄 Abstract статьи (можно оставить пустым)")

if st.button("🔍 Классифицировать"):
    if not title.strip():
        st.error("Пожалуйста, введите название статьи.")
    else:
        probs = predict_proba(title, abstract)
        results = top_95(probs)

        # When six or more categories are needed to cover the top-95% mass, the
        # prediction is diffuse and the article may not fit any category well.
        if len(results) >= 6:
            st.error("Возможно, ваша статья не соответствует ни одной из доступных категорий.")

        st.subheader("📊 Результат:")
        for tag, prob in results:
            # Fall back to the raw tag if the model emits a label that has no
            # entry in class_name, instead of crashing with KeyError.
            st.write(f"**{tag}. {class_name.get(tag, tag)}** — {prob:.2%}")

        total_prob = sum(prob for _, prob in results)
        st.caption(f"Суммарная вероятность: {total_prob:.2%}")