| import numpy as np |
| import streamlit as st |
| from typing import Dict, List |
| import torch |
| import json |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
# NOTE(review): neither constant is referenced in the visible part of this
# file (the input widgets in main() use their own literal labels). The wording
# suggests they were meant as help/placeholder text for the title and summary
# fields -- confirm before wiring them in or removing them.
TITLE = "Enter the article or leave this field blank...."
SUMMARY = "Enter summary of the article or leave it blank...."
|
|
@st.cache_resource
def load_model():
    """Load the fine-tuned classifier and its tokenizer.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the same
    objects instead of loading them again.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been switched to
        evaluation mode via ``eval()``.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(
        "Kulebyaka-kokosik/articles-classifier-model"
    )
    # The tokenizer is taken from the base DistilBERT checkpoint, not from the
    # fine-tuned model repository.
    text_tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-cased")
    classifier.eval()
    return classifier, text_tokenizer
|
|
|
|
@st.cache_resource
def load_label_to_topic():
    """Load the label-index -> topic-name mapping from ``label_to_topic.json``.

    The path is relative, so the file is resolved against the current working
    directory. Wrapped in ``st.cache_resource`` so the file is parsed once and
    the result is reused across reruns.

    Returns:
        The object decoded from the JSON file (callers index it with the
        stringified label index).

    Raises:
        FileNotFoundError: If ``label_to_topic.json`` is not present.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8; pin the encoding so non-ASCII topic names do not
    # depend on the platform's default locale encoding.
    with open("label_to_topic.json", encoding="utf-8") as file:
        return json.load(file)
|
|
|
|
def inference(
    model: AutoModelForSequenceClassification,
    tokenizer: AutoTokenizer,
    title: str,
    summary: str
) -> torch.Tensor:
    """Run the classifier on one (title, summary) pair and return its logits.

    Args:
        model: Classifier called with the tokenizer output as keyword arguments.
        tokenizer: Callable that encodes the title/summary pair.
        title: Article title (first segment of the pair).
        summary: Article summary (second segment of the pair).

    Returns:
        The ``logits`` attribute of the model output, computed with gradient
        tracking disabled.
    """
    encoding_options = {
        "padding": "max_length",
        "truncation": True,
        "return_tensors": "pt",
    }
    encoded_pair = tokenizer(title, summary, **encoding_options)

    # Pure forward pass: no autograd graph is needed for prediction.
    with torch.no_grad():
        model_output = model(**encoded_pair)
        return model_output.logits
|
|
|
|
def predict_topics(
    model: AutoModelForSequenceClassification,
    tokenizer: AutoTokenizer,
    title: str,
    summary: str,
    label_to_topic: Dict[str, str],
    top_k: int = 5,
) -> List[str]:
    """Return the names of the ``top_k`` highest-scoring topics for an article.

    Args:
        model: Classifier passed through to ``inference``.
        tokenizer: Tokenizer passed through to ``inference``.
        title: Article title.
        summary: Article summary.
        label_to_topic: Mapping from the label index, as a string, to the
            topic name.
        top_k: Maximum number of topics to return.

    Returns:
        Topic names ordered from the highest to the lowest score. Fewer than
        ``top_k`` names are returned when the model has fewer labels, and an
        empty list when ``top_k`` is zero or negative.

    Raises:
        KeyError: If a predicted label index has no entry in
            ``label_to_topic``.
    """
    logits = inference(model, tokenizer, title, summary)

    # A single (title, summary) pair is scored, so the leading dimension is
    # assumed to be a batch of size 1 -- drop it to get one score per label.
    scores = logits.squeeze(0)

    # Clamp k so that an oversized or non-positive request cannot raise in
    # topk() or be misread as a negative slice bound.
    k = max(0, min(int(top_k), scores.shape[-1]))

    # Rank with torch directly instead of handing a tensor to NumPy. Sigmoid
    # is monotonic, so ranking the raw logits gives the same order as ranking
    # the probabilities while avoiding ties between saturated values.
    labels = torch.topk(scores, k).indices.tolist()
    return [label_to_topic[str(label)] for label in labels]
|
|
|
|
|
|
def main():
    """Render the classifier page and show predicted topics on demand.

    Builds the title/summary inputs and the Top-K selector, and once the
    "Classify" button is pressed writes the predicted topics to the page.
    An error is shown instead when both text fields are empty or contain only
    whitespace.
    """
    model, tokenizer = load_model()
    label_to_topic = load_label_to_topic()

    st.title("Articles classifier")

    title = st.text_input("Enter the article title in english or leave this field blank:")
    summary = st.text_area("Enter the article summary in english or leave this field blank:")
    k = st.number_input(
        "Top-K labels. Enter K:",
        min_value=1,
        max_value=20,
        value=5,
        step=1,
    )

    if not st.button('Classify'):
        return

    # Whitespace-only input carries no text to classify, so treat it the same
    # as an empty field.
    if not title.strip() and not summary.strip():
        st.error("Please enter title and/or summary")
        return

    st.subheader("Most probable topics:")
    topics = predict_topics(model, tokenizer, title, summary, label_to_topic=label_to_topic, top_k=k)
    # NOTE(review): the topics are displayed alphabetically, which discards
    # the ranking order returned by predict_topics -- confirm this is intended.
    for topic in sorted(topics):
        st.write(topic)
|
|
# Guard the entry point so importing this module (e.g. from a test) does not
# build the UI or load the model; the app still starts when the file is run
# as the main script.
if __name__ == "__main__":
    main()
|
|