"""Streamlit app that classifies arXiv papers from their title and/or abstract."""
import json
import os

import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Directory with the model weights, tokenizer files and the two label maps.
MODEL_DIR = os.path.join(BASE_DIR, "arxiv_dir")

# Number of predictions shown before the rest are folded into an expander.
VISIBLE_RESULTS = 10

# Must be the first Streamlit call in the script.
st.set_page_config(
    page_title="Arxiv Classifier",
    page_icon="🚀",
    layout="wide",
    initial_sidebar_state="collapsed",
)

st.title("Arxiv Classifier")


@st.cache_resource
def load_model():
    """Load the classifier, its tokenizer and the label maps from MODEL_DIR.

    Decorated with ``st.cache_resource`` so the objects are reused across
    script reruns instead of being reloaded on every interaction.

    Returns:
        ``(model, tokenizer, id2tag, tag2name)``: ``id2tag`` is read from
        ``id2tag.json`` with its keys converted to ``int``; ``tag2name`` is
        ``tag2name.json`` as loaded.
    """
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    # Explicit UTF-8: the platform default encoding is not guaranteed to be
    # UTF-8, and the label files may contain non-ASCII names.
    with open(os.path.join(MODEL_DIR, "id2tag.json"), encoding="utf-8") as f:
        # JSON object keys are always strings; class ids are ints.
        id2tag = {int(k): v for k, v in json.load(f).items()}
    with open(os.path.join(MODEL_DIR, "tag2name.json"), encoding="utf-8") as f:
        tag2name = json.load(f)
    model.eval()
    return model, tokenizer, id2tag, tag2name


model, tokenizer, id2tag, tag2name = load_model()


def predict_top95(title, summary=None, threshold=0.95):
    """Return the top classes whose cumulative probability reaches ``threshold``.

    Args:
        title: Paper title.
        summary: Optional abstract; when truthy it is appended to the title
            after a literal ``" [SEP] "`` separator.
        threshold: Cumulative probability mass to cover (default 0.95).

    Returns:
        A list of ``(tag, probability)`` pairs sorted by descending
        probability. For any positive ``threshold`` the top class is always
        included, so the list is never empty.
    """
    text = title
    if summary:
        text += " [SEP] " + summary
    tokens = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
    tokens = {k: v.to(model.device) for k, v in tokens.items()}
    with torch.no_grad():
        logits = model(**tokens).logits
    # Single input, so take the first (only) row of the batch.
    probs = torch.softmax(logits, dim=-1)[0]
    sorted_probs, sorted_idx = probs.sort(descending=True)
    # Keep a class while the mass accumulated *before* it is still below the
    # threshold; this includes the class that crosses the threshold.
    cumsum = sorted_probs.cumsum(dim=-1)
    mask = (cumsum - sorted_probs) < threshold
    return [
        (id2tag[idx.item()], prob.item())
        for prob, idx in zip(sorted_probs[mask], sorted_idx[mask])
    ]


def colored_bar(prob):
    """Draw a horizontal bar whose width is ``prob`` and color reflects confidence.

    Green above 0.4, orange above 0.15, red otherwise. The width is clamped
    to the [0, 1] range.
    """
    if prob > 0.4:
        color = "#2ecc71"
    elif prob > 0.15:
        color = "#f39c12"
    else:
        color = "#e74c3c"
    width = max(0.0, min(float(prob), 1.0)) * 100
    # Only a fixed color constant and a float are interpolated, so enabling
    # raw HTML here cannot inject user-controlled markup.
    st.markdown(
        '<div style="background-color:#eeeeee;border-radius:4px;'
        'height:10px;width:100%;">'
        f'<div style="background-color:{color};width:{width:.1f}%;'
        'height:100%;border-radius:4px;"></div>'
        "</div>",
        unsafe_allow_html=True,
    )


def _classes_word(n):
    """Return the Russian word for "class" that agrees grammatically with count ``n``."""
    if n % 10 == 1 and n % 100 != 11:
        return "класс"
    if 2 <= n % 10 <= 4 and not 12 <= n % 100 <= 14:
        return "класса"
    return "классов"


def _render_result(tag, prob):
    """Render one prediction: display name, raw tag, probability bar and percentage."""
    label = tag2name.get(tag, tag)
    st.markdown(f"**{label}** `{tag}`")
    colored_bar(prob)
    st.caption(f"{prob:.1%}")


col1, col2 = st.columns([1, 1])

with col1:
    name = st.text_input("Название статьи")
    abstract = st.text_area("Abstract", height=200)
    clicked = st.button("Классифицировать")

with col2:
    if clicked:
        # Whitespace-only input is treated as empty.
        name = name.strip()
        abstract = abstract.strip()
        if not name and not abstract:
            st.warning("Введите название или abstract")
        else:
            results = predict_top95(name, abstract or None)
            count = len(results)
            st.markdown(f"### Результаты — {count} {_classes_word(count)}")
            visible = results[:VISIBLE_RESULTS]
            hidden = results[VISIBLE_RESULTS:]
            for tag, prob in visible:
                _render_result(tag, prob)
            if hidden:
                with st.expander(f"Показать ещё {len(hidden)}"):
                    for tag, prob in hidden:
                        _render_result(tag, prob)