# articles_classifier/src/streamlit_app.py
# Kulebyaka-kokosik
# src/streamlit_app.py: or -> and
# 36588ff
import numpy as np
import streamlit as st
from typing import Dict, List
import torch
import json
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Prompt strings for the two free-text inputs.
# NOTE(review): neither constant is referenced in the visible part of this
# file — confirm whether they are still needed.
TITLE: str = "Enter the article or leave this field blank...."
SUMMARY: str = "Enter summary of the article or leave it blank...."
@st.cache_resource
def load_model():
    """Load the sequence-classification model and its tokenizer.

    The function is wrapped in ``st.cache_resource``. The model is put into
    evaluation mode before being returned.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(
        "Kulebyaka-kokosik/articles-classifier-model"
    )
    text_tokenizer = AutoTokenizer.from_pretrained(
        "distilbert/distilbert-base-cased"
    )
    # Switch to eval mode: this app only runs inference.
    classifier.eval()
    return classifier, text_tokenizer
@st.cache_resource
def load_label_to_topic():
    """Read ``label_to_topic.json`` and return its parsed contents.

    The path is relative, so it is resolved against the process's current
    working directory. The function is wrapped in ``st.cache_resource``.

    Returns:
        The object decoded from the JSON file.

    Raises:
        FileNotFoundError: If ``label_to_topic.json`` is not in the current
            working directory.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open("label_to_topic.json", encoding="utf-8") as file:
        return json.load(file)
def inference(
    model: AutoModelForSequenceClassification,
    tokenizer: AutoTokenizer,
    title: str,
    summary: str
) -> torch.Tensor:
    """Tokenize a title/summary pair and return the model's raw logits.

    Args:
        model: Classifier to call on the tokenized input.
        tokenizer: Tokenizer used to encode ``title`` and ``summary`` together.
        title: Article title text.
        summary: Article summary text.

    Returns:
        The ``logits`` attribute of the model output, computed with gradient
        tracking disabled.
    """
    encoded = tokenizer(
        title,
        summary,
        padding='max_length',
        truncation=True,
        return_tensors="pt",
    )
    # No gradients are needed for a pure forward pass.
    with torch.no_grad():
        output = model(**encoded)
    return output.logits
def predict_topics(
    model: AutoModelForSequenceClassification,
    tokenizer: AutoTokenizer,
    title: str,
    summary: str,
    label_to_topic: Dict[str, str],
    top_k: int = 5,
) -> List[str]:
    """Return the names of the ``top_k`` highest-scoring topics.

    Args:
        model: Classifier passed through to ``inference``.
        tokenizer: Tokenizer passed through to ``inference``.
        title: Article title text.
        summary: Article summary text.
        label_to_topic: Mapping from the stringified label index to the
            topic name.
        top_k: Maximum number of topics to return. Values larger than the
            number of labels are clamped; values below zero yield an empty
            list.

    Returns:
        Topic names ordered from the highest to the lowest score.

    Raises:
        KeyError: If a selected label index is missing from
            ``label_to_topic``.
    """
    logits = inference(model, tokenizer, title, summary)
    # Drop the leading batch dimension of the single-example batch.
    probs = torch.sigmoid(logits).squeeze(0)
    # Clamp k so torch.topk never receives an out-of-range value.
    k = max(0, min(int(top_k), probs.shape[-1]))
    # torch.topk stays within torch (no implicit NumPy conversion) and
    # returns indices already sorted by descending score.
    labels = torch.topk(probs, k).indices.tolist()
    return [label_to_topic[str(label)] for label in labels]
def main():
    """Render the classifier page and show predictions on request.

    Loads the model, tokenizer and label mapping, draws the title, summary
    and top-K inputs, and, once the "Classify" button is pressed, writes the
    predicted topics in alphabetical order. If both text fields are empty an
    error message is shown instead.
    """
    model, tokenizer = load_model()
    label_to_topic = load_label_to_topic()

    st.title("Articles classifier")
    title = st.text_input("Enter the article title in english or leave this field blank:")
    summary = st.text_area("Enter the article summary in english or leave this field blank:")
    k = st.number_input("Top-K labels. Enter K:", min_value=1, max_value=20, value=5, step=1)

    # Nothing to do until the user asks for a classification.
    if not st.button('Classify'):
        return

    # At least one of the two text fields must be filled in.
    if (title, summary) == ("", ""):
        st.error("Please enter title and/or summary")
        return

    st.subheader("Most probable topics:")
    predicted = predict_topics(
        model,
        tokenizer,
        title,
        summary,
        label_to_topic=label_to_topic,
        top_k=k,
    )
    # Topics are displayed alphabetically, not by score.
    alphabetical = list(predicted)
    alphabetical.sort()
    for name in alphabetical:
        st.write(name)
# Run the app only when this file is executed as a script, so importing the
# module (e.g. from a test) does not load the model or render the page.
# NOTE(review): relies on `streamlit run` executing the script with
# __name__ == "__main__" — confirm against the Streamlit version in use.
if __name__ == "__main__":
    main()