# Build a prompt/answer SFT dataset (JSONL) from CNN/DailyMail.
#
# Each accepted row becomes one JSON line: {"prompt": ..., "answer": ...},
# where the prompt is a randomly chosen template wrapped around the cleaned
# article and the answer is the cleaned, joined highlights.
from __future__ import annotations

import argparse
import json
import random
import re
import unicodedata
from pathlib import Path

# NOTE: the third-party ``datasets`` and ``tqdm`` packages are imported lazily
# inside main() so the pure text helpers below can be imported (and tested)
# without them, and so ``--help`` does not pay their import cost.

_WS = re.compile(r"\s+")
_BAD_CHARS = re.compile(r"[\u0000-\u001f]")
_REFS = re.compile(r"\[\s*\d+\s*\]")

# CNN/DailyMail articles often start with "(CNN) -- " or
# "By . SOMEBODY . PUBLISHED:"
_CNN_PREFIX = re.compile(r"^\s*\(CNN\)\s*--\s*", re.IGNORECASE)
_BYLINE = re.compile(
    r"^\s*By\s+\.\s+.*?PUBLISHED:.*?\s*\.\s*", re.IGNORECASE | re.DOTALL
)

PROMPT_TEMPLATES = [
    "Read the article and write a short summary.\n\nArticle:\n{passage}\n\nSummary:\n",
    "Summarize the following article in a few sentences.\n\nArticle:\n{passage}\n\nShort summary:\n",
    "Below is a news article. Give a concise summary using key facts from the article.\n\nArticle:\n{passage}\n\nSummary:\n",
    "Provide a short summary of the article below.\n\nArticle:\n{passage}\n\nSummary:\n",
]


def normalize(text: str) -> str:
    """Return *text* cleaned up and collapsed onto a single line.

    Steps, in order: drop U+FFFD replacement characters, apply NFKC
    normalization, blank out ASCII control characters (this includes
    newlines), remove bracketed numeric references such as ``[12]``, strip a
    leading ``(CNN) --`` prefix or ``By . ... PUBLISHED: ...`` byline, and
    finally collapse every whitespace run to one space.

    ``None`` yields an empty string; non-string values are passed through
    ``str()`` first.
    """
    if text is None:
        return ""
    text = str(text).replace("\ufffd", " ")
    text = unicodedata.normalize("NFKC", text)
    text = _BAD_CHARS.sub(" ", text)
    text = _REFS.sub("", text)
    # The prefix patterns are anchored at the start of the string, so they
    # must run before the final strip but work regardless of leading spaces.
    text = _CNN_PREFIX.sub("", text)
    text = _BYLINE.sub("", text)
    return _WS.sub(" ", text).strip()


def join_highlights(highlights: str) -> str:
    """Join newline-separated highlights into one multi-sentence string.

    CNN/DailyMail highlights come as several short lines joined by newlines.
    Blank lines are dropped, each remaining line is stripped, and a period is
    appended to any line that does not already end in ``.``, ``!`` or ``?``.
    The lines are then joined with single spaces.

    The input must still contain its newlines: text that has already been
    through :func:`normalize` is a single line and will be treated as one
    highlight.
    """
    if not highlights:
        return ""
    stripped = (line.strip() for line in highlights.split("\n"))
    sentences = [
        line if line.endswith((".", "!", "?")) else line + "."
        for line in stripped
        if line
    ]
    return " ".join(sentences)


def clean_summary(highlights: str | None) -> str:
    """Normalize raw highlights line by line, then join them into a summary.

    :func:`normalize` turns newlines into spaces, so it has to be applied to
    each highlight line *before* :func:`join_highlights` splits on newlines;
    doing it the other way round silently merges all highlights into one
    unpunctuated run-on sentence.
    """
    if not highlights:
        return ""
    cleaned_lines = (normalize(line) for line in str(highlights).split("\n"))
    return join_highlights("\n".join(cleaned_lines))


def is_good_pair(
    article: str,
    summary: str,
    min_article_chars: int,
    max_article_chars: int,
    min_summary_chars: int,
    max_summary_chars: int,
) -> bool:
    """Return True when an (article, summary) pair passes all quality filters.

    A pair is rejected when either text is empty, when either length falls
    outside its inclusive ``[min, max]`` character range, when the summary is
    at least 80% as long as the article, or when fewer than 60% of the
    article's characters are alphabetic.
    """
    if not article or not summary:
        return False
    if not (min_article_chars <= len(article) <= max_article_chars):
        return False
    if not (min_summary_chars <= len(summary) <= max_summary_chars):
        return False
    # Reject if the summary is basically the whole article (rare here but safe).
    if len(summary) >= 0.8 * len(article):
        return False
    # Must be mostly letters.
    letters = sum(ch.isalpha() for ch in article)
    return letters / max(1, len(article)) >= 0.6


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the dataset builder."""
    parser = argparse.ArgumentParser(
        description="Build an SFT set from CNN/DailyMail (near-extractive summaries)."
    )
    # NOTE(review): this default path mentions wikitext103 although the data
    # is CNN/DailyMail; it is kept unchanged because downstream scripts may
    # rely on it -- confirm and rename if it is copy-paste residue.
    parser.add_argument(
        "--out_file", type=Path, default=Path("data/wikitext103/paragraph_sft.jsonl")
    )
    parser.add_argument("--dataset", type=str, default="abisee/cnn_dailymail")
    parser.add_argument("--config", type=str, default="3.0.0")
    parser.add_argument("--max_examples", type=int, default=100000)
    parser.add_argument("--min_article_chars", type=int, default=400)
    parser.add_argument("--max_article_chars", type=int, default=2200)
    parser.add_argument("--min_summary_chars", type=int, default=80)
    parser.add_argument("--max_summary_chars", type=int, default=400)
    parser.add_argument("--seed", type=int, default=1337)
    return parser


def main() -> None:
    """Stream the training split and write accepted pairs as JSON lines.

    Rows are processed in dataset order until ``--max_examples`` examples
    have been written. The prompt template for each written example is drawn
    from a ``random.Random`` seeded with ``--seed``.
    """
    # Deferred third-party imports; see the note at the top of the file.
    from datasets import load_dataset
    from tqdm import tqdm

    args = _build_parser().parse_args()
    args.out_file.parent.mkdir(parents=True, exist_ok=True)
    rng = random.Random(args.seed)

    print(f"Loading {args.dataset} ({args.config})...")
    dataset = load_dataset(args.dataset, args.config, split="train")

    count = 0
    skipped = 0
    with args.out_file.open("w", encoding="utf-8") as f:
        for row in tqdm(dataset, desc="building SFT"):
            # Checked before processing so that --max_examples 0 writes nothing.
            if count >= args.max_examples:
                break

            article = normalize(row.get("article", ""))
            summary = clean_summary(row.get("highlights", ""))
            if not is_good_pair(
                article,
                summary,
                args.min_article_chars,
                args.max_article_chars,
                args.min_summary_chars,
                args.max_summary_chars,
            ):
                skipped += 1
                continue

            # is_good_pair() has already enforced
            # len(article) <= max_article_chars, so no truncation is needed.
            # The template is only drawn for accepted rows, which keeps the
            # output reproducible for a given seed.
            template = rng.choice(PROMPT_TEMPLATES)
            prompt = template.format(passage=article)
            record = {"prompt": prompt, "answer": summary}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            count += 1

    print(f"Wrote {count:,} examples to {args.out_file} (skipped={skipped:,})")


if __name__ == "__main__":
    main()