| import streamlit as st |
| from pathlib import Path |
| import os |
|
|
| import torch |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
|
| def read_hf_token(): |
| HF_TOKEN = os.getenv("HF_TOKEN") |
| return HF_TOKEN |
|
|
| @st.cache_resource(show_spinner="Loading tokenizer/model...") |
| def load_model( |
| model_id, |
| num_labels, |
| checkpoint_path=None, |
| ignore_mismatched_sizes=True, |
| ): |
| hf_token = read_hf_token() |
|
|
| tok = AutoTokenizer.from_pretrained(model_id, token=hf_token) |
| model = AutoModelForSequenceClassification.from_pretrained(model_id, token=hf_token, |
| num_labels=num_labels, problem_type="multi_label_classification", |
| ignore_mismatched_sizes=ignore_mismatched_sizes) |
|
|
| if checkpoint_path: |
| state = torch.load(str(checkpoint_path), map_location=torch.device("cpu")) |
| model.load_state_dict(state, strict=False) |
| model.eval() |
| return tok, model |
|
|
|
|
| def predict(model, tok, title, summary, labels, threshold=0.5, max_length=512, device="cpu"): |
| title = (title or "").strip() |
| summary = (summary or "").strip() |
|
|
| enc = tok( |
| [title], |
| [summary], |
| padding=True, |
| truncation=True, |
| max_length=max_length, |
| return_tensors="pt", |
| ) |
|
|
| enc = {k: v.to(device) for k, v in enc.items()} |
|
|
| with torch.no_grad(): |
| logits = model(**enc).logits |
| if logits.ndim == 3: |
| logits = logits[:, 0, :] |
| probs = torch.sigmoid(logits[0]).cpu() |
|
|
| pred = [] |
| for i, p in enumerate(probs.tolist()): |
| if p >= threshold: |
| pred.append((labels[i], float(p))) |
| pred.sort(key=lambda x: x[1], reverse=True) |
| return pred, probs.tolist() |
|
|
|
|
| st.set_page_config(page_title="arXiv taxonomy classifier", layout="centered") |
| st.title("arXiv taxonomy classifier") |
|
|
| with st.expander("О проекте", expanded=False): |
| st.markdown( |
| """ |
| ### Назначение |
| Приложение предсказывает **группы таксономии arXiv** для статьи по **заголовку** и **краткому описанию (abstract)**. |
| |
| ### Данные и разметка |
| - **Источник**: выгрузка arXiv из файла (поля `title`, `summary`, `tag`). |
| - **Как получали классы**: из `tag` брали `term` (arXiv category codes, например `cs.AI`, `stat.ML`), парсили и |
| маппили в 8 верхнеуровневых групп. |
| - **Разметка**: multi-label — у статьи может быть несколько групп одновременно; целевой вектор `y` формировался как |
| multi-hot по списку групп. |
| - Добавление в выборку статей из 2020-2025 годов, а также сбалансированность классов, так как статей по Биологии кратно меньше, |
| чем статей по Компьтерным наукам, и это также надо учитывать при обучении, что я и передаю в параметре weights. |
| |
| ### Метки |
| В интерфейсе используются **8 верхнеуровневых групп**: |
| Computer Science, Economics, Electrical Engineering and Systems Science, Mathematics, Physics, Quantitative Biology, Quantitative Finance, Statistics. |
| |
| ### Модели и чекпоинты |
| - **HF model (`HF model id`)**: базовая архитектура и токенайзер скачиваются с Hugging Face. |
| - **Локальный `.pth` (по желанию)**: если включено, поверх весов HF подгружаются **дообученные веса** |
| (`torch.load(..., map_location='cpu')` + `model.load_state_dict(..., strict=False)`). |
| - bert-paper-classifier-arxiv(предобученная модель на статьях arXiv) [link](https://huggingface.co/oracat/bert-paper-classifier-arxiv) |
| - distilbert-base-cased(базовая модель берта) [link](https://huggingface.co/distilbert-base-cased) |
| |
| ### Обучение (в общих чертах, подробности в Readme.md) |
| Задача — **multi-label classification** с `BCEWithLogitsLoss`; на инференсе — `sigmoid(logits)` и настраиваемый **threshold**. |
| """ |
| ) |
|
|
| with st.expander("Документация", expanded=False): |
| st.markdown( |
| """ |
| ### Что можно делать в приложении |
| |
| **В боковой панели (Settings)** |
| - Указать **`HF model id`** — какую модель с Hugging Face загружать как основу (архитектура + токенайзер). |
| - Настроить **`Threshold`** — порог вероятности: метки с `sigmoid(logits)` **не ниже** порога попадают в список предсказаний. |
| - Задать **`Max length`** — максимальная длина токенизации для пары title + summary. |
| - Выбрать **локальный чекпоинт** из папки `models/` или ввести **относительный путь** к файлу `.pth` внутри `models/` — чтобы подменить веса HF своим дообучением. |
| - При смене других моделей с HF задать переменную окружения **`HF_TOKEN`**. |
| |
| **В основной форме** |
| - Ввести **заголовок** и/или **краткое описание (abstract)** статьи (достаточно хотя бы одного поля). |
| - Нажать **`Predict`** — получить предсказание по **8 группам** таксономии. |
| |
| **После предсказания** |
| - Просмотреть таблицу **Predicted labels** — метки, прошедшие порог, с вероятностями. |
| - Раскрыть **All probabilities** — полный JSON со всеми вероятностями по меткам (удобно подбирать `Threshold`). |
| |
| **Подсказка** |
| Строка **`Device`** показывает, на чём считается инференс: `cuda` или `cpu`. |
| """ |
| ) |
|
|
|
|
| if "model_id" not in st.session_state: |
| st.session_state.model_id = "oracat/bert-paper-classifier-arxiv" |
| if "threshold" not in st.session_state: |
| st.session_state.threshold = 0.5 |
| if "use_local_ckpt" not in st.session_state: |
| st.session_state.use_local_ckpt = False |
|
|
| with st.sidebar: |
| st.subheader("Settings") |
| st.session_state.model_id = st.text_input("HF model id", value=st.session_state.model_id) |
| st.session_state.threshold = st.slider("Threshold", 0.0, 1.0, float(st.session_state.threshold), 0.01) |
| max_length = st.slider("Max length", 64, 512, 512, 16) |
|
|
| st.caption("Local finetuned weights (.pth) to override HF weights.") |
|
|
| model_dir = Path(__file__).parent/"models" |
| local_ckpts = sorted([p.name for p in model_dir.glob("*.pth")]) |
| default_choice = local_ckpts[0] if local_ckpts else "" |
| ckpt_choice = st.selectbox("Checkpoint (from models dir)", options=[""] + local_ckpts, index=0) |
| ckpt_path_text = st.text_input("Or explicit checkpoint path (relative to models dir)", value="") |
|
|
| ckpt_path = str(model_dir/"bert_paper_classifier_arxiv.pth") |
| |
| if ckpt_choice: |
| ckpt_path = str(model_dir/ckpt_choice) |
|
|
| if ckpt_path_text: |
| ckpt_path = str(model_dir/ckpt_path_text) |
|
|
| st.caption("Set env var `HF_TOKEN` to your Hugging Face token if you want use other models.") |
|
|
|
|
| DEFAULT_LABELS = [ |
| "Computer Science", |
| "Economics", |
| "Electrical Engineering and Systems Science", |
| "Mathematics", |
| "Physics", |
| "Quantitative Biology", |
| "Quantitative Finance", |
| "Statistics", |
| ] |
|
|
| num_labels = len(DEFAULT_LABELS) |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
| tok, model = load_model( |
| st.session_state.model_id, |
| num_labels=num_labels, |
| checkpoint_path=ckpt_path, |
| ignore_mismatched_sizes=True, |
| ) |
| model.to(device) |
|
|
| st.write(f"Device: `{device}`") |
|
|
| with st.form("predict"): |
| title = st.text_input("Title", value="") |
| summary = st.text_area("Summary / abstract", value="", height=220) |
| submitted = st.form_submit_button("Predict") |
|
|
| if submitted: |
| if not (title.strip() or summary.strip()): |
| st.warning("Provide at least a title or a summary.") |
| else: |
| preds, probs = predict( |
| model, |
| tok, |
| title=title, |
| summary=summary, |
| labels=DEFAULT_LABELS, |
| threshold=float(st.session_state.threshold), |
| max_length=int(max_length), |
| device=device, |
| ) |
|
|
| if not preds: |
| st.info("No labels above threshold.") |
| else: |
| st.subheader("Predicted labels") |
| st.table([{"label": lab, "prob": round(p, 4)} for lab, p in preds]) |
|
|
| with st.expander("All probabilities"): |
| st.json({lab: float(p) for lab, p in zip(DEFAULT_LABELS, probs)}) |
|
|