File size: 4,717 Bytes
3b97420
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import annotations

import argparse
import json
import random
import re
import unicodedata
from pathlib import Path

from datasets import load_dataset
from tqdm import tqdm


_WS = re.compile(r"\s+")
_BAD_CHARS = re.compile(r"[\u0000-\u001f]")
_REFS = re.compile(r"\[\s*\d+\s*\]")
# CNN/DailyMail articles often start with "(CNN) -- " or "By . SOMEBODY . PUBLISHED:"
_CNN_PREFIX = re.compile(r"^\s*\(CNN\)\s*--\s*", re.IGNORECASE)
_BYLINE = re.compile(r"^\s*By\s+\.\s+.*?PUBLISHED:.*?\s*\.\s*", re.IGNORECASE | re.DOTALL)


# Instruction templates for the SFT prompt; main() picks one at random per
# example and fills the "{passage}" placeholder with the article text.
# Each entry is built from adjacent string literals (instruction, article
# section, answer cue) purely for readability.
PROMPT_TEMPLATES = [
    (
        "Read the article and write a short summary.\n\n"
        "Article:\n{passage}\n\n"
        "Summary:\n"
    ),
    (
        "Summarize the following article in a few sentences.\n\n"
        "Article:\n{passage}\n\n"
        "Short summary:\n"
    ),
    (
        "Below is a news article. Give a concise summary using key facts from the article.\n\n"
        "Article:\n{passage}\n\n"
        "Summary:\n"
    ),
    (
        "Provide a short summary of the article below.\n\n"
        "Article:\n{passage}\n\n"
        "Summary:\n"
    ),
]


def normalize(text: str) -> str:
    """
    Clean a raw dataset string into a single line of plain text.

    The value is stringified, U+FFFD replacement characters become spaces, and
    the text is NFKC-normalized. Then, in order: control characters become
    spaces, bracketed numeric references are dropped, a leading CNN prefix and a
    leading DailyMail byline are removed. Finally, whitespace runs (including
    newlines) are collapsed to one space and the ends are trimmed.
    None yields "".
    """
    if text is None:
        return ""
    cleaned = unicodedata.normalize("NFKC", str(text).replace("\ufffd", " "))
    # Order matters: the prefix/byline patterns are anchored at the start and
    # run on text whose control characters are already spaces.
    substitutions = (
        (_BAD_CHARS, " "),
        (_REFS, ""),
        (_CNN_PREFIX, ""),
        (_BYLINE, ""),
    )
    for pattern, replacement in substitutions:
        cleaned = pattern.sub(replacement, cleaned)
    return _WS.sub(" ", cleaned).strip()


def join_highlights(highlights: str) -> str:
    """
    Join CNN/DailyMail highlight lines into one multi-sentence string.

    Highlights arrive as several short lines. Each non-blank line is stripped.
    A terminal mark separated from its sentence by whitespace ("Monday .") is
    re-attached ("Monday."). A period is appended to any line that does not
    already end in ".", "!" or "?". Trailing closing quotes/brackets are looked
    past, so 'He said "yes."' is not given an extra period.

    Args:
        highlights: Raw highlights text, one highlight per line; falsy input
            yields "".

    Returns:
        The cleaned highlights joined by single spaces ("" if none remain).
    """
    if not highlights:
        return ""
    closers = "\"'\u201d\u2019)]"
    sentences = []
    # splitlines() also handles "\r\n" and bare "\r" line endings.
    for line in highlights.splitlines():
        piece = line.strip()
        if not piece:
            continue
        # Re-attach a detached terminal mark: "Monday ." -> "Monday.".
        if len(piece) > 1 and piece[-1] in ".!?" and piece[-2].isspace():
            piece = piece[:-1].rstrip() + piece[-1]
        # Ignore closing quotes/brackets when checking for terminal punctuation.
        core = piece.rstrip(closers)
        if not core or core[-1] not in ".!?":
            piece += "."
        sentences.append(piece)
    return " ".join(sentences)


