Breadknife committed on
Commit
d4e6067
·
1 Parent(s): b5e269a

feat: Implement and fix the news web scraping application

Browse files
Files changed (5) hide show
  1. NewsApex +1 -0
  2. app.py +226 -128
  3. bias_module/evaluate.py +42 -15
  4. bias_module/load_data.py +20 -3
  5. evaluation_results.txt +11 -0
NewsApex ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 31d4e764b5898ca393eb0a0343bc026a052db6d7
app.py CHANGED
@@ -36,15 +36,69 @@ st.markdown("""
36
  visibility: hidden !important;
37
  } */
38
  .main {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  background-color: #f8f9fa;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  }
41
  .stArticle {
42
  background-color: white;
43
- padding: 2rem;
44
- border-radius: 10px;
45
  border: 1px solid #e9ecef;
46
- margin-bottom: 2rem;
47
- box-shadow: 0 2px 4px rgba(0,0,0,0.05);
48
  }
49
  .full-content {
50
  color: #212529;
@@ -52,9 +106,9 @@ st.markdown("""
52
  line-height: 1.6;
53
  margin-top: 1rem;
54
  margin-bottom: 1rem;
55
- padding: 1.5rem;
56
- background-color: #ffffff;
57
- border-left: 5px solid #007bff;
58
  white-space: pre-wrap;
59
  }
60
  .summary-content {
@@ -97,164 +151,208 @@ def cached_split_into_sentences(_service, text):
97
  @st.cache_data(show_spinner=False)
98
  def cached_rate_bias(_service, text):
99
  return _service.rate_bias(text)
100
- # ------------------------
101
 
102
- def fetch_and_display_news(query, news_service, title=None, bias_mode=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  if title:
104
  st.subheader(title)
105
 
106
- spinner_text = f"🔍 Searching for articles about '{query}'..." if query else "🔍 Fetching latest headlines..."
107
  with st.spinner(spinner_text):
108
  try:
109
- # Use cached news fetching
110
  articles = cached_fetch_all_news(news_service, query, language="en")
111
 
112
  if not articles:
113
  st.error("No articles found for this topic. Please try another one.")
114
- else:
115
- # Limit to 10 articles to avoid endless fetching/loading
116
- articles = articles[:10]
117
-
118
- if query:
119
- st.success(f"✅ Found {len(articles)} articles!")
120
-
121
- found_any = False
122
- for idx, article in enumerate(articles):
123
- url = article.get('link')
124
- if not url:
125
- continue
 
 
 
 
126
 
127
- # Use cached content extraction
128
- full_content = cached_get_full_content(news_service, url)
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- if full_content:
131
- found_any = True
132
- # Use cached summarization and sentence splitting
133
- summary = cached_summarize_content(news_service, full_content)
134
-
135
- with st.container():
136
- st.markdown('<div class="stArticle">', unsafe_allow_html=True)
137
- st.header(article.get('title', 'No Title'))
138
- st.markdown(f"""
139
- <div class="meta-info">
140
- <b>Source:</b> {article.get('source_id', 'Unknown')} |
141
- <b>Date:</b> {article.get('pubDate', 'Unknown')} |
142
- <a href="{url}" target="_blank">Original Link</a>
143
- </div>
144
- """, unsafe_allow_html=True)
145
-
146
- if bias_mode:
147
- st.subheader("⚖️ Bias Analysis")
148
- overall_bias = cached_rate_bias(news_service, full_content)
149
-
150
- # Display overall rating
151
- col_b1, col_b2 = st.columns([1, 3])
152
- with col_b1:
153
- color = "red" if overall_bias['label'] == "Biased" else "green"
154
- st.markdown(f"**Overall Rating:** <span style='color:{color}; font-weight:bold;'>{overall_bias['label']}</span>", unsafe_allow_html=True)
155
- with col_b2:
156
- st.progress(overall_bias['score'], text=f"Confidence: {overall_bias['score']:.1%}")
157
-
158
- st.subheader("📋 Full Article (Numbered Sentences for Bias Detection)")
159
- sentences = cached_split_into_sentences(news_service, full_content)
160
- # Display sentences as a numbered list
161
- sentence_html = ""
162
- for i, sentence in enumerate(sentences, 1):
163
- # Rate individual sentence bias
164
- s_bias = cached_rate_bias(news_service, sentence)
165
- s_label = s_bias.get('label', 'Factual')
166
- s_color = "rgba(255, 0, 0, 0.1)" if s_label == "Biased" else "transparent"
167
-
168
- # Escape each sentence
169
- escaped_sentence = html.escape(sentence)
170
- # Use a single line to avoid accidental markdown code block triggers
171
- # We access the dictionary keys outside the f-string to avoid quoting issues
172
- sentence_html += f'<div style="margin-bottom: 8px; padding: 4px; background-color: {s_color}; border-radius: 4px;"><b>{i}.</b> {escaped_sentence} <span style="font-size: 0.8rem; color: gray; margin-left: 10px;">({s_label})</span></div>'
173
-
174
- st.markdown(f'<div class="full-content">{sentence_html}</div>', unsafe_allow_html=True)
175
- else:
176
- st.subheader("Full Article")
177
- # Escape the full content to ensure it's not cut off by HTML tags
178
- escaped_content = html.escape(full_content)
179
- st.markdown(f'<div class="full-content">{escaped_content}</div>', unsafe_allow_html=True)
180
-
181
- if summary:
182
- st.subheader("Summarization")
183
- st.markdown(f'<div class="summary-content">🤖 {summary}</div>', unsafe_allow_html=True)
184
-
185
- st.markdown('</div>', unsafe_allow_html=True)
186
- st.divider()
187
-
188
- if not found_any:
189
- st.warning("Could not retrieve full content for any articles. The sources might be protected or paywalled.")
190
  except Exception as e:
191
- st.error(f"An error occurred while fetching news: {str(e)}")
192
 
193
  def main():
194
  st.title("📰 NEXTER")
195
- st.markdown("Your AI-powered gateway to the latest news from around the world.")
196
 
197
- # Initialize session state for query and search trigger
198
- if 'search_query' not in st.session_state:
199
- st.session_state.search_query = ""
200
- if 'trigger_search' not in st.session_state:
201
- st.session_state.trigger_search = False
202
- if 'is_home' not in st.session_state:
203
- st.session_state.is_home = True
204
 
205
- # Initialize Service (Cached)
206
  with st.sidebar:
207
  st.header("Search Settings")
208
  news_service = get_news_service()
209
 
210
- query_input = st.text_input("Enter Topic", value=st.session_state.search_query, placeholder="e.g. Artificial Intelligence, Space, Finance")
211
 
212
- # Bias Detection Toggle
213
- bias_mode = st.toggle("Bias Detection Mode", help="Enable to see articles broken down into numbered sentences.")
214
-
215
- # Update session state if input changes manually
216
  if query_input != st.session_state.search_query:
217
  st.session_state.search_query = query_input
 
218
  st.session_state.is_home = False
219
 
220
- # Manual search button
221
  if st.button("Fetch News", type="primary"):
222
- st.session_state.trigger_search = True
223
  st.session_state.is_home = False
224
-
225
- # Clear Cache button
226
- if st.button("🧹 Clear Cache"):
227
- st.cache_data.clear()
228
- st.success("Cache cleared! Reloading...")
229
- time.sleep(0.5)
230
  st.rerun()
231
 
232
- # Reset to Home button
233
- if st.button("🏠 Home"):
234
  st.session_state.search_query = ""
235
- st.session_state.trigger_search = False
236
  st.session_state.is_home = True
237
  st.rerun()
238
 
239
- # Model Status indicator
 
 
 
 
 
 
240
  if news_service.bias_model:
241
- st.success("✅ Local Model Loaded")
242
  else:
243
- st.info("☁️ Using Cloud Fallback")
244
-
245
- # Handle search execution
246
- if st.session_state.trigger_search and st.session_state.search_query:
247
- st.session_state.is_home = False # Ensure we're not in home view
248
- st.write(f"DEBUG: Searching for {st.session_state.search_query}") # DEBUG
249
- fetch_and_display_news(st.session_state.search_query, news_service, title="Search Results", bias_mode=bias_mode)
250
- st.session_state.trigger_search = False # Reset trigger after search
251
- elif st.session_state.trigger_search and not st.session_state.search_query:
252
- st.sidebar.error("Please enter a topic!")
253
- st.session_state.trigger_search = False
254
-
255
- # Home View: Latest Headlines
256
- if st.session_state.is_home:
257
- fetch_and_display_news(None, news_service, title="Latest Headlines", bias_mode=bias_mode)
258
 
259
 
260
  if __name__ == "__main__":
 
36
  visibility: hidden !important;
37
  } */
38
  .main {
39
+ background-color: #f0f2f6;
40
+ }
41
+ .stCard {
42
+ background-color: white;
43
+ padding: 1rem;
44
+ border-radius: 12px;
45
+ border: 1px solid #e0e0e0;
46
+ margin-bottom: 1.5rem;
47
+ box-shadow: 0 4px 6px rgba(0,0,0,0.05);
48
+ transition: transform 0.2s ease-in-out;
49
+ height: 480px; /* Fixed height for alignment */
50
+ display: flex;
51
+ flex-direction: column;
52
+ justify-content: space-between;
53
+ }
54
+ .stCard:hover {
55
+ transform: translateY(-5px);
56
+ box-shadow: 0 8px 15px rgba(0,0,0,0.1);
57
+ }
58
+ .card-img {
59
+ width: 100%;
60
+ height: 200px; /* Fixed image height */
61
+ object-fit: cover;
62
+ border-radius: 8px;
63
+ margin-bottom: 0.8rem;
64
  background-color: #f8f9fa;
65
+ display: block;
66
+ }
67
+ .card-img-placeholder {
68
+ width: 100%;
69
+ height: 200px;
70
+ background-color: #e9ecef;
71
+ border-radius: 8px;
72
+ margin-bottom: 0.8rem;
73
+ display: flex;
74
+ align-items: center;
75
+ justify-content: center;
76
+ color: #adb5bd;
77
+ font-size: 0.9rem;
78
+ }
79
+ .card-title {
80
+ font-size: 1.1rem;
81
+ font-weight: 700;
82
+ color: #1a1a1a;
83
+ margin-bottom: 0.5rem;
84
+ line-height: 1.4;
85
+ height: 4.2rem; /* Fixed height for 3 lines of text */
86
+ display: -webkit-box;
87
+ -webkit-line-clamp: 3;
88
+ -webkit-box-orient: vertical;
89
+ overflow: hidden;
90
+ }
91
+ .card-meta {
92
+ font-size: 0.8rem;
93
+ color: #6c757d;
94
+ margin-bottom: 0.5rem;
95
  }
96
  .stArticle {
97
  background-color: white;
98
+ padding: 2.5rem;
99
+ border-radius: 15px;
100
  border: 1px solid #e9ecef;
101
+ box-shadow: 0 10px 25px rgba(0,0,0,0.05);
 
102
  }
103
  .full-content {
104
  color: #212529;
 
106
  line-height: 1.6;
107
  margin-top: 1rem;
108
  margin-bottom: 1rem;
109
+ padding: 0;
110
+ background-color: transparent;
111
+ border-left: none;
112
  white-space: pre-wrap;
113
  }
114
  .summary-content {
 
151
  @st.cache_data(show_spinner=False)
152
  def cached_rate_bias(_service, text):
153
  return _service.rate_bias(text)
 
154
 
155
+ def display_article_detail(article, news_service):
156
+ """Displays the detailed view of a selected article with bias analysis and summary."""
157
+ if st.button("← Back to Feed"):
158
+ st.session_state.selected_article = None
159
+ st.rerun()
160
+
161
+ url = article.get('link')
162
+ if not url:
163
+ st.error("Invalid article link.")
164
+ return
165
+
166
+ st.markdown('<div class="stArticle">', unsafe_allow_html=True)
167
+ st.title(article.get('title', 'No Title'))
168
+
169
+ # Hero image
170
+ img_url = article.get('image_url')
171
+ if img_url:
172
+ st.image(img_url, use_container_width=True)
173
+
174
+ st.markdown(f"""
175
+ <div class="meta-info">
176
+ <b>Source:</b> {article.get('source_id', 'Unknown')} |
177
+ <b>Date:</b> {article.get('pubDate', 'Unknown')} |
178
+ <a href="{url}" target="_blank">Original Link</a>
179
+ </div>
180
+ """, unsafe_allow_html=True)
181
+
182
+ with st.spinner("🧠 Analyzing article content..."):
183
+ full_content = cached_get_full_content(news_service, url)
184
+
185
+ # --- ROBUST FALLBACK LOGIC ---
186
+ text_to_analyze = full_content
187
+ status_msg = None
188
+
189
+ # 1. Try Full Content
190
+ if not text_to_analyze or len(text_to_analyze.strip()) < 100:
191
+ # 2. Try Snippet/Description
192
+ text_to_analyze = article.get('snippet')
193
+ if text_to_analyze and len(text_to_analyze.strip()) > 20:
194
+ status_msg = "⚠️ Full content extraction limited. Analyzing article snippet."
195
+ else:
196
+ # 3. Last Resort: Use Title
197
+ text_to_analyze = article.get('title', '')
198
+ status_msg = "⚠️ No content found. Analyzing article headline only."
199
+
200
+ if not text_to_analyze:
201
+ st.error("Could not retrieve any text for this article.")
202
+ return
203
+
204
+ if status_msg:
205
+ st.warning(status_msg)
206
+
207
+ # Display the text being analyzed
208
+ with st.expander("📖 View Analyzed Text", expanded=True):
209
+ st.markdown(f'<div style="font-size: 1rem; line-height: 1.5; color: #333;">{html.escape(text_to_analyze)}</div>', unsafe_allow_html=True)
210
+
211
+ # AI Summary
212
+ summary = cached_summarize_content(news_service, text_to_analyze)
213
+ if summary:
214
+ st.info(f"🤖 **AI Summary:** {summary}")
215
+
216
+ st.divider()
217
+
218
+ # Bias Analysis
219
+ st.subheader("⚖️ Bias Analysis")
220
+ overall_bias = cached_rate_bias(news_service, text_to_analyze)
221
+
222
+ col_b1, col_b2 = st.columns([1, 3])
223
+ with col_b1:
224
+ color = "#dc3545" if overall_bias['label'] == "Biased" else "#28a745"
225
+ st.markdown(f"**Overall Rating:** <span style='color:{color}; font-weight:bold; font-size:1.2rem;'>{overall_bias['label']}</span>", unsafe_allow_html=True)
226
+ with col_b2:
227
+ st.progress(overall_bias['score'], text=f"Confidence: {overall_bias['score']:.1%}")
228
+
229
+ # --- Interpretation Section ---
230
+ if overall_bias['label'] == "Factual":
231
+ st.success("✅ **Interpretation: Factual Content**\nThis article primarily uses objective language, reports verifiable events, and avoids subjective modifiers or emotional framing. It aims to inform rather than influence.")
232
+ else:
233
+ st.error("⚠️ **Interpretation: Biased Content**\nThis article contains elements that suggest a non-neutral perspective. This could include the use of loaded language, emotional appeals, or selective framing designed to influence the reader's opinion.")
234
+
235
+ o_reasoning = overall_bias.get('reasoning', 'No specific reasoning provided.')
236
+ st.warning(f"💡 **Analysis Reasoning:** {o_reasoning}")
237
+
238
+ st.subheader("📋 Sentence-by-Sentence Breakdown")
239
+ sentences = cached_split_into_sentences(news_service, text_to_analyze)
240
+
241
+ sentence_html = ""
242
+ for i, sentence in enumerate(sentences, 1):
243
+ s_bias = cached_rate_bias(news_service, sentence)
244
+ s_label = s_bias.get('label', 'Factual')
245
+ s_reasoning = s_bias.get('reasoning', '')
246
+ s_color = "rgba(220, 53, 69, 0.08)" if s_label == "Biased" else "transparent"
247
+
248
+ escaped_sentence = html.escape(sentence)
249
+ escaped_reasoning = html.escape(s_reasoning)
250
+
251
+ reasoning_html = f'<div style="font-size: 0.85rem; color: #721c24; margin-top: 4px; font-style: italic;">Why? {escaped_reasoning}</div>' if s_label == "Biased" else ""
252
+
253
+ border_style = "border-left: 4px solid #dc3545;" if s_label == "Biased" else "border-left: 4px solid #e9ecef;"
254
+ sentence_html += f'<div style="margin-bottom: 12px; padding: 12px; background-color: {s_color}; border-radius: 6px; {border_style}"><b>{i}.</b> {escaped_sentence} <span style="font-size: 0.8rem; color: #6c757d; margin-left: 10px; font-weight: bold;">[{s_label}]</span>{reasoning_html}</div>'
255
+
256
+ st.markdown(f'<div class="full-content" style="border:none; padding:0;">{sentence_html}</div>', unsafe_allow_html=True)
257
+
258
+ st.markdown('</div>', unsafe_allow_html=True)
259
+
260
+ def fetch_and_display_news(query, news_service, title=None):
261
+ """Fetches and displays news in a grid layout."""
262
  if title:
263
  st.subheader(title)
264
 
265
+ spinner_text = f"🔍 Searching for '{query}'..." if query else "🔍 Fetching latest headlines..."
266
  with st.spinner(spinner_text):
267
  try:
 
268
  articles = cached_fetch_all_news(news_service, query, language="en")
269
 
270
  if not articles:
271
  st.error("No articles found for this topic. Please try another one.")
272
+ return
273
+
274
+ # Grid layout: 3 columns
275
+ articles = articles[:12]
276
+ cols = st.columns(3)
277
+ for idx, article in enumerate(articles):
278
+ with cols[idx % 3]:
279
+ st.markdown('<div class="stCard">', unsafe_allow_html=True)
280
+
281
+ # Thumbnail Image handling
282
+ img_url = article.get('image_url')
283
+ if img_url and img_url.startswith('http'):
284
+ st.markdown(f'<img src="{img_url}" class="card-img" onerror="this.style.display=\'none\'; this.nextSibling.style.display=\'flex\';">', unsafe_allow_html=True)
285
+ st.markdown('<div class="card-img-placeholder" style="display:none;">🖼️ Image Unavailable</div>', unsafe_allow_html=True)
286
+ else:
287
+ st.markdown('<div class="card-img-placeholder">🖼️ No Image</div>', unsafe_allow_html=True)
288
 
289
+ # Source and Date
290
+ date_str = article.get('pubDate', 'Unknown')[:10]
291
+ st.markdown(f'<div class="card-meta">{article.get("source_id", "Unknown")} • {date_str}</div>', unsafe_allow_html=True)
292
+
293
+ # Title
294
+ st.markdown(f'<div class="card-title">{article.get("title", "No Title")}</div>', unsafe_allow_html=True)
295
+
296
+ # Analyze Button
297
+ if st.button("Analyze Article", key=f"btn_{idx}", use_container_width=True):
298
+ st.session_state.selected_article = article
299
+ st.rerun()
300
+
301
+ st.markdown('</div>', unsafe_allow_html=True)
302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  except Exception as e:
304
+ st.error(f"Error fetching news: {str(e)}")
305
 
306
  def main():
307
  st.title("📰 NEXTER")
308
+ st.markdown("Modern AI-Powered News Analysis")
309
 
310
+ # Initialize session states
311
+ if 'search_query' not in st.session_state: st.session_state.search_query = ""
312
+ if 'selected_article' not in st.session_state: st.session_state.selected_article = None
313
+ if 'is_home' not in st.session_state: st.session_state.is_home = True
 
 
 
314
 
 
315
  with st.sidebar:
316
  st.header("Search Settings")
317
  news_service = get_news_service()
318
 
319
+ query_input = st.text_input("Topic Search", value=st.session_state.search_query, placeholder="e.g. Finance, AI, Sports")
320
 
 
 
 
 
321
  if query_input != st.session_state.search_query:
322
  st.session_state.search_query = query_input
323
+ st.session_state.selected_article = None
324
  st.session_state.is_home = False
325
 
 
326
  if st.button("Fetch News", type="primary"):
327
+ st.session_state.selected_article = None
328
  st.session_state.is_home = False
 
 
 
 
 
 
329
  st.rerun()
330
 
331
+ if st.button("🏠 Home Feed"):
 
332
  st.session_state.search_query = ""
333
+ st.session_state.selected_article = None
334
  st.session_state.is_home = True
335
  st.rerun()
336
 
337
+ st.divider()
338
+ if st.button("🧹 Clear Cache"):
339
+ st.cache_data.clear()
340
+ st.success("Cache cleared!")
341
+ time.sleep(0.5)
342
+ st.rerun()
343
+
344
  if news_service.bias_model:
345
+ st.success("✅ AI Bias Model Loaded")
346
  else:
347
+ st.info("☁️ Cloud Analysis Active")
348
+
349
+ # Display logic
350
+ if st.session_state.selected_article:
351
+ display_article_detail(st.session_state.selected_article, news_service)
352
+ elif st.session_state.is_home:
353
+ fetch_and_display_news(None, news_service, title="Top Headlines")
354
+ else:
355
+ fetch_and_display_news(st.session_state.search_query, news_service, title=f"Results for: {st.session_state.search_query}")
 
 
 
 
 
 
356
 
357
 
358
  if __name__ == "__main__":
bias_module/evaluate.py CHANGED
@@ -24,21 +24,36 @@ print("Using device:", device)
24
 
25
  # -----------------------------
26
  # Load tokenizer and model
27
- tokenizer = BertTokenizer.from_pretrained(config.MODEL_NAME)
28
- model = BertForSequenceClassification.from_pretrained(
29
- config.MODEL_NAME,
30
- num_labels=2
31
- )
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Correct model path regardless of current working directory
34
- model_path = os.path.join(PROJECT_ROOT, "models", "bert_babe.pt")
35
- if not os.path.exists(model_path):
36
- raise FileNotFoundError(f"Model file not found at {model_path}")
37
 
38
- model.load_state_dict(torch.load(model_path, map_location=device))
39
- model.to(device)
40
- model.eval()
41
- print(f"Model loaded successfully from {model_path}")
 
 
 
42
 
43
  # -----------------------------
44
  # Load dataset and create dataloaders
@@ -48,11 +63,14 @@ _, test_loader = create_dataloaders(dataset, batch_size=config.BATCH_SIZE)
48
  # -----------------------------
49
  # Evaluation function
50
  def evaluate_model(model, test_loader):
 
51
  true_labels = []
52
  predicted_labels = []
53
 
54
  with torch.no_grad():
55
- for batch in test_loader:
 
 
56
  input_ids, attention_mask, labels = [x.to(device) for x in batch]
57
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
58
  predictions = torch.argmax(outputs.logits, dim=1)
@@ -60,10 +78,19 @@ def evaluate_model(model, test_loader):
60
  true_labels.extend(labels.cpu().numpy())
61
  predicted_labels.extend(predictions.cpu().numpy())
62
 
 
63
  acc = accuracy_score(true_labels, predicted_labels)
 
 
 
 
 
 
 
 
64
  print(f"\nTest Accuracy: {acc:.4f}")
65
  print("\nClassification Report:")
66
- print(classification_report(true_labels, predicted_labels, target_names=["factual", "biased"]))
67
 
68
  # -----------------------------
69
  # Run evaluation
 
24
 
25
  # -----------------------------
26
  # Load tokenizer and model
27
+ print("Loading model and tokenizer...")
28
+ try:
29
+ model_cache_dir = os.path.join(PROJECT_ROOT, "bias_module", "data", "model_cache")
30
+ if os.path.exists(model_cache_dir):
31
+ print(f"Loading from local cache: {model_cache_dir}")
32
+ tokenizer = BertTokenizer.from_pretrained(model_cache_dir)
33
+ model = BertForSequenceClassification.from_pretrained(
34
+ model_cache_dir,
35
+ num_labels=2
36
+ )
37
+ else:
38
+ print(f"Loading from HF Hub: {config.MODEL_NAME}")
39
+ tokenizer = BertTokenizer.from_pretrained(config.MODEL_NAME)
40
+ model = BertForSequenceClassification.from_pretrained(
41
+ config.MODEL_NAME,
42
+ num_labels=2
43
+ )
44
 
45
+ # Correct model path regardless of current working directory
46
+ model_path = os.path.join(PROJECT_ROOT, "models", "bert_babe.pt")
47
+ if not os.path.exists(model_path):
48
+ raise FileNotFoundError(f"Model file not found at {model_path}")
49
 
50
+ model.load_state_dict(torch.load(model_path, map_location=device))
51
+ model.to(device)
52
+ model.eval()
53
+ print(f"Model loaded successfully from {model_path}")
54
+ except Exception as e:
55
+ print(f"Error loading model: {e}")
56
+ sys.exit(1)
57
 
58
  # -----------------------------
59
  # Load dataset and create dataloaders
 
63
  # -----------------------------
64
  # Evaluation function
65
  def evaluate_model(model, test_loader):
66
+ print("Starting evaluation...")
67
  true_labels = []
68
  predicted_labels = []
69
 
70
  with torch.no_grad():
71
+ for i, batch in enumerate(test_loader):
72
+ if i % 10 == 0:
73
+ print(f"Processing batch {i}...")
74
  input_ids, attention_mask, labels = [x.to(device) for x in batch]
75
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
76
  predictions = torch.argmax(outputs.logits, dim=1)
 
78
  true_labels.extend(labels.cpu().numpy())
79
  predicted_labels.extend(predictions.cpu().numpy())
80
 
81
+ print("Calculating metrics...")
82
  acc = accuracy_score(true_labels, predicted_labels)
83
+ report = classification_report(true_labels, predicted_labels, target_names=["factual", "biased"])
84
+
85
+ # Save to file as well as print
86
+ with open("evaluation_results.txt", "w") as f:
87
+ f.write(f"Test Accuracy: {acc:.4f}\n")
88
+ f.write("\nClassification Report:\n")
89
+ f.write(report)
90
+
91
  print(f"\nTest Accuracy: {acc:.4f}")
92
  print("\nClassification Report:")
93
+ print(report)
94
 
95
  # -----------------------------
96
  # Run evaluation
bias_module/load_data.py CHANGED
@@ -10,14 +10,31 @@ if PROJECT_ROOT not in sys.path:
10
 
11
  # Now we can safely import config
12
  import config
13
- from datasets import load_dataset
 
14
 
15
  def load_babe_dataset():
16
  """
17
- Load the BABE dataset from Hugging Face and split into train/test.
 
18
  Returns:
19
- dataset: dict with 'train' and 'test' splits
20
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # Load the full dataset
22
  dataset = load_dataset(config.DATASET_NAME)
23
 
 
10
 
11
  # Now we can safely import config
12
  import config
13
+ from datasets import load_dataset, DatasetDict, Dataset
14
+ import pandas as pd
15
 
16
  def load_babe_dataset():
17
  """
18
+ Load the BABE dataset from local parquet files if available,
19
+ otherwise from Hugging Face.
20
  Returns:
21
+ dataset: DatasetDict with 'train' and 'test' splits
22
  """
23
+ local_train = os.path.join(PROJECT_ROOT, "bias_module", "data", "cache", "data", "train-00000-of-00001.parquet")
24
+ local_test = os.path.join(PROJECT_ROOT, "bias_module", "data", "cache", "data", "test-00000-of-00001.parquet")
25
+
26
+ if os.path.exists(local_train) and os.path.exists(local_test):
27
+ print("Loading BABE dataset from local parquet files...")
28
+ train_df = pd.read_parquet(local_train)
29
+ test_df = pd.read_parquet(local_test)
30
+
31
+ dataset = DatasetDict({
32
+ "train": Dataset.from_pandas(train_df),
33
+ "test": Dataset.from_pandas(test_df)
34
+ })
35
+ return dataset
36
+
37
+ print(f"Loading BABE dataset from Hugging Face ({config.DATASET_NAME})...")
38
  # Load the full dataset
39
  dataset = load_dataset(config.DATASET_NAME)
40
 
evaluation_results.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Test Accuracy: 0.8200
2
+
3
+ Classification Report:
4
+ precision recall f1-score support
5
+
6
+ factual 0.74 0.91 0.82 441
7
+ biased 0.91 0.75 0.82 559
8
+
9
+ accuracy 0.82 1000
10
+ macro avg 0.83 0.83 0.82 1000
11
+ weighted avg 0.84 0.82 0.82 1000