Spaces:
Sleeping
Sleeping
import os
import re
import string
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import nltk
import pandas as pd
import torch
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from transformers import pipeline
# ====================== NLTK SETUP ======================
# Download the NLTK data packages this app relies on (quietly, so repeated
# start-ups do not spam the log), then build the shared lemmatizer instance.
for _nltk_resource in ("wordnet", "punkt", "punkt_tab"):
    nltk.download(_nltk_resource, quiet=True)

lemmatizer = WordNetLemmatizer()
# Punctuation that is blanked out before tokenizing. Apostrophes, double
# quotes, "$", "%" and "?" are deliberately kept.
_PUNCT_TO_REMOVE = (
    string.punctuation
    .replace("'", "")
    .replace('"', "")
    .replace("$", "")
    .replace("%", "")
    .replace("?", "")
)

# A translation table is used instead of an unescaped regex character class:
# string.punctuation contains "\" and "]", and inside "[...]" the pair "\]" is
# read as an escaped "]", which silently left backslashes in the text.
_PUNCT_TABLE = str.maketrans(_PUNCT_TO_REMOVE, " " * len(_PUNCT_TO_REMOVE))


def preprocess_text(text):
    """Normalize one piece of text before it is sent to the classifier.

    The text is lower-cased, unwanted punctuation is replaced by spaces, the
    result is tokenized with NLTK and every token is lemmatized.

    Args:
        text: Raw text. Anything that is not a ``str`` (for example the NaN
            that pandas uses for an empty cell) is treated as missing.

    Returns:
        The lemmatized tokens joined by single spaces, or ``""`` when *text*
        is not a string.
    """
    if not isinstance(text, str):
        return ""

    cleaned = text.lower().translate(_PUNCT_TABLE)
    tokens = nltk.word_tokenize(cleaned)
    return " ".join(lemmatizer.lemmatize(token) for token in tokens)
# Model identifier used to build the text-classification pipeline for CSV classification.
classifier_model = "Ginidu2003/Distilbert-Base-News-classifier"
# ====================== BEAUTIFUL COLORED BAR CHART ======================
def create_colored_bar_chart(category_counts):
    """Draw a bar chart of how many rows fell into each category.

    Args:
        category_counts: Table with a "Category" column and a "Count" column,
            or ``None``.

    Returns:
        A matplotlib figure. When *category_counts* is ``None`` or has no
        rows, the figure only contains the text "No data available".
    """
    if category_counts is None or len(category_counts) == 0:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    categories = category_counts["Category"]
    counts = category_counts["Count"]

    # Nice modern color palette
    colors = ['#3498DB', '#E67E22', '#9B59B6', '#2ECC71', '#E74C3C']

    fig, ax = plt.subplots(figsize=(11, 6))
    bars = ax.bar(categories, counts, color=colors, edgecolor='white', linewidth=0.8)

    # Add value on top of bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2, height + 0.8,
                str(int(height)), ha='center', va='bottom', fontsize=13, fontweight='bold')

    ax.set_title("Category Distribution Across 5 Classes", fontsize=16, fontweight='bold', pad=20)
    ax.set_xlabel("Category", fontsize=12)
    ax.set_ylabel("Count", fontsize=12)

    # Style this figure through its own Axes/Figure objects rather than the
    # pyplot "current figure", which is global state and may point at another
    # figure when several requests are being handled at once.
    ax.tick_params(axis='x', labelrotation=15)
    fig.tight_layout()
    return fig
# ====================== CLASSIFICATION FUNCTION ======================
# Shared text-classification pipeline; built on first use and then reused so
# the model is not reloaded on every button click.
_classifier = None


def _get_classifier():
    """Return the shared classification pipeline, creating it on first use."""
    global _classifier
    if _classifier is None:
        _classifier = pipeline("text-classification", model=classifier_model, device=-1)
    return _classifier


def classify_csv(file):
    """Classify every row of an uploaded CSV and summarize the result.

    The CSV must contain a ``content`` column. Each row's text is preprocessed
    and classified; rows whose preprocessed text is blank are labelled
    "Unknown". The predicted label is stored in a new ``class`` column.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``, or
            ``None`` when nothing was uploaded.

    Returns:
        A ``(status_message, output_csv_path, figure)`` tuple. On any failure
        the last two items are ``None`` and the message describes the error.
    """
    if file is None:
        return "Error: Please upload a CSV file first", None, None

    try:
        df = pd.read_csv(file)
        if 'content' not in df.columns:
            return "Error: CSV must have a column named 'content'", None, None

        # Keep the cleaned text in a separate Series so a user column that
        # happens to be called "clean_content" is neither overwritten nor dropped.
        clean_texts = df['content'].apply(preprocess_text)

        classifier = _get_classifier()
        # truncation=True keeps an over-long article from raising inside the
        # model and failing the whole batch.
        df['class'] = [
            classifier(text, truncation=True)[0]['label'] if text.strip() else "Unknown"
            for text in clean_texts
        ]

        # Write into a fresh temporary directory: a fixed "output.csv" in the
        # working directory would be overwritten by concurrent requests.
        output_file = os.path.join(tempfile.mkdtemp(prefix="news_classifier_"), "output.csv")
        df.to_csv(output_file, index=False)

        category_counts = df['class'].value_counts().reset_index()
        category_counts.columns = ["Category", "Count"]
        fig = create_colored_bar_chart(category_counts)
        return f"β Success! Classified {len(df)} rows", output_file, fig
    except Exception as e:
        # UI boundary: report the problem in the status box instead of crashing.
        return f"β Error: {str(e)}", None, None
