Ginidu2003 commited on
Commit
deac599
Β·
verified Β·
1 Parent(s): c3cf5b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -69
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import gradio as gr
3
  import pandas as pd
4
  import torch
@@ -11,12 +10,10 @@ import string
11
  import matplotlib.pyplot as plt
12
 
13
  # ====================== NLTK SETUP ======================
14
-
15
  nltk.download('wordnet', quiet=True)
16
  nltk.download('punkt', quiet=True)
17
  nltk.download('punkt_tab', quiet=True)
18
 
19
-
20
  lemmatizer = WordNetLemmatizer()
21
 
22
  def preprocess_text(text):
@@ -24,17 +21,39 @@ def preprocess_text(text):
24
  return ""
25
  text = text.lower()
26
  punct_to_remove = string.punctuation.replace("'","").replace('"',"").replace("$","").replace("%","").replace("?","")
27
- text = re.sub(f"[{punct_to_remove}]", " ", text)# Remove punctuation except: ' " $ % ?
28
  tokens = nltk.word_tokenize(text)
29
- tokens = [word for word in tokens]
30
  tokens = [lemmatizer.lemmatize(word) for word in tokens]
31
  return ' '.join(tokens)
32
 
33
-
34
  classifier_model = "Ginidu2003/Distilbert-Base-News-classifier"
35
 
 
 
 
 
 
 
 
 
 
36
 
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  # ====================== CLASSIFICATION FUNCTION ======================
40
  @torch.no_grad()
@@ -42,7 +61,7 @@ def classify_csv(file):
42
  try:
43
  df = pd.read_csv(file)
44
  if 'content' not in df.columns:
45
- return "Error: CSV must have a column named 'content'", None
46
 
47
  df['clean_content'] = df['content'].apply(preprocess_text)
48
 
@@ -61,116 +80,106 @@ def classify_csv(file):
61
 
62
  output_file = "output.csv"
63
  df.to_csv(output_file, index=False)
64
-
65
- # Count categories
66
  category_counts = df['class'].value_counts().reset_index()
67
  category_counts.columns = ["Category", "Count"]
68
-
69
- # Create colored bar chart
70
  fig = create_colored_bar_chart(category_counts)
71
 
72
  return f"βœ… Success! Classified {len(df)} rows", output_file, fig
73
  except Exception as e:
74
  return f"❌ Error: {str(e)}", None, None
75
 
76
- # ====================== COLORED BAR CHART ======================
77
- def create_colored_bar_chart(category_counts):
78
- if category_counts is None or len(category_counts) == 0:
79
- fig, ax = plt.subplots()
80
- ax.text(0.5, 0.5, "No data available", ha='center', va='center')
81
- return fig
82
-
83
- categories = category_counts["Category"]
84
- counts = category_counts["Count"]
85
-
86
- # Different attractive colors for each category
87
- colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEEAD']
88
-
89
- fig, ax = plt.subplots(figsize=(10, 6))
90
- bars = ax.bar(categories, counts, color=colors)
91
-
92
- # Add count numbers on top of bars
93
- for bar in bars:
94
- height = bar.get_height()
95
- ax.text(bar.get_x() + bar.get_width()/2, height + 0.5,
96
- str(int(height)), ha='center', va='bottom', fontsize=12, fontweight='bold')
97
-
98
- ax.set_title("Category Distribution Across 5 Classes", fontsize=14, fontweight='bold')
99
- ax.set_xlabel("Category")
100
- ax.set_ylabel("Count")
101
- plt.xticks(rotation=15)
102
- plt.tight_layout()
103
- return fig
104
-
105
-
106
-
107
  # ====================== Q&A FUNCTION ======================
108
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering
109
-
110
  qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
111
  qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
 
112
  def answer_question(news_content, question):
113
  if not news_content.strip() or not question.strip():
114
  return "Please enter both news content and a question."
115
-
116
  try:
117
  inputs = qa_tokenizer(question, news_content, return_tensors="pt", truncation=True, max_length=512)
118
-
119
  with torch.no_grad():
120
  outputs = qa_model(**inputs)
121
 
122
  start_idx = torch.argmax(outputs.start_logits)
123
  end_idx = torch.argmax(outputs.end_logits) + 1
124
 
125
- # Clean answer - remove question repetition and special tokens
126
- answer = qa_tokenizer.decode(inputs.input_ids[0][start_idx:end_idx],
127
- skip_special_tokens=True,
128
  clean_up_tokenization_spaces=True)
129
 
130
  confidence = torch.max(torch.softmax(outputs.start_logits, dim=1)).item()
