import streamlit as st from pathlib import Path import os import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer def read_hf_token(): HF_TOKEN = os.getenv("HF_TOKEN") return HF_TOKEN @st.cache_resource(show_spinner="Loading tokenizer/model...") def load_model( model_id, num_labels, checkpoint_path=None, ignore_mismatched_sizes=True, ): hf_token = read_hf_token() tok = AutoTokenizer.from_pretrained(model_id, token=hf_token) model = AutoModelForSequenceClassification.from_pretrained(model_id, token=hf_token, num_labels=num_labels, problem_type="multi_label_classification", ignore_mismatched_sizes=ignore_mismatched_sizes) if checkpoint_path: state = torch.load(str(checkpoint_path), map_location=torch.device("cpu")) model.load_state_dict(state, strict=False) model.eval() return tok, model def predict(model, tok, title, summary, labels, threshold=0.5, max_length=512, device="cpu"): title = (title or "").strip() summary = (summary or "").strip() enc = tok( [title], [summary], padding=True, truncation=True, max_length=max_length, return_tensors="pt", ) enc = {k: v.to(device) for k, v in enc.items()} with torch.no_grad(): logits = model(**enc).logits if logits.ndim == 3: logits = logits[:, 0, :] probs = torch.sigmoid(logits[0]).cpu() pred = [] for i, p in enumerate(probs.tolist()): if p >= threshold: pred.append((labels[i], float(p))) pred.sort(key=lambda x: x[1], reverse=True) return pred, probs.tolist() st.set_page_config(page_title="arXiv taxonomy classifier", layout="centered") st.title("arXiv taxonomy classifier") with st.expander("О проекте", expanded=False): st.markdown( """ ### Назначение Приложение предсказывает **группы таксономии arXiv** для статьи по **заголовку** и **краткому описанию (abstract)**. ### Данные и разметка - **Источник**: выгрузка arXiv из файла (поля `title`, `summary`, `tag`). - **Как получали классы**: из `tag` брали `term` (arXiv category codes, например `cs.AI`, `stat.ML`), парсили и маппили в 8 верхнеуровневых групп. - **Разметка**: multi-label — у статьи может быть несколько групп одновременно; целевой вектор `y` формировался как multi-hot по списку групп. - Добавление в выборку статей из 2020-2025 годов, а также сбалансированность классов, так как статей по Биологии кратно меньше, чем статей по Компьтерным наукам, и это также надо учитывать при обучении, что я и передаю в параметре weights. ### Метки В интерфейсе используются **8 верхнеуровневых групп**: Computer Science, Economics, Electrical Engineering and Systems Science, Mathematics, Physics, Quantitative Biology, Quantitative Finance, Statistics. ### Модели и чекпоинты - **HF model (`HF model id`)**: базовая архитектура и токенайзер скачиваются с Hugging Face. - **Локальный `.pth` (по желанию)**: если включено, поверх весов HF подгружаются **дообученные веса** (`torch.load(..., map_location='cpu')` + `model.load_state_dict(..., strict=False)`). - bert-paper-classifier-arxiv(предобученная модель на статьях arXiv) [link](https://huggingface.co/oracat/bert-paper-classifier-arxiv) - distilbert-base-cased(базовая модель берта) [link](https://huggingface.co/distilbert-base-cased) ### Обучение (в общих чертах, подробности в Readme.md) Задача — **multi-label classification** с `BCEWithLogitsLoss`; на инференсе — `sigmoid(logits)` и настраиваемый **threshold**. """ ) with st.expander("Документация", expanded=False): st.markdown( """ ### Что можно делать в приложении **В боковой панели (Settings)** - Указать **`HF model id`** — какую модель с Hugging Face загружать как основу (архитектура + токенайзер). - Настроить **`Threshold`** — порог вероятности: метки с `sigmoid(logits)` **не ниже** порога попадают в список предсказаний. - Задать **`Max length`** — максимальная длина токенизации для пары title + summary. - Выбрать **локальный чекпоинт** из папки `models/` или ввести **относительный путь** к файлу `.pth` внутри `models/` — чтобы подменить веса HF своим дообучением. - При смене других моделей с HF задать переменную окружения **`HF_TOKEN`**. **В основной форме** - Ввести **заголовок** и/или **краткое описание (abstract)** статьи (достаточно хотя бы одного поля). - Нажать **`Predict`** — получить предсказание по **8 группам** таксономии. **После предсказания** - Просмотреть таблицу **Predicted labels** — метки, прошедшие порог, с вероятностями. - Раскрыть **All probabilities** — полный JSON со всеми вероятностями по меткам (удобно подбирать `Threshold`). **Подсказка** Строка **`Device`** показывает, на чём считается инференс: `cuda` или `cpu`. """ ) if "model_id" not in st.session_state: st.session_state.model_id = "oracat/bert-paper-classifier-arxiv" if "threshold" not in st.session_state: st.session_state.threshold = 0.5 if "use_local_ckpt" not in st.session_state: st.session_state.use_local_ckpt = False with st.sidebar: st.subheader("Settings") st.session_state.model_id = st.text_input("HF model id", value=st.session_state.model_id) st.session_state.threshold = st.slider("Threshold", 0.0, 1.0, float(st.session_state.threshold), 0.01) max_length = st.slider("Max length", 64, 512, 512, 16) st.caption("Local finetuned weights (.pth) to override HF weights.") model_dir = Path(__file__).parent/"models" local_ckpts = sorted([p.name for p in model_dir.glob("*.pth")]) default_choice = local_ckpts[0] if local_ckpts else "" ckpt_choice = st.selectbox("Checkpoint (from models dir)", options=[""] + local_ckpts, index=0) ckpt_path_text = st.text_input("Or explicit checkpoint path (relative to models dir)", value="") ckpt_path = str(model_dir/"bert_paper_classifier_arxiv.pth") if ckpt_choice: ckpt_path = str(model_dir/ckpt_choice) if ckpt_path_text: ckpt_path = str(model_dir/ckpt_path_text) st.caption("Set env var `HF_TOKEN` to your Hugging Face token if you want use other models.") DEFAULT_LABELS = [ "Computer Science", "Economics", "Electrical Engineering and Systems Science", "Mathematics", "Physics", "Quantitative Biology", "Quantitative Finance", "Statistics", ] num_labels = len(DEFAULT_LABELS) device = "cuda" if torch.cuda.is_available() else "cpu" tok, model = load_model( st.session_state.model_id, num_labels=num_labels, checkpoint_path=ckpt_path, ignore_mismatched_sizes=True, ) model.to(device) st.write(f"Device: `{device}`") with st.form("predict"): title = st.text_input("Title", value="") summary = st.text_area("Summary / abstract", value="", height=220) submitted = st.form_submit_button("Predict") if submitted: if not (title.strip() or summary.strip()): st.warning("Provide at least a title or a summary.") else: preds, probs = predict( model, tok, title=title, summary=summary, labels=DEFAULT_LABELS, threshold=float(st.session_state.threshold), max_length=int(max_length), device=device, ) if not preds: st.info("No labels above threshold.") else: st.subheader("Predicted labels") st.table([{"label": lab, "prob": round(p, 4)} for lab, p in preds]) with st.expander("All probabilities"): st.json({lab: float(p) for lab, p in zip(DEFAULT_LABELS, probs)})