# ====================== Q&A FUNCTION ======================
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# One checkpoint name for both objects, so the tokenizer and the model are
# always loaded from the same place.
QA_MODEL_NAME = "deepset/roberta-base-squad2"

qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
qa_model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME)
def answer_question(news_content, question):
    """Answer a question from a news article with the extractive QA model.

    Args:
        news_content: Article text to search for the answer.
        question: Question about the article.

    Returns:
        A Markdown-formatted string with the extracted answer and a confidence
        value, a prompt when either input is missing or blank, a notice when
        no answer span could be extracted, or an "Error: ..." message when
        inference fails. The confidence is the highest softmax probability of
        the start logits only.
    """
    # Guard against None as well as blank strings; calling .strip() on None
    # would raise before the try block below could report anything.
    if not news_content or not question or not news_content.strip() or not question.strip():
        return "Please enter both news content and a question."

    try:
        inputs = qa_tokenizer(question, news_content, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = qa_model(**inputs)

        start_idx = int(torch.argmax(outputs.start_logits))
        end_idx = int(torch.argmax(outputs.end_logits)) + 1

        # Start and end are picked independently, so the end can land before
        # the start; the span can also cover only special tokens. Both cases
        # produce no usable text.
        answer = ""
        if end_idx > start_idx:
            answer = qa_tokenizer.decode(inputs.input_ids[0][start_idx:end_idx],
                                         skip_special_tokens=True,
                                         clean_up_tokenization_spaces=True).strip()
        if not answer:
            return "No answer could be found in the provided news content."

        confidence = torch.max(torch.softmax(outputs.start_logits, dim=1)).item()
        return f"**Answer:** :- {answer}\n\n**Confidence:** :- {confidence:.2%}"
    except Exception as e:
        # UI boundary: show the failure in the answer box instead of crashing.
        return f"Error: {str(e)}"
# ====================== BEAUTIFUL UI ======================
# Stylesheet handed to gr.Blocks. The ".file-upload" and ".status-box" rules
# only take effect on components that carry those names in elem_classes.
# NOTE(review): no component below is given the "tab-label" class, so that
# rule currently looks unused - confirm whether it should be attached somewhere.
CUSTOM_CSS = """
.gradio-container {
    max-width: 1250px;
    min-width: 1000px;
    margin: auto;
    background: linear-gradient(135deg, #0f172a 0%, #1e2937 100%);
}
h1 {
    text-align: center;
    font-size: 2.9rem;
    background: linear-gradient(90deg, #60a5fa, #c084fc);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    margin-bottom: 10px;
}
/* Upload Box - Gradient */
.file-upload {
    background: linear-gradient(135deg, #3b82f6, #8b5cf6) !important;
    border-radius: 16px;
    border: none;
}
/* Classify News Button - Gradient */
button.primary {
    background: linear-gradient(90deg, #6366f1, #a855f7) !important;
    border: none;
    font-weight: 700;
    font-size: 1.1rem;
    padding: 14px 0;
    border-radius: 12px;
    transition: all 0.3s ease;
}
button.primary:hover {
    transform: translateY(-3px);
    box-shadow: 0 15px 25px rgba(139, 92, 246, 0.5);
}
/* Status Box - Gradient */
.status-box {
    background: linear-gradient(135deg, #10b981, #34d399) !important;
    color: white;
    border-radius: 12px;
}
/* Download Button - Gradient */
button.secondary {
    background: linear-gradient(90deg, #ec4899, #f43f5e) !important;
    color: white;
    font-weight: 600;
}
/* Tab styling */
.tab-label {
    font-size: 1.15rem;
    font-weight: 600;
}
"""

with gr.Blocks(
    title="News Classifier & Question Answering App",
    theme=gr.themes.Soft(),
    css=CUSTOM_CSS,
) as demo:
    gr.Markdown("# π° News Classifier & Question Answering App..")

    with gr.Tabs():
        # Tab 1: upload a CSV, classify each row, download the result.
        with gr.Tab("π News Classification"):
            gr.Markdown("### Upload CSV and get automatic category prediction")
            file_input = gr.File(
                label="π€ Upload your CSV file",
                file_types=[".csv"],
                height=160,
                # Attach the class the ".file-upload" CSS rule targets.
                elem_classes="file-upload",
            )
            # "lg" is the large button size; "large" is not a recognized value.
            classify_btn = gr.Button("π Classify News", variant="primary", size="lg")
            with gr.Row():
                # Attach the class the ".status-box" CSS rule targets.
                output_text = gr.Textbox(label="Status", scale=2, elem_classes="status-box")
                output_file = gr.File(label="π₯ Download output.csv")
            bar_chart = gr.Plot(label="π Category Distribution Across 5 Classes")
            classify_btn.click(
                fn=classify_csv,
                inputs=file_input,
                outputs=[output_text, output_file, bar_chart],
            )

        # Tab 2: paste an article and ask a question about it.
        with gr.Tab("β Question Answering"):
            gr.Markdown("### Ask any question about a news article")
            news_input = gr.Textbox(lines=12, label="π Paste News Content", placeholder="Paste the full news article here...")
            question_input = gr.Textbox(label="β Your Question", placeholder="e.g. What is the main topic?")
            qa_btn = gr.Button("π Get Answer", variant="primary", size="lg")
            # NOTE(review): answer_question returns Markdown ("**Answer:**"),
            # which a Textbox shows literally - confirm whether that is intended.
            qa_output = gr.Textbox(label="π‘ Answer", lines=6)
            qa_btn.click(
                fn=answer_question,
                inputs=[news_input, question_input],
                outputs=qa_output,
            )

    gr.Markdown("---")

demo.launch()