| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import random |
| import re |
| import unicodedata |
| from pathlib import Path |
|
|
| from datasets import load_dataset |
| from tqdm import tqdm |
|
|
|
|
# Runs of whitespace; normalize() collapses each run to a single space.
_WS = re.compile(r"\s+")
# C0 control characters (which include \t, \n and \r) plus DEL and the C1
# control range; normalize() replaces each with a space.
_BAD_CHARS = re.compile(r"[\u0000-\u001f\u007f-\u009f]")
# Bracketed numeric reference markers such as "[12]" or "[ 3 ]".
_REFS = re.compile(r"\[\s*\d+\s*\]")

# Leading CNN source tag. Both "(CNN) -- Text" and "(CNN)Text" occur, so the
# "--" separator is optional.
_CNN_PREFIX = re.compile(r"^\s*\(CNN\)\s*(?:--\s*)?", re.IGNORECASE)

# One Daily Mail timestamp, e.g. "06:33 EST, 8 January 2013 . "
_BYLINE_STAMP = r"\d{1,2}:\d{2}\s+[A-Z]{2,4},\s+\d{1,2}\s+[A-Za-z]+\s+\d{4}\s*\.\s*"
# Leading Daily Mail byline header, e.g.
#   "By . Jane Doe . PUBLISHED: . 06:33 EST, 8 January 2013 . | .
#    UPDATED: . 10:29 EST, 8 January 2013 . "
# Matching only up to the first "." after "PUBLISHED:" would leave the
# timestamp (and any UPDATED stamp) at the start of the article. Both are now
# consumed when present. Both groups are optional, so anything the shorter
# match removed is still removed.
_BYLINE = re.compile(
    r"^\s*By\s+\.\s+.*?PUBLISHED:.*?\s*\.\s*"
    r"(?:" + _BYLINE_STAMP + r")?"
    r"(?:\|\s*\.\s*UPDATED:\s*\.\s*(?:" + _BYLINE_STAMP + r")?)?",
    re.IGNORECASE | re.DOTALL,
)
|
|
|
|
# Instruction prompts for the SFT examples. Each contains exactly one
# "{passage}" placeholder that str.format() fills with the cleaned article;
# main() picks one per example with a seeded RNG.
PROMPT_TEMPLATES = [
    (
        "Read the article and write a short summary.\n\n"
        "Article:\n{passage}\n\n"
        "Summary:\n"
    ),
    (
        "Summarize the following article in a few sentences.\n\n"
        "Article:\n{passage}\n\n"
        "Short summary:\n"
    ),
    (
        "Below is a news article. Give a concise summary using key facts from the article.\n\n"
        "Article:\n{passage}\n\n"
        "Summary:\n"
    ),
    (
        "Provide a short summary of the article below.\n\n"
        "Article:\n{passage}\n\n"
        "Summary:\n"
    ),
]
|
|
|
|
def normalize(text: str) -> str:
    """Return *text* cleaned and collapsed into a single trimmed line.

    ``None`` maps to ``""``; any other value is first coerced with ``str()``.
    The cleaning steps, in order, are:

    1. Replace each U+FFFD replacement character with a space.
    2. Apply NFKC Unicode normalization.
    3. Replace every character matched by ``_BAD_CHARS`` with a space.
    4. Drop bracketed numeric references (``_REFS``).
    5. Strip a leading CNN tag (``_CNN_PREFIX``).
    6. Strip a leading byline (``_BYLINE``).
    7. Collapse whitespace runs to one space and trim the ends.

    Newlines fall inside ``_BAD_CHARS``, so any line structure in the input
    is lost.
    """
    if text is None:
        return ""
    cleaned = unicodedata.normalize("NFKC", str(text).replace("\ufffd", " "))
    cleaned = _BAD_CHARS.sub(" ", cleaned)
    # Removal passes; order matters because the last two are anchored at ^.
    for pattern in (_REFS, _CNN_PREFIX, _BYLINE):
        cleaned = pattern.sub("", cleaned)
    return _WS.sub(" ", cleaned).strip()
|
|
|
|
def join_highlights(highlights: str) -> str:
    """Merge newline-separated highlight lines into one multi-sentence string.

    CNN/DailyMail highlights arrive as several short lines separated by
    newlines. Each line is stripped and blank lines are dropped. A period is
    appended to any line that does not already end in ".", "!" or "?". The
    lines are then joined with single spaces. Falsy input returns "".
    """
    if not highlights:
        return ""
    sentences = []
    for raw_line in highlights.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        sentences.append(line if line[-1] in ".!?" else f"{line}.")
    return " ".join(sentences)
