# Source: Hugging Face Space file page (app.py, commit e964457, 5.98 kB) —
# page chrome from the scrape converted to a comment so the module parses.
# --- Imports (deduplicated: the original repeated this entire section twice) ---
import re
import string

import nltk
import pandas as pd
import streamlit as st
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from PIL import Image
from transformers import pipeline

# Download required NLTK data (tokenizer models, stopword list, WordNet lemmas).
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('wordnet')

# Load each model exactly once. The original constructed the news classifier
# pipeline twice (before and after the duplicated import block), doubling
# startup time and memory for no benefit.
news_classifier = pipeline("text-classification", model="Oneli/News_Classification")
qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
# Map the classifier's raw LABEL_k outputs to human-readable category names.
_CATEGORY_NAMES = ("Business", "Opinion", "Political Gossip", "Sports", "World News")
label_mapping = {f"LABEL_{idx}": name for idx, name in enumerate(_CATEGORY_NAMES)}

# Module-level scratch state shared between the classifiers and the QA chatbot:
# the most recent single-article context, the concatenated bulk-CSV context,
# and the row count of the last uploaded CSV.
context_storage = {"context": "", "bulk_context": "", "num_articles": 0}
# Text Cleaning Functions
def clean_text(text):
    """Normalize raw article text for classification.

    Lowercases, strips punctuation and remaining non-alphanumeric characters,
    tokenizes, drops English stopwords, and lemmatizes each token.
    Returns the cleaned tokens rejoined with single spaces.
    """
    text = text.lower()
    text = re.sub(f"[{string.punctuation}]", "", text)  # Remove punctuation
    text = re.sub(r"[^a-zA-Z0-9\s]", "", text)  # Remove special characters
    tokens = word_tokenize(text)
    # Build the stopword set ONCE per call. The original evaluated
    # stopwords.words("english") — a fresh ~180-element list — inside the
    # comprehension, i.e. once per token, with O(n) list membership each time.
    stop_words = set(stopwords.words("english"))
    tokens = [word for word in tokens if word not in stop_words]
    lemmatizer = WordNetLemmatizer()
    return " ".join(lemmatizer.lemmatize(word) for word in tokens)
# Define the functions
def classify_text(text):
    """Classify one news article and cache its cleaned text for the QA chatbot.

    Returns a (category, confidence-string) pair; unknown labels map to
    "Unknown".
    """
    normalized = clean_text(text)
    prediction = news_classifier(normalized)[0]
    category_name = label_mapping.get(prediction["label"], "Unknown")
    pct = round(prediction["score"] * 100, 2)
    context_storage["context"] = normalized  # QA chatbot reads this later
    return category_name, f"Confidence: {pct}%"
def classify_csv(file):
    """Classify every row of an uploaded CSV and write the results to disk.

    Assumes the first column holds the article text. On success returns
    (dataframe, "output.csv"); on failure returns (None, error-string) so the
    caller can branch on the first element.
    """
    try:
        df = pd.read_csv(file, encoding="utf-8")
        text_column = df.columns[0]  # Assume first column is the text column
        df[text_column] = df[text_column].astype(str).apply(clean_text)
        # Run inference ONCE per row. The original called news_classifier()
        # twice per row (once for the label, again for the score), doubling
        # the model cost on every upload.
        predictions = df[text_column].apply(lambda x: news_classifier(x)[0])
        df["Encoded Prediction"] = predictions.apply(lambda p: p['label'])
        df["Decoded Prediction"] = df["Encoded Prediction"].map(label_mapping)
        df["Confidence"] = predictions.apply(lambda p: round(p['score'] * 100, 2))
        # Store all text as a single context for QA
        context_storage["bulk_context"] = " ".join(df[text_column].dropna().astype(str).tolist())
        context_storage["num_articles"] = len(df)
        output_file = "output.csv"
        df.to_csv(output_file, index=False)
        return df, output_file
    except Exception as e:
        # Best-effort: surface the error text to the UI instead of crashing.
        return None, f"Error: {str(e)}"
