# Hugging Face Space app (page capture: status "Sleeping", file size 7,892 bytes).
import gradio as gr
import pandas as pd
import torch
from transformers import pipeline
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import re
import string
import matplotlib.pyplot as plt
# ====================== NLTK SETUP ======================
# Fetch the NLTK data that preprocess_text relies on: the WordNet corpus for
# the lemmatizer, and the "punkt" / "punkt_tab" tokenizer models used by
# nltk.word_tokenize (newer NLTK releases look up "punkt_tab").
for _resource in ("wordnet", "punkt", "punkt_tab"):
    nltk.download(_resource, quiet=True)
del _resource

# Shared lemmatizer instance used by preprocess_text.
lemmatizer = WordNetLemmatizer()
# Punctuation characters that are NOT blanked out by preprocess_text.
KEPT_PUNCTUATION = "'\"$%?"

# Translation table mapping every other ASCII punctuation character to a space.
# str.translate replaces characters literally, so regex metacharacters such as
# ']', '\\', '^' and '-' in string.punctuation need no escaping.
PUNCTUATION_TO_SPACE = str.maketrans(
    {ch: " " for ch in string.punctuation if ch not in KEPT_PUNCTUATION}
)


def preprocess_text(text):
    """Normalize a news text before classification.

    Lower-cases the text and replaces every ASCII punctuation character except
    ``' " $ % ?`` with a space. It then tokenizes the result with
    ``nltk.word_tokenize`` and lemmatizes each token with the shared WordNet
    lemmatizer.

    Args:
        text: Raw article text. Any non-string value (e.g. a pandas NaN for a
            missing cell) is treated as empty.

    Returns:
        The lemmatized tokens joined by single spaces, or ``""`` when ``text``
        is not a string.
    """
    if not isinstance(text, str):
        return ""
    text = text.lower().translate(PUNCTUATION_TO_SPACE)
    tokens = nltk.word_tokenize(text)
    return " ".join(lemmatizer.lemmatize(token) for token in tokens)


# Hugging Face Hub model id loaded by classify_csv's text-classification pipeline.
classifier_model = "Ginidu2003/Distilbert-Base-News-classifier"
# ====================== BEAUTIFUL COLORED BAR CHART ======================
def create_colored_bar_chart(category_counts):
    """Build a bar chart of predicted-category counts.

    Args:
        category_counts: DataFrame with ``Category`` and ``Count`` columns (as
            built by classify_csv), or ``None`` / empty when there is nothing
            to plot.

    Returns:
        A ``matplotlib.figure.Figure``. For empty input the figure only shows
        the text "No data available".
    """
    # Build figures without pyplot. pyplot registers every figure in a global
    # manager (these were never closed, so memory grew with every request), and
    # its implicit "current figure" is shared between concurrent requests.
    from matplotlib.figure import Figure

    if category_counts is None or len(category_counts) == 0:
        fig = Figure()
        ax = fig.subplots()
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    categories = category_counts["Category"]
    counts = category_counts["Count"]
    # Nice modern color palette (matplotlib cycles it if there are more bars)
    colors = ['#3498DB', '#E67E22', '#9B59B6', '#2ECC71', '#E74C3C']

    # tight_layout=True applies the layout at draw time, once all artists exist.
    fig = Figure(figsize=(11, 6), tight_layout=True)
    ax = fig.subplots()
    bars = ax.bar(categories, counts, color=colors, edgecolor='white', linewidth=0.8)
    # Value labels sit a fixed 3 points above each bar, whatever the data scale
    # (a fixed data-unit offset looks wrong for both small and large counts).
    ax.bar_label(bars, fmt='%d', padding=3, fontsize=13, fontweight='bold')
    ax.margins(y=0.1)  # headroom so the labels stay inside the axes
    ax.set_title("Category Distribution Across 5 Classes", fontsize=16, fontweight='bold', pad=20)
    ax.set_xlabel("Category", fontsize=12)
    ax.set_ylabel("Count", fontsize=12)
    ax.tick_params(axis='x', labelrotation=15)
    return fig
# ====================== CLASSIFICATION FUNCTION ======================
# The text-classification pipeline is created on first use and then reused, so
# the model is not reloaded for every uploaded file.
_classifier_pipeline = None


def _get_classifier():
    """Return the shared text-classification pipeline, creating it on first call."""
    global _classifier_pipeline
    if _classifier_pipeline is None:
        _classifier_pipeline = pipeline("text-classification", model=classifier_model, device=-1)
    return _classifier_pipeline


