Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import re | |
| import string | |
| import nltk | |
| from nltk.corpus import stopwords | |
| from nltk.stem import WordNetLemmatizer | |
| from transformers import pipeline | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from wordcloud import WordCloud | |
# Download required NLTK data.
# quiet=True keeps the log clean when this module-level code is executed again.
for _nltk_resource in ("stopwords", "wordnet", "omw-1.4"):
    nltk.download(_nltk_resource, quiet=True)


# Load Models
@st.cache_resource(show_spinner=False)
def _load_pipeline(task, model_name):
    """Build a Hugging Face pipeline once and reuse it on later script runs.

    Streamlit re-executes the script on every widget interaction; without the
    cache both models would be re-instantiated each time. The spinner is
    disabled so that nothing is rendered before ``st.set_page_config`` runs
    further down the script.
    """
    return pipeline(task, model=model_name)


news_classifier = _load_pipeline("text-classification", "Oneli/News_Classification")
qa_pipeline = _load_pipeline("question-answering", "distilbert-base-cased-distilled-squad")

# Label Mapping: raw classifier label ids -> human-readable category names
label_mapping = {
    "LABEL_0": "Business",
    "LABEL_1": "Opinion",
    "LABEL_2": "Political Gossip",
    "LABEL_3": "Sports",
    "LABEL_4": "World News",
}

# Store classified article for QA.
#   "context"      - cleaned text of the last single article classified
#   "bulk_context" - all cleaned CSV rows joined into one string
#   "num_articles" - number of rows in the last CSV processed
context_storage = {"context": "", "bulk_context": "", "num_articles": 0}
# Text Cleaning Functions
def clean_text(text):
    """Normalise raw text before it is handed to the classifier.

    Steps: lower-case, drop every character that is not an ASCII letter,
    digit or whitespace, split on whitespace, remove English stopwords and
    lemmatize the remaining tokens.

    Args:
        text: The raw input string.

    Returns:
        The cleaned tokens joined by single spaces (may be an empty string).
    """
    text = text.lower()
    # A single pass is enough: every punctuation character is also outside
    # [a-zA-Z0-9\s], so a separate punctuation-stripping step adds nothing.
    text = re.sub(r"[^a-zA-Z0-9\s]", "", text)

    # Build the stopword set once per call; evaluating stopwords.words()
    # inside the filter would repeat the lookup for every single token.
    stop_words = set(stopwords.words("english"))
    lemmatizer = WordNetLemmatizer()

    # Tokenization without Punkt: plain whitespace split.
    return " ".join(
        lemmatizer.lemmatize(word)
        for word in text.split()
        if word not in stop_words
    )
# Define the functions
def classify_text(text):
    """Classify a single news article.

    The text is cleaned with ``clean_text`` and passed to ``news_classifier``.
    As a side effect the cleaned text is stored in
    ``context_storage["context"]``.

    Args:
        text: Raw article text.

    Returns:
        A ``(category, confidence_text)`` tuple, where ``category`` is the
        mapped label name (``"Unknown"`` for an unmapped label) and
        ``confidence_text`` has the form ``"Confidence: 97.31%"``.
    """
    cleaned_text = clean_text(text)
    # truncation=True clips inputs that exceed the model's maximum sequence
    # length instead of letting the forward pass fail on long articles.
    result = news_classifier(cleaned_text, truncation=True)[0]
    category = label_mapping.get(result["label"], "Unknown")
    confidence = round(result["score"] * 100, 2)

    # Store context for QA
    context_storage["context"] = cleaned_text
    return category, f"Confidence: {confidence}%"
