# Space: News_Classifier / app.py
# Author: Ginidu2003
# Last commit: 339e6ca ("Update app.py")
import gradio as gr
import pandas as pd
import torch
from transformers import pipeline
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import re
import string
import matplotlib.pyplot as plt
# ====================== NLTK SETUP ======================
# NLTK data packages fetched (quietly, at import time) for the tokenizer and
# lemmatizer that preprocess_text relies on.
for _resource in ("wordnet", "punkt", "punkt_tab"):
    nltk.download(_resource, quiet=True)
del _resource

# Single lemmatizer instance reused by every preprocess_text call.
lemmatizer = WordNetLemmatizer()
# Translation table that blanks out every ASCII punctuation character except
# the ones deliberately kept: ' " $ % ?
# str.translate maps characters literally, so regex metacharacters such as
# \ ] ^ - need no escaping (an unescaped regex class silently skipped "\").
_PUNCT_TO_SPACE = str.maketrans(
    {ch: " " for ch in string.punctuation if ch not in "'\"$%?"}
)


def preprocess_text(text):
    """Normalize raw article text before it is fed to the news classifier.

    Lower-cases the text, replaces each punctuation character (except
    ``' " $ % ?``) with a space, tokenizes with ``nltk.word_tokenize`` and
    lemmatizes every token with the module-level ``lemmatizer``.

    Args:
        text: Raw article text. Non-string values (e.g. ``NaN`` produced by
            pandas for an empty CSV cell) are treated as empty.

    Returns:
        The lemmatized tokens joined by single spaces, or ``""`` when
        ``text`` is not a string.
    """
    if not isinstance(text, str):
        return ""
    text = text.lower().translate(_PUNCT_TO_SPACE)
    tokens = nltk.word_tokenize(text)
    return " ".join(lemmatizer.lemmatize(token) for token in tokens)
# Hugging Face Hub repo id of the news-classification model used to build the
# text-classification pipeline for CSV classification.
classifier_model = "Ginidu2003/Distilbert-Base-News-classifier"
# ====================== BEAUTIFUL COLORED BAR CHART ======================
def create_colored_bar_chart(category_counts):
    """Build a bar chart of how many rows fell into each predicted category.

    Args:
        category_counts: DataFrame with ``Category`` and ``Count`` columns,
            or ``None``.

    Returns:
        A ``matplotlib.figure.Figure``. With no data the figure only shows a
        "No data available" message.

    The figure is created with ``matplotlib.figure.Figure`` instead of
    ``pyplot`` so it is never registered in pyplot's global figure manager.
    Pyplot-managed figures stay alive until explicitly closed, which leaks
    memory in a long-running server. Pyplot's implicit "current figure" is
    also shared between concurrent requests.
    """
    from matplotlib.figure import Figure

    if category_counts is None or len(category_counts) == 0:
        fig = Figure()
        ax = fig.subplots()
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    categories = category_counts["Category"]
    counts = category_counts["Count"]

    # Nice modern color palette; matplotlib cycles it if there are more bars
    # (e.g. an extra "Unknown" category).
    colors = ['#3498DB', '#E67E22', '#9B59B6', '#2ECC71', '#E74C3C']

    # tight_layout=True applies the same layout pass plt.tight_layout() did,
    # but at draw time and without touching pyplot state.
    fig = Figure(figsize=(11, 6), tight_layout=True)
    ax = fig.subplots()
    bars = ax.bar(categories, counts, color=colors, edgecolor='white', linewidth=0.8)

    # Value label above each bar, offset in points (not data units) so the
    # gap looks the same whatever the y-axis scale is.
    for bar in bars:
        height = bar.get_height()
        ax.annotate(
            str(int(height)),
            xy=(bar.get_x() + bar.get_width() / 2, height),
            xytext=(0, 3),
            textcoords='offset points',
            ha='center',
            va='bottom',
            fontsize=13,
            fontweight='bold',
        )
    # Headroom so the tallest bar's label stays inside the axes.
    ax.margins(y=0.1)

    ax.set_title("Category Distribution Across 5 Classes", fontsize=16, fontweight='bold', pad=20)
    ax.set_xlabel("Category", fontsize=12)
    ax.set_ylabel("Count", fontsize=12)
    ax.tick_params(axis='x', labelrotation=15)
    return fig
# ====================== CLASSIFICATION FUNCTION ======================
# Text-classification pipeline shared by all requests, created on first use.
_classifier = None


def _get_classifier():
    """Return the shared classification pipeline, building it on first use.

    Building the pipeline loads the model weights, which is far too slow to
    repeat on every button click.
    """
    global _classifier
    if _classifier is None:
        _classifier = pipeline("text-classification", model=classifier_model, device=-1)
    return _classifier


