new-classifier / app.py
dinusha11's picture
Update app.py
07d2160 verified
# Setup (run in a shell, not inside Python): pip install torch
import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load the fine-tuned model
MODEL_NAME = "dinusha11/finetuned-distilbert-news"


@st.cache_resource
def _load_model(model_name):
    """Load the tokenizer and sequence classifier for *model_name*.

    Streamlit re-executes the script on each user interaction; caching the
    loaded objects as a shared resource avoids re-loading the weights on
    every rerun.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return loaded_tokenizer, loaded_model


# Keep the original module-level names so the rest of the script is unchanged.
tokenizer, model = _load_model(MODEL_NAME)
# Category names keyed by integer class index (0-4), in index order.
LABEL_MAPPING = dict(
    enumerate(
        (
            "Business",
            "Opinion",
            "Sports",
            "Political_gossip",
            "World_news",
        )
    )
)
# Function to classify text
def classify_text(text):
    """Return the category name the model predicts for *text*.

    The text is tokenized into PyTorch tensors (truncated to at most 128
    tokens), passed through the model with gradient tracking disabled, and
    the index of the largest logit is looked up in ``LABEL_MAPPING``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    label_id = logits.argmax(dim=1).item()
    return LABEL_MAPPING[label_id]
# Streamlit UI
st.title("News Classification App")
st.write("Upload a CSV file with news excerpts, and the model will classify each record.")

# File uploader
uploaded_file = st.file_uploader("Upload CSV", type=["csv"])

# Test for "no file yet" explicitly rather than relying on object truthiness.
if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        # Report an unreadable upload on the page instead of a raw traceback.
        st.error(f"Could not read the uploaded CSV: {exc}")
    else:
        if "text" not in df.columns:
            st.error("CSV must contain a 'text' column.")
        else:
            # Preprocess text: blank out missing values, then coerce to str so
            # the .str accessor also works when the column was parsed as
            # numbers, before trimming whitespace and lower-casing.
            df["text"] = df["text"].fillna("").astype(str).str.strip().str.lower()

            # Apply classification row by row
            df["class"] = df["text"].apply(classify_text)

            # Offer the labelled data as a downloadable UTF-8 CSV
            output_csv = df.to_csv(index=False).encode("utf-8")
            st.download_button(
                "Download Results",
                data=output_csv,
                file_name="output.csv",
                mime="text/csv",
            )
            st.write("Classification Complete! Download your file above.")