Ginidu2003's picture
Update src/streamlit_app.py
6472126 verified
import streamlit as st
import pandas as pd
import torch # ← This was missing
from transformers import pipeline
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import re
import string
# ====================== PREPROCESSING ======================
# Fetch the NLTK data that preprocess_text relies on: the stop-word list,
# WordNet for the lemmatizer, and the Punkt models used by
# nltk.word_tokenize. Recent NLTK releases load the tokenizer from
# 'punkt_tab' instead of the older 'punkt' resource, so both are requested
# to keep the app working across NLTK versions.
for _resource in ('stopwords', 'wordnet', 'punkt', 'punkt_tab'):
    nltk.download(_resource, quiet=True)

# Built once at import time; a set gives constant-time membership checks
# when filtering tokens.
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()
def preprocess_text(text: object) -> str:
    """Normalise raw article text into a space-separated string of lemmas.

    The text is lower-cased, every character that is not ``a``-``z`` or
    whitespace is replaced by a space, the result is tokenised with
    ``nltk.word_tokenize``, English stop words are removed and each
    remaining token is lemmatised.

    Args:
        text: Raw cell value. Anything that is not a ``str`` (for example
            ``None`` or a float NaN) is treated as empty.

    Returns:
        The cleaned tokens joined by single spaces, or ``""`` when ``text``
        is not a string.
    """
    if not isinstance(text, str):
        return ""

    # A single substitution handles punctuation, digits and non-ASCII
    # characters alike. string.punctuation is deliberately not interpolated
    # into a character class: unescaped, its "\]" is parsed as an escaped
    # bracket and its ",-." as a range, which makes such a pattern fragile.
    cleaned = re.sub(r'[^a-z\s]', ' ', text.lower())

    tokens = [
        lemmatizer.lemmatize(token)
        for token in nltk.word_tokenize(cleaned)
        if token not in stop_words
    ]
    return ' '.join(tokens)
# ====================== LOAD MODEL ======================
@st.cache_resource
def load_model():
    """Create the Hugging Face text-classification pipeline used by the app.

    The function is wrapped in ``st.cache_resource`` so the pipeline is
    meant to be built once and reused rather than reloaded on every script
    rerun.

    Returns:
        A ``transformers`` text-classification pipeline for the configured
        model, placed on the first CUDA device when one is available and on
        the CPU otherwise.
    """
    # ← Change if your model name is different
    model_name = "Ginidu2003/Distilbert-Base-News-classifier"

    # Pipeline device convention: 0 selects the first GPU, -1 selects the CPU.
    target_device = 0 if torch.cuda.is_available() else -1

    return pipeline("text-classification", model=model_name, device=target_device)


classifier = load_model()
# ====================== STREAMLIT APP ======================
# Page flow: upload a CSV -> preview it -> classify the `content` column ->
# preview the labelled rows -> offer the result as a CSV download.
st.title("📰 Daily Mirror News Classifier")
st.subheader("Classify news into Business, Opinion, Political Gossip, Sports, or World News")
st.markdown("**Upload a CSV file** with a column named `content`")

uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])

if uploaded_file is not None:
    # Report unreadable uploads in the page instead of surfacing a traceback.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        df = None
        st.error(f"Could not read the uploaded file as CSV: {exc}")

    if df is not None:
        st.write("### Preview of uploaded data")
        st.dataframe(df.head())

        if 'content' not in df.columns:
            st.error("Your CSV must have a column named 'content'")
        else:
            with st.spinner("Preprocessing and classifying..."):
                # Keep the cleaned text in a local list rather than a
                # temporary DataFrame column, so a user column that happens
                # to be called `clean_content` is never overwritten or
                # dropped from the output.
                cleaned_texts = [preprocess_text(value) for value in df['content']]

                # truncation=True caps each input at the model's maximum
                # sequence length; without it, long articles make the
                # pipeline raise. Rows with no usable text are labelled
                # "Unknown" without calling the model.
                df['class'] = [
                    classifier(text, truncation=True)[0]['label']
                    if text.strip()
                    else "Unknown"
                    for text in cleaned_texts
                ]

            st.success("✅ Classification completed!")
            st.write("### Preview of classified data")
            st.dataframe(df.head())

            # Download button
            csv = df.to_csv(index=False).encode('utf-8')
            st.download_button(
                label="📥 Download output.csv",
                data=csv,
                file_name="output.csv",
                mime="text/csv",
            )

st.caption("Built for Text Analytics Assignment - Section 02")