# RON-110M / code / make_cnndm_sft.py
# endurasolution's picture
# Upload Ron-110M: pretrain + summarizer + tokenizer + code
# 3b97420 verified
from __future__ import annotations
import argparse
import json
import random
import re
import unicodedata
from pathlib import Path
from datasets import load_dataset
from tqdm import tqdm
# Any run of whitespace characters; used to collapse such runs to one space.
_WS = re.compile(r"\s+")
# Control characters U+0000 through U+001F. This range includes tab, newline
# and carriage return, so substituting it removes line breaks as well.
_BAD_CHARS = re.compile(r"[\u0000-\u001f]")
# Bracketed numeric markers such as "[12]" or "[ 3 ]".
_REFS = re.compile(r"\[\s*\d+\s*\]")
# CNN/DailyMail articles often start with "(CNN) -- " or "By . SOMEBODY . PUBLISHED:"
# Leading "(CNN) --" dateline, matched case-insensitively at the start of the text only.
_CNN_PREFIX = re.compile(r"^\s*\(CNN\)\s*--\s*", re.IGNORECASE)
# Leading byline: from "By ." through "PUBLISHED:" and on to the first period
# after it (both gaps are lazy; DOTALL lets them span line breaks).
_BYLINE = re.compile(r"^\s*By\s+\.\s+.*?PUBLISHED:.*?\s*\.\s*", re.IGNORECASE | re.DOTALL)
# Alternative prompt wordings. Each one contains a single {passage} placeholder
# that is filled with the article text via str.format.
PROMPT_TEMPLATES = [
"Read the article and write a short summary.\n\nArticle:\n{passage}\n\nSummary:\n",
"Summarize the following article in a few sentences.\n\nArticle:\n{passage}\n\nShort summary:\n",
"Below is a news article. Give a concise summary using key facts from the article.\n\nArticle:\n{passage}\n\nSummary:\n",
"Provide a short summary of the article below.\n\nArticle:\n{passage}\n\nSummary:\n",
]
def normalize(text: str) -> str:
    """Clean a raw string into a single line of plain text.

    ``None`` yields an empty string; any other value is converted with
    ``str()``. The cleanup replaces U+FFFD with a space, applies NFKC
    normalization, then runs the module-level patterns in a fixed order and
    strips the ends.

    Because control characters (newlines included) become spaces and
    whitespace runs are collapsed, the result never contains a line break.
    """
    if text is None:
        return ""

    cleaned = unicodedata.normalize("NFKC", str(text).replace("\ufffd", " "))

    # Order matters: control characters must already be spaces before the
    # start-anchored prefix patterns run, and whitespace is collapsed last so
    # gaps left by the removals do not survive.
    substitutions = (
        (_BAD_CHARS, " "),
        (_REFS, ""),
        (_CNN_PREFIX, ""),
        (_BYLINE, ""),
        (_WS, " "),
    )
    for pattern, replacement in substitutions:
        cleaned = pattern.sub(replacement, cleaned)

    return cleaned.strip()
def join_highlights(highlights: str) -> str:
    """Turn newline-separated highlight lines into one space-joined string.

    Each line is stripped of surrounding whitespace and blank lines are
    dropped. A line that does not already end in ``.``, ``!`` or ``?`` gets a
    period appended. A falsy input (``None`` or ``""``) returns ``""``.
    """
    if not highlights:
        return ""

    sentences = []
    for raw_line in highlights.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        # Guarantee terminal punctuation so the joined text reads as sentences.
        if not line.endswith((".", "!", "?")):
            line += "."
        sentences.append(line)
    return " ".join(sentences)
