Ginidu2003 committed on
Commit
c269a58
·
verified ·
1 Parent(s): 2e79406

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -67
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import gradio as gr
3
  import pandas as pd
4
  import torch
@@ -11,36 +10,60 @@ import string
11
  import matplotlib.pyplot as plt
12
 
13
  # ====================== NLTK SETUP ======================
14
-
15
  nltk.download('wordnet', quiet=True)
16
  nltk.download('punkt', quiet=True)
17
  nltk.download('punkt_tab', quiet=True)
18
 
19
  lemmatizer = WordNetLemmatizer()
20
- # ============= Preprocessing==============================
21
  def preprocess_text(text):
22
  if not isinstance(text, str):
23
  return ""
24
  text = text.lower()
25
  punct_to_remove = string.punctuation.replace("'","").replace('"',"").replace("$","").replace("%","").replace("?","")
26
- text = re.sub(f"[{punct_to_remove}]", " ", text)# Remove punctuation except: ' " $ % ?
27
  tokens = nltk.word_tokenize(text)
28
- tokens = [word for word in tokens]
29
  tokens = [lemmatizer.lemmatize(word) for word in tokens]
30
  return ' '.join(tokens)
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
 
 
33
 
 
 
 
 
 
34
 
 
 
 
 
 
 
35
 
36
  # ====================== CLASSIFICATION FUNCTION ======================
37
- classifier_model = "Ginidu2003/Distilbert-Base-News-classifier"
38
  @torch.no_grad()
39
  def classify_csv(file):
40
  try:
41
  df = pd.read_csv(file)
42
  if 'content' not in df.columns:
43
- return "Error: CSV must have a column named 'content'", None
44
 
45
  df['clean_content'] = df['content'].apply(preprocess_text)
46
 
@@ -59,116 +82,142 @@ def classify_csv(file):
59
 
60
  output_file = "output.csv"
61
  df.to_csv(output_file, index=False)
62
-
63
- # Count categories
64
  category_counts = df['class'].value_counts().reset_index()
65
  category_counts.columns = ["Category", "Count"]
66
-
67
- # Create colored bar chart
68
  fig = create_colored_bar_chart(category_counts)
69
 
70
  return f"βœ… Success! Classified {len(df)} rows", output_file, fig
71
  except Exception as e:
72
  return f"❌ Error: {str(e)}", None, None
73
 
74
- # ====================== COLORED BAR CHART ======================
75
- def create_colored_bar_chart(category_counts):
76
- if category_counts is None or len(category_counts) == 0:
77
- fig, ax = plt.subplots()
78
- ax.text(0.5, 0.5, "No data available", ha='center', va='center')
79
- return fig
80
-
81
- categories = category_counts["Category"]
82
- counts = category_counts["Count"]
83
-
84
- # Different attractive colors for each category
85
- colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEEAD']
86
-
87
- fig, ax = plt.subplots(figsize=(10, 6))
88
- bars = ax.bar(categories, counts, color=colors)
89
-
90
- # Add count numbers on top of bars
91
- for bar in bars:
92
- height = bar.get_height()
93
- ax.text(bar.get_x() + bar.get_width()/2, height + 0.5,
94
- str(int(height)), ha='center', va='bottom', fontsize=12, fontweight='bold')
95
-
96
- ax.set_title("News Category Distribution Across 5 Classes", fontsize=14, fontweight='bold')
97
- ax.set_xlabel("Category")
98
- ax.set_ylabel("Count")
99
- plt.xticks(rotation=15)
100
- plt.tight_layout()
101
- return fig
102
-
103
-
104
-
105
  # ====================== Q&A FUNCTION ======================
106
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering
107
-
108
  qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
109
  qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
 
110
  def answer_question(news_content, question):
111
  if not news_content.strip() or not question.strip():
112
  return "Please enter both news content and a question."
113
-
114
  try:
115
  inputs = qa_tokenizer(question, news_content, return_tensors="pt", truncation=True, max_length=512)
116
-
117
  with torch.no_grad():
118
  outputs = qa_model(**inputs)
119
 
120
  start_idx = torch.argmax(outputs.start_logits)
121
  end_idx = torch.argmax(outputs.end_logits) + 1
122
 
123
- # Clean answer - remove question repetition and special tokens
124
- answer = qa_tokenizer.decode(inputs.input_ids[0][start_idx:end_idx],
125
- skip_special_tokens=True,
126
  clean_up_tokenization_spaces=True)
127
 
128
  confidence = torch.max(torch.softmax(outputs.start_logits, dim=1)).item()
129
 
130
  return f"**Answer:** {answer.strip()}\n\n**Confidence:** {confidence:.2%}"
131
-
132
  except Exception as e:
133
  return f"Error: {str(e)}"
134
 
135
- # ====================== GRADIO INTERFACE ======================
136
- with gr.Blocks(title="News Classifier & Question Answering App..") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  gr.Markdown("# πŸ“° News Classifier & Question Answering App..")
138
 
