import gradio as gr
import pandas as pd
import torch
from transformers import pipeline
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import re
import string
import matplotlib.pyplot as plt
# ====================== NLTK SETUP ======================
# Fetch the NLTK data used by preprocess_text (word_tokenize and the WordNet
# lemmatizer); quiet=True keeps the download output out of the logs.
for _nltk_resource in ('wordnet', 'punkt', 'punkt_tab'):
    nltk.download(_nltk_resource, quiet=True)

# Single lemmatizer instance shared by every preprocess_text call.
lemmatizer = WordNetLemmatizer()
# Punctuation that is replaced by a space before tokenizing. Quotes, "$", "%"
# and "?" are deliberately kept. The pattern is compiled once, and re.escape()
# makes every character a literal inside the character class: unescaped, the
# "\]" pair in string.punctuation is read as an escaped "]" and the backslash
# itself is never stripped.
_KEPT_PUNCTUATION = "'\"$%?"
_PUNCT_PATTERN = re.compile(
    "["
    + re.escape("".join(ch for ch in string.punctuation if ch not in _KEPT_PUNCTUATION))
    + "]"
)


def preprocess_text(text):
    """Normalize one piece of text before it is sent to the classifier.

    The text is lower-cased, punctuation (except quotes, "$", "%" and "?") is
    replaced by spaces, the result is tokenized with nltk.word_tokenize and
    each token is passed through the module-level lemmatizer.

    Args:
        text: The raw cell value. Anything that is not a str (for example a
            NaN from an empty CSV cell) is treated as missing.

    Returns:
        The lemmatized tokens joined by single spaces, or "" for non-string
        input.
    """
    if not isinstance(text, str):
        return ""
    lowered = _PUNCT_PATTERN.sub(" ", text.lower())
    tokens = nltk.word_tokenize(lowered)
    return ' '.join(lemmatizer.lemmatize(token) for token in tokens)
# Model identifier passed as `model=` to transformers.pipeline() in classify_csv.
classifier_model = "Ginidu2003/Distilbert-Base-News-classifier"
# ====================== BEAUTIFUL COLORED BAR CHART ======================
def create_colored_bar_chart(category_counts):
    """Build a bar chart of how many rows fell into each predicted category.

    Args:
        category_counts: DataFrame with a "Category" and a "Count" column, or
            None / empty when there is nothing to plot.

    Returns:
        A matplotlib Figure. When there is no data the figure only contains
        the text "No data available".
    """
    # Build the Figure directly instead of through pyplot: pyplot keeps every
    # figure it creates registered until plt.close() is called (a leak in a
    # long-running server), and plt.xticks()/plt.tight_layout() act on the
    # global "current figure" rather than on this one.
    from matplotlib.figure import Figure

    if category_counts is None or len(category_counts) == 0:
        fig = Figure()
        ax = fig.subplots()
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    categories = category_counts["Category"]
    counts = category_counts["Count"]

    # Nice modern color palette (matplotlib cycles it if there are more bars)
    colors = ['#3498DB', '#E67E22', '#9B59B6', '#2ECC71', '#E74C3C']

    fig = Figure(figsize=(11, 6))
    ax = fig.subplots()
    bars = ax.bar(categories, counts, color=colors, edgecolor='white', linewidth=0.8)

    # Add value on top of bars
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            height + 0.8,
            str(int(height)),
            ha='center',
            va='bottom',
            fontsize=13,
            fontweight='bold',
        )

    ax.set_title("Category Distribution Across 5 Classes", fontsize=16, fontweight='bold', pad=20)
    ax.set_xlabel("Category", fontsize=12)
    ax.set_ylabel("Count", fontsize=12)
    ax.tick_params(axis='x', labelrotation=15)
    fig.tight_layout()
    return fig
# ====================== CLASSIFICATION FUNCTION ======================
# Pipeline shared by all requests; built on first use so the model is loaded
# once instead of on every button click.
_classifier_pipeline = None


