Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| import json | |
| import os | |
# Locate the model directory next to this script, so the app does not depend
# on the working directory it is launched from.
BASE_DIR, _ = os.path.split(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "arxiv_dir")

# The same text is used for the browser tab title and the page heading.
_APP_TITLE = "Arxiv Classifier"

# Page-wide settings: wide layout, sidebar collapsed on first load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
    page_title=_APP_TITLE,
    page_icon="🚀",
)
st.title(_APP_TITLE)
@st.cache_resource
def load_model():
    """Load the classifier, its tokenizer and the label lookup tables.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the model is read from
    ``MODEL_DIR`` once per server process instead of once per click.

    Returns:
        A ``(model, tokenizer, id2tag, tag2name)`` tuple. ``id2tag`` maps an
        integer class index to its tag; ``tag2name`` is the parsed content of
        ``tag2name.json``.
    """
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

    # Read the JSON files as UTF-8 explicitly rather than relying on the
    # platform's default encoding.
    with open(os.path.join(MODEL_DIR, "id2tag.json"), encoding="utf-8") as f:
        # JSON object keys are always strings; convert them back to the
        # integer class indices used to look up predictions.
        id2tag = {int(k): v for k, v in json.load(f).items()}
    with open(os.path.join(MODEL_DIR, "tag2name.json"), encoding="utf-8") as f:
        tag2name = json.load(f)

    # Inference only: switch the model to evaluation mode.
    model.eval()
    return model, tokenizer, id2tag, tag2name


model, tokenizer, id2tag, tag2name = load_model()
def predict_top95(title, summary=None):
    """Return the most probable tags covering 95% of the probability mass.

    Args:
        title: Text to classify.
        summary: Optional extra text; when truthy it is appended to the
            title after a " [SEP] " separator.

    Returns:
        A list of ``(tag, probability)`` tuples in descending probability
        order. A class is kept while the total probability of the classes
        ranked above it is below 0.95, so the top class is always included.
    """
    text = " [SEP] ".join((title, summary)) if summary else title

    encoded = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
    inputs = {key: tensor.to(model.device) for key, tensor in encoded.items()}

    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0]

    ranked_probs, ranked_ids = probs.sort(descending=True)
    # Probability mass accumulated by the classes ranked above each position.
    mass_before = ranked_probs.cumsum(dim=-1) - ranked_probs
    keep = mass_before < 0.95

    return [
        (id2tag[class_id], prob)
        for prob, class_id in zip(ranked_probs[keep].tolist(), ranked_ids[keep].tolist())
    ]
def colored_bar(prob):
    """Draw a horizontal bar whose fill width and colour reflect *prob*.

    The fill is green above 0.4, orange above 0.15 and red otherwise; its
    width is ``prob`` expressed as a percentage with one decimal place.
    """
    # (exclusive lower bound, fill colour), checked from highest to lowest.
    palette = ((0.4, "#2ecc71"), (0.15, "#f39c12"))
    color = next((shade for floor, shade in palette if prob > floor), "#e74c3c")
    st.markdown(f"""
<div style="background:#e0e0e0;border-radius:4px;height:10px;margin:4px 0">
<div style="background:{color};width:{prob*100:.1f}%;height:100%;border-radius:4px"></div>
</div>
""", unsafe_allow_html=True)
def _plural_classes(count):
    """Return the Russian form of the noun "класс" that agrees with *count*.

    Russian picks the form from the last digits of the number: 1, 21, 31...
    take the singular; 2-4, 22-24... take the "few" form; everything else,
    including 11-14, takes the "many" form.
    """
    last_digit = count % 10
    last_two = count % 100
    if last_digit == 1 and last_two != 11:
        return "класс"
    if 2 <= last_digit <= 4 and not 12 <= last_two <= 14:
        return "класса"
    return "классов"


def _render_predictions(predictions):
    """Draw a label line, a probability bar and a percentage caption per item."""
    for tag, prob in predictions:
        # Fall back to the raw tag when no readable name is known for it.
        label = tag2name.get(tag, tag)
        st.markdown(f"**{label}** `{tag}`")
        colored_bar(prob)
        st.caption(f"{prob:.1%}")


# Two equal-width columns: inputs on the left, predictions on the right.
col1, col2 = st.columns([1, 1])

with col1:
    name = st.text_input("Название статьи")
    abstract = st.text_area("Abstract", height=200)
    clicked = st.button("Классифицировать")

with col2:
    if clicked:
        if not name and not abstract:
            st.warning("Введите название или abstract")
        else:
            results = predict_top95(name, abstract if abstract else None)
            st.markdown(
                f"### Результаты — {len(results)} {_plural_classes(len(results))}"
            )
            # Show the first ten results directly; the rest go in an expander.
            visible = results[:10]
            hidden = results[10:]
            _render_predictions(visible)
            if hidden:
                with st.expander(f"Показать ещё {len(hidden)}"):
                    _render_predictions(hidden)