139
 
140
  with gr.Tabs():
141
  with gr.Tab("πŸ“Š News Classification"):
142
- gr.Markdown("Upload CSV with `content` column")
143
- file_input = gr.File(label="Upload CSV", file_types=[".csv"])
144
- classify_btn = gr.Button("πŸš€ Classify News", variant="primary")
145
- output_text = gr.Textbox(label="Status")
146
- output_file = gr.File(label="Download output.csv")
147
- bar_chart = gr.Plot(
148
- label="News Category Distribution Across 5 Classes"
149
-
150
-
151
  )
152
 
 
 
 
 
 
 
 
 
153
  classify_btn.click(
154
  fn=classify_csv,
155
  inputs=file_input,
156
- outputs=[output_text, output_file,bar_chart]
157
  )
158
 
159
  with gr.Tab("❓ Question Answering"):
160
- gr.Markdown("Ask any question about a news article")
161
- news_input = gr.Textbox(lines=12, label="Paste News Content", placeholder="Paste the full news article here...")
162
- question_input = gr.Textbox(label="Your Question", placeholder="e.g. What is the main topic?")
163
- qa_btn = gr.Button("πŸ” Get Answer", variant="primary")
164
- qa_output = gr.Textbox(label="Answer", lines=5)
165
-
166
  qa_btn.click(
167
  fn=answer_question,
168
  inputs=[news_input, question_input],
169
  outputs=qa_output
170
  )
171
 
 
172
 
173
 
174
  demo.launch()
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import torch
 
10
  import matplotlib.pyplot as plt
11
 
12
  # ====================== NLTK SETUP ======================
 
13
  nltk.download('wordnet', quiet=True)
14
  nltk.download('punkt', quiet=True)
15
  nltk.download('punkt_tab', quiet=True)
16
 
17
  lemmatizer = WordNetLemmatizer()
18
+
19
  def preprocess_text(text):
20
  if not isinstance(text, str):
21
  return ""
22
  text = text.lower()
23
  punct_to_remove = string.punctuation.replace("'","").replace('"',"").replace("$","").replace("%","").replace("?","")
24
+ text = re.sub(f"[{punct_to_remove}]", " ", text)
25
  tokens = nltk.word_tokenize(text)
 
26
  tokens = [lemmatizer.lemmatize(word) for word in tokens]
27
  return ' '.join(tokens)
28
 
29
+ classifier_model = "Ginidu2003/Distilbert-Base-News-classifier"
30
+
31
+ # ====================== BEAUTIFUL COLORED BAR CHART ======================
32
+ def create_colored_bar_chart(category_counts):
33
+ if category_counts is None or len(category_counts) == 0:
34
+ fig, ax = plt.subplots()
35
+ ax.text(0.5, 0.5, "No data available", ha='center', va='center')
36
+ return fig
37
+
38
+ categories = category_counts["Category"]
39
+ counts = category_counts["Count"]
40
+
41
+ # Nice modern color palette
42
+ colors = ['#3498DB', '#E67E22', '#9B59B6', '#2ECC71', '#E74C3C']
43
 
44
+ fig, ax = plt.subplots(figsize=(11, 6))
45
+ bars = ax.bar(categories, counts, color=colors, edgecolor='white', linewidth=0.8)
46
 
47
+ # Add value on top of bars
48
+ for bar in bars:
49
+ height = bar.get_height()
50
+ ax.text(bar.get_x() + bar.get_width()/2, height + 0.8,
51
+ str(int(height)), ha='center', va='bottom', fontsize=13, fontweight='bold')
52
 
53
+ ax.set_title("Category Distribution Across 5 Classes", fontsize=16, fontweight='bold', pad=20)
54
+ ax.set_xlabel("Category", fontsize=12)
55
+ ax.set_ylabel("Count", fontsize=12)
56
+ plt.xticks(rotation=15)
57
+ plt.tight_layout()
58
+ return fig
59
 
60
  # ====================== CLASSIFICATION FUNCTION ======================
 
61
  @torch.no_grad()
62
  def classify_csv(file):
63
  try:
64
  df = pd.read_csv(file)
65
  if 'content' not in df.columns:
66
+ return "Error: CSV must have a column named 'content'", None, None
67
 
68
  df['clean_content'] = df['content'].apply(preprocess_text)
69
 
 
82
 
83
  output_file = "output.csv"
84
  df.to_csv(output_file, index=False)
85
+
 
86
  category_counts = df['class'].value_counts().reset_index()
87
  category_counts.columns = ["Category", "Count"]
88
+
 
89
  fig = create_colored_bar_chart(category_counts)
90
 
91
  return f"βœ… Success! Classified {len(df)} rows", output_file, fig
92
  except Exception as e:
93
  return f"❌ Error: {str(e)}", None, None
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  # ====================== Q&A FUNCTION ======================
96
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering
 
97
  qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
98
  qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
99
+
100
  def answer_question(news_content, question):
101
  if not news_content.strip() or not question.strip():