131
 
132
  return f"**Answer:** {answer.strip()}\n\n**Confidence:** {confidence:.2%}"
133
-
134
  except Exception as e:
135
  return f"Error: {str(e)}"
136
 
137
- # ====================== GRADIO INTERFACE ======================
138
- with gr.Blocks(title=" News Classifier & Question Answering App") as demo:
 
 
 
 
 
 
 
 
 
139
  gr.Markdown("# πŸ“° English News Classifier")
140
- #gr.Markdown("### Section 02 - Text Analytics Assignment")
141
 
142
  with gr.Tabs():
 
143
  with gr.Tab("πŸ“Š News Classification"):
144
- gr.Markdown("Upload CSV with `content` column")
145
- file_input = gr.File(label="Upload CSV", file_types=[".csv"])
146
- classify_btn = gr.Button("πŸš€ Classify News", variant="primary")
147
- output_text = gr.Textbox(label="Status")
148
- output_file = gr.File(label="Download output.csv")
149
- bar_chart = gr.Plot(
150
- label="Category Distribution Across 5 Classes"
151
-
152
-
153
- )
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  classify_btn.click(
156
  fn=classify_csv,
157
  inputs=file_input,
158
- outputs=[output_text, output_file,bar_chart]
159
  )
160
 
 
161
  with gr.Tab("❓ Question Answering"):
162
- gr.Markdown("Ask any question about a news article")
163
- news_input = gr.Textbox(lines=12, label="Paste News Content", placeholder="Paste the full news article here...")
164
- question_input = gr.Textbox(label="Your Question", placeholder="e.g. What is the main topic?")
165
- qa_btn = gr.Button("πŸ” Get Answer", variant="primary")
166
- qa_output = gr.Textbox(label="Answer", lines=5)
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  qa_btn.click(
169
  fn=answer_question,
170
  inputs=[news_input, question_input],
171
  outputs=qa_output
172
  )
173
 
174
- #gr.Markdown("Built for Text Analytics Assignment - Section 02")
 
175
 
176
  demo.launch()
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import torch
 
10
  import matplotlib.pyplot as plt
11
 
12
  # ====================== NLTK SETUP ======================
 
13
  nltk.download('wordnet', quiet=True)
14
  nltk.download('punkt', quiet=True)
15
  nltk.download('punkt_tab', quiet=True)
16
 
 
17
  lemmatizer = WordNetLemmatizer()
18
 
19
  def preprocess_text(text):
 
21
  return ""
22
  text = text.lower()
23
  punct_to_remove = string.punctuation.replace("'","").replace('"',"").replace("$","").replace("%","").replace("?","")
24
+ text = re.sub(f"[{punct_to_remove}]", " ", text)
25
  tokens = nltk.word_tokenize(text)
 
26
  tokens = [lemmatizer.lemmatize(word) for word in tokens]
27
  return ' '.join(tokens)
28
 
 
29
  classifier_model = "Ginidu2003/Distilbert-Base-News-classifier"
30
 
31
+ # ====================== COLORED BAR CHART ======================
32
+ def create_colored_bar_chart(category_counts):
33
+ if category_counts is None or len(category_counts) == 0:
34
+ fig, ax = plt.subplots()
35
+ ax.text(0.5, 0.5, "No data available", ha='center', va='center')
36
+ return fig
37
+
38
+ categories = category_counts["Category"]
39
+ counts = category_counts["Count"]
40
 
41
+ colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEEAD']
42
 
43
+ fig, ax = plt.subplots(figsize=(10, 6))
44
+ bars = ax.bar(categories, counts, color=colors)
45
+
46
+ for bar in bars:
47
+ height = bar.get_height()
48
+ ax.text(bar.get_x() + bar.get_width()/2, height + 0.5,
49
+ str(int(height)), ha='center', va='bottom', fontsize=12, fontweight='bold')
50
+
51
+ ax.set_title("Category Distribution Across 5 Classes", fontsize=16, fontweight='bold')
52
+ ax.set_xlabel("Category", fontsize=12)
53
+ ax.set_ylabel("Count", fontsize=12)
54
+ plt.xticks(rotation=15)
55
+ plt.tight_layout()
56
+ return fig
57
 
58
  # ====================== CLASSIFICATION FUNCTION ======================
59
  @torch.no_grad()
 
61
  try:
62
  df = pd.read_csv(file)
63
  if 'content' not in df.columns:
64
+ return "Error: CSV must have a column named 'content'", None, None
65
 