def is_good_pair(article: str, summary: str, min_article_chars: int, max_article_chars: int,
                 min_summary_chars: int, max_summary_chars: int) -> bool:
    """Decide whether an (article, summary) pair should be kept.

    A pair is accepted only when all of the following hold:

    * both strings are non-empty;
    * the article length is within ``[min_article_chars, max_article_chars]``;
    * the summary length is within ``[min_summary_chars, max_summary_chars]``;
    * the summary is shorter than 80% of the article length;
    * at least 60% of the article's characters are alphabetic.

    All lengths are measured in characters; both bounds are inclusive.
    """
    if not (article and summary):
        return False

    article_len = len(article)
    summary_len = len(summary)

    if article_len < min_article_chars or article_len > max_article_chars:
        return False
    if summary_len < min_summary_chars or summary_len > max_summary_chars:
        return False

    # A summary nearly as long as the article is not really a summary.
    if summary_len >= 0.8 * article_len:
        return False

    # Reject articles dominated by digits, punctuation or other non-letters.
    alpha_count = sum(1 for ch in article if ch.isalpha())
    return alpha_count / max(1, article_len) >= 0.6
def main() -> None:
    """Build a prompt/answer JSONL file from a CNN/DailyMail-style dataset.

    Parses the command-line options, loads the ``train`` split of the
    requested dataset, and iterates its rows. For each row the ``article``
    and ``highlights`` fields are cleaned; pairs rejected by
    ``is_good_pair`` are counted as skipped. Each accepted pair is written
    as one JSON object per line with the keys ``prompt`` (a randomly chosen
    template filled with the article) and ``answer`` (the summary).

    Writing stops once ``--max_examples`` records have been written. The
    template choice is driven by ``random.Random(--seed)``, so a given seed
    and dataset order reproduce the same output.
    """
    parser = argparse.ArgumentParser(
        description="Build an SFT set from CNN/DailyMail (near-extractive summaries)."
    )
    # NOTE(review): the default path sits under a "wikitext103" directory even
    # though this script reads CNN/DailyMail; kept unchanged because other
    # tooling may rely on it - confirm before renaming.
    parser.add_argument("--out_file", type=Path, default=Path("data/wikitext103/paragraph_sft.jsonl"))
    parser.add_argument("--dataset", type=str, default="abisee/cnn_dailymail")
    parser.add_argument("--config", type=str, default="3.0.0")
    parser.add_argument("--max_examples", type=int, default=100000)
    parser.add_argument("--min_article_chars", type=int, default=400)
    parser.add_argument("--max_article_chars", type=int, default=2200)
    parser.add_argument("--min_summary_chars", type=int, default=80)
    parser.add_argument("--max_summary_chars", type=int, default=400)
    parser.add_argument("--seed", type=int, default=1337)
    args = parser.parse_args()

    args.out_file.parent.mkdir(parents=True, exist_ok=True)
    rng = random.Random(args.seed)

    print(f"Loading {args.dataset} ({args.config})...")
    dataset = load_dataset(args.dataset, args.config, split="train")

    count = 0
    skipped = 0
    with args.out_file.open("w", encoding="utf-8") as f:
        for row in tqdm(dataset, desc="building SFT"):
            # Checked before any work so that a limit of 0 writes nothing.
            if count >= args.max_examples:
                break

            article = normalize(row.get("article", ""))
            # Join the highlight lines BEFORE normalizing: normalize() turns
            # newlines into spaces, which would erase the line boundaries
            # that join_highlights() needs to punctuate each highlight.
            summary = normalize(join_highlights(row.get("highlights", "")))

            # Articles longer than --max_article_chars are rejected here
            # rather than truncated, so no length clamp is needed afterwards.
            if not is_good_pair(
                article, summary,
                args.min_article_chars, args.max_article_chars,
                args.min_summary_chars, args.max_summary_chars,
            ):
                skipped += 1
                continue

            template = rng.choice(PROMPT_TEMPLATES)
            prompt = template.format(passage=article)
            f.write(json.dumps({"prompt": prompt, "answer": summary}, ensure_ascii=False) + "\n")
            count += 1

    print(f"Wrote {count:,} examples to {args.out_file} (skipped={skipped:,})")


if __name__ == "__main__":
    main()