102
  return "Please enter both news content and a question."
 
103
  try:
104
  inputs = qa_tokenizer(question, news_content, return_tensors="pt", truncation=True, max_length=512)
 
105
  with torch.no_grad():
106
  outputs = qa_model(**inputs)
107
 
108
  start_idx = torch.argmax(outputs.start_logits)
109
  end_idx = torch.argmax(outputs.end_logits) + 1
110
 
111
+ answer = qa_tokenizer.decode(inputs.input_ids[0][start_idx:end_idx],
112
+ skip_special_tokens=True,
 
113
  clean_up_tokenization_spaces=True)
114
 
115
  confidence = torch.max(torch.softmax(outputs.start_logits, dim=1)).item()
116
 
117
  return f"**Answer:** {answer.strip()}\n\n**Confidence:** {confidence:.2%}"
 
118
  except Exception as e:
119
  return f"Error: {str(e)}"
120
 
121
+ # ====================== BEAUTIFUL UI ======================
122
+ with gr.Blocks(
123
+ title="News Classifier & Question Answering App",
124
+ theme=gr.themes.Soft(),
125
+ css="""
126
+ .gradio-container {
127
+ max-width: 1250px;
128
+ min-width: 1000px;
129
+ margin: auto;
130
+ background: linear-gradient(135deg, #0f172a 0%, #1e2937 100%);
131
+ }
132
+ h1 {
133
+ text-align: center;
134
+ font-size: 2.9rem;
135
+ background: linear-gradient(90deg, #60a5fa, #c084fc);
136
+ -webkit-background-clip: text;
137
+ -webkit-text-fill-color: transparent;
138
+ margin-bottom: 10px;
139
+ }
140
+ /* Upload Box - Gradient */
141
+ .file-upload {
142
+ background: linear-gradient(135deg, #3b82f6, #8b5cf6) !important;
143
+ border-radius: 16px;
144
+ border: none;
145
+ }
146
+ /* Classify News Button - Gradient */
147
+ button.primary {
148
+ background: linear-gradient(90deg, #6366f1, #a855f7) !important;
149
+ border: none;
150
+ font-weight: 700;
151
+ font-size: 1.1rem;
152
+ padding: 14px 0;
153
+ border-radius: 12px;
154
+ transition: all 0.3s ease;
155
+ }
156
+ button.primary:hover {
157
+ transform: translateY(-3px);
158
+ box-shadow: 0 15px 25px rgba(139, 92, 246, 0.5);
159
+ }
160
+ /* Status Box - Gradient */
161
+ .status-box {
162
+ background: linear-gradient(135deg, #10b981, #34d399) !important;
163
+ color: white;
164
+ border-radius: 12px;
165
+ }
166
+ /* Download Button - Gradient */
167
+ button.secondary {
168
+ background: linear-gradient(90deg, #ec4899, #f43f5e) !important;
169
+ color: white;
170
+ font-weight: 600;
171
+ }
172
+ /* Tab styling */
173
+ .tab-label {
174
+ font-size: 1.15rem;
175
+ font-weight: 600;
176
+ }
177
+ """
178
+ ) as demo:
179
+
180
  gr.Markdown("# πŸ“° News Classifier & Question Answering App..")
181
 
182
 
183
  with gr.Tabs():
184
  with gr.Tab("πŸ“Š News Classification"):
185
+ gr.Markdown("### Upload CSV and get automatic category prediction")
186
+
187
+ file_input = gr.File(
188
+ label="πŸ“€ Upload your CSV file",
189
+ file_types=[".csv"],
190
+ height=160
 
 
 
191
  )
192
 
193
+ classify_btn = gr.Button("πŸš€ Classify News", variant="primary", size="large")
194
+
195
+ with gr.Row():
196
+ output_text = gr.Textbox(label="Status", scale=2)
197
+ output_file = gr.File(label="πŸ“₯ Download output.csv")
198
+
199
+ bar_chart = gr.Plot(label="πŸ“Š Category Distribution Across 5 Classes")
200
+
201
  classify_btn.click(
202
  fn=classify_csv,
203
  inputs=file_input,
204
+ outputs=[output_text, output_file, bar_chart]
205
  )
206
 
207
  with gr.Tab("❓ Question Answering"):
208
+ gr.Markdown("### Ask any question about a news article")
209
+ news_input = gr.Textbox(lines=12, label="πŸ“ Paste News Content", placeholder="Paste the full news article here...")
210
+ question_input = gr.Textbox(label="❓ Your Question", placeholder="e.g. What is the main topic?")
211
+ qa_btn = gr.Button("πŸ” Get Answer", variant="primary", size="large")
212
+ qa_output = gr.Textbox(label="πŸ’‘ Answer", lines=6)
213
+
214
  qa_btn.click(
215
  fn=answer_question,
216
  inputs=[news_input, question_input],
217
  outputs=qa_output
218
  )
219
 
220
+ gr.Markdown("---")
221
 
222
 
223
  demo.launch()