# Provenance (residue from the hosting page): author saitsharipov,
# commit 65f80b0 "Update app.py" (verified).
import sys
import streamlit as st
from model_utils import (
load_model_and_tokenizer,
predict_category
)
# Storing None under a key in sys.modules makes any later
# `import torch.classes` raise ModuleNotFoundError instead of importing it.
# NOTE(review): presumably a workaround for Streamlit's file watcher failing
# when it walks torch.classes' __path__ — confirm. Setting
# `server.fileWatcherType = "none"` in .streamlit/config.toml may be a
# less invasive alternative.
sys.modules["torch.classes"] = None
# Configure the browser tab and page layout. Kept ahead of every other st.*
# call because older Streamlit versions require set_page_config to be the
# first Streamlit command in the script.
st.set_page_config(
    page_title="ArXiv Category Classifier",
    # U+1F4D7 GREEN BOOK, written as a named escape so the icon cannot be
    # corrupted into mojibake if this file is re-encoded.
    page_icon="\N{GREEN BOOK}",
    layout="centered",
)
@st.cache_resource
def init_model(model_path: str = "saitsharipov/arxiv-category-model"):
    """Load the classifier model, tokenizer and label maps once and cache them.

    ``st.cache_resource`` memoizes the return value keyed on the arguments.
    Streamlit re-executes the whole script on every widget interaction, so
    reruns reuse the loaded objects instead of reloading the model.

    Args:
        model_path: Identifier forwarded to ``load_model_and_tokenizer``.
            Defaults to ``"saitsharipov/arxiv-category-model"``.
            NOTE(review): presumably a Hugging Face Hub repo id or a local
            directory — confirm against ``model_utils``.

    Returns:
        The 4-tuple ``(model, tokenizer, id2label, label2id)`` produced by
        ``load_model_and_tokenizer``.
    """
    # Unpacking enforces the 4-element contract and names each component.
    model, tokenizer, id2label, label2id = load_model_and_tokenizer(model_path)
    return model, tokenizer, id2label, label2id


model, tokenizer, id2label, label2id = init_model()
# Page header. U+1F4D7 GREEN BOOK is written as a named escape so the title
# cannot be turned into mojibake by a non-UTF-8 re-encode of this file.
st.title("\N{GREEN BOOK} ArXiv Category Classifier")
st.caption("AI-app to predict categories for scientific articles using a modern NLP model.")
st.divider()
# Pre-fill the form with an example paper so the app can be tried with one click.
default_title = "Attention Is All You Need"
_abstract_sentences = (
    "The dominant sequence transduction models are based on complex recurrent "
    "or convolutional neural networks in an encoder-decoder configuration.",
    "The best performing models also connect the encoder and decoder through "
    "an attention mechanism.",
)
# Sentences are joined with single spaces into one abstract paragraph.
default_abstract = " ".join(_abstract_sentences)

# Editable text areas; their current contents feed the prediction below.
title_input = st.text_area(label="**Title**", value=default_title, height=80)
abstract_input = st.text_area(label="**Abstract**", value=default_abstract, height=200)
# U+1F52E CRYSTAL BALL as a named escape keeps the label ASCII-safe in source.
if st.button("\N{CRYSTAL BALL} Predict Category"):
    # Don't send an empty request to the model: require a title or an abstract.
    if not (title_input.strip() or abstract_input.strip()):
        st.warning("Please enter a title or an abstract before predicting.")
    else:
        with st.spinner("Predicting category..."):
            predicted_label = predict_category(
                title_input, abstract_input, tokenizer, model, id2label
            )
        # Rendered after the spinner has cleared.
        st.success(f"Predicted Category: **{predicted_label}**")