@torch.no_grad()
def classify_csv(file):
    """Classify every row of an uploaded CSV and summarise the predictions.

    Args:
        file: The uploaded CSV as given by ``gr.File`` (anything
            ``pd.read_csv`` accepts), or ``None`` when nothing was uploaded.
            It must contain a ``content`` column.

    Returns:
        ``(status_message, output_path, figure)``. On success ``output_path``
        is ``"output.csv"`` (the input plus a ``class`` column) and ``figure``
        is the category bar chart. On failure both are ``None``.
    """
    if file is None:
        return "Error: Please upload a CSV file.", None, None
    try:
        df = pd.read_csv(file)
        if 'content' not in df.columns:
            return "Error: CSV must have a column named 'content'", None, None

        clean_texts = df['content'].apply(preprocess_text).tolist()
        # Rows that are empty after preprocessing keep the "Unknown" label. The
        # rest go to the model in a single pipeline call.
        predictions = ["Unknown"] * len(clean_texts)
        to_classify = [i for i, text in enumerate(clean_texts) if text.strip()]
        if to_classify:
            # truncation=True: without it, an article longer than the model's
            # maximum input length makes the whole call fail.
            results = _get_classifier()(
                [clean_texts[i] for i in to_classify], truncation=True
            )
            for i, result in zip(to_classify, results):
                predictions[i] = result['label']
        df['class'] = predictions

        output_file = "output.csv"
        df.to_csv(output_file, index=False)

        category_counts = df['class'].value_counts().reset_index()
        category_counts.columns = ["Category", "Count"]
        fig = create_colored_bar_chart(category_counts)
        return f"\N{WHITE HEAVY CHECK MARK} Success! Classified {len(df)} rows", output_file, fig
    except Exception as e:  # UI boundary: surface any failure in the status box
        return f"\N{CROSS MARK} Error: {e}", None, None
# ====================== Q&A FUNCTION ======================
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

# Extractive question-answering checkpoint (by its name, RoBERTa-base trained on
# SQuAD 2.0). Tokenizer and model are loaded once, when the module is imported.
_QA_CHECKPOINT = "deepset/roberta-base-squad2"
qa_tokenizer = AutoTokenizer.from_pretrained(_QA_CHECKPOINT)
qa_model = AutoModelForQuestionAnswering.from_pretrained(_QA_CHECKPOINT)
def _best_answer_span(start_logits, end_logits, context_mask, max_answer_tokens):
    """Find the highest-scoring answer span inside one tokenized window.

    Args:
        start_logits: 1-D tensor of start scores, one per token.
        end_logits: 1-D tensor of end scores, same length as ``start_logits``.
        context_mask: 1-D bool tensor, ``True`` for tokens of the news text.
        max_answer_tokens: Longest span (in tokens) that may be returned.

    Returns:
        ``(score, start, end)``. ``end`` is inclusive and ``score`` is the sum
        of the start and end logits. Returns ``None`` when the window contains
        no context token.
    """
    n = start_logits.shape[0]
    positions = torch.arange(n)
    span_len = positions[None, :] - positions[:, None]  # span_len[i, j] == j - i
    valid = (
        (span_len >= 0)
        & (span_len < max_answer_tokens)
        & context_mask[:, None]
        & context_mask[None, :]
    )
    if not bool(valid.any()):
        return None
    scores = (start_logits[:, None] + end_logits[None, :]).masked_fill(~valid, float("-inf"))
    start, end = divmod(int(torch.argmax(scores)), n)
    return float(scores[start, end]), start, end