def is_good_pair(article: str, summary: str, min_article_chars: int, max_article_chars: int,
                 min_summary_chars: int, max_summary_chars: int) -> bool:
    """
    Decide whether an (article, summary) pair is kept for the SFT set.

    A pair passes only if all of the following hold:
      * both strings are non-empty;
      * each length lies within its inclusive [min, max] character bounds;
      * the summary is shorter than 80% of the article;
      * at least 60% of the article's characters are alphabetic.
    """
    if not (article and summary):
        return False

    article_len = len(article)
    summary_len = len(summary)
    lengths_ok = (
        min_article_chars <= article_len <= max_article_chars
        and min_summary_chars <= summary_len <= max_summary_chars
    )
    if not lengths_ok:
        return False

    # A "summary" nearly as long as the article is not a real summary.
    if summary_len >= 0.8 * article_len:
        return False

    # The article must be mostly letters (filters tables, numbers, markup).
    alpha_count = sum(1 for ch in article if ch.isalpha())
    return alpha_count / max(1, article_len) >= 0.6


def main() -> None:
    """
    Build a prompt/answer JSONL SFT file from a summarization dataset.

    Loads the ``train`` split of ``--dataset``/``--config`` (CNN/DailyMail by
    default) and normalizes each article and its highlights. Pairs that fail
    ``is_good_pair`` are dropped. Each kept pair is written as one
    ``{"prompt": ..., "answer": ...}`` JSON object per line to ``--out_file``.
    The prompt template is drawn from ``PROMPT_TEMPLATES`` with an RNG seeded
    by ``--seed``. At most ``--max_examples`` lines are written.
    """
    parser = argparse.ArgumentParser(
        description="Build an SFT set from CNN/DailyMail (near-extractive summaries)."
    )
    # NOTE(review): the default path names a wikitext103 directory although the
    # data comes from CNN/DailyMail; kept unchanged for existing callers.
    parser.add_argument("--out_file", type=Path, default=Path("data/wikitext103/paragraph_sft.jsonl"))
    parser.add_argument("--dataset", type=str, default="abisee/cnn_dailymail")
    parser.add_argument("--config", type=str, default="3.0.0")
    parser.add_argument("--max_examples", type=int, default=100000)
    parser.add_argument("--min_article_chars", type=int, default=400)
    parser.add_argument("--max_article_chars", type=int, default=2200)
    parser.add_argument("--min_summary_chars", type=int, default=80)
    parser.add_argument("--max_summary_chars", type=int, default=400)
    parser.add_argument("--seed", type=int, default=1337)
    args = parser.parse_args()

    args.out_file.parent.mkdir(parents=True, exist_ok=True)
    rng = random.Random(args.seed)

    print(f"Loading {args.dataset} ({args.config})...")
    dataset = load_dataset(args.dataset, args.config, split="train")

    count = 0
    skipped = 0

    with args.out_file.open("w", encoding="utf-8") as f:
        for row in tqdm(dataset, desc="building SFT"):
            # Check the cap before writing, so --max_examples 0 (or a negative
            # value) writes nothing instead of one example.
            if count >= args.max_examples:
                break

            article = normalize(row.get("article", ""))
            # Normalize each highlight line on its own, then join them.
            # normalize() turns newlines into spaces, so running it on the whole
            # string first would hand join_highlights() a single line.
            raw_highlights = row.get("highlights") or ""
            summary = join_highlights(
                "\n".join(normalize(line) for line in str(raw_highlights).splitlines())
            )

            if not is_good_pair(
                article, summary,
                args.min_article_chars, args.max_article_chars,
                args.min_summary_chars, args.max_summary_chars,
            ):
                skipped += 1
                continue

            # No truncation needed: is_good_pair() already rejects any article
            # longer than --max_article_chars.
            template = rng.choice(PROMPT_TEMPLATES)
            prompt = template.format(passage=article)
            f.write(json.dumps({"prompt": prompt, "answer": summary}, ensure_ascii=False) + "\n")
            count += 1

    print(f"Wrote {count:,} examples to {args.out_file} (skipped={skipped:,})")


if __name__ == "__main__":
    main()