import json

import numpy as np
import pandas as pd
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForSequenceClassification, AutoTokenizer


@st.cache_resource
def load_model():
    """Load the tokenizer, the classifier and its label metadata.

    The model and tokenizer are fetched with ``from_pretrained`` and the
    label metadata is read from ``labels.json`` in the same repository.
    The function is decorated with ``st.cache_resource`` so the returned
    objects are reused instead of being rebuilt on every call.

    Returns:
        A ``(tokenizer, model, id2label, categories)`` tuple, where
        ``id2label`` maps integer class indices to labels and
        ``categories`` is the ``"categories"`` entry of ``labels.json``
        (indexed by label further down to get a display name).
    """
    repo_id = "MurDanya/ml-course-article-classifier"
    model = AutoModelForSequenceClassification.from_pretrained(repo_id)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    file_path = hf_hub_download(repo_id, "labels.json")
    # Explicit encoding: the platform default is not UTF-8 everywhere, and
    # JSON files are UTF-8 by specification.
    with open(file_path, encoding="utf-8") as f:
        labels = json.load(f)

    # JSON object keys are always strings; convert them back to the integer
    # class indices used to index the model output.
    id2label = {int(idx): label for idx, label in labels['id2label'].items()}
    categories = labels['categories']
    return tokenizer, model, id2label, categories


def get_top95(labels, probs: torch.Tensor, threshold: float = 0.95) -> list:
    """Return the most probable labels that together reach ``threshold``.

    Labels are taken in order of descending probability until their
    cumulative probability is at least ``threshold``.

    Args:
        labels: Lookup from class index to label (indexed with ``labels[i]``).
        probs: 1-D tensor of class probabilities.
        threshold: Cumulative probability to reach; defaults to 0.95.

    Returns:
        A list of ``(label, probability)`` pairs, most probable first. When
        the cumulative sum never reaches ``threshold``, every label is
        returned.
    """
    sorted_indices = torch.argsort(probs, descending=True)
    sorted_probs = probs[sorted_indices]
    sorted_labels = [labels[i.item()] for i in sorted_indices]

    cumulative = torch.cumsum(sorted_probs, dim=0)
    cutoff = torch.where(cumulative >= threshold)[0]
    # The first position reaching the threshold is still part of the result,
    # hence the +1 to turn it into an exclusive slice end.
    last_idx = cutoff[0].item() + 1 if len(cutoff) > 0 else len(sorted_probs)

    return list(zip(sorted_labels[:last_idx], sorted_probs[:last_idx].tolist()))


# UI
st.set_page_config(page_title="Article Topic Classifier")
st.title("Article Topic Classifier")
st.markdown("Enter the **title** and optionally **abstract** of the article.")

title = st.text_input(
    "Title",
    placeholder="e.g. Neural Networks for Quantum Physics",
)
abstract = st.text_area(
    "Abstract (optional)",
    placeholder="e.g. We explore the application of neural nets...",
)

if st.button("Classify"):
    # Treat whitespace-only fields as empty so they are not sent to the model.
    title = title.strip()
    abstract = abstract.strip()

    if not title and not abstract:
        st.warning("Please enter at least the title.")
    else:
        tokenizer, model, id2label, categories = load_model()

        # Join only the fields that were filled in, so an abstract without a
        # title does not start with a dangling " - " separator.
        text = " - ".join(part for part in (title, abstract) if part)
        inputs = tokenizer(text, return_tensors="pt", truncation=True)

        with torch.no_grad():
            outputs = model(**inputs)
            # Single input, so take the first (only) row of logits.
            probs = torch.nn.functional.softmax(outputs.logits[0], dim=-1)

        top_labels = get_top95(id2label, probs)

        results = [
            {
                "Category": categories[label],
                "ID": label,
                "Confidence": f"{prob * 100:.1f} %",
            }
            for label, prob in top_labels
        ]

        df = pd.DataFrame(results)
        df.index += 1  # show 1-based ranks instead of 0-based row numbers

        st.markdown("### Top 95% Predicted Topics")
        st.table(df)