@torch.no_grad()
def classify_csv(file):
    """Classify every row of an uploaded CSV and summarize the predictions.

    Args:
        file: The uploaded CSV. This can be a filesystem path (what
            ``gr.File`` passes in recent Gradio versions), an object whose
            ``.name`` is a path, or a file-like object. The CSV must have a
            ``content`` column.

    Returns:
        A tuple ``(status_message, output_csv_path, figure)``:

        * ``output_csv_path`` is a copy of the input with an added ``class``
          column. Rows whose text is empty after preprocessing get
          ``"Unknown"``.
        * The CSV is written to a fresh temporary directory, so concurrent
          users never overwrite each other's results.
        * ``figure`` is the category-count bar chart.

        On failure the last two items are ``None``.
    """
    import os
    import tempfile

    if file is None:
        return "Error: Please upload a CSV file.", None, None
    try:
        df = pd.read_csv(getattr(file, "name", file))
        if 'content' not in df.columns:
            return "Error: CSV must have a column named 'content'", None, None

        texts = df['content'].apply(preprocess_text).tolist()
        predictions = ["Unknown"] * len(texts)
        to_classify = [i for i, text in enumerate(texts) if text.strip()]
        if to_classify:
            # One pipeline call for all non-empty rows. truncation=True keeps
            # long articles within the model's maximum input length instead
            # of raising and failing the whole file.
            results = _get_classifier()([texts[i] for i in to_classify], truncation=True)
            for i, result in zip(to_classify, results):
                predictions[i] = result['label']
        df['class'] = predictions

        output_file = os.path.join(tempfile.mkdtemp(), "output.csv")
        df.to_csv(output_file, index=False)

        category_counts = (
            df['class'].value_counts().rename_axis("Category").reset_index(name="Count")
        )
        fig = create_colored_bar_chart(category_counts)
        return f"✅ Success! Classified {len(df)} rows", output_file, fig
    except Exception as e:
        return f"❌ Error: {str(e)}", None, None
# ====================== Q&A FUNCTION ======================
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

# Extractive question-answering checkpoint. Tokenizer and model are loaded
# from the same repo so their vocabularies match.
_QA_CHECKPOINT = "deepset/roberta-base-squad2"
qa_tokenizer = AutoTokenizer.from_pretrained(_QA_CHECKPOINT)
qa_model = AutoModelForQuestionAnswering.from_pretrained(_QA_CHECKPOINT)
def _select_answer_span(start_logits, end_logits, context_mask, max_answer_len=30):
    """Return ``(start, end, score)`` of the most probable answer span.

    Only spans that meet all of these conditions are considered:

    * start and end both lie on context tokens;
    * ``start <= end``;
    * the span is at most ``max_answer_len`` tokens long.

    ``end`` is inclusive. ``score`` is P(start) * P(end) under the softmax
    of each logit vector. Returns ``None`` if no token is marked as context.

    Args:
        start_logits: 1-D tensor of per-token start logits.
        end_logits: 1-D tensor of per-token end logits (same length).
        context_mask: 1-D bool tensor, True where the token belongs to the
            news content rather than the question or special tokens.
        max_answer_len: Longest span, in tokens, that may be returned.
    """
    if not bool(context_mask.any()):
        return None
    start_probs = torch.softmax(start_logits, dim=-1)
    end_probs = torch.softmax(end_logits, dim=-1)
    n = start_probs.shape[0]
    # scores[i, j] = P(start = i) * P(end = j)
    scores = start_probs[:, None] * end_probs[None, :]
    positions = torch.arange(n)
    span_len = positions[None, :] - positions[:, None]  # j - i
    valid = (span_len >= 0) & (span_len < max_answer_len)
    valid &= context_mask[:, None] & context_mask[None, :]
    # Probabilities are >= 0, so -1 guarantees invalid spans never win.
    scores = scores.masked_fill(~valid, -1.0)
    best = int(torch.argmax(scores))  # index into the flattened n x n grid
    start, end = divmod(best, n)
    return start, end, float(scores[start, end])