def classify_csv(file):
    """Classify every row of a CSV file and write the result to ``output.csv``.

    The first column is assumed to hold the article text. It is cleaned in
    place with ``clean_text``; the columns ``"Decoded Prediction"`` and
    ``"Confidence"`` (percentage, two decimals) are added. The joined cleaned
    text and the row count are stored in ``context_storage``.

    Args:
        file: A path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        ``(df, "output.csv")`` on success, or ``(None, "Error: <message>")``
        if anything goes wrong; callers rely on this shape, so no exception
        is propagated.
    """
    try:
        df = pd.read_csv(file, encoding="utf-8")
        text_column = df.columns[0]  # Assume first column is the text column
        df[text_column] = df[text_column].astype(str).apply(clean_text)  # Clean text column

        # Run the model once per row and derive both output columns from the
        # same prediction, instead of paying for inference twice.
        predictions = [
            news_classifier(article, truncation=True)[0]
            for article in df[text_column]
        ]
        df["Decoded Prediction"] = [
            label_mapping.get(prediction["label"], "Unknown")
            for prediction in predictions
        ]
        df["Confidence"] = [
            round(prediction["score"] * 100, 2) for prediction in predictions
        ]

        # Store all text as a single context for QA. Every value is already a
        # string after cleaning, so no NaN filtering is needed here.
        context_storage["bulk_context"] = " ".join(df[text_column].tolist())
        context_storage["num_articles"] = len(df)

        output_file = "output.csv"
        df.to_csv(output_file, index=False)
        return df, output_file
    except Exception as e:
        # Deliberate catch-all: the UI shows the message rather than crashing.
        return None, f"Error: {str(e)}"
def chatbot_response(history, user_input, text_input=None, file_input=None):
    """Answer a question against the selected article and/or CSV context.

    Args:
        history: List of ``[question, answer]`` pairs; the new pair is
            appended to it in place.
        user_input: The question. It is lower-cased before use.
        text_input: Optional article text to use as context.
        file_input: Optional uploaded CSV. The bulk context already stored by
            ``classify_csv`` is reused when present; otherwise the file is
            rewound and classified to build it.

    Returns:
        ``(history, answer)``. When no context is available the answer is a
        short message asking for input rather than a model prediction.
    """
    user_input = user_input.lower()
    context = ""
    if text_input:
        context += text_input
    if file_input:
        if not context_storage["bulk_context"]:
            # The upload may already have been read to the end by an earlier
            # parse; rewind so read_csv does not see an empty stream.
            if hasattr(file_input, "seek"):
                file_input.seek(0)
            classify_csv(file_input)
        context += context_storage["bulk_context"]

    if context:
        with st.spinner("Finding answer..."):
            result = qa_pipeline(question=user_input, context=context)
        answer = result["answer"]
    else:
        # Without this branch `answer` would be unbound below.
        answer = "Please provide an article or upload a CSV file so I have something to answer from."

    history.append([user_input, answer])
    return history, answer
# Function to generate word cloud from the text column of a processed DataFrame
def generate_word_cloud_from_output(df):
    """Build a word cloud from the text column of ``df``.

    The ``"content"`` column is used when it exists; otherwise the first
    column is used, matching ``classify_csv``, which treats the first column
    as the article text whatever its header is.

    Args:
        df: DataFrame holding the article text.

    Returns:
        The generated ``WordCloud`` object (800x400, white background).
    """
    text_column = "content" if "content" in df.columns else df.columns[0]
    content_text = " ".join(df[text_column].dropna().astype(str).tolist())
    wordcloud = WordCloud(width=800, height=400, background_color="white").generate(content_text)
    return wordcloud
# Function to generate bar graph for decoded predictions
def generate_bar_graph(df):
    """Render a bar chart of how often each predicted category occurs.

    Args:
        df: DataFrame containing a ``"Decoded Prediction"`` column.

    The chart is drawn into the Streamlit page; nothing is returned.
    """
    prediction_counts = df["Decoded Prediction"].value_counts()
    fig, ax = plt.subplots(figsize=(10, 6))
    prediction_counts.plot(kind="bar", ax=ax, color="skyblue")
    ax.set_title("Frequency of Decoded Predictions", fontsize=16)
    ax.set_xlabel("Category", fontsize=12)
    ax.set_ylabel("Frequency", fontsize=12)
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; release this one so
    # figures do not pile up each time the chart is redrawn.
    plt.close(fig)