def answer_question(news_content, question, max_answer_tokens=30, stride=128):
    """Answer ``question`` with a span extracted from ``news_content``.

    Long articles are split into overlapping 512-token windows that share
    ``stride`` tokens, so text past the first window is still searched. A
    candidate span must:

    - lie inside the article,
    - end at or after its start,
    - be at most ``max_answer_tokens`` tokens long.

    The model's "no answer" prediction is the span at token 0. If it scores
    above every candidate, that outcome is reported instead.

    Args:
        news_content: Article text to search.
        question: Question to answer.
        max_answer_tokens: Maximum answer length in tokens.
        stride: Number of overlapping tokens between consecutive windows.

    Returns:
        Markdown with the answer and its confidence (start probability times
        end probability). Returns a prompt when either input is blank, or an
        error message if inference fails.
    """
    if not (news_content and news_content.strip() and question and question.strip()):
        return "Please enter both news content and a question."
    try:
        encoded = qa_tokenizer(
            question,
            news_content,
            truncation="only_second",  # only ever cut the article, never the question
            max_length=512,
            stride=stride,
            return_overflowing_tokens=True,
            padding=True,
            return_tensors="pt",
        )
        # overflow_to_sample_mapping is tokenizer bookkeeping, not a model input.
        model_inputs = {
            key: value for key, value in encoded.items()
            if key != "overflow_to_sample_mapping"
        }
        with torch.no_grad():
            outputs = qa_model(**model_inputs)

        best = None  # (score, window, start, end)
        null_score = float("inf")
        for window in range(encoded["input_ids"].shape[0]):
            start_logits = outputs.start_logits[window]
            end_logits = outputs.end_logits[window]
            # Score of "no answer" (span at token 0). The lowest value across
            # windows is used.
            null_score = min(null_score, float(start_logits[0] + end_logits[0]))
            # sequence id 1 marks article tokens. Question, special and padding
            # tokens are excluded from answers.
            context_mask = torch.tensor(
                [seq_id == 1 for seq_id in encoded.sequence_ids(window)]
            )
            span = _best_answer_span(start_logits, end_logits, context_mask, max_answer_tokens)
            if span is not None and (best is None or span[0] > best[0]):
                best = (span[0], window, span[1], span[2])

        if best is None or null_score > best[0]:
            return "**Answer:** No answer found in the provided text."

        _, window, start, end = best
        answer = qa_tokenizer.decode(
            encoded["input_ids"][window][start:end + 1],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        # Probabilities are normalised over the real (non-padding) tokens of the
        # winning window only.
        padding_mask = encoded["attention_mask"][window] == 0
        start_probs = torch.softmax(
            outputs.start_logits[window].masked_fill(padding_mask, float("-inf")), dim=-1
        )
        end_probs = torch.softmax(
            outputs.end_logits[window].masked_fill(padding_mask, float("-inf")), dim=-1
        )
        confidence = float(start_probs[start] * end_probs[end])
        return f"**Answer:** {answer.strip()}\n\n**Confidence:** {confidence:.2%}"
    except Exception as e:  # UI boundary: show the failure in the answer box
        return f"Error: {str(e)}"
# ====================== BEAUTIFUL UI ======================
# Custom styling. The .file-upload and .status-box rules apply to the
# components tagged with those elem_classes below.
CUSTOM_CSS = """
.gradio-container {
    max-width: 1250px;
    min-width: 1000px;
    margin: auto;
    background: linear-gradient(135deg, #0f172a 0%, #1e2937 100%);
}
h1 {
    text-align: center;
    font-size: 2.9rem;
    background: linear-gradient(90deg, #60a5fa, #c084fc);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    margin-bottom: 10px;
}
/* Upload Box - Gradient */
.file-upload {
    background: linear-gradient(135deg, #3b82f6, #8b5cf6) !important;
    border-radius: 16px;
    border: none;
}
/* Classify News Button - Gradient */
button.primary {
    background: linear-gradient(90deg, #6366f1, #a855f7) !important;
    border: none;
    font-weight: 700;
    font-size: 1.1rem;
    padding: 14px 0;
    border-radius: 12px;
    transition: all 0.3s ease;
}
button.primary:hover {
    transform: translateY(-3px);
    box-shadow: 0 15px 25px rgba(139, 92, 246, 0.5);
}
/* Status Box - Gradient */
.status-box {
    background: linear-gradient(135deg, #10b981, #34d399) !important;
    color: white;
    border-radius: 12px;
}
/* Download Button - Gradient */
button.secondary {
    background: linear-gradient(90deg, #ec4899, #f43f5e) !important;
    color: white;
    font-weight: 600;
}
/* Tab styling */
.tab-label {
    font-size: 1.15rem;
    font-weight: 600;
}
"""

# Emoji are written as \N{...} escapes so the labels survive any re-encoding
# of this file (they were previously garbled into Greek letters).
with gr.Blocks(
    title="News Classifier & Question Answering App",
    theme=gr.themes.Soft(),
    css=CUSTOM_CSS,
) as demo:
    gr.Markdown("# \N{NEWSPAPER} News Classifier & Question Answering App")
    with gr.Tabs():
        with gr.Tab("\N{BAR CHART} News Classification"):
            gr.Markdown("### Upload CSV and get automatic category prediction")
            file_input = gr.File(
                label="\N{OUTBOX TRAY} Upload your CSV file",
                file_types=[".csv"],
                height=160,
                elem_classes=["file-upload"],
            )
            # Gradio Button sizes are "sm"/"lg"; "large" is not a recognised value.
            classify_btn = gr.Button("\N{ROCKET} Classify News", variant="primary", size="lg")
            with gr.Row():
                output_text = gr.Textbox(label="Status", scale=2, elem_classes=["status-box"])
                output_file = gr.File(label="\N{INBOX TRAY} Download output.csv")
            bar_chart = gr.Plot(label="\N{BAR CHART} Category Distribution Across 5 Classes")
            classify_btn.click(
                fn=classify_csv,
                inputs=file_input,
                outputs=[output_text, output_file, bar_chart],
            )
        with gr.Tab("\N{BLACK QUESTION MARK ORNAMENT} Question Answering"):
            gr.Markdown("### Ask any question about a news article")
            news_input = gr.Textbox(
                lines=12,
                label="\N{MEMO} Paste News Content",
                placeholder="Paste the full news article here...",
            )
            question_input = gr.Textbox(
                label="\N{BLACK QUESTION MARK ORNAMENT} Your Question",
                placeholder="e.g. What is the main topic?",
            )
            qa_btn = gr.Button(
                "\N{LEFT-POINTING MAGNIFYING GLASS} Get Answer", variant="primary", size="lg"
            )
            qa_output = gr.Textbox(label="\N{ELECTRIC LIGHT BULB} Answer", lines=6)
            qa_btn.click(
                fn=answer_question,
                inputs=[news_input, question_input],
                outputs=qa_output,
            )
    gr.Markdown("---")

if __name__ == "__main__":
    demo.launch()