def answer_question(news_content, question):
    """Answer ``question`` by extracting a span from ``news_content``.

    Uses the module-level ``qa_tokenizer`` and ``qa_model``. The encoded
    pair is capped at 512 tokens, and only the news content is truncated
    (never the question). For long articles, the answer therefore comes
    from their beginning.

    Args:
        news_content: Article text to search for the answer.
        question: Question about the article.

    Returns:
        One of:

        * a Markdown string with the answer and its confidence (the joint
          start/end probability of the chosen span);
        * a prompt to fill in both fields when either is blank or ``None``;
        * an ``Error: ...`` message if tokenization or inference fails.
    """
    if not (news_content or "").strip() or not (question or "").strip():
        return "Please enter both news content and a question."
    try:
        inputs = qa_tokenizer(
            question,
            news_content,
            return_tensors="pt",
            truncation="only_second",
            max_length=512,
        )
        with torch.no_grad():
            outputs = qa_model(**inputs)
        # sequence_ids: 0 = question token, 1 = context token, None = special
        # token. Needs a "fast" tokenizer; presumably what AutoTokenizer
        # returns for this checkpoint — TODO confirm if the tokenizer changes.
        context_mask = torch.tensor(
            [seq_id == 1 for seq_id in inputs.sequence_ids(0)]
        )
        span = _select_answer_span(
            outputs.start_logits[0], outputs.end_logits[0], context_mask
        )
        if span is None:
            return "Error: the news content produced no tokens to answer from."
        start, end, confidence = span
        answer = qa_tokenizer.decode(
            inputs["input_ids"][0][start : end + 1],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return f"**Answer:** :- {answer.strip()}\n\n**Confidence:** :- {confidence:.2%}"
    except Exception as e:
        return f"Error: {str(e)}"
# ====================== BEAUTIFUL UI ======================
# Two-tab Gradio interface:
#   * "News Classification": upload a CSV -> classify_csv -> status text,
#     downloadable output.csv and a category bar chart.
#   * "Question Answering": paste an article and a question -> answer_question.
# NOTE(review): the .file-upload and .status-box CSS selectors only apply if
# some element carries those classes. No component below sets elem_classes,
# so confirm they match Gradio's own markup or wire them up explicitly.
with gr.Blocks(
    title="News Classifier & Question Answering App",
    theme=gr.themes.Soft(),
    css="""
.gradio-container {
max-width: 1250px;
min-width: 1000px;
margin: auto;
background: linear-gradient(135deg, #0f172a 0%, #1e2937 100%);
}
h1 {
text-align: center;
font-size: 2.9rem;
background: linear-gradient(90deg, #60a5fa, #c084fc);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin-bottom: 10px;
}
/* Upload Box - Gradient */
.file-upload {
background: linear-gradient(135deg, #3b82f6, #8b5cf6) !important;
border-radius: 16px;
border: none;
}
/* Classify News Button - Gradient */
button.primary {
background: linear-gradient(90deg, #6366f1, #a855f7) !important;
border: none;
font-weight: 700;
font-size: 1.1rem;
padding: 14px 0;
border-radius: 12px;
transition: all 0.3s ease;
}
button.primary:hover {
transform: translateY(-3px);
box-shadow: 0 15px 25px rgba(139, 92, 246, 0.5);
}
/* Status Box - Gradient */
.status-box {
background: linear-gradient(135deg, #10b981, #34d399) !important;
color: white;
border-radius: 12px;
}
/* Download Button - Gradient */
button.secondary {
background: linear-gradient(90deg, #ec4899, #f43f5e) !important;
color: white;
font-weight: 600;
}
/* Tab styling */
.tab-label {
font-size: 1.15rem;
font-weight: 600;
}
""",
) as demo:
    gr.Markdown("# 📰 News Classifier & Question Answering App..")
    with gr.Tabs():
        with gr.Tab("📊 News Classification"):
            gr.Markdown("### Upload CSV and get automatic category prediction")
            file_input = gr.File(
                label="📤 Upload your CSV file",
                file_types=[".csv"],
                height=160,
            )
            # Gradio button sizes are "sm"/"lg"; "large" was silently ignored.
            classify_btn = gr.Button("🚀 Classify News", variant="primary", size="lg")
            with gr.Row():
                output_text = gr.Textbox(label="Status", scale=2)
                output_file = gr.File(label="📥 Download output.csv")
            bar_chart = gr.Plot(label="📊 Category Distribution Across 5 Classes")
            classify_btn.click(
                fn=classify_csv,
                inputs=file_input,
                outputs=[output_text, output_file, bar_chart],
            )
        with gr.Tab("❓ Question Answering"):
            gr.Markdown("### Ask any question about a news article")
            news_input = gr.Textbox(
                lines=12,
                label="📝 Paste News Content",
                placeholder="Paste the full news article here...",
            )
            question_input = gr.Textbox(
                label="❓ Your Question",
                placeholder="e.g. What is the main topic?",
            )
            qa_btn = gr.Button("🔍 Get Answer", variant="primary", size="lg")
            qa_output = gr.Textbox(label="💡 Answer", lines=6)
            qa_btn.click(
                fn=answer_question,
                inputs=[news_input, question_input],
                outputs=qa_output,
            )
    gr.Markdown("---")

# Launch only when run as a script (`python app.py`), so importing this module
# (e.g. for tests or `gradio` reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()