|
|
|
|
def is_good_pair(article: str, summary: str, min_article_chars: int, max_article_chars: int,
                 min_summary_chars: int, max_summary_chars: int) -> bool:
    """Decide whether an (article, summary) pair is kept for the SFT set.

    A pair passes only if all of the following hold:

    * both texts are non-empty;
    * both lengths fall within their inclusive ``[min, max]`` character bounds;
    * the summary is shorter than 80% of the article's length;
    * at least 60% of the article's characters are alphabetic.
    """
    if not (article and summary):
        return False

    article_len, summary_len = len(article), len(summary)
    in_bounds = (
        min_article_chars <= article_len <= max_article_chars
        and min_summary_chars <= summary_len <= max_summary_chars
    )
    # A "summary" nearly as long as its source is not a useful target.
    if not in_bounds or summary_len >= 0.8 * article_len:
        return False

    alpha_count = sum(1 for ch in article if ch.isalpha())
    return alpha_count / max(1, article_len) >= 0.6
|
|
|
|
def main() -> None:
    """Build a prompt/answer SFT JSONL file from a CNN/DailyMail-style dataset.

    For each row of the chosen split:

    * ``article`` is cleaned with ``normalize``.
    * ``highlights`` is merged into a multi-sentence summary, then cleaned.
    * Pairs rejected by ``is_good_pair`` are counted as skipped.

    Each kept pair becomes one JSON line ``{"prompt": ..., "answer": ...}``.
    The prompt template is drawn from ``PROMPT_TEMPLATES`` with an RNG seeded
    by ``--seed``, so output is reproducible for a fixed seed and dataset
    order. Writing stops once ``--max_examples`` pairs have been written.
    """
    parser = argparse.ArgumentParser(
        description="Build an SFT set from CNN/DailyMail (near-extractive summaries)."
    )
    # NOTE(review): the default path mentions wikitext103 although the data is
    # CNN/DailyMail; kept unchanged in case downstream scripts rely on it.
    parser.add_argument("--out_file", type=Path, default=Path("data/wikitext103/paragraph_sft.jsonl"))
    parser.add_argument("--dataset", type=str, default="abisee/cnn_dailymail")
    parser.add_argument("--config", type=str, default="3.0.0")
    parser.add_argument("--split", type=str, default="train",
                        help="Dataset split to read (default: train).")
    parser.add_argument("--max_examples", type=int, default=100000)
    parser.add_argument("--min_article_chars", type=int, default=400)
    parser.add_argument("--max_article_chars", type=int, default=2200)
    parser.add_argument("--min_summary_chars", type=int, default=80)
    parser.add_argument("--max_summary_chars", type=int, default=400)
    parser.add_argument("--seed", type=int, default=1337)
    args = parser.parse_args()

    args.out_file.parent.mkdir(parents=True, exist_ok=True)
    rng = random.Random(args.seed)

    print(f"Loading {args.dataset} ({args.config})...")
    dataset = load_dataset(args.dataset, args.config, split=args.split)

    count = 0
    skipped = 0

    with args.out_file.open("w", encoding="utf-8") as f:
        for row in tqdm(dataset, desc="building SFT"):
            # Check the cap before doing any work. This way --max_examples <= 0
            # writes nothing instead of always emitting one example.
            if count >= args.max_examples:
                break

            article = normalize(row.get("article", ""))
            # Join the highlight lines BEFORE normalizing. normalize() turns
            # newlines into spaces, so normalizing first would leave
            # join_highlights() a single line and add no sentence boundaries.
            summary = normalize(join_highlights(row.get("highlights", "")))

            if not is_good_pair(
                article, summary,
                args.min_article_chars, args.max_article_chars,
                args.min_summary_chars, args.max_summary_chars,
            ):
                skipped += 1
                continue

            # No truncation step is needed: is_good_pair() already rejects any
            # article longer than --max_article_chars.
            template = rng.choice(PROMPT_TEMPLATES)
            prompt = template.format(passage=article)
            f.write(json.dumps({"prompt": prompt, "answer": summary}, ensure_ascii=False) + "\n")
            count += 1

    print(f"Wrote {count:,} examples to {args.out_file} (skipped={skipped:,})")


if __name__ == "__main__":
    main()