# Streamlit App Layout
# The icon and heading emoji were mis-encoded in the source; restored here.
st.set_page_config(page_title="News Classifier", page_icon="📰")

# Load image
cover_image = Image.open("cover.png")  # Ensure this image exists

# Display image
st.image(cover_image, use_container_width=True)

# Custom styled caption
st.markdown(
    "<h2 style='text-align: center; font-size: 32px;'>News Classifier 📢</h2>",
    unsafe_allow_html=True,
)
# Section for Single Article Classification
st.subheader("📰 Single Article Classification")
text_input = st.text_area("Enter News Text", placeholder="Type or paste news content here...")
if st.button("🔍 Classify"):
    # Treat whitespace-only input like empty input.
    if text_input and text_input.strip():
        category, confidence = classify_text(text_input)
        st.write(f"Predicted Category: {category}")
        # classify_text already returns a ready-made "Confidence: NN%" string;
        # prefixing it again would print the label twice.
        st.write(confidence)

        # Generate word cloud for the text input (wrapped in a one-row DataFrame)
        try:
            wordcloud = generate_word_cloud_from_output(pd.DataFrame({"content": [text_input]}))
        except ValueError:
            # WordCloud raises ValueError when the text yields no words.
            st.info("Not enough words to draw a word cloud.")
        else:
            st.image(wordcloud.to_array(), caption="Word Cloud for Text Input", use_container_width=True)
    else:
        st.warning("Please enter some text to classify.")
# Section for Bulk CSV Classification
st.subheader("📂 Bulk Classification (CSV)")
file_input = st.file_uploader("Upload CSV File", type="csv")
if file_input:
    # classify_csv returns (DataFrame, path) on success or (None, message) on failure.
    df, output_file = classify_csv(file_input)
    if df is not None:
        st.dataframe(df)

        # Read the processed file through a context manager so the handle is
        # closed instead of being left to the garbage collector.
        with open(output_file, "rb") as processed_csv:
            csv_bytes = processed_csv.read()
        st.download_button(
            label="Download Processed CSV",
            data=csv_bytes,
            file_name=output_file,
            mime="text/csv",
        )

        # Generate word cloud for the text of the processed CSV data
        wordcloud = generate_word_cloud_from_output(df)
        st.image(wordcloud.to_array(), caption="Word Cloud for CSV Content", use_container_width=True)

        # Generate bar graph for decoded predictions frequency
        generate_bar_graph(df)
    else:
        st.error(f"Error processing file: {output_file}")
# Section for Chatbot Interaction
st.subheader("💬 AI Chat Assistant")

# Keep the conversation in session state; a plain module-level list would be
# recreated empty every time the script runs again.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
history = st.session_state["chat_history"]

user_input = st.text_input("Ask about news classification or topics", placeholder="Type a message...")
source_toggle = st.radio("Select Context Source", ["Single Article", "Bulk Classification"])

if st.button("✉ Send"):
    use_single_article = source_toggle == "Single Article"
    selected_text = text_input if use_single_article else None
    selected_file = None if use_single_article else file_input

    if not user_input or not user_input.strip():
        # The QA model needs an actual question.
        st.warning("Please type a question first.")
    elif not selected_text and not selected_file:
        # The chosen source has no content to answer from.
        st.warning("Please upload your file or provide text input for QA.")
    else:
        history, bot_response = chatbot_response(
            history,
            user_input,
            text_input=selected_text,
            file_input=selected_file,
        )
        st.session_state["chat_history"] = history
        st.write("Chatbot Response:")
        for q, a in history:
            st.write(f"Q: {q}")
            st.write(f"A: {a}")