def chatbot_response(history, user_input, source):
    """Answer a chat message, appending [question, answer] to *history*.

    Resolution order: article-count questions, then extractive QA over the
    stored context (single-article or bulk, per *source*), then canned
    small-talk replies. Always returns (history, "") — the empty string
    clears the input box.
    """
    user_input = user_input.lower()
    if source == "Single Article":
        context = context_storage["context"]
    else:
        context = context_storage["bulk_context"]
    num_articles = context_storage["num_articles"]

    # Direct question about the CSV size — answered without the QA model.
    if ("number of articles" in user_input) or ("how many articles" in user_input):
        history.append([user_input, f"There are {num_articles} articles in the uploaded CSV."])
        return history, ""

    # If any context has been classified, defer to the extractive QA model.
    if context:
        qa_result = qa_pipeline(question=user_input, context=context)
        history.append([user_input, qa_result["answer"]])
        return history, ""

    # No context yet: fall back to canned small talk keyed on the exact message.
    canned = {
        "hello": "πŸ‘‹ Hello! How can I assist you with news today?",
        "hi": "😊 Hi there! What do you want to know about news?",
        "how are you": "πŸ€– I'm just a bot, but I'm here to help!",
        "thank you": "πŸ™ You're welcome! Let me know if you need anything else.",
        "news": "πŸ“° I can classify news into Business, Sports, Politics, and more!",
    }
    fallback = "πŸ€” I'm here to help with news classification and general info. Ask me about news topics!"
    history.append([user_input, canned.get(user_input, fallback)])
    return history, ""
# Streamlit App Layout: page chrome plus the single-article section.
st.set_page_config(page_title="News Classifier", page_icon="πŸ“°")
cover_image = Image.open("cover.png")  # cover.png must ship alongside the app
st.image(cover_image, caption="News Classifier πŸ“’", use_column_width=True)

# Section for Single Article Classification
st.subheader("πŸ“° Single Article Classification")
text_input = st.text_area("Enter News Text", placeholder="Type or paste news content here...")
if st.button("πŸ” Classify"):
    if not text_input:
        st.warning("Please enter some text to classify.")
    else:
        category, confidence = classify_text(text_input)
        st.write(f"*Predicted Category:* {category}")
        st.write(f"*Confidence Level:* {confidence}")
# Section for Bulk CSV Classification
st.subheader("πŸ“‚ Bulk Classification (CSV)")
file_input = st.file_uploader("Upload CSV File", type="csv")
if file_input:
    df, output_file = classify_csv(file_input)
    if df is not None:
        st.dataframe(df)
        # Read the processed CSV inside a context manager. The original did
        # open(output_file, 'rb').read() inline, which leaks the file handle
        # on every rerun until the GC happens to collect it.
        with open(output_file, 'rb') as f:
            csv_bytes = f.read()
        st.download_button(
            label="Download Processed CSV",
            data=csv_bytes,
            file_name=output_file,
            mime="text/csv"
        )
    else:
        # classify_csv returned (None, error-string) — show the error text.
        st.error(f"Error processing file: {output_file}")
# Section for Chatbot Interaction
st.subheader("πŸ’¬ AI Chat Assistant")
# Persist chat history across Streamlit reruns. The original did
# `history = []` at module level, which re-ran (and wiped the history) on
# every widget interaction, so past turns could never accumulate.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
user_input = st.text_input("Ask about news classification or topics", placeholder="Type a message...")
source_toggle = st.radio("Select Context Source", ["Single Article", "Bulk Classification"])
if st.button("βœ‰ Send"):
    # chatbot_response mutates and returns the history list; the second
    # return value is always "" and is unused here.
    st.session_state["chat_history"], _ = chatbot_response(
        st.session_state["chat_history"], user_input, source_toggle
    )
    st.write("*Chatbot Response:*")
    for q, a in st.session_state["chat_history"]:
        st.write(f"*Q:* {q}")
        st.write(f"*A:* {a}")