MurDanya's picture
Update app.py
2c37fc0 verified
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download
import torch
import numpy as np
import pandas as pd
import json
@st.cache_resource
def load_model():
    """Load the classifier, its tokenizer and the label metadata from the Hub.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the loaded objects
    across script reruns instead of rebuilding them on every click.

    Returns:
        tuple: ``(tokenizer, model, id2label, categories)``.
            ``id2label`` maps an integer class index to a label ID, and
            ``categories`` maps that label ID to the name shown in the
            results table. Both come from ``labels.json`` in the model repo.
    """
    repo_id = "MurDanya/ml-course-article-classifier"
    model = AutoModelForSequenceClassification.from_pretrained(repo_id)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    file_path = hf_hub_download(repo_id, "labels.json")
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
    # mis-decode any non-ASCII characters in the label file.
    with open(file_path, encoding="utf-8") as f:
        labels = json.load(f)

    # JSON object keys are always strings; convert back to integer indices
    # so they line up with positions in the model's output vector.
    id2label = {int(idx): label for idx, label in labels["id2label"].items()}
    categories = labels["categories"]
    return tokenizer, model, id2label, categories
def get_top95(labels, probs, threshold=0.95):
    """Return the top-ranked labels whose cumulative probability reaches *threshold*.

    Classes are ranked by descending probability and kept until the running
    sum first reaches ``threshold``. The class that crosses the threshold is
    included. If the running sum never reaches ``threshold`` (e.g. float
    rounding, or a threshold above 1), every class is returned.

    Args:
        labels: Mapping or sequence indexed by integer class position that
            gives the label for each entry of ``probs``.
        probs: 1-D tensor of class probabilities (the caller passes a
            softmax over a single row of logits).
        threshold: Cumulative-probability cutoff. The 0.95 default is what
            gives the function its name.

    Returns:
        list[tuple]: ``(label, probability)`` pairs, with probabilities as
        Python floats, in descending order of probability.
    """
    order = torch.argsort(probs, descending=True)
    ranked_probs = probs[order]
    cumulative = torch.cumsum(ranked_probs, dim=0)
    reached = torch.where(cumulative >= threshold)[0]
    # +1 so the class that pushes the sum over the threshold is kept; fall
    # back to every class when the threshold is never reached.
    keep = reached[0].item() + 1 if reached.numel() > 0 else len(ranked_probs)

    kept_indices = order[:keep].tolist()
    kept_probs = ranked_probs[:keep].tolist()
    return [(labels[idx], prob) for idx, prob in zip(kept_indices, kept_probs)]
# UI
st.set_page_config(page_title="Article Topic Classifier")
st.title("Article Topic Classifier")
st.markdown("Enter the **title** and optionally **abstract** of the article.")

title = st.text_input("Title", placeholder="e.g. Neural Networks for Quantum Physics")
abstract = st.text_area("Abstract (optional)", placeholder="e.g. We explore the application of neural nets...")

if st.button("Classify"):
    # The title is required (both the prompt above and the warning below say
    # so). Blank or whitespace-only input counts as missing; previously an
    # abstract alone slipped through and was classified as " - <abstract>".
    title_text = title.strip()
    abstract_text = abstract.strip()
    if not title_text:
        st.warning("Please enter at least the title.")
    else:
        tokenizer, model, id2label, categories = load_model()
        text = f"{title_text} - {abstract_text}" if abstract_text else title_text
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        # A single text was tokenized, so take the one row of logits.
        probs = torch.nn.functional.softmax(outputs.logits[0], dim=-1)
        top_labels = get_top95(id2label, probs)

        df = pd.DataFrame(
            [
                {
                    "Category": categories[label],
                    "ID": label,
                    "Confidence": f"{prob * 100:.1f} %",
                }
                for label, prob in top_labels
            ]
        )
        df.index += 1  # show 1-based row numbers in the rendered table
        st.markdown("### Top 95% Predicted Topics")
        st.table(df)