| import sys |
| import streamlit as st |
| from model_utils import ( |
| load_model_and_tokenizer, |
| predict_category |
| ) |
|
|
# Replace the cached `torch.classes` module entry with None.
# NOTE(review): presumably this keeps Streamlit's source-file watcher from
# probing torch.classes (whose attribute lookup is known to raise) - confirm;
# any later `import torch.classes` will fail while this entry is None.
sys.modules["torch.classes"] = None


st.set_page_config(
    page_title="ArXiv Category Classifier",
    # The original emoji was corrupted by a text-encoding error (it appeared as
    # the Thai digit "๐"); BOOKS is a best-guess restoration. The \N{...}
    # escape keeps the icon intact no matter how this file is re-encoded.
    page_icon="\N{BOOKS}",
    layout="centered",
)
|
|
@st.cache_resource
def init_model(model_path: str = "saitsharipov/arxiv-category-model"):
    """Load the classifier, its tokenizer and label mappings once.

    The result is cached by ``st.cache_resource``, keyed on ``model_path``,
    so the model is not reloaded on every script rerun.

    Args:
        model_path: Identifier handed to ``load_model_and_tokenizer``.
            The default is presumably a Hugging Face Hub repo id - confirm.

    Returns:
        Tuple ``(model, tokenizer, id2label, label2id)`` as produced by
        ``load_model_and_tokenizer``.
    """
    # Unpack explicitly so an unexpected return arity fails loudly here
    # rather than at the call site.
    model, tokenizer, id2label, label2id = load_model_and_tokenizer(model_path)
    return model, tokenizer, id2label, label2id
|
|
# Fetch the (cached) model artifacts; label2id is not used further in this file.
model, tokenizer, id2label, label2id = init_model()


# The original title emoji was corrupted by a text-encoding error (it appeared
# as the Thai digit "๐"); BOOKS is a best-guess restoration, written as an
# escape so it survives any future re-encoding of this file.
st.title("\N{BOOKS} ArXiv Category Classifier")
st.caption("AI-app to predict categories for scientific articles using a modern NLP model.")
st.divider()
|
|
# Example paper used to pre-fill the title and abstract inputs.
default_title = "Attention Is All You Need"
default_abstract = " ".join(
    (
        "The dominant sequence transduction models are based on complex recurrent "
        "or convolutional neural networks in an encoder-decoder configuration.",
        "The best performing models also connect the encoder and decoder through "
        "an attention mechanism.",
    )
)
|
|
title_input = st.text_area("**Title**", value=default_title, height=80)
abstract_input = st.text_area("**Abstract**", value=default_abstract, height=200)


# The original button emoji was mis-encoded as "๐ฎ"; it decodes to U+1F52E
# CRYSTAL BALL, written as an escape so it survives re-encoding of the file.
if st.button("\N{CRYSTAL BALL} Predict Category"):
    if not (title_input.strip() or abstract_input.strip()):
        # Nothing to classify: ask for input instead of running the model
        # on empty text.
        st.warning("Please enter a title and/or an abstract to classify.")
    else:
        # Keep only the model call inside the spinner.
        with st.spinner("Predicting category..."):
            predicted_label = predict_category(
                title_input, abstract_input, tokenizer, model, id2label
            )
        st.success(f"Predicted Category: **{predicted_label}**")
|
|