from __future__ import annotations
import argparse
import json
import random
import re
import unicodedata
from pathlib import Path
from datasets import load_dataset
from tqdm import tqdm
# Any run of whitespace; used to collapse it to a single space.
_WS = re.compile(r"\s+")
# C0 control characters U+0000-U+001F (this range includes tab, newline and CR).
_BAD_CHARS = re.compile(r"[\u0000-\u001f]")
# Bracketed numeric markers such as "[3]" or "[ 12 ]".
_REFS = re.compile(r"\[\s*\d+\s*\]")
# CNN/DailyMail articles often start with "(CNN) -- " or "By . SOMEBODY . PUBLISHED:"
# Both patterns are anchored at the start of the text and are case-insensitive.
_CNN_PREFIX = re.compile(r"^\s*\(CNN\)\s*--\s*", re.IGNORECASE)
# Lazy match: consumes from "By ." through the first period that follows
# "PUBLISHED:". DOTALL lets the byline span line breaks.
_BYLINE = re.compile(r"^\s*By\s+\.\s+.*?PUBLISHED:.*?\s*\.\s*", re.IGNORECASE | re.DOTALL)
# Prompt variants; each contains a single "{passage}" placeholder that is
# filled with the cleaned article text. One is picked at random per example.
PROMPT_TEMPLATES = [
    "Read the article and write a short summary.\n\nArticle:\n{passage}\n\nSummary:\n",
    "Summarize the following article in a few sentences.\n\nArticle:\n{passage}\n\nShort summary:\n",
    "Below is a news article. Give a concise summary using key facts from the article.\n\nArticle:\n{passage}\n\nSummary:\n",
    "Provide a short summary of the article below.\n\nArticle:\n{passage}\n\nSummary:\n",
]
def normalize(text: str) -> str:
    """Return a cleaned, single-line version of *text*.

    ``None`` yields an empty string. Otherwise the value is coerced to ``str``
    and then, in order: U+FFFD replacement characters become spaces, the text
    is NFKC-normalized, control characters become spaces, bracketed numeric
    markers and the leading "(CNN) --" / byline prefixes are removed, and all
    whitespace runs are collapsed to one space with the ends stripped.
    """
    if text is None:
        return ""

    cleaned = unicodedata.normalize("NFKC", str(text).replace("\ufffd", " "))
    cleaned = _BAD_CHARS.sub(" ", cleaned)

    # These three are deleted outright rather than replaced by a space; the
    # order matters because the prefix patterns are anchored at the start.
    for pattern in (_REFS, _CNN_PREFIX, _BYLINE):
        cleaned = pattern.sub("", cleaned)

    return _WS.sub(" ", cleaned).strip()
def join_highlights(highlights: str) -> str:
    """Turn newline-separated highlights into one multi-sentence string.

    Each non-blank line is stripped of surrounding whitespace and, unless it
    already ends in ".", "!" or "?", gets a trailing period. The resulting
    sentences are joined with single spaces. A falsy input (empty string or
    ``None``) yields an empty string.
    """
    if not highlights:
        return ""

    stripped_lines = (line.strip() for line in highlights.split("\n"))
    return " ".join(
        sentence if sentence.endswith((".", "!", "?")) else sentence + "."
        for sentence in stripped_lines
        if sentence  # blank lines contribute nothing
    )
def is_good_pair(article: str, summary: str, min_article_chars: int, max_article_chars: int,
                 min_summary_chars: int, max_summary_chars: int) -> bool:
    """Decide whether an (article, summary) pair should be kept.

    A pair is accepted only if all of the following hold:

    * both strings are non-empty;
    * the article length is within ``[min_article_chars, max_article_chars]``;
    * the summary length is within ``[min_summary_chars, max_summary_chars]``;
    * the summary is shorter than 80% of the article length;
    * at least 60% of the article's characters are alphabetic.
    """
    if not (article and summary):
        return False

    article_len = len(article)
    summary_len = len(summary)

    article_in_range = min_article_chars <= article_len <= max_article_chars
    summary_in_range = min_summary_chars <= summary_len <= max_summary_chars
    if not (article_in_range and summary_in_range):
        return False

    # A summary nearly as long as the article is not really a summary.
    if summary_len >= 0.8 * article_len:
        return False

    # Require the article to consist mostly of letters.
    alpha_count = sum(1 for ch in article if ch.isalpha())
    return alpha_count / max(1, article_len) >= 0.6
def main() -> None:
    """Build a prompt/answer JSONL file for SFT from a summarization dataset.

    Command-line options select the output file, the dataset name and config
    passed to ``load_dataset``, the maximum number of examples to write, the
    character-length bounds used for filtering, and the RNG seed used to pick
    a prompt template per example.

    For every row of the ``train`` split, the article is cleaned with
    ``normalize``; the highlights are joined into sentences and then cleaned.
    Pairs rejected by ``is_good_pair`` are counted as skipped. Accepted pairs
    are written one JSON object per line with keys ``"prompt"`` and
    ``"answer"``. A final count is printed to stdout.
    """
    parser = argparse.ArgumentParser(
        description="Build an SFT set from CNN/DailyMail (near-extractive summaries)."
    )
    # NOTE(review): the default path mentions "wikitext103" although the
    # default dataset is CNN/DailyMail -- confirm this is intentional.
    parser.add_argument("--out_file", type=Path, default=Path("data/wikitext103/paragraph_sft.jsonl"))
    parser.add_argument("--dataset", type=str, default="abisee/cnn_dailymail")
    parser.add_argument("--config", type=str, default="3.0.0")
    parser.add_argument("--max_examples", type=int, default=100000)
    parser.add_argument("--min_article_chars", type=int, default=400)
    parser.add_argument("--max_article_chars", type=int, default=2200)
    parser.add_argument("--min_summary_chars", type=int, default=80)
    parser.add_argument("--max_summary_chars", type=int, default=400)
    parser.add_argument("--seed", type=int, default=1337)
    args = parser.parse_args()

    args.out_file.parent.mkdir(parents=True, exist_ok=True)
    rng = random.Random(args.seed)

    print(f"Loading {args.dataset} ({args.config})...")
    dataset = load_dataset(args.dataset, args.config, split="train")

    count = 0
    skipped = 0
    with args.out_file.open("w", encoding="utf-8") as f:
        for row in tqdm(dataset, desc="building SFT"):
            # Checked up front so that a limit of 0 (or less) writes nothing.
            if count >= args.max_examples:
                break

            article = normalize(row.get("article", ""))
            # Join the highlights BEFORE normalizing: normalize() turns
            # newlines into spaces, which would leave join_highlights() with
            # nothing to split on and skip the per-highlight punctuation fix.
            summary = normalize(join_highlights(row.get("highlights", "")))

            # is_good_pair() already enforces max_article_chars, so no
            # truncation of the article is needed past this point.
            if not is_good_pair(
                article, summary,
                args.min_article_chars, args.max_article_chars,
                args.min_summary_chars, args.max_summary_chars,
            ):
                skipped += 1
                continue

            template = rng.choice(PROMPT_TEMPLATES)
            prompt = template.format(passage=article)
            f.write(json.dumps({"prompt": prompt, "answer": summary}, ensure_ascii=False) + "\n")
            count += 1

    print(f"Wrote {count:,} examples to {args.out_file} (skipped={skipped:,})")
# Run the builder only when executed as a script, not when imported.
if __name__ == "__main__":
    main()