def _get_classifier():
    """Return the shared text-classification pipeline, creating it on first use."""
    global _classifier_pipeline
    if _classifier_pipeline is None:
        _classifier_pipeline = pipeline("text-classification", model=classifier_model, device=-1)
    return _classifier_pipeline


@torch.no_grad()
def classify_csv(file):
    """Classify every row of an uploaded CSV and summarize the result.

    The CSV must contain a 'content' column. Each value is cleaned with
    preprocess_text and labelled by the classification pipeline; rows whose
    cleaned text is empty get the label "Unknown". The labelled table is
    written to "output.csv" with an added 'class' column.

    Args:
        file: Path or file object accepted by pandas.read_csv, or None when
            nothing was uploaded.

    Returns:
        A (status message, output file path, figure) tuple. On any failure the
        message describes the error and the other two items are None.
    """
    if file is None:
        return "❌ Error: Please upload a CSV file first.", None, None
    try:
        df = pd.read_csv(file)
        if 'content' not in df.columns:
            return "Error: CSV must have a column named 'content'", None, None

        # Keep the cleaned text in a separate Series so no helper column has
        # to be added to (and later dropped from) the user's table.
        clean_texts = df['content'].apply(preprocess_text)
        classifier = _get_classifier()

        # truncation=True cuts inputs that exceed the model's maximum length;
        # without it a long article makes the model call fail.
        df['class'] = [
            classifier(text, truncation=True)[0]['label'] if text.strip() else "Unknown"
            for text in clean_texts
        ]

        output_file = "output.csv"
        df.to_csv(output_file, index=False)

        category_counts = df['class'].value_counts().reset_index()
        category_counts.columns = ["Category", "Count"]
        fig = create_colored_bar_chart(category_counts)
        return f"✅ Success! Classified {len(df)} rows", output_file, fig
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        return f"❌ Error: {str(e)}", None, None
# ====================== Q&A FUNCTION ======================
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Tokenizer and model come from the same checkpoint and are loaded once at
# import time; answer_question reuses them for every request.
_QA_CHECKPOINT = "deepset/roberta-base-squad2"
qa_tokenizer = AutoTokenizer.from_pretrained(_QA_CHECKPOINT)
qa_model = AutoModelForQuestionAnswering.from_pretrained(_QA_CHECKPOINT)
def answer_question(news_content, question):
    """Extract an answer to `question` from `news_content`.

    Args:
        news_content: The article text to search. None or blank is rejected.
        question: The question to answer. None or blank is rejected.

    Returns:
        A Markdown string with the answer span and the probability of the
        chosen start token, a prompt asking for both inputs when either is
        missing, a "no answer" message when the selected span decodes to
        nothing, or "Error: ..." if the model call fails.
    """
    # `or ""` lets a None value take the same path as an empty textbox
    # instead of raising AttributeError on .strip().
    if not (news_content or "").strip() or not (question or "").strip():
        return "Please enter both news content and a question."
    try:
        inputs = qa_tokenizer(question, news_content, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = qa_model(**inputs)

        start_logits = outputs.start_logits[0]
        end_logits = outputs.end_logits[0]
        start_idx = int(torch.argmax(start_logits))
        # Look for the end only at or after the start; two independent argmax
        # calls can put the end before the start, which yields an empty span.
        end_idx = start_idx + int(torch.argmax(end_logits[start_idx:]))

        answer = qa_tokenizer.decode(
            inputs.input_ids[0][start_idx:end_idx + 1],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ).strip()
        if not answer:
            # The span held only special tokens, so there is nothing to show.
            return "No answer found in the provided news content."

        # Probability of the chosen start token (start_idx is the argmax).
        confidence = torch.max(torch.softmax(outputs.start_logits, dim=1)).item()
        return f"**Answer:** :- {answer}\n\n**Confidence:** :- {confidence:.2%}"
    except Exception as e:
        # UI boundary: show the failure in the answer box instead of crashing.
        return f"Error: {str(e)}"
# ====================== BEAUTIFUL UI ======================
# Stylesheet handed to gr.Blocks. The .file-upload and .status-box rules only
# take effect on components that list those names in `elem_classes`.
CUSTOM_CSS = """
.gradio-container {
    max-width: 1250px;
    min-width: 1000px;
    margin: auto;
    background: linear-gradient(135deg, #0f172a 0%, #1e2937 100%);
}
h1 {
    text-align: center;
    font-size: 2.9rem;
    background: linear-gradient(90deg, #60a5fa, #c084fc);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    margin-bottom: 10px;
}
/* Upload Box - Gradient */
.file-upload {
    background: linear-gradient(135deg, #3b82f6, #8b5cf6) !important;
    border-radius: 16px;
    border: none;
}
/* Classify News Button - Gradient */
button.primary {
    background: linear-gradient(90deg, #6366f1, #a855f7) !important;
    border: none;
    font-weight: 700;
    font-size: 1.1rem;
    padding: 14px 0;
    border-radius: 12px;
    transition: all 0.3s ease;
}
button.primary:hover {
    transform: translateY(-3px);
    box-shadow: 0 15px 25px rgba(139, 92, 246, 0.5);
}
/* Status Box - Gradient */
.status-box {
    background: linear-gradient(135deg, #10b981, #34d399) !important;
    color: white;
    border-radius: 12px;
}
/* Download Button - Gradient */
button.secondary {
    background: linear-gradient(90deg, #ec4899, #f43f5e) !important;
    color: white;
    font-weight: 600;
}
/* Tab styling */
.tab-label {
    font-size: 1.15rem;
    font-weight: 600;
}
"""

# NOTE(review): the emoji in these labels were damaged by a wrong text
# decoding. Where the surviving characters identify the original it is
# restored; the remaining labels use an icon chosen to fit the text. Confirm
# against the original file.
with gr.Blocks(
    title="News Classifier & Question Answering App",
    theme=gr.themes.Soft(),
    css=CUSTOM_CSS,
) as demo:
    gr.Markdown("# 📰 News Classifier & Question Answering App..")

    with gr.Tabs():
        # ---- Tab 1: batch classification of an uploaded CSV ----
        with gr.Tab("📊 News Classification"):
            gr.Markdown("### Upload CSV and get automatic category prediction")
            file_input = gr.File(
                label="📤 Upload your CSV file",
                file_types=[".csv"],
                height=160,
                elem_classes=["file-upload"],
            )
            # "lg" is the documented large size; "large" is not an accepted value.
            classify_btn = gr.Button("🚀 Classify News", variant="primary", size="lg")
            with gr.Row():
                output_text = gr.Textbox(label="Status", scale=2, elem_classes=["status-box"])
                output_file = gr.File(label="📥 Download output.csv")
            bar_chart = gr.Plot(label="📊 Category Distribution Across 5 Classes")
            classify_btn.click(
                fn=classify_csv,
                inputs=file_input,
                outputs=[output_text, output_file, bar_chart],
            )

        # ---- Tab 2: extractive question answering on pasted text ----
        with gr.Tab("❓ Question Answering"):
            gr.Markdown("### Ask any question about a news article")
            news_input = gr.Textbox(
                lines=12,
                label="📝 Paste News Content",
                placeholder="Paste the full news article here...",
            )
            question_input = gr.Textbox(
                label="❓ Your Question",
                placeholder="e.g. What is the main topic?",
            )
            qa_btn = gr.Button("🔍 Get Answer", variant="primary", size="lg")
            qa_output = gr.Textbox(label="💡 Answer", lines=6)
            qa_btn.click(
                fn=answer_question,
                inputs=[news_input, question_input],
                outputs=qa_output,
            )

    gr.Markdown("---")

demo.launch()