Gennadiy Polyakov
hw4
68471fc
import streamlit as st
from pathlib import Path
import os
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
def read_hf_token():
HF_TOKEN = os.getenv("HF_TOKEN")
return HF_TOKEN
@st.cache_resource(show_spinner="Loading tokenizer/model...")
def load_model(
model_id,
num_labels,
checkpoint_path=None,
ignore_mismatched_sizes=True,
):
hf_token = read_hf_token()
tok = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model = AutoModelForSequenceClassification.from_pretrained(model_id, token=hf_token,
num_labels=num_labels, problem_type="multi_label_classification",
ignore_mismatched_sizes=ignore_mismatched_sizes)
if checkpoint_path:
state = torch.load(str(checkpoint_path), map_location=torch.device("cpu"))
model.load_state_dict(state, strict=False)
model.eval()
return tok, model
def predict(model, tok, title, summary, labels, threshold=0.5, max_length=512, device="cpu"):
title = (title or "").strip()
summary = (summary or "").strip()
enc = tok(
[title],
[summary],
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
)
enc = {k: v.to(device) for k, v in enc.items()}
with torch.no_grad():
logits = model(**enc).logits
if logits.ndim == 3:
logits = logits[:, 0, :]
probs = torch.sigmoid(logits[0]).cpu()
pred = []
for i, p in enumerate(probs.tolist()):
if p >= threshold:
pred.append((labels[i], float(p)))
pred.sort(key=lambda x: x[1], reverse=True)
return pred, probs.tolist()
st.set_page_config(page_title="arXiv taxonomy classifier", layout="centered")
st.title("arXiv taxonomy classifier")
with st.expander("О проекте", expanded=False):
st.markdown(
"""
### Назначение
Приложение предсказывает **группы таксономии arXiv** для статьи по **заголовку** и **краткому описанию (abstract)**.
### Данные и разметка
- **Источник**: выгрузка arXiv из файла (поля `title`, `summary`, `tag`).
- **Как получали классы**: из `tag` брали `term` (arXiv category codes, например `cs.AI`, `stat.ML`), парсили и
маппили в 8 верхнеуровневых групп.
- **Разметка**: multi-label — у статьи может быть несколько групп одновременно; целевой вектор `y` формировался как
multi-hot по списку групп.
- Добавление в выборку статей из 2020-2025 годов, а также сбалансированность классов, так как статей по Биологии кратно меньше,
чем статей по Компьтерным наукам, и это также надо учитывать при обучении, что я и передаю в параметре weights.
### Метки
В интерфейсе используются **8 верхнеуровневых групп**:
Computer Science, Economics, Electrical Engineering and Systems Science, Mathematics, Physics, Quantitative Biology, Quantitative Finance, Statistics.
### Модели и чекпоинты
- **HF model (`HF model id`)**: базовая архитектура и токенайзер скачиваются с Hugging Face.
- **Локальный `.pth` (по желанию)**: если включено, поверх весов HF подгружаются **дообученные веса**
(`torch.load(..., map_location='cpu')` + `model.load_state_dict(..., strict=False)`).
- bert-paper-classifier-arxiv(предобученная модель на статьях arXiv) [link](https://huggingface.co/oracat/bert-paper-classifier-arxiv)
- distilbert-base-cased(базовая модель берта) [link](https://huggingface.co/distilbert-base-cased)
### Обучение (в общих чертах, подробности в Readme.md)
Задача — **multi-label classification** с `BCEWithLogitsLoss`; на инференсе — `sigmoid(logits)` и настраиваемый **threshold**.
"""
)
with st.expander("Документация", expanded=False):
st.markdown(
"""
### Что можно делать в приложении
**В боковой панели (Settings)**
- Указать **`HF model id`** — какую модель с Hugging Face загружать как основу (архитектура + токенайзер).
- Настроить **`Threshold`** — порог вероятности: метки с `sigmoid(logits)` **не ниже** порога попадают в список предсказаний.
- Задать **`Max length`** — максимальная длина токенизации для пары title + summary.
- Выбрать **локальный чекпоинт** из папки `models/` или ввести **относительный путь** к файлу `.pth` внутри `models/` — чтобы подменить веса HF своим дообучением.
- При смене других моделей с HF задать переменную окружения **`HF_TOKEN`**.
**В основной форме**
- Ввести **заголовок** и/или **краткое описание (abstract)** статьи (достаточно хотя бы одного поля).
- Нажать **`Predict`** — получить предсказание по **8 группам** таксономии.
**После предсказания**
- Просмотреть таблицу **Predicted labels** — метки, прошедшие порог, с вероятностями.
- Раскрыть **All probabilities** — полный JSON со всеми вероятностями по меткам (удобно подбирать `Threshold`).
**Подсказка**
Строка **`Device`** показывает, на чём считается инференс: `cuda` или `cpu`.
"""
)
if "model_id" not in st.session_state:
st.session_state.model_id = "oracat/bert-paper-classifier-arxiv"
if "threshold" not in st.session_state:
st.session_state.threshold = 0.5
if "use_local_ckpt" not in st.session_state:
st.session_state.use_local_ckpt = False
with st.sidebar:
st.subheader("Settings")
st.session_state.model_id = st.text_input("HF model id", value=st.session_state.model_id)
st.session_state.threshold = st.slider("Threshold", 0.0, 1.0, float(st.session_state.threshold), 0.01)
max_length = st.slider("Max length", 64, 512, 512, 16)
st.caption("Local finetuned weights (.pth) to override HF weights.")
model_dir = Path(__file__).parent/"models"
local_ckpts = sorted([p.name for p in model_dir.glob("*.pth")])
default_choice = local_ckpts[0] if local_ckpts else ""
ckpt_choice = st.selectbox("Checkpoint (from models dir)", options=[""] + local_ckpts, index=0)
ckpt_path_text = st.text_input("Or explicit checkpoint path (relative to models dir)", value="")
ckpt_path = str(model_dir/"bert_paper_classifier_arxiv.pth")
if ckpt_choice:
ckpt_path = str(model_dir/ckpt_choice)
if ckpt_path_text:
ckpt_path = str(model_dir/ckpt_path_text)
st.caption("Set env var `HF_TOKEN` to your Hugging Face token if you want use other models.")
DEFAULT_LABELS = [
"Computer Science",
"Economics",
"Electrical Engineering and Systems Science",
"Mathematics",
"Physics",
"Quantitative Biology",
"Quantitative Finance",
"Statistics",
]
num_labels = len(DEFAULT_LABELS)
device = "cuda" if torch.cuda.is_available() else "cpu"
tok, model = load_model(
st.session_state.model_id,
num_labels=num_labels,
checkpoint_path=ckpt_path,
ignore_mismatched_sizes=True,
)
model.to(device)
st.write(f"Device: `{device}`")
with st.form("predict"):
title = st.text_input("Title", value="")
summary = st.text_area("Summary / abstract", value="", height=220)
submitted = st.form_submit_button("Predict")
if submitted:
if not (title.strip() or summary.strip()):
st.warning("Provide at least a title or a summary.")
else:
preds, probs = predict(
model,
tok,
title=title,
summary=summary,
labels=DEFAULT_LABELS,
threshold=float(st.session_state.threshold),
max_length=int(max_length),
device=device,
)
if not preds:
st.info("No labels above threshold.")
else:
st.subheader("Predicted labels")
st.table([{"label": lab, "prob": round(p, 4)} for lab, p in preds])
with st.expander("All probabilities"):
st.json({lab: float(p) for lab, p in zip(DEFAULT_LABELS, probs)})