"""Streamlit app that predicts topic labels for an article title and summary."""

import json
from typing import Dict, List

import numpy as np
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder texts; not referenced by the code in this file.
TITLE = "Enter the article or leave this field blank...."
SUMMARY = "Enter summary of the article or leave it blank...."


@st.cache_resource
def load_model():
    """Load the classifier and its tokenizer.

    Wrapped in ``st.cache_resource`` so the pretrained weights are not
    reloaded every time Streamlit re-executes the script.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is switched to eval mode.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        "Kulebyaka-kokosik/articles-classifier-model"
    )
    tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-cased")
    model.eval()
    return model, tokenizer


@st.cache_resource
def load_label_to_topic():
    """Load the label -> topic mapping from ``label_to_topic.json``.

    The path is relative, so it is resolved against the current working
    directory of the process.

    Returns:
        The parsed JSON content. ``predict_topics`` indexes it with
        stringified integer label ids.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open("label_to_topic.json", encoding="utf-8") as file:
        return json.load(file)


def inference(
    model: AutoModelForSequenceClassification,
    tokenizer: AutoTokenizer,
    title: str,
    summary: str,
) -> torch.Tensor:
    """Run the classifier on a single (title, summary) pair.

    Args:
        model: Sequence-classification model to evaluate.
        tokenizer: Tokenizer used to encode the pair.
        title: Article title (may be empty).
        summary: Article summary (may be empty).

    Returns:
        The raw logits produced by the model; presumably shaped
        ``(1, num_labels)`` for this single example (TODO confirm).
    """
    tokenized = tokenizer(
        title,
        summary,
        padding='max_length',
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(**tokenized).logits
    return logits


def predict_topics(
    model: AutoModelForSequenceClassification,
    tokenizer: AutoTokenizer,
    title: str,
    summary: str,
    label_to_topic: Dict[str, str],
    top_k: int = 5,
) -> List[str]:
    """Return the ``top_k`` highest-scoring topics for an article.

    Args:
        model: Sequence-classification model to evaluate.
        tokenizer: Tokenizer used to encode the inputs.
        title: Article title (may be empty).
        summary: Article summary (may be empty).
        label_to_topic: Mapping from stringified label index to topic name.
        top_k: Maximum number of topics to return. Values larger than the
            number of labels return every label; values <= 0 return ``[]``.

    Returns:
        Topic names ordered from highest to lowest sigmoid score.

    Raises:
        KeyError: If a predicted label index is missing from
            ``label_to_topic``.
    """
    logits = inference(model, tokenizer, title, summary)
    probs = torch.sigmoid(logits).squeeze(0)
    # Clamp k: torch.topk raises when k exceeds the number of elements.
    k = max(0, min(int(top_k), probs.numel()))
    # Stay in torch instead of routing the tensor through np.argsort, which
    # only works via NumPy's fallback conversion of foreign array types.
    top_labels = torch.topk(probs, k).indices.tolist()
    return [label_to_topic[str(label)] for label in top_labels]


def main():
    """Render the Streamlit UI and show predictions when requested."""
    model, tokenizer = load_model()
    label_to_topic = load_label_to_topic()

    st.title("Articles classifier")

    title = st.text_input(
        "Enter the article title in english or leave this field blank:"
    )
    summary = st.text_area(
        "Enter the article summary in english or leave this field blank:"
    )
    k = st.number_input(
        "Top-K labels. Enter K:", min_value=1, max_value=20, value=5, step=1
    )

    if st.button('Classify'):
        # Whitespace-only input carries no content, so treat it as empty.
        if not title.strip() and not summary.strip():
            st.error("Please enter title and/or summary")
            return

        st.subheader("Most probable topics:")
        topics = predict_topics(
            model,
            tokenizer,
            title,
            summary,
            label_to_topic=label_to_topic,
            top_k=int(k),
        )
        # NOTE: topics are displayed alphabetically, not by score.
        for topic in sorted(topics):
            st.write(topic)


# Streamlit executes the script as ``__main__``, so the app still starts under
# ``streamlit run`` while a plain import of this module has no UI side effects.
if __name__ == "__main__":
    main()