66
  df['clean_content'] = df['content'].apply(preprocess_text)
67
 
 
80
 
81
  output_file = "output.csv"
82
  df.to_csv(output_file, index=False)
83
+
 
84
  category_counts = df['class'].value_counts().reset_index()
85
  category_counts.columns = ["Category", "Count"]
86
+
 
87
  fig = create_colored_bar_chart(category_counts)
88
 
89
  return f"βœ… Success! Classified {len(df)} rows", output_file, fig
90
  except Exception as e:
91
  return f"❌ Error: {str(e)}", None, None
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  # ====================== Q&A FUNCTION ======================
94
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering
 
95
  qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
96
  qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
97
+
98
  def answer_question(news_content, question):
99
  if not news_content.strip() or not question.strip():
100
  return "Please enter both news content and a question."
 
101
  try:
102
  inputs = qa_tokenizer(question, news_content, return_tensors="pt", truncation=True, max_length=512)
 
103
  with torch.no_grad():
104
  outputs = qa_model(**inputs)
105
 
106
  start_idx = torch.argmax(outputs.start_logits)
107
  end_idx = torch.argmax(outputs.end_logits) + 1
108
 
109
+ answer = qa_tokenizer.decode(inputs.input_ids[0][start_idx:end_idx],
110
+ skip_special_tokens=True,
 
111
  clean_up_tokenization_spaces=True)
112
 
113
  confidence = torch.max(torch.softmax(outputs.start_logits, dim=1)).item()
114
 
115
  return f"**Answer:** {answer.strip()}\n\n**Confidence:** {confidence:.2%}"
 
116
  except Exception as e:
117
  return f"Error: {str(e)}"
118
 
119
+ # ====================== BEAUTIFUL GRADIO INTERFACE ======================
120
+ with gr.Blocks(
121
+ title="English News Classifier",
122
+ theme=gr.themes.Soft(), # Beautiful modern theme
123
+ css="""
124
+ .gradio-container {max-width: 1100px; margin: auto;}
125
+ h1 {font-size: 2.5rem; text-align: center;}
126
+ .tab-label {font-size: 1.1rem; font-weight: 600;}
127
+ """
128
+ ) as demo:
129
+
130
  gr.Markdown("# πŸ“° English News Classifier")
131
+ gr.Markdown("### Intelligent News Analysis Tool | Daily Mirror Sri Lanka")
132
 
133
  with gr.Tabs():
134
+ # ====================== CLASSIFICATION TAB ======================
135
  with gr.Tab("πŸ“Š News Classification"):
136
+ gr.Markdown("### Upload CSV and get automatic category prediction")
 
 
 
 
 
 
 
 
 
137
 
138
+ with gr.Row():
139
+ file_input = gr.File(
140
+ label="πŸ“€ Upload your CSV file",
141
+ file_types=[".csv"],
142
+ height=120
143
+ )
144
+
145
+ classify_btn = gr.Button("πŸš€ Classify News", variant="primary", size="large")
146
+
147
+ with gr.Row():
148
+ output_text = gr.Textbox(label="Status", scale=2)
149
+ output_file = gr.File(label="πŸ“₯ Download output.csv")
150
+
151
+ bar_chart = gr.Plot(label="πŸ“Š Category Distribution Across 5 Classes")
152
+
153
  classify_btn.click(
154
  fn=classify_csv,
155
  inputs=file_input,
156
+ outputs=[output_text, output_file, bar_chart]
157
  )
158
 
159
+ # ====================== Q&A TAB ======================
160
  with gr.Tab("❓ Question Answering"):
161
+ gr.Markdown("### Ask any question about a news article")
 
 
 
 
162
 
163
+ news_input = gr.Textbox(
164
+ lines=10,
165
+ label="πŸ“ Paste News Content",
166
+ placeholder="Paste the full news article here..."
167
+ )
168
+ question_input = gr.Textbox(
169
+ label="❓ Your Question",
170
+ placeholder="e.g., What is the main issue? Who is involved?"
171
+ )
172
+
173
+ qa_btn = gr.Button("πŸ” Get Answer", variant="primary", size="large")
174
+ qa_output = gr.Textbox(label="πŸ’‘ Answer", lines=6)
175
+
176
  qa_btn.click(
177
  fn=answer_question,
178
  inputs=[news_input, question_input],
179
  outputs=qa_output
180
  )
181
 
182
+ gr.Markdown("---")
183
+ gr.Markdown("**Built for Text Analytics Assignment (In23-S5-DA3111) - Section 02**")
184
 
185
  demo.launch()