Spaces:
Sleeping
Sleeping
File size: 1,640 Bytes
07d2160 5b4cc33 25c8e68 5b4cc33 89f2ab5 5b4cc33 89f2ab5 5b4cc33 89f2ab5 5b4cc33 89f2ab5 5b4cc33 89f2ab5 5b4cc33 89f2ab5 5b4cc33 89f2ab5 5b4cc33 1eee1d7 89f2ab5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
# Install dependencies outside this script (e.g. `pip install torch transformers
# streamlit pandas`); a bare shell command here is not valid Python and would
# abort the app with a SyntaxError before anything runs.
import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hugging Face Hub id of the fine-tuned classifier to load.
MODEL_NAME = "dinusha11/finetuned-distilbert-news"


@st.cache_resource
def _load_model(model_name=MODEL_NAME):
    """Load the tokenizer and sequence-classification model for *model_name*.

    Streamlit re-executes this whole script on every user interaction;
    ``st.cache_resource`` keeps the loaded objects alive across reruns so the
    weights are not reloaded each time.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForSequenceClassification.from_pretrained(model_name)
    mdl.eval()  # inference only: make sure dropout is disabled
    return tok, mdl


# Module-level names used by the classification code below.
tokenizer, model = _load_model()

# Maps the model's predicted class index to a human-readable category.
# NOTE(review): presumably this must match the label order used during
# fine-tuning — verify against model.config.id2label.
LABEL_MAPPING = {0: "Business", 1: "Opinion", 2: "Sports", 3: "Political_gossip", 4: "World_news"}
def classify_text(text):
    """Return the predicted news category for a single piece of text.

    The text is tokenized (truncated to at most 128 tokens), run through the
    module-level ``model`` without gradient tracking, and the highest-scoring
    class index is translated to its name via ``LABEL_MAPPING``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    # Single input -> batch of one; take the argmax over the class axis.
    best_index = logits.argmax(dim=1).item()
    return LABEL_MAPPING[best_index]
# Streamlit UI
st.title("News Classification App")
st.write("Upload a CSV file with news excerpts, and the model will classify each record.")

# File uploader
uploaded_file = st.file_uploader("Upload CSV", type=["csv"])

if uploaded_file:
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        # Report unreadable uploads in the UI instead of a raw traceback.
        st.error(f"Could not read the uploaded CSV: {err}")
        st.stop()

    if "text" not in df.columns:
        st.error("CSV must contain a 'text' column.")
    else:
        # Preprocess text. astype(str) is needed when pandas infers the column
        # as numeric: the .str accessor raises on int columns and turns the
        # float values of a float column into NaN, which the tokenizer rejects.
        df["text"] = df["text"].fillna("").astype(str).str.strip().str.lower()

        # Apply classification (one model call per row).
        with st.spinner("Classifying..."):
            df["class"] = df["text"].apply(classify_text)

        # Download output CSV
        output_csv = df.to_csv(index=False).encode("utf-8")
        st.download_button("Download Results", data=output_csv, file_name="output.csv", mime="text/csv")
        st.write("Classification Complete! Download your file above.")
|