| from typing import Tuple, Any, List
|
|
|
| import numpy as np
|
| import streamlit as st
|
| import os
|
| import torch
|
| from pathlib import Path
|
| from os import PathLike
|
|
|
| from peft import AutoPeftModelForSequenceClassification
|
| from transformers import AutoTokenizer, AutoConfig
|
| from typing import Dict
|
| import requests
|
| from bs4 import BeautifulSoup
|
| import json
|
| import pandas as pd
|
|
|
| from deep_translator import (GoogleTranslator,
|
| MyMemoryTranslator,
|
| single_detection)
|
|
|
|
|
| os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
| device = "cuda" if torch.cuda.is_available() else "cpu"
|
| NUM_LABELS = 149
|
|
|
| BASE_DIR = Path(__file__).resolve().parent.parent
|
| DATA_DIR = BASE_DIR / "data"
|
| MODEL_DIR = BASE_DIR / "model"
|
|
|
|
|
| @st.cache_resource(show_spinner="Загрузка модели...")
|
| def load_model(
|
| model_name: str = "oracat/bert-paper-classifier-arxiv",
|
| model_path: str | PathLike[str]= MODEL_DIR / 'bert-paper-classifier-arxiv'
|
| ) -> Tuple[Any, Any, AutoConfig]:
|
|
|
| model = AutoPeftModelForSequenceClassification.from_pretrained(
|
| model_path,
|
| num_labels=149,
|
| problem_type="multi_label_classification",
|
| ignore_mismatched_sizes=True
|
| )
|
| tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| model_cfg = AutoConfig.from_pretrained(model_name)
|
|
|
| model = model.to(device)
|
| model.eval()
|
| return model, tokenizer, model_cfg
|
|
|
|
|
| @st.cache_data(show_spinner='Получение категорий с Arxiv...')
|
| def fetch_arxiv_cat_describes(url: str = "https://arxiv.org/category_taxonomy") -> Dict[str, str]:
|
| resp = requests.get(url)
|
| resp.raise_for_status()
|
| soup = BeautifulSoup(resp.text, "html5lib")
|
|
|
| cat2descr = {}
|
| for head in soup.select("h2.accordion-head"):
|
|
|
| body = head.find_next_sibling("div", class_="accordion-body")
|
| if body is None:
|
| continue
|
|
|
| for block in body.select("div.columns.divided"):
|
| h4 = block.find("h4")
|
| if h4 is None:
|
| continue
|
|
|
| cat_code = h4.contents[0].text.strip()
|
| cat2descr[cat_code] = h4.contents[1].text.strip('()')
|
|
|
| return cat2descr
|
|
|
|
|
| @st.cache_data
|
| def get_ids_to_cats(json_path=DATA_DIR / 'ids2cat.json') -> Dict[int, str]:
|
| if not json_path.exists():
|
| raise FileNotFoundError(json_path)
|
|
|
| with open(json_path, 'r') as f:
|
| ids2cat = json.load(f)
|
| ids2cat = {int(k): v for k, v in ids2cat.items()}
|
|
|
| return ids2cat
|
|
|
|
|
| @st.cache_data(show_spinner=False)
|
| def translate_text(text: str, translator_name: str, source: str = "auto") -> str:
|
| if not text or not text.strip():
|
| return text
|
|
|
| if translator_name == "Google Translator":
|
| return GoogleTranslator(source=source, target='en').translate(text)
|
|
|
| elif translator_name == "MyMemory Translator":
|
| return MyMemoryTranslator(source=source, target='en-GB').translate(text)
|
|
|
| raise ValueError(f"Неизвестный переводчик: {translator_name}")
|
|
|
|
|
| @torch.inference_mode()
|
| def predict(text: str) -> List[Tuple[int, float]]:
|
|
|
| inputs = tokenizer(
|
| text,
|
| return_tensors="pt",
|
| truncation=True,
|
| padding=True,
|
| max_length=model_cfg.max_position_embeddings,
|
| )
|
| inputs = {k: v.to(device) for k, v in inputs.items()}
|
| outputs = model(**inputs)
|
| probs = torch.sigmoid(outputs.logits[0]).cpu().tolist()
|
| return [(i, p) for i, p in enumerate(probs)]
|
|
|
|
|
| def go_to_phase1():
|
| st.session_state.step = 1
|
|
|
|
|
| def go_to_phase2():
|
| if ((st.session_state.title_input is None or len(st.session_state.title_input) == 0 )
|
| and (st.session_state.abstract_input is None or len(st.session_state.abstract_input) == 0)):
|
| st.toast('Пустой заголовок и текст', icon="🚫")
|
| return
|
|
|
| st.session_state.step = 2
|
|
|
| st.session_state.saved_title = st.session_state.title_input
|
| st.session_state.saved_abstract = st.session_state.abstract_input
|
| st.session_state.saved_k = st.session_state.k_input
|
| st.session_state.saved_translator = st.session_state.translator_choice
|
|
|
|
|
| def clear_form():
|
| st.session_state.title_input = None
|
| st.session_state.abstract_input = None
|
| st.session_state.k_input = 10
|
| st.session_state.translator_choice = "Google Translator"
|
|
|
| st.session_state.saved_title = None
|
| st.session_state.saved_abstract = None
|
| st.session_state.saved_k = 10
|
| st.session_state.saved_translator = "Google Translator"
|
|
|
|
|
| if 'saved_title' not in st.session_state:
|
| st.session_state.saved_title = None
|
|
|
| if 'saved_abstract' not in st.session_state:
|
| st.session_state.saved_abstract = None
|
|
|
| if 'saved_k' not in st.session_state:
|
| st.session_state.saved_k = 10
|
|
|
| if 'saved_translator' not in st.session_state:
|
| st.session_state.saved_translator = "Google Translator"
|
|
|
| if 'step' not in st.session_state:
|
| st.session_state.step = 1
|
|
|
|
|
| model, tokenizer, model_cfg = load_model()
|
| cat2descr = fetch_arxiv_cat_describes()
|
| ids2cat = get_ids_to_cats()
|
|
|
|
|
| st.markdown("""
|
| <style>
|
| html, body, [class*="css"] {
|
| font-family: "Nunito Sans", sans-serif;
|
| }
|
|
|
| h1, h2, h3, h4, h5, h6,
|
| [data-testid="stMarkdownContainer"] h1,
|
| [data-testid="stMarkdownContainer"] h2,
|
| [data-testid="stMarkdownContainer"] h3 {
|
| font-family: "Instrument Serif", serif !important;
|
| font-weight: 400 !important;
|
| }
|
|
|
| code, pre, kbd, samp {
|
| font-family: "JetBrains Mono", monospace !important;
|
| }
|
| </style>
|
| """, unsafe_allow_html=True)
|
|
|
|
|
| with st.container(border=True):
|
| st.caption("О ПРОЕКТЕ")
|
| st.subheader("BERT Multi-label классификатор arXiv-статей")
|
|
|
| st.markdown(
|
| """
|
| Проект представляет собой систему **multi-label классификации научных статей arXiv**
|
| по **заголовку** и **аннотации**.
|
|
|
| Датасет был собран напрямую через **arXiv API**: статьи выгружались по категориям,
|
| затем сопоставлялись с **официальной таксономией arXiv** и дедуплицировались
|
| по нормализованному идентификатору. В результате был получен корпус
|
| объёмом **более 1M+ примеров**.
|
| """
|
| )
|
|
|
| st.markdown(
|
| """
|
| Модель предсказывает **сразу несколько тегов** для одной статьи.
|
| Чтобы система была устойчивой к неполным данным, в обучении используются
|
| два режима входа:
|
|
|
| - **Только заголовок**
|
| - **Заголовок + аннотация**
|
|
|
| Выбор режима происходит прямо в **batch collator**, поэтому модель учится
|
| работать и с полным описанием статьи, и с сокращённым вариантом.
|
| """
|
| )
|
|
|
| st.markdown(
|
| """
|
| В качестве основы используется **BERT для sequence classification**,
|
| дообучаемый с помощью **LoRA**.
|
|
|
| Для борьбы с сильным дисбалансом классов применяется
|
| **взвешенная binary cross-entropy**: вес каждого класса рассчитывается
|
| по соотношению отрицательных и положительных примеров.
|
| Итоговая модель развёрнута в **Streamlit-приложении**
|
| для интерактивного предсказания arXiv-категорий.
|
| """
|
| )
|
|
|
| st.divider()
|
|
|
| c1, c2, c3 = st.columns(3)
|
| c1.metric("Датасет", "1M+ статей")
|
| c2.metric("Обучение", "~6 часов")
|
| c3.metric("GPU", "H100")
|
|
|
| st.caption(
|
| "Trainable params: 999,317 · All params: 110,596,138 · Trainable%: 0.9036"
|
| )
|
|
|
| st.markdown("### Гиперпараметры обучения")
|
|
|
| train_params = pd.DataFrame(
|
| [
|
| ("Базовая модель", "oracat/bert-paper-classifier-arxiv"),
|
| ("Тип задачи", "Multi-label classification"),
|
| ("Число классов", "149"),
|
| ("Разделение датасета", "90% train / 10% test"),
|
| ("Batch size (train)", "512"),
|
| ("Batch size (eval)", "512"),
|
| ("Число эпох", "10"),
|
| ("Learning rate", "5e-4"),
|
| ("LR scheduler", "cosine"),
|
| ("Warmup steps", "10"),
|
| ("Оптимизатор", "adamw_torch_fused"),
|
| ("Weight decay", "0.001"),
|
| ("Gradient accumulation steps", "1"),
|
| ("Mixed precision", "bf16"),
|
| ("LoRA target modules", "query, key, value"),
|
| ("LoRA rank (r)", "16"),
|
| ("LoRA alpha", "32"),
|
| ("LoRA dropout", "0.05"),
|
| ("use_rslora", "True"),
|
| ("modules_to_save", "classifier"),
|
| ("Функция потерь", "Weighted BCEWithLogitsLoss"),
|
| ],
|
| columns=["Гиперпараметр", "Значение"]
|
| )
|
|
|
| st.dataframe(train_params, use_container_width=True, hide_index=True)
|
|
|
| if st.session_state.step == 1:
|
| st.subheader('Часть 1: Ввод статьи')
|
| with st.form(key='insert form', height='content'):
|
|
|
| title = st.text_input(
|
| label='Введите заголовок статьи',
|
| placeholder='Заголовок статьи',
|
| key='title_input',
|
| value=st.session_state.saved_title
|
| )
|
|
|
| abstract = st.text_area(
|
| label='Введите текст статьи',
|
| placeholder='Текст статьи',
|
| key='abstract_input',
|
| value = st.session_state.saved_abstract,
|
| height='content'
|
| )
|
|
|
| slider_help = 'Параметр регулирует вывод K наиболее вероятных тэгов из таксономии arxiv для статьи '
|
| k = st.slider(
|
| label='Введите top-K',
|
| min_value=1,
|
| max_value=20,
|
| value = st.session_state.saved_k,
|
| key='k_input',
|
| help=slider_help
|
| )
|
|
|
| options = ["Google Translator", "MyMemory Translator"]
|
| selectbox_help = ('Модель была обучена понимать только английский, поэтому если текст на другом языке'
|
| ' или на смеси языков, то его необходимо сначала перевести.')
|
| option = st.selectbox(
|
| label='Выбор средства перевода',
|
| options=options,
|
| key='translator_choice',
|
| help=selectbox_help
|
| )
|
|
|
| col1, col2, col3 = st.columns([1, 1, 5], gap="small")
|
|
|
| with col1:
|
| pressed_input = st.form_submit_button(
|
| "Ввод",
|
| width="stretch",
|
| on_click=go_to_phase2
|
| )
|
|
|
| with col2:
|
| pressed_clear = st.form_submit_button(
|
| "Очистка",
|
| width="stretch",
|
| on_click=clear_form
|
| )
|
|
|
| elif st.session_state.step == 2:
|
|
|
| if np.random.rand(1) < 0.5:
|
| st.balloons()
|
| else:
|
| st.snow()
|
|
|
| st.subheader('Часть 2: Результаты классификации')
|
|
|
| title = (
|
| st.session_state.saved_title
|
| if st.session_state.saved_title is not None
|
| else ''
|
| )
|
|
|
| abstract = (
|
| st.session_state.saved_abstract
|
| if st.session_state.saved_abstract is not None
|
| else ''
|
| )
|
|
|
| k = (
|
| st.session_state.saved_k
|
| if st.session_state.saved_k is not None
|
| else 10
|
| )
|
|
|
| translator_name = (
|
| st.session_state.saved_translator
|
| if st.session_state.saved_translator is not None
|
| else "Google Translator"
|
| )
|
|
|
| if title and abstract:
|
| input_text = title + '\n\n' + abstract
|
| elif title:
|
| input_text = title
|
| else:
|
| input_text = abstract
|
|
|
| lang = single_detection(input_text, api_key='db46a38fbac48c2b159384593d450933')
|
| translated_text = input_text
|
| if lang != 'en':
|
| with st.spinner(f"Перевод через {translator_name}..."):
|
| try:
|
| translated_text = translate_text(
|
| text=input_text,
|
| translator_name=translator_name,
|
| source='auto'
|
| )
|
| except Exception as e:
|
| st.error(f"Ошибка перевода: {e}")
|
| st.stop()
|
|
|
| with st.spinner("Модель анализирует текст..."):
|
| probabilities = sorted(predict(translated_text), key=lambda x: x[1], reverse=True)
|
| probs_k = [(ids2cat[x[0]], x[1]) for x in probabilities[:k]]
|
|
|
| with st.container(border=True):
|
| st.caption("Лучшие совпадения")
|
| c1, c2 = st.columns([3, 1])
|
|
|
| with c1:
|
| best_cat = cat2descr[probs_k[0][0]]
|
| st.markdown(f"# **{best_cat}**")
|
| st.caption(f"Arxiv Tag: {probs_k[0][0]}")
|
|
|
| with c2:
|
| best_p = probs_k[0][1]
|
| st.markdown(f"# _**{round(best_p * 100, 2)}%**_")
|
|
|
| with st.container(border=True):
|
| st.caption(f"Наиболее подходящие {k}/{NUM_LABELS} категорий")
|
| st.markdown(
|
| "<hr style='margin: 8px 0; border: none; border-top: 1px solid #ddd;'>",
|
| unsafe_allow_html=True
|
| )
|
| rows = []
|
| for cat, prob in probs_k:
|
| rows.append({'Category': cat2descr[cat], 'Probability': f'{round(prob * 100, 2)}%', 'Arxiv Tag': cat})
|
| df = pd.DataFrame(rows)
|
| st.dataframe(df)
|
|
|
| st.button(label='Назад', on_click=go_to_phase1, width="stretch") |