Breadknife committed on
Commit
2a62a5a
·
1 Parent(s): f885013

chore: Prepare NewsApex for Vercel deployment with self-contained Python logic

Browse files
NewsApex/bridge_logic.py CHANGED
@@ -4,10 +4,10 @@ import os
4
  import json
5
  import argparse
6
 
7
- # Add parent directory to sys.path to access our news_service and model
8
- parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
9
- if parent_dir not in sys.path:
10
- sys.path.insert(0, parent_dir)
11
 
12
  from news_service import NewsService
13
 
 
4
  import json
5
  import argparse
6
 
7
# Make local modules (news_service and the model) importable regardless of
# the working directory by prepending this file's directory to sys.path.
current_dir = os.path.abspath(os.path.dirname(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

from news_service import NewsService
13
 
NewsApex/news_service.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import os
3
+ import re
4
+ from dotenv import load_dotenv
5
+ from newspaper import Article, Config
6
+ from huggingface_hub import InferenceClient
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ import sys
9
+
10
+ # Add current directory to path for bias_module integration
11
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
12
+
13
+ load_dotenv()
14
+
15
class NewsService:
    """Aggregates news from multiple APIs, scrapes full article text,
    summarizes it, and rates linguistic bias with a local BERT model."""

    # Common English stop words filtered out of the "top biased words" ranking.
    # FIX: the original literal listed 'but' twice; the duplicate is removed
    # (harmless in a set literal, but redundant).
    STOP_WORDS = {
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'with',
        'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had',
        'do', 'does', 'did', 'if', 'then', 'else', 'when', 'where', 'why',
        'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some',
        'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very',
        's', 't', 'can', 'will', 'just', 'don', "should", "now", "of", "as", "by",
        "it", "its", "they", "them", "their", "this", "that", "these", "those",
        "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your"
    }
27
+
28
def __init__(self):
    """Set up API credentials, the HTTP session, scraper config, and lazy model slots."""
    # API keys from .env with fallbacks for ease of use.
    # SECURITY NOTE(review): hardcoded fallback keys ship with the source;
    # rotate them and prefer environment-only configuration in production.
    self.newsdata_api_key = os.getenv("NEWSDATA_API_KEY", "pub_c319de1ec46240dc912d9b112e01c866")
    self.guardian_api_key = os.getenv("GUARDIAN_API_KEY", "438ab5df-f19b-42b6-9ca9-83b8e971f219")
    self.hf_token = os.getenv("HF_TOKEN")

    # Reusable HTTP session plus newspaper3k scraping configuration.
    self.session = requests.Session()
    self.config = Config()
    self.config.browser_user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36'
    self.config.request_timeout = 20
    self.config.fetch_images = False
    self.config.memoize_articles = False
    self.config.MAX_TEXT = 100000  # Ensure we get as much text as possible

    # Initialize HF client for summarization; degrade gracefully when absent.
    # FIX: narrowed the bare `except:` (which also swallowed
    # KeyboardInterrupt/SystemExit) to `except Exception:`.
    if self.hf_token:
        try:
            self.hf_client = InferenceClient(token=self.hf_token)
        except Exception:
            self.hf_client = None
    else:
        self.hf_client = None

    # --- BIAS MODULE INTEGRATION ---
    # Model/tokenizer are loaded lazily on first use of rate_bias.
    self.bias_model = None
    self.bias_tokenizer = None
55
+
56
def load_local_bias_model(self):
    """Try to load the fine-tuned local BERT bias classifier.

    On any failure self.bias_model / self.bias_tokenizer stay None so
    callers can degrade gracefully.
    """
    # Expose the heavy dependencies at module scope so sibling methods
    # (which reference `torch` without importing it) can reuse them.
    global torch, BertTokenizer, BertForSequenceClassification
    try:
        import torch
        from transformers import BertTokenizer, BertForSequenceClassification
        from bias_module import config as bias_config

        # Resolve everything relative to this file so the model loads no
        # matter which directory the service was started from.
        base_path = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(base_path, "bias_module", "models", "bert_babe.pt")
        model_cache_dir = os.path.join(base_path, "bias_module", "data", "model_cache")

        if os.path.exists(model_path):
            if os.path.exists(model_cache_dir):
                # Prefer the on-disk cache to avoid HF connectivity issues.
                self.bias_tokenizer = BertTokenizer.from_pretrained(model_cache_dir)
                self.bias_model = BertForSequenceClassification.from_pretrained(
                    model_cache_dir,
                    num_labels=2
                )
            else:
                self.bias_tokenizer = BertTokenizer.from_pretrained(bias_config.MODEL_NAME)
                self.bias_model = BertForSequenceClassification.from_pretrained(
                    bias_config.MODEL_NAME,
                    num_labels=2
                )

            # Apply the fine-tuned weights on CPU and switch to inference mode.
            self.bias_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            self.bias_model.eval()
            # stderr keeps stdout clean for JSON parsing in the bridge.
            print(f"Local bias model loaded successfully from {model_path}", file=sys.stderr)
    except Exception as e:
        print(f"Error loading local bias model: {e}", file=sys.stderr)
        self.bias_model = None
        self.bias_tokenizer = None
92
+
93
def get_top_biased_words_gradient(self, text, top_k=5):
    """Rank words by gradient-based importance toward the 'biased' class.

    Mirrors the L2-norm saliency formula used in bias_module/predict.py.
    Returns up to *top_k* dicts of the form {"word": token, "score": pct}.
    """
    if not self.bias_model or not self.bias_tokenizer:
        return []

    try:
        import torch

        encoded = self.bias_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128
        )
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # Clear stale gradients, then track them only on the word-embedding
        # layer — that is all the saliency computation needs.
        self.bias_model.zero_grad()
        for param in self.bias_model.bert.embeddings.word_embeddings.parameters():
            param.requires_grad = True

        # enable_grad() in case the caller wrapped us in torch.no_grad().
        with torch.enable_grad():
            model_out = self.bias_model(input_ids=input_ids, attention_mask=attention_mask)
            # Backpropagate from the "biased" logit (class index 1).
            model_out.logits[0, 1].backward()

        embedding_grad = self.bias_model.bert.embeddings.word_embeddings.weight.grad
        if embedding_grad is None:
            return []

        token_ids = input_ids[0]
        # Importance score = L2 norm of each token's embedding gradient
        # (as used in bias_module/predict.py).
        importances = torch.norm(embedding_grad[token_ids], dim=1)
        tokens = self.bias_tokenizer.convert_ids_to_tokens(token_ids)

        # Drop special/subword tokens, punctuation, stop words, and short
        # fragments that carry no standalone meaning.
        candidates = []
        for token, importance in zip(tokens, importances):
            if token in ("[CLS]", "[SEP]", "[PAD]") or token.startswith("##"):
                continue
            if not any(ch.isalnum() for ch in token):
                continue
            if token.lower() in self.STOP_WORDS:
                continue
            if len(token) < 3 and token.lower() not in ('a', 'i'):
                continue
            candidates.append((token, importance.item()))

        candidates.sort(key=lambda pair: pair[1], reverse=True)
        return [{"word": w, "score": round(s * 100, 2)} for w, s in candidates[:top_k]]
    except Exception as e:
        print(f"Gradient Calculation Error: {e}", file=sys.stderr)
        return []
161
+
162
def get_bias_reasoning(self, text, label, bias_score):
    """Map a bias probability to a human-readable interpretation string.

    *text* and *label* are accepted for interface compatibility with the
    labels produced in predict.py; only *bias_score* drives the result.
    """
    if bias_score > 0.7:
        return "Interpretation: Strongly biased"
    return "Interpretation: Likely Biased" if bias_score > 0.5 else "Interpretation: Likely Factual"
172
+
173
def rate_bias_batch(self, sentences):
    """Classify many sentences in a single forward pass for performance.

    Falls back to per-sentence rate_bias() if the batched run fails.
    """
    if not sentences:
        return []

    # Lazily load the model on first use.
    if not self.bias_model:
        self.load_local_bias_model()

    if not self.bias_model or not self.bias_tokenizer:
        return [{"label": "Offline", "score": 0.0, "reasoning": "Model not available."} for _ in sentences]

    try:
        import torch

        batch = self.bias_tokenizer(sentences, return_tensors="pt", truncation=True, padding=True, max_length=128)
        with torch.no_grad():
            probs = torch.nn.functional.softmax(self.bias_model(**batch).logits, dim=-1)

        analysed = []
        for sentence, prob in zip(sentences, probs):
            bias_score = prob[1].item()  # P(biased)
            label = "Biased" if bias_score > 0.5 else "Factual"
            analysed.append({
                "label": label,
                "score": bias_score,
                "reasoning": self.get_bias_reasoning(sentence, label, bias_score)
            })
        return analysed
    except Exception as e:
        print(f"Batch Analysis Error: {e}", file=sys.stderr)
        return [self.rate_bias(s) for s in sentences]
209
+
210
def rate_bias(self, text):
    """Rate the overall bias of *text*.

    Returns a dict with keys label/score/reasoning (plus top_words when the
    local model produced a prediction). Never raises; degrades to an
    "Offline" result when the model is unavailable or fails.
    """
    if not self.bias_model:
        self.load_local_bias_model()

    if not text or len(text.strip()) < 10:
        return {"label": "Neutral", "score": 0.0, "reasoning": "Text too short for analysis."}

    # Analyse only real sentences (ads/boilerplate stripped).
    filtered_sentences = self.split_into_sentences(text)
    if not filtered_sentences:
        return {"label": "Neutral", "score": 0.0, "reasoning": "No valid content found."}

    filtered_text = " ".join(filtered_sentences)

    if self.bias_model and self.bias_tokenizer:
        try:
            # FIX: import torch locally. The original referenced `torch`
            # without importing it, relying on the `global torch` side effect
            # of load_local_bias_model() having run successfully — fragile.
            import torch

            inputs = self.bias_tokenizer(filtered_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
            with torch.no_grad():
                outputs = self.bias_model(**inputs)
                probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

            bias_score = probs[0][1].item()  # P(biased)
            label = "Biased" if bias_score > 0.5 else "Factual"

            # Top biased words + interpretation for the overall document.
            top_words = self.get_top_biased_words_gradient(filtered_text, top_k=8)
            reasoning = self.get_bias_reasoning(filtered_text, label, bias_score)

            return {
                "label": label,
                "score": bias_score,
                "reasoning": reasoning,
                "top_words": top_words
            }
        except Exception as e:
            print(f"Local Model Prediction Error: {e}", file=sys.stderr)

    return {"label": "Offline", "score": 0.0, "reasoning": "Model failed or is offline."}
250
+
251
+
252
def split_into_sentences(self, text):
    """Split *text* into sentences, dropping ads, boilerplate and fragments.

    Returns a list of cleaned sentence strings ([] for empty input).
    """
    if not text:
        return []

    # Sentence boundary: terminal punctuation, whitespace, then a capital.
    raw_sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text.strip())

    # Phrases that mark advertising / navigation / legal boilerplate.
    ad_keywords = [
        "sign up for", "subscribe to", "advertisement", "promoted",
        "sponsored", "click here", "follow us on", "read more",
        "newsletter", "privacy policy", "terms of service", "cookies",
        "related stories", "recommended for you", "check out our",
        "all rights reserved", "photo by", "image credit", "copyright"
    ]

    kept = []
    for candidate in raw_sentences:
        sentence = candidate.strip()
        if not sentence:
            continue
        # Very short "sentences" are usually UI fragments.
        if len(sentence.split()) < 4:
            continue
        # Reject anything containing an advertisement keyword.
        lowered = sentence.lower()
        if any(keyword in lowered for keyword in ad_keywords):
            continue
        kept.append(sentence)

    return kept
289
+
290
def fetch_newsdata(self, query=None, category=None, language="en"):
    """Fetch headlines from NewsData.io; returns [] on any failure.

    NOTE: the *language* parameter is currently ignored — requests are
    pinned to English (parameter kept for interface compatibility).
    """
    if not self.newsdata_api_key:
        return []
    url = "https://newsdata.io/api/1/news"
    # Strictly enforce English
    params = {"apikey": self.newsdata_api_key, "language": "en"}
    if query:
        params["q"] = query
    if category and category != 'general':
        params["category"] = category
    try:
        res = self.session.get(url, params=params)
        data = res.json()
        if data.get("status") == "success":
            # Normalize the provider's schema to our internal article dict.
            return [{
                "title": r.get("title"),
                "link": r.get("link"),
                "source_id": r.get("source_id"),
                "pubDate": r.get("pubDate"),
                "image_url": r.get("image_url"),
                "snippet": r.get("description") or r.get("content")
            } for r in data.get("results", [])]
        return []
    except Exception:
        # FIX: was a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit. Best-effort: any error yields [].
        return []
314
+
315
def fetch_guardian(self, query=None, category=None):
    """Fetch articles from The Guardian Content API; returns [] on any failure."""
    # Skip when unconfigured (empty key or a "your_..." placeholder).
    if not self.guardian_api_key or "your_" in self.guardian_api_key:
        return []
    url = "https://content.guardianapis.com/search"
    params = {"api-key": self.guardian_api_key, "show-fields": "thumbnail,trailText"}
    if query:
        params["q"] = query

    # Map frontend categories to Guardian sections
    category_map = {
        'business': 'business',
        'technology': 'technology',
        'entertainment': 'culture',
        'health': 'society',
        'science': 'science',
        'sports': 'sport'
    }
    if category and category in category_map:
        params["section"] = category_map[category]

    try:
        res = self.session.get(url, params=params)
        data = res.json()
        results = data.get("response", {}).get("results", [])
        # Normalize the Guardian schema to our internal article dict.
        return [{
            "title": r.get("webTitle"),
            "link": r.get("webUrl"),
            "source_id": "The Guardian",
            "pubDate": r.get("webPublicationDate"),
            "image_url": r.get("fields", {}).get("thumbnail"),
            "snippet": r.get("fields", {}).get("trailText")
        } for r in results]
    except Exception:
        # FIX: was a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit. Best-effort: any error yields [].
        return []
347
+
348
def fetch_all_news(self, query=None, category=None, language="en"):
    """Aggregate articles from all sources, dedupe, sort, and keep only
    those whose full text is actually scrapable."""
    all_articles = []
    all_articles.extend(self.fetch_newsdata(query, category, language))
    all_articles.extend(self.fetch_guardian(query, category))

    # Deduplicate articles by case-insensitive title.
    unique_articles = []
    seen_titles = set()
    for article in all_articles:
        title = article.get("title")
        if title and title.lower() not in seen_titles:
            unique_articles.append(article)
            seen_titles.add(title.lower())

    # Newest first. FIX: `x.get('pubDate', '')` still returns None when the
    # key exists with a None value, making the comparison raise TypeError
    # (previously swallowed silently, leaving the list unsorted). `or ''`
    # normalizes both missing and None dates.
    try:
        unique_articles.sort(key=lambda x: x.get('pubDate') or '', reverse=True)
    except Exception:
        pass

    # --- SCRAPABILITY FILTER ---
    # Only surface articles whose content we can actually extract. Scraping
    # is slow, so probe a bounded number of candidates concurrently.
    def check_article(article):
        # Return the article if its page yields usable text, else None.
        url = article.get("link")
        if not url:
            return None
        content = self.get_full_content(url)
        return article if content else None

    # Limit to the 15 most recent/relevant articles for fast response time.
    articles_to_check = unique_articles[:15]
    with ThreadPoolExecutor(max_workers=5) as executor:
        results = list(executor.map(check_article, articles_to_check))
    filtered_articles = [r for r in results if r is not None]

    # Top up from the next batch if the first pass was too thin,
    # without blowing the API timeout.
    if len(filtered_articles) < 5 and len(unique_articles) > 15:
        extra_to_check = unique_articles[15:25]
        with ThreadPoolExecutor(max_workers=5) as executor:
            extra_results = list(executor.map(check_article, extra_to_check))
        filtered_articles.extend([r for r in extra_results if r is not None])

    return filtered_articles
401
+
402
def get_full_content(self, url):
    """Download *url* and extract its article body text.

    Returns the raw extracted text, or None when the page is unreachable,
    paywalled, too short, or looks like a non-article page.
    """
    try:
        headers = {
            'User-Agent': self.config.browser_user_agent,
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
            'Accept-Language': 'en-US,en;q=0.9',
            'Accept-Encoding': 'gzip, deflate, br',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1'
        }
        response = self.session.get(url, headers=headers, timeout=self.config.request_timeout)
        if response.status_code != 200:
            return None

        # Cheap pre-parse rejection of obviously paywalled pages.
        html_lower = response.text.lower()
        quick_fail_terms = ["subscribe to continue", "paywall", "premium access", "log in to read", "javascript is required"]
        if any(term in html_lower for term in quick_fail_terms):
            return None

        article = Article(url, config=self.config)
        article.set_html(response.text)
        article.parse()

        # Raw extracted text — no truncation or AI post-processing here.
        text = article.text.strip()

        # --- STRICT SCANNING FILTER ---
        # Under ~400 chars usually means a paywall stub or failed extraction.
        if not text or len(text) < 400:
            return None

        # NOTE(review): these are plain substring checks, so an article that
        # merely mentions e.g. "404" anywhere in its body is also rejected —
        # confirm this aggressiveness is intended.
        system_errors = [
            "javascript is required", "enable javascript", "allow cookies", "cookie policy",
            "subscribe to continue", "log in to read", "premium access", "paywall",
            "watch the video", "live updates", "photo gallery", "video-only",
            "forbidden", "access denied", "403", "404"
        ]
        text_lower = text.lower()
        if any(err in text_lower for err in system_errors):
            return None

        return text
    except Exception:
        return None
447
+
448
def summarize_content(self, text):
    """Summarize *text* via the HF Inference API.

    Returns the summary string, or None when no client is configured, the
    input is too short, or the API call fails (best-effort by design).
    """
    # Require a configured client and a reasonably long input.
    if not self.hf_client or not text or len(text.strip()) < 100:
        return None
    try:
        snippet = text[:2000]  # keep the request small for latency
        # Smaller/faster summarization model.
        response = self.hf_client.summarization(
            snippet,
            model="facebook/bart-large-cnn"
        )

        # The client may return an object, a list of dicts, or a dict —
        # extract summary_text from whichever shape arrives.
        if hasattr(response, 'summary_text'):
            return response.summary_text
        if isinstance(response, list) and len(response) > 0:
            first = response[0]
            return first.get('summary_text') if isinstance(first, dict) else str(first)
        if isinstance(response, dict):
            return response.get('summary_text')

        return str(response)
    except Exception:
        # Silently fail as requested
        return None
471
+
472
+ if __name__ == "__main__":
473
+ service = NewsService()
474
+ test_text = "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."
475
+ print(f"Summary: {service.summarize_content(test_text)}")
NewsApex/requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ requests
2
+ python-dotenv
3
+ newspaper3k
4
+ lxml_html_clean
5
+ huggingface_hub
6
+ streamlit
7
+ colorama
8
+ pandas
9
+ torch
10
+ transformers
11
+ datasets
12
+ scikit-learn
NewsApex/vercel.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "framework": "nextjs",
3
+ "installCommand": "npm install && pip install -r requirements.txt",
4
+ "buildCommand": "next build"
5
+ }
bart_rouge_test.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import BartForConditionalGeneration, BartTokenizer
3
+ from rouge_score import rouge_scorer
4
+ import time
5
+
6
def summarize_with_bart_cnn(text):
    """Summarize a news article with the pre-trained facebook/bart-large-cnn model."""
    model_name = "facebook/bart-large-cnn"
    print(f"Loading {model_name}...")

    # Tokenizer + seq2seq model (downloaded or served from the HF cache).
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)

    # BART accepts at most 1024 input tokens; truncate accordingly.
    inputs = tokenizer([text], max_length=1024, return_tensors="pt", truncation=True)

    print("Generating summary...")
    summary_ids = model.generate(
        inputs["input_ids"],
        num_beams=4,        # beam search for higher-quality output
        max_length=150,
        min_length=40,
        early_stopping=True
    )

    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
32
+
33
def evaluate_summary(reference, generated):
    """Print ROUGE-1/2/L precision, recall and F1 of *generated* vs *reference*."""
    print("\nCalculating ROUGE scores...")
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = scorer.score(reference, generated)

    separator = "=" * 40
    print("\n" + separator)
    print("ROUGE Evaluation Results:")
    print(separator)
    for key, score in scores.items():
        print(f"{key.upper()}:")
        print(f" Precision: {score.precision:.4f}")
        print(f" Recall: {score.recall:.4f}")
        print(f" F1-Score: {score.fmeasure:.4f}")
    print(separator)
50
+
51
+ if __name__ == "__main__":
52
+ # Sample News Article (Source: BBC News)
53
+ sample_article = """
54
+ The UK's inflation rate has fallen to 1.7% in September, its lowest level in three and a half years.
55
+ The drop, which was larger than expected, was driven by lower airfares and petrol prices.
56
+ It is the first time the rate has fallen below the Bank of England's 2% target since April 2021.
57
+ Economists say the fall makes it more likely that the Bank of England will cut interest rates at its next meeting in November.
58
+ Official figures from the Office for National Statistics (ONS) showed that inflation, the rate at which prices rise, fell from 2.2% in August.
59
+ Lower airfares, which usually drop after the summer holidays, and lower fuel prices for drivers were the main factors behind the decrease.
60
+ However, prices for food and non-alcoholic beverages continued to rise, although at a slower pace than in previous months.
61
+ """
62
+
63
+ # Reference Summary (Human-written or Golden Standard)
64
+ reference_summary = """
65
+ UK inflation fell to 1.7% in September, the lowest level in over three years and below the Bank of England's 2% target.
66
+ The decrease was mainly due to lower airfares and fuel prices, increasing the likelihood of an interest rate cut in November.
67
+ """
68
+
69
+ print("--- BART-CNN Summarization and ROUGE Evaluation ---")
70
+ start_time = time.time()
71
+
72
+ # Generate summary
73
+ generated_summary = summarize_with_bart_cnn(sample_article)
74
+
75
+ print("\nOriginal Article Length:", len(sample_article.split()), "words")
76
+ print("Generated Summary Length:", len(generated_summary.split()), "words")
77
+ print("\nGenerated Summary:")
78
+ print("-" * 20)
79
+ print(generated_summary)
80
+ print("-" * 20)
81
+
82
+ # Evaluate summary
83
+ evaluate_summary(reference_summary, generated_summary)
84
+
85
+ end_time = time.time()
86
+ print(f"\nTotal execution time: {end_time - start_time:.2f} seconds")
87
+ print("\nNOTE: To run this script, you must install rouge-score: pip install rouge-score")