| | |
| | |
| | """ |
| | chat_sprint_standalone.py |
| | One-file pipeline: collect datasets -> reformat as You:/Bot: -> train tiny GPT (CUDA) -> sample & save model |
| | |
| | Requirements (install once): |
| | uv pip install torch datasets sentencepiece tqdm numpy |
| | |
| | Run: |
| | python chat_sprint_standalone.py |
| | """ |
| |
|
| | import os, re, time, math, random, json |
| | from pathlib import Path |
| | from typing import List, Optional, Tuple |
| | from itertools import islice |
| | from contextlib import nullcontext |
| |
|
| | import numpy as np |
| | from tqdm import tqdm |
| | from datasets import load_dataset, get_dataset_config_names |
| | import sentencepiece as spm |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | |
| | |
| | |
# Reproducibility: one seed shared by every RNG this script touches.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is visible; the TF32 matmul settings only apply there.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True

# Output directory for generated artifacts; created eagerly at import time.
SAVE_DIR = Path("./chat_sprint_artifacts")
SAVE_DIR.mkdir(parents=True, exist_ok=True)
| |
|
| | |
# Per-dataset cap on how many You:/Bot: pairs a collector may keep.
# Keys are looked up by the collect_* functions; insertion order is preserved.
CAPS = {
    # dialogue-style sources
    "oasst1": 0,
    "hhrlhf": 12_000,
    "ultrachat": 20_000,
    "dailydialog": 8_000,
    "bst": 6_000,
    "personachat": 8_000,
    "soda": 50_000,
    "topical": 8_000,
    # style / humour / one-liner sources
    "shakespeare": 15_000,
    "jokes": 20_000,
    "dadjokes": 8_000,
    "rsarcasm": 8_000,
    "figlang": 12_000,
    "shower": 4_000,
    "personas": 2_000,
    "tweeteval": 4_000,
    "fourchan": 500,
    "elonvtrump": 3_000,
}
| |
|
| | |
# Per-dataset limit on how many raw rows a collector reads (via `limited`)
# while trying to fill its CAPS quota.
SCAN = {
    # dialogue-style sources
    "oasst1": 200_000,
    "hhrlhf": 100_000,
    "ultrachat": 220_000,
    "dailydialog": 30_000,
    "bst": 30_000,
    "personachat": 60_000,
    "soda": 300_000,
    "topical": 15_000,
    # style / humour / one-liner sources
    "jokes": 120_000,
    "dadjokes": 60_000,
    "rsarcasm": 120_000,
    "figlang": 150_000,
    "shower": 250_000,
    "personas": 30_000,
    "tweeteval": 60_000,
    "fourchan": 2_000,
    "elonvtrump": 60_000,
}
| |
|
# Corpus limits: build_corpus truncates the shuffled corpus to MAX_TOTAL_PAIRS
# blocks, and MAX_LEN is the default character budget of keep_or_clip.
MAX_TOTAL_PAIRS = 150_000
MAX_LEN = 120

# SentencePiece settings consumed by train_spm.
VOCAB_SIZE = 5000
TOKENIZER_PREFIX = SAVE_DIR / "spm_chat"
USER_SYMBOLS = [
    "You:",
    "Bot:",
    "[STYLE=Snark]",
    "[FORM=TWEET]",
    "[FORM=HEADLINE]",
    "[MOOD=Unhinged]",
    "[MOOD=Cheeky]",
]
| |
|
| | |
# Model shape (passed to TinyGPT / Block / CausalSelfAttention).
block_size = 256
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.0

# Training schedule.
# NOTE(review): the training loop is outside this view; MAX_SECONDS is
# presumably a wall-clock budget and accum_steps a gradient-accumulation
# factor -- confirm against the loop that reads them.
MAX_SECONDS = 300
train_steps = 5000
log_interval = 100
eval_every = 400
batch_size = 24
accum_steps = 3
base_lr = 3e-3
min_lr = 5e-4
warmup_ratio = 0.06

# Sampling knobs (presumably temperature / top-k / top-p / repetition penalty).
TEMP = 0.8
TOP_K = 60
TOP_P = 0.95
REP_PEN = 1.08
| |
|
| | |
| | |
| | |
# Pre-compiled patterns used by clean_text.
URL_RE = re.compile(r"https?://\S+|www\.\S+", flags=re.IGNORECASE)  # http(s) links and bare www. hosts
MENT_RE = re.compile(r"@\w+")  # @handles
WS_RE = re.compile(r"\s+")  # any whitespace run
QUOTE_RE = re.compile(r"^[\"'“”‘’]+|[\"'“”‘’]+$")  # straight/curly quotes at either end
| |
|
def clean_text(s: str) -> str:
    """Normalise one raw utterance.

    Strips the ends, then removes URLs, @mentions and leading/trailing quote
    characters, collapses whitespace runs to single spaces, and strips again.
    """
    cleaned = s.strip()
    # Order matters: quotes are anchored to the string ends, so they are
    # removed before whitespace is collapsed, exactly as listed here.
    for pattern, replacement in (
        (URL_RE, ""),
        (MENT_RE, ""),
        (QUOTE_RE, ""),
        (WS_RE, " "),
    ):
        cleaned = pattern.sub(replacement, cleaned)
    return cleaned.strip()
| |
|
def shorten_to(s: str, n: int) -> str:
    """Collapse whitespace in *s* and shorten it to at most *n* characters.

    If the text is too long, it is cut just after the last sentence
    terminator (". ", "! " or "? ") found within the first *n* characters;
    when there is none, it is hard-cut at *n* characters.
    """
    flat = re.sub(r"\s+", " ", s.strip())
    if len(flat) <= n:
        return flat
    last_stop = max(flat.rfind(mark, 0, n) for mark in (". ", "! ", "? "))
    if last_stop == -1:
        return flat[:n].strip()
    # Keep the punctuation mark itself, drop the space after it.
    return flat[: last_stop + 1].strip()
| |
|
def keep_or_clip(s: str, min_len: int = 6, max_len: int = MAX_LEN) -> Optional[str]:
    """Return *s* shortened to *max_len*, or None when it is unusable.

    Empty/None input and text shorter than *min_len* characters (after
    whitespace collapsing) yield None; everything else goes through
    shorten_to.
    """
    if not s:
        return None
    squeezed = re.sub(r"\s+", " ", s.strip())
    return shorten_to(squeezed, max_len) if len(squeezed) >= min_len else None
| |
|
def turn(you: str, bot: str, tags: str = "") -> str:
    """Format one exchange as a newline-terminated You:/Bot: block.

    When *tags* is non-empty it is placed on its own line between the two
    speaker lines. Trailing whitespace on the speaker lines is removed.
    """
    user_line = f"You: {you}".rstrip()
    bot_line = f"Bot: {bot}".rstrip()
    parts = [user_line, tags, bot_line] if tags else [user_line, bot_line]
    return "\n".join(parts) + "\n"
| |
|
def limited(ds, limit: int):
    """Return at most *limit* items of *ds*.

    Uses the object's own ``take`` method when that call succeeds and falls
    back to ``itertools.islice`` on any exception (including a missing method).
    """
    try:
        head = ds.take(limit)
    except Exception:
        head = islice(ds, limit)
    return head
| |
|
def get_first_nonempty(ex, keys) -> Optional[str]:
    """Return the first value among ``ex[key]`` (in *keys* order) that is a
    non-blank string, or None when no key qualifies."""
    candidates = (ex.get(key) for key in keys)
    return next(
        (value for value in candidates if isinstance(value, str) and value.strip()),
        None,
    )
| |
|
def to_str(x) -> Optional[str]:
    """Extract text from an utterance that is either a string or a dict.

    Strings are returned unchanged. For dicts, the first non-blank string
    found under a fixed list of common text keys is returned. Anything else
    yields None.
    """
    if isinstance(x, str):
        return x
    if not isinstance(x, dict):
        return None
    for field in ("text", "utterance", "content", "response", "value", "message", "msg"):
        candidate = x.get(field)
        if isinstance(candidate, str) and candidate.strip():
            return candidate
    return None
| |
|
| | |
| | |
| | |
def collect_oasst1(pairs: List[str], overall: tqdm):
    """Collect prompter -> assistant pairs from OpenAssistant/oasst1.

    Rows are streamed; each message is remembered by id so that an assistant
    message can be paired with its parent when that parent is a prompter
    message that was seen earlier in the stream.

    Args:
        pairs: output list; formatted You:/Bot: blocks are appended in place.
        overall: global progress bar, advanced once per kept pair.

    Any exception is reported as ``[skip]`` instead of being raised.
    """
    pbar = None
    try:
        keep_cap, scan_cap = CAPS["oasst1"], SCAN["oasst1"]
        # A cap of 0 (the configured value) means "disabled": do not open the
        # remote stream at all and never emit a pair before the cap is checked.
        if keep_cap <= 0:
            print("[skip] oasst1: cap is 0 (disabled)")
            return

        ds = load_dataset("OpenAssistant/oasst1", split="train", streaming=True)
        pbar = tqdm(total=keep_cap, desc="[oasst1]", unit="pair", leave=False, ncols=100)
        kept = scanned = 0
        seen = {}  # message id -> (role, cleaned text or "")
        for ex in limited(ds, scan_cap):
            scanned += 1
            role = str(ex.get("role") or "")
            txt = keep_or_clip(clean_text(str(ex.get("text") or "")))
            mid = ex.get("message_id") or ex.get("id")
            pid = ex.get("parent_id")
            if role == "assistant" and pid in seen and seen[pid][0] == "prompter":
                you = seen[pid][1]
                if you and txt:
                    pairs.append(turn(you, txt))
                    kept += 1
                    pbar.update(1)
                    overall.update(1)
                    if kept >= keep_cap:
                        break
            if mid and role in ("assistant", "prompter"):
                seen[mid] = (role, txt or "")
        pbar.close()
        print(f"[ok] oasst1 kept={kept} (scanned {scanned})")
    except Exception as e:
        print(f"[skip] oasst1: {e}")
    finally:
        # Also close the bar when the stream fails midway.
        if pbar is not None:
            pbar.close()
| |
|
def collect_ultrachat(pairs: List[str], overall: tqdm):
    """Collect adjacent user -> assistant message pairs from
    HuggingFaceH4/ultrachat_200k (split ``train_sft``, streamed).

    Appends formatted blocks to *pairs* and advances *overall* once per kept
    pair. Any exception is reported as ``[skip]`` instead of being raised.
    """
    user_roles = ("user", "human")
    bot_roles = ("assistant", "gpt")
    try:
        stream = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft", streaming=True)
        keep_cap = CAPS["ultrachat"]
        scan_cap = SCAN["ultrachat"]
        bar = tqdm(total=keep_cap, desc="[ultrachat]", unit="pair", leave=False, ncols=100)
        kept = 0
        scanned = 0
        for row in limited(stream, scan_cap):
            scanned += 1
            messages = row.get("messages") or []
            for first, second in zip(messages, messages[1:]):
                # Only a user message directly followed by an assistant message counts.
                if first.get("role") not in user_roles or second.get("role") not in bot_roles:
                    continue
                you = keep_or_clip(clean_text(str(first.get("content") or "")))
                bot = keep_or_clip(clean_text(str(second.get("content") or "")))
                if not (you and bot):
                    continue
                pairs.append(turn(you, bot))
                kept += 1
                bar.update(1)
                overall.update(1)
                if kept >= keep_cap:
                    break
            if kept >= keep_cap:
                break
        bar.close()
        print(f"[ok] ultrachat kept={kept} (scanned {scanned})")
    except Exception as e:
        print(f"[skip] ultrachat: {e}")
| |
|
| |
|
def collect_dailydialog(pairs: List[str], overall: tqdm):
    """Collect adjacent-utterance pairs from the first DailyDialog mirror that works.

    Candidates are tried in order; each is a (dataset id, config name,
    revision) triple, where None means "do not pass that argument". The first
    candidate after which at least one pair has been kept ends the search.
    The kept counter is shared across candidates.
    """
    keep_cap = CAPS["dailydialog"]
    scan_cap = SCAN["dailydialog"]
    mirrors = [
        ("roskoN/dailydialog", "full", "refs/convert/parquet"),
        ("roskoN/dailydialog", "default", "refs/convert/parquet"),
        ("frankdarkluo/DailyDialog", "default", "refs/convert/parquet"),
        ("ConvLab/dailydialog", None, None),
    ]
    dialog_keys = ("dialog", "dialogue", "utterances", "turns", "content")

    kept = 0
    for dsid, name, rev in mirrors:
        try:
            kwargs = {"split": "train", "streaming": True}
            if rev is not None:
                kwargs["revision"] = rev
            positional = (dsid,) if name is None else (dsid, name)
            stream = load_dataset(*positional, **kwargs)

            bar = tqdm(total=keep_cap, desc=f"[dailydialog:{dsid}]", unit="pair", leave=False, ncols=100)
            for row in limited(stream, scan_cap):
                # First truthy field wins; fall back to an empty list.
                dialog = []
                for key in dialog_keys:
                    value = row.get(key)
                    if value:
                        dialog = value
                        break

                utterances = []
                if isinstance(dialog, list):
                    cleaned = (keep_or_clip(clean_text(str(to_str(item) or ""))) for item in dialog)
                    utterances = [text for text in cleaned if text]

                for you, bot in zip(utterances, utterances[1:]):
                    pairs.append(turn(you, bot))
                    kept += 1
                    bar.update(1)
                    overall.update(1)
                    if kept >= keep_cap:
                        break
                if kept >= keep_cap:
                    break
            bar.close()

            if kept > 0:
                label = "name=" + name if name else "no-name"
                print(f"[ok] dailydialog kept={kept} via {dsid} ({label} {rev or ''})")
                return
            print(f"[try next] dailydialog: {dsid} produced 0 pairs; trying next candidate…")
        except Exception as e:
            print(f"[try next] dailydialog {dsid}: {e}")

    print("[skip] dailydialog: no usable Parquet/JSON mirror found (all candidates failed)")
| |
|
| |
|
def collect_bst(pairs: List[str], overall: tqdm):
    """Collect pairs from blended_skill_talk: previous utterance -> first
    guided / free / suggested message.

    Args:
        pairs: output list; formatted You:/Bot: blocks are appended in place.
        overall: global progress bar, advanced once per kept pair.

    Any exception is reported as ``[skip]`` instead of being raised.
    """

    def pick_first(x):
        """First element of a list, or of the first non-empty list in a dict."""
        if isinstance(x, list) and x:
            return x[0]
        if isinstance(x, dict):
            for v in x.values():
                if isinstance(v, list) and v:
                    return v[0]
        return None

    def as_prompt(value) -> str:
        """Text to use as the prompt for *value*.

        A list/tuple yields its last non-blank string item (the most recent
        utterance) rather than ``str(list)``, which would put a Python list
        repr such as "['a', 'b']" into the corpus. Other values keep the
        previous ``str(value or "")`` behaviour.
        """
        if isinstance(value, (list, tuple)):
            for item in reversed(value):
                if isinstance(item, str) and item.strip():
                    return item
            return ""
        return str(value or "")

    pbar = None
    try:
        ds = load_dataset("blended_skill_talk", split="train", streaming=True)
        keep_cap, scan_cap = CAPS["bst"], SCAN["bst"]
        pbar = tqdm(total=keep_cap, desc="[bst]", unit="pair", leave=False, ncols=100)
        kept = 0
        for ex in limited(ds, scan_cap):
            prompt_raw = as_prompt(ex.get("previous_utterance")) or as_prompt(ex.get("context"))
            you = keep_or_clip(clean_text(prompt_raw))
            cand = (
                pick_first(ex.get("guided_messages"))
                or pick_first(ex.get("free_messages"))
                or pick_first(ex.get("suggestions"))
            )
            bot = keep_or_clip(clean_text(str(cand or "")))
            if you and bot:
                pairs.append(turn(you, bot))
                kept += 1
                pbar.update(1)
                overall.update(1)
                if kept >= keep_cap:
                    break
        pbar.close()
        print(f"[ok] bst kept={kept}")
    except Exception as e:
        print(f"[skip] bst: {e}")
    finally:
        # Also close the bar when the stream fails midway.
        if pbar is not None:
            pbar.close()
| |
|
def collect_personachat(pairs: List[str], overall: tqdm):
    """Collect adjacent pairs from the ``history`` list of each
    bavard/personachat_truecased row (Parquet revision, streamed).

    Appends formatted blocks to *pairs* and advances *overall* once per kept
    pair. Any exception is reported as ``[skip]`` instead of being raised.
    """
    try:
        stream = load_dataset(
            "bavard/personachat_truecased",
            split="train",
            streaming=True,
            revision="refs/convert/parquet",
        )
        keep_cap = CAPS["personachat"]
        scan_cap = SCAN["personachat"]
        bar = tqdm(total=keep_cap, desc="[personachat]", unit="pair", leave=False, ncols=100)
        kept = 0
        for row in limited(stream, scan_cap):
            utterances = []
            for raw in row.get("history") or []:
                text = keep_or_clip(clean_text(str(raw)))
                if text:
                    utterances.append(text)
            for you, bot in zip(utterances, utterances[1:]):
                pairs.append(turn(you, bot))
                kept += 1
                bar.update(1)
                overall.update(1)
                if kept >= keep_cap:
                    break
            if kept >= keep_cap:
                break
        bar.close()
        print(f"[ok] personachat kept={kept}")
    except Exception as e:
        print(f"[skip] personachat: {e}")
| |
|
| |
|
def collect_soda(pairs: List[str], overall: tqdm):
    """Collect adjacent-utterance pairs from allenai/soda (streamed).

    The dialogue field may hold plain strings or dicts; ``to_str`` handles
    both. Any exception is reported as ``[skip]`` instead of being raised.
    """
    try:
        stream = load_dataset("allenai/soda", split="train", streaming=True)
        keep_cap = CAPS["soda"]
        scan_cap = SCAN["soda"]
        bar = tqdm(total=keep_cap, desc="[soda]", unit="pair", leave=False, ncols=100)
        kept = 0
        for row in limited(stream, scan_cap):
            dialogue = row.get("dialogue") or row.get("dialog") or row.get("utterances") or []
            cleaned = (keep_or_clip(clean_text(str(to_str(item) or ""))) for item in dialogue)
            utterances = [text for text in cleaned if text]
            for you, bot in zip(utterances, utterances[1:]):
                pairs.append(turn(you, bot))
                kept += 1
                bar.update(1)
                overall.update(1)
                if kept >= keep_cap:
                    break
            if kept >= keep_cap:
                break
        bar.close()
        print(f"[ok] soda kept={kept}")
    except Exception as e:
        print(f"[skip] soda: {e}")
| |
|
def collect_topical_chat(pairs: List[str], overall: tqdm):
    """Collect adjacent-utterance pairs from Topical-Chat.

    Strategy:
      1. Try several Hub dataset ids, streaming every split name that loads.
         The first id after which at least one pair has been kept wins.
      2. Otherwise download the conversation JSON files from the alexa
         GitHub repository and read each conversation's ``content`` list.

    Only conversation text is used (no reading sets).

    Args:
        pairs: output list; formatted You:/Bot: blocks are appended in place.
        overall: global progress bar, advanced once per kept pair.
    """
    # Function-scope import: nothing else in the file binds `urlreq`, so
    # without this the GitHub fallback raised NameError on every file, and the
    # per-file `except Exception` silently swallowed it.
    import urllib.request as urlreq

    keep_cap, scan_cap = CAPS["topical"], SCAN["topical"]
    kept = 0

    def emit_adjacent(seq, bar) -> bool:
        """Append every adjacent pair of *seq*; return True once the cap is reached."""
        nonlocal kept
        for a, b in zip(seq, seq[1:]):
            pairs.append(turn(a, b))
            kept += 1
            bar.update(1)
            overall.update(1)
            if kept >= keep_cap:
                return True
        return kept >= keep_cap

    # 1) Hub mirrors
    dsids = [
        "Conversational-Reasoning/Topical-Chat",
        "AmazonScience/Topical-Chat",
        "microsoft/Topical-Chat",
    ]
    splits = ["train", "valid_freq", "valid_rare", "test_freq", "test_rare",
              "validation", "test", "valid_frequent", "test_frequent"]

    for dsid in dsids:
        pbar = None
        try:
            pbar = tqdm(total=keep_cap, desc=f"[topical:{dsid}]", unit="pair", leave=False, ncols=100)
            for split in splits:
                try:
                    ds = load_dataset(dsid, split=split, streaming=True)
                except Exception:
                    # Split names differ between mirrors; a missing one is expected.
                    continue
                done = False
                for ex in limited(ds, scan_cap):
                    dia = (ex.get("messages") or ex.get("conversation") or ex.get("utterances")
                           or ex.get("dialogue") or ex.get("content") or [])
                    seq = []
                    for u in dia:
                        s = keep_or_clip(clean_text(str(to_str(u) or "")))
                        if s:
                            seq.append(s)
                    if emit_adjacent(seq, pbar):
                        done = True
                        break
                if done:
                    break
            pbar.close()
            if kept > 0:
                print(f"[ok] topical_chat ({dsid}) kept={kept}")
                return
        except Exception as e:
            print(f"[try next] topical_chat {dsid}: {e}")
        finally:
            if pbar is not None:
                pbar.close()

    # 2) GitHub fallback
    pbar = None
    try:
        names = ["train", "valid_freq", "valid_rare", "test_freq", "test_rare"]
        base = "https://raw.githubusercontent.com/alexa/Topical-Chat/master/conversations"
        pbar = tqdm(total=keep_cap, desc="[topical:github]", unit="pair", leave=False, ncols=100)
        for nm in names:
            try:
                # Timeout so an unresponsive host cannot stall the pipeline.
                with urlreq.urlopen(f"{base}/{nm}.json", timeout=60) as r:
                    data = json.loads(r.read().decode("utf-8"))
            except Exception as e:
                # Best effort per file, but say why instead of hiding it.
                print(f"[try next] topical_chat github {nm}: {e}")
                continue
            done = False
            for convo in data.values():
                content = convo.get("content") or []
                seq = []
                for t in content:
                    s = keep_or_clip(clean_text(str(t.get("message") or t.get("text") or "")))
                    if s:
                        seq.append(s)
                if emit_adjacent(seq, pbar):
                    done = True
                    break
            if done:
                break
        pbar.close()
        if kept == 0:
            print("[skip] topical_chat: no usable conversations found")
        else:
            print(f"[ok] topical_chat (github) kept={kept}")
    except Exception as e:
        print(f"[skip] topical_chat (github): {e}")
    finally:
        if pbar is not None:
            pbar.close()
| |
|
| |
|
| | |
| | |
| | |
def collect_shakespeare(pairs: List[str], overall: tqdm):
    """Turn usable lines of the tinyshakespeare text file into pairs with a
    fixed "Continue in Shakespearean style." prompt.

    The file is streamed through the ``text`` loader; lines rejected by
    keep_or_clip are skipped. Any exception is reported as ``[skip]``.
    """
    source_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    try:
        stream = load_dataset(
            "text",
            data_files={"train": source_url},
            split="train",
            streaming=True,
        )
        cap = CAPS["shakespeare"]
        kept = 0
        bar = tqdm(total=cap, desc="[shakespeare]", unit="pair", leave=False, ncols=100)
        for row in stream:
            line = keep_or_clip(clean_text(row["text"]))
            if line:
                pairs.append(turn("Continue in Shakespearean style.", line))
                kept += 1
                bar.update(1)
                overall.update(1)
                if kept >= cap:
                    break
        bar.close()
        print(f"[ok] shakespeare kept={kept}")
    except Exception as e:
        print(f"[skip] shakespeare: {e}")
| |
|
def collect_reddit_jokes(pairs: List[str], overall: tqdm):
    """Collect joke pairs from the first Reddit-jokes dataset id that loads.

    A row with both title and body becomes title -> body; a row with only a
    title becomes "Tell me a short joke." -> title; rows without a usable
    title are skipped. The first dataset id that streams without raising ends
    the search, whatever number of pairs it produced.
    """
    candidates = ["SocialGrep/one-million-reddit-jokes", "SocialGrep/reddit_jokes", "timc1/reddit_jokes"]
    for dsid in candidates:
        try:
            stream = load_dataset(dsid, split="train", streaming=True)
            kept = 0
            bar = tqdm(total=CAPS["jokes"], desc="[jokes]", unit="pair", leave=False, ncols=100)
            for row in limited(stream, SCAN["jokes"]):
                title = keep_or_clip(clean_text(str(row.get("title") or "")))
                body = keep_or_clip(clean_text(str(row.get("selftext") or row.get("body") or "")))
                if not title:
                    continue
                if body:
                    block = turn(title, body)
                else:
                    block = turn("Tell me a short joke.", title)
                pairs.append(block)
                kept += 1
                bar.update(1)
                overall.update(1)
                if kept >= CAPS["jokes"]:
                    break
            bar.close()
            print(f"[ok] jokes {dsid} kept={kept}")
            return
        except Exception as e:
            print(f"[try next] jokes {dsid}: {e}")
    print("[skip] jokes: none worked")
| |
|
def collect_dadjokes(pairs: List[str], overall: tqdm):
    """Collect setup -> punchline pairs from shuttie/reddit-dadjokes (streamed).

    Several alternative column names are accepted for both halves; rows
    missing either half are skipped. Any exception is reported as ``[skip]``.
    """
    try:
        stream = load_dataset("shuttie/reddit-dadjokes", split="train", streaming=True)
        kept = 0
        bar = tqdm(total=CAPS["dadjokes"], desc="[dadjokes]", unit="pair", leave=False, ncols=100)
        for row in limited(stream, SCAN["dadjokes"]):
            setup_raw = row.get("setup") or row.get("instruction") or row.get("input") or ""
            punch_raw = row.get("punchline") or row.get("output") or ""
            setup = keep_or_clip(clean_text(str(setup_raw)))
            punch = keep_or_clip(clean_text(str(punch_raw)))
            if setup and punch:
                pairs.append(turn(setup, punch))
                kept += 1
                bar.update(1)
                overall.update(1)
                if kept >= CAPS["dadjokes"]:
                    break
        bar.close()
        print(f"[ok] dadjokes kept={kept}")
    except Exception as e:
        print(f"[skip] dadjokes: {e}")
| |
|
def collect_reddit_sarcasm(pairs: List[str], overall: tqdm):
    """Collect [STYLE=Snark] pairs from Thewillonline/reddit-sarcasm.

    Each row's ``text`` is parsed flexibly: labelled "User:/Comment:" or
    "Post:/Comment:" layouts first, then the first two non-blank lines, then a
    single line paired with a fixed "Reply with sarcasm:" prompt. Separate
    progress bars track scanned rows and kept pairs.

    Args:
        pairs: output list; formatted You:/Bot: blocks are appended in place.
        overall: global progress bar, advanced once per kept pair.

    Any exception is reported as ``[skip]`` instead of being raised.
    """
    patterns = [
        re.compile(r"User\s*:\s*(.+?)\s*(?:Reddit\s*Comment|Comment|Reply)\s*:\s*(.+)", re.IGNORECASE | re.DOTALL),
        re.compile(r"Post\s*:\s*(.+?)\s*(?:Top\s*Comment|Comment)\s*:\s*(.+)", re.IGNORECASE | re.DOTALL),
    ]

    def parse(raw: str) -> Tuple[Optional[str], Optional[str]]:
        """Split one raw row into (prompt, reply); (None, None) when empty."""
        raw = raw.replace("<|endoftext|>", "\n")
        for pat in patterns:
            m = pat.search(raw)
            if m:
                return m.group(1).strip(), m.group(2).strip()
        lines = [ln.strip() for ln in raw.splitlines() if ln.strip()]
        if len(lines) >= 2:
            return lines[0], lines[1]
        if len(lines) == 1:
            return "Reply with sarcasm:", lines[0]
        return None, None

    scanbar = keepbar = None
    try:
        ds = load_dataset("Thewillonline/reddit-sarcasm", split="train", streaming=True)
        keep_cap, scan_cap = CAPS["rsarcasm"], SCAN["rsarcasm"]
        scanbar = tqdm(total=scan_cap, desc="[sarcasm scan]", unit="row", leave=False, ncols=100)
        keepbar = tqdm(total=keep_cap, desc="[sarcasm kept]", unit="pair", leave=False, ncols=100)

        kept = scanned = 0
        for ex in limited(ds, scan_cap):
            scanned += 1
            # Count the row up front so the bar matches `scanned` even on the
            # iteration that hits the cap and breaks.
            scanbar.update(1)
            you, bot = parse(str(ex.get("text") or ""))
            # Same cleaning as every other collector (URLs, @mentions and
            # wrapping quotes removed) before the length filter.
            you = keep_or_clip(clean_text(you)) if you else None
            bot = keep_or_clip(clean_text(bot)) if bot else None
            if you and bot:
                pairs.append(turn(you, bot, "[STYLE=Snark]"))
                kept += 1
                keepbar.update(1)
                overall.update(1)
                if kept >= keep_cap:
                    break
            if scanned % 2000 == 0:
                keepbar.set_postfix(rate=f"{kept/max(1,scanned):.2%}")
        scanbar.close()
        keepbar.close()
        print(f"[ok] reddit-sarcasm kept={kept} (scanned {scanned})")
    except Exception as e:
        print(f"[skip] reddit-sarcasm: {e}")
    finally:
        # Also close both bars when the stream fails midway.
        for bar in (scanbar, keepbar):
            if bar is not None:
                bar.close()
| |
|
def collect_figlang(pairs: List[str], overall: tqdm):
    """Collect [STYLE=Snark] pairs from the first figlang2020 sarcasm dataset
    id that loads.

    The prompt is the last two context entries joined by a space (when the
    context is a non-empty list) or the context/prompt field as text; when no
    usable prompt remains, "Reply with sarcasm:" is used. Rows without a
    usable reply are skipped.
    """
    for dsid in ["tasksource/figlang2020-sarcasm", "tasksource/figlang2020_sarcasm"]:
        try:
            stream = load_dataset(dsid, split="train", streaming=True)
            kept = 0
            bar = tqdm(total=CAPS["figlang"], desc="[figlang]", unit="pair", leave=False, ncols=100)
            for row in limited(stream, SCAN["figlang"]):
                context = row.get("context")
                if isinstance(context, list) and context:
                    prompt_raw = " ".join(str(part) for part in context[-2:])
                else:
                    prompt_raw = str(row.get("context") or row.get("prompt") or "")
                reply_raw = str(row.get("response") or row.get("answer") or row.get("text") or "")

                prompt = keep_or_clip(clean_text(prompt_raw))
                reply = keep_or_clip(clean_text(reply_raw))
                if not reply:
                    continue
                pairs.append(turn(prompt or "Reply with sarcasm:", reply, "[STYLE=Snark]"))
                kept += 1
                bar.update(1)
                overall.update(1)
                if kept >= CAPS["figlang"]:
                    break
            bar.close()
            print(f"[ok] figlang {dsid} kept={kept}")
            return
        except Exception as e:
            print(f"[try next] figlang {dsid}: {e}")
    print("[skip] figlang")
| |
|
def collect_showerthoughts(pairs: List[str], overall: tqdm):
    """Collect shower thoughts from HuggingFaceGECLM/REDDIT_submissions,
    using its ``Showerthoughts`` split directly (there is no 'train').

    The first non-blank field among title/selftext/text becomes the reply to
    a fixed "Give me a shower thought." prompt. Any exception is reported as
    ``[skip]`` instead of being raised.
    """
    try:
        stream = load_dataset("HuggingFaceGECLM/REDDIT_submissions", split="Showerthoughts", streaming=True)
        keep_cap = CAPS["shower"]
        scan_cap = SCAN["shower"]
        scan_bar = tqdm(total=scan_cap, desc="[shower scan]", unit="row", leave=False, ncols=100)
        keep_bar = tqdm(total=keep_cap, desc="[shower kept]", unit="pair", leave=False, ncols=100)

        kept = 0
        scanned = 0
        for row in limited(stream, scan_cap):
            scanned += 1
            raw = get_first_nonempty(row, ["title", "selftext", "text"]) or ""
            thought = keep_or_clip(clean_text(raw))
            if thought:
                pairs.append(turn("Give me a shower thought.", thought))
                kept += 1
                keep_bar.update(1)
                overall.update(1)
                if kept >= keep_cap:
                    break
            scan_bar.update(1)
        scan_bar.close()
        keep_bar.close()
        print(f"[ok] showerthoughts kept={kept} (scanned {scanned})")
    except Exception as e:
        print(f"[skip] showerthoughts: {e}")
| |
|
def collect_personas(pairs: List[str], overall: tqdm):
    """Collect persona descriptions from NapthaAI/twitter_personas.

    The dataset is loaded non-streaming and its ``train`` split iterated in
    full (no scan cap). Each usable description becomes the reply to
    "Adopt this persona in one sentence." with a [FORM=TWEET] tag line.
    """
    text_keys = ["description", "persona", "bio", "text", "content", "full_text"]
    try:
        rows = load_dataset("NapthaAI/twitter_personas")["train"]
        keep_cap = CAPS["personas"]
        bar = tqdm(total=keep_cap, desc="[personas]", unit="pair", leave=False, ncols=100)
        kept = 0
        for row in rows:
            raw = get_first_nonempty(row, text_keys)
            # Fall back to a nested {"content": {"text": ...}} layout.
            if not isinstance(raw, str) and isinstance(row.get("content"), dict):
                raw = row["content"].get("text")
            persona = keep_or_clip(clean_text(str(raw or "")))
            if persona:
                pairs.append(turn("Adopt this persona in one sentence.", persona, "[FORM=TWEET]"))
                kept += 1
                bar.update(1)
                overall.update(1)
                if kept >= keep_cap:
                    break
        bar.close()
        print(f"[ok] personas kept={kept}")
    except Exception as e:
        print(f"[skip] personas: {e}")
| |
|
def collect_tweeteval(pairs: List[str], overall: tqdm):
    """Collect pairs from cardiffnlp/super_tweeteval, falling back to
    cardiffnlp/tweet_eval only when the former kept nothing.

    Configs are visited in a fixed priority order (irony, sarcasm, ...) and
    then in listing order. Single-text rows become replies to a fixed prompt;
    two-text rows (text_1/text_2, premise/hypothesis, ...) become
    prompt -> reply pairs tagged [STYLE=Snark].

    CAPS["tweeteval"] is one budget shared by all configs: each config may
    only fill what is left, so the total never exceeds the cap. A config that
    fails to load or stream is reported and skipped rather than aborting the
    remaining configs.

    Args:
        pairs: output list; formatted You:/Bot: blocks are appended in place.
        overall: global progress bar, advanced once per kept pair.
    """
    default_prompt = "React with a sharp one-liner."
    keep_cap, scan_cap = CAPS["tweeteval"], SCAN["tweeteval"]
    if keep_cap <= 0:
        print("[skip] tweeteval: cap is 0 (disabled)")
        return
    kept_total = 0

    def extract_pair(ex):
        """Return (prompt, reply) for a super_tweeteval row, or None."""
        t = ex.get("text")
        if isinstance(t, str) and t.strip():
            return default_prompt, t
        for a, b in [("text_1", "text_2"), ("sentence1", "sentence2"),
                     ("premise", "hypothesis"), ("question", "answer"),
                     ("context", "response"), ("tweet1", "tweet2")]:
            t1, t2 = ex.get(a), ex.get(b)
            if isinstance(t1, str) and t1.strip() and isinstance(t2, str) and t2.strip():
                return t1, t2
        return None

    def run_on(dsname, pick, is_super):
        """Stream one config, keeping pairs until the shared cap is reached."""
        nonlocal kept_total
        pbar = tqdm(total=keep_cap - kept_total, desc=f"[tweeteval:{pick}]", unit="pair", leave=False, ncols=100)
        try:
            ds = load_dataset(dsname, pick, split="train", streaming=True)
            for ex in limited(ds, scan_cap):
                if is_super:
                    pair = extract_pair(ex)
                elif ex.get("text"):
                    pair = (default_prompt, ex.get("text"))
                else:
                    pair = None
                if not pair:
                    continue
                you = keep_or_clip(clean_text(str(pair[0] or "")))
                bot = keep_or_clip(clean_text(str(pair[1] or "")))
                if not (you and bot):
                    continue
                # Only genuine two-text pairs carry the style tag.
                tag = "[STYLE=Snark]" if you != default_prompt else ""
                pairs.append(turn(you, bot, tag))
                # Counted immediately so pairs kept before a mid-stream
                # failure still count against the cap.
                kept_total += 1
                pbar.update(1)
                overall.update(1)
                if kept_total >= keep_cap:
                    break
        finally:
            pbar.close()

    prio = ["irony", "sarcasm", "humor", "sentiment", "emoji", "emotion", "stance", "offensive", "hate"]
    sources = (
        ("cardiffnlp/super_tweeteval", "super", True),
        ("cardiffnlp/tweet_eval", "base", False),
    )
    for dsname, label, is_super in sources:
        if kept_total > 0:
            break  # the second source is only a fallback for an empty first pass
        try:
            cfgs = get_dataset_config_names(dsname)
        except Exception:
            cfgs = []
        ordered = [c for c in prio if c in cfgs] + [c for c in cfgs if c not in prio]
        for pick in ordered:
            try:
                run_on(dsname, pick, is_super)
            except Exception as e:
                print(f"[try next] tweeteval {dsname}:{pick}: {e}")
            if kept_total >= keep_cap:
                print(f"[ok] tweeteval({label}) kept={kept_total}")
                return
    print(f"[ok] tweeteval kept={kept_total}")
| |
|
def collect_fourchan(pairs: List[str], overall: tqdm):
    """Collect pairs from sbussiso/4chan-dataset (streamed).

    Rows with usable prompt and response become prompt -> response; otherwise
    a usable text/body/content field becomes the reply to
    "Drop a spicy one-liner.". The keep cap is CAPS["fourchan"], limited to
    at most 195.
    """
    try:
        stream = load_dataset("sbussiso/4chan-dataset", split="train", streaming=True)
        keep_cap = min(CAPS["fourchan"], 195)
        bar = tqdm(total=keep_cap, desc="[4chan]", unit="pair", leave=False, ncols=100)
        kept = 0
        for row in limited(stream, SCAN["fourchan"]):
            prompt = keep_or_clip(clean_text(str(row.get("prompt") or "")))
            resp = keep_or_clip(clean_text(str(row.get("response") or "")))
            block = None
            if prompt and resp:
                block = turn(prompt, resp)
            else:
                loose = row.get("text") or row.get("body") or row.get("content") or ""
                txt = keep_or_clip(clean_text(str(loose)))
                if txt:
                    block = turn("Drop a spicy one-liner.", txt)
            if block is not None:
                pairs.append(block)
                kept += 1
                bar.update(1)
                overall.update(1)
            if kept >= keep_cap:
                break
        bar.close()
        print(f"[ok] 4chan kept={kept}")
    except Exception as e:
        print(f"[skip] 4chan: {e}")
| |
|
def collect_elon_trump(pairs: List[str], overall: tqdm):
    """Collect tweets from MasaFoundation/Twitter_X_Elon_vs_Trump (Parquet
    revision, streamed) as replies to a fixed "hot take" prompt, with a
    [FORM=TWEET] tag line.

    Separate progress bars track scanned rows and kept pairs. Any exception
    is reported as ``[skip]`` instead of being raised.
    """
    try:
        stream = load_dataset(
            "MasaFoundation/Twitter_X_Elon_vs_Trump",
            split="train",
            streaming=True,
            revision="refs/convert/parquet",
        )
        keep_cap = CAPS["elonvtrump"]
        scan_cap = SCAN["elonvtrump"]
        scan_bar = tqdm(total=scan_cap, desc="[elon_vs_trump scan]", unit="row", leave=False, ncols=100)
        keep_bar = tqdm(total=keep_cap, desc="[elon_vs_trump kept]", unit="pair", leave=False, ncols=100)
        kept = 0
        scanned = 0
        for row in limited(stream, scan_cap):
            scanned += 1
            raw = get_first_nonempty(row, ["content", "text", "tweet", "full_text"]) or ""
            tweet = keep_or_clip(clean_text(raw))
            if tweet:
                pairs.append(turn("[FORM=TWEET] One sentence hot take:", tweet, "[FORM=TWEET]"))
                kept += 1
                keep_bar.update(1)
                overall.update(1)
                if kept >= keep_cap:
                    break
            scan_bar.update(1)
        scan_bar.close()
        keep_bar.close()
        print(f"[ok] Elon_vs_Trump kept={kept} (scanned {scanned})")
    except Exception as e:
        print(f"[skip] Elon_vs_Trump: {e}")
| | |
| | |
def collect_hh_rlhf(pairs: List[str], overall: tqdm):
    """Collect Human -> Assistant pairs from Anthropic/hh-rlhf (streamed).

    The ``chosen`` transcript (or ``prompt`` as a fallback) is split on the
    literal "Human:" / "Assistant:" markers. A pair is emitted only when a
    Human turn is immediately followed by an Assistant turn in the transcript
    and both survive cleaning. Turns rejected by keep_or_clip keep their
    position, so a dropped turn can never cause a question to be paired with
    the answer to a later question.

    Args:
        pairs: output list; formatted You:/Bot: blocks are appended in place.
        overall: global progress bar, advanced once per kept pair.

    Any exception is reported as ``[skip]`` instead of being raised.
    """
    # Capturing group: re.split keeps the markers, giving
    # [preamble, marker, text, marker, text, ...].
    speaker_re = re.compile(r"(Human:|Assistant:)")
    pbar = None
    try:
        ds = load_dataset("Anthropic/hh-rlhf", split="train", streaming=True)
        keep_cap, scan_cap = CAPS["hhrlhf"], SCAN["hhrlhf"]
        pbar = tqdm(total=keep_cap, desc="[hh-rlhf]", unit="pair", leave=False, ncols=100)
        kept = 0

        for ex in limited(ds, scan_cap):
            convo = ex.get("chosen") or ex.get("prompt") or ""
            if not isinstance(convo, str) or not convo.strip():
                continue

            tokens = speaker_re.split(convo)
            # (role, cleaned text or None) in transcript order; None marks a
            # turn that was filtered out but still occupies its slot.
            turns = [
                ("user" if marker == "Human:" else "assistant", keep_or_clip(clean_text(body)))
                for marker, body in zip(tokens[1::2], tokens[2::2])
            ]

            for (role_a, you), (role_b, bot) in zip(turns, turns[1:]):
                if role_a == "user" and role_b == "assistant" and you and bot:
                    pairs.append(turn(you, bot))
                    kept += 1
                    pbar.update(1)
                    overall.update(1)
                    if kept >= keep_cap:
                        break
            if kept >= keep_cap:
                break

        pbar.close()
        print(f"[ok] hh-rlhf kept={kept}")
    except Exception as e:
        print(f"[skip] hh-rlhf: {e}")
    finally:
        # Also close the bar when the stream fails midway.
        if pbar is not None:
            pbar.close()
| | |
| |
|
| | |
| | |
| | |
def build_corpus() -> Path:
    """Run every collector, deduplicate, shuffle, truncate and write the corpus.

    Blocks are deduplicated on the lower-cased text of their first "Bot:"
    line (or on the whole block when there is none), shuffled with the
    global ``random`` state, truncated to MAX_TOTAL_PAIRS and written to
    ``SAVE_DIR / "corpus.txt"``.

    Returns:
        Path of the written corpus file.
    """
    pairs: List[str] = []
    print("[1/6] Collecting & reformatting datasets (streaming, capped)…")
    overall = tqdm(total=sum(CAPS.values()), desc="[all] collecting", unit="pair", ncols=100)

    collectors = (
        collect_oasst1,
        collect_hh_rlhf,
        collect_ultrachat,
        collect_dailydialog,
        collect_bst,
        collect_personachat,
        collect_soda,
        collect_topical_chat,
        collect_shakespeare,
        collect_reddit_jokes,
        collect_dadjokes,
        collect_reddit_sarcasm,
        collect_figlang,
        collect_showerthoughts,
        collect_personas,
        collect_tweeteval,
        collect_fourchan,
        collect_elon_trump,
    )
    for collector in collectors:
        # One failing source must not stop the remaining ones.
        try:
            collector(pairs, overall)
        except Exception as e:
            print(f"[collector error] {collector.__name__}: {e}")
    overall.close()

    print("[2/6] Deduplicating & clipping…")
    seen_keys = set()
    unique_blocks = []
    for block in pairs:
        bot_line = next((ln for ln in block.splitlines() if ln.startswith("Bot:")), None)
        if bot_line is None:
            key = block.strip().lower()
        else:
            key = bot_line[4:].strip().lower()
        if key not in seen_keys:
            seen_keys.add(key)
            unique_blocks.append(block)

    random.shuffle(unique_blocks)
    unique_blocks = unique_blocks[:MAX_TOTAL_PAIRS]

    out_path = SAVE_DIR / "corpus.txt"
    # Each block already ends with a newline, so joining with "\n" leaves one
    # blank line between blocks.
    out_path.write_text("\n".join(unique_blocks), encoding="utf-8")
    print(f" wrote {len(unique_blocks)} pairs → {out_path}")
    return out_path
| |
|
| | |
| | |
| | |
def train_spm(corpus_path: Path) -> spm.SentencePieceProcessor:
    """Train a unigram SentencePiece model on *corpus_path* and load it.

    The model is written next to TOKENIZER_PREFIX (``<prefix>.model``), with
    USER_SYMBOLS registered as user-defined symbols and ids unk=0, bos=1,
    eos=2, and padding disabled (pad_id=-1).

    Returns:
        A processor with the freshly trained model loaded.
    """
    print("[3/6] Training SentencePiece tokenizer…")
    trainer_options = {
        "input": str(corpus_path),
        "model_prefix": str(TOKENIZER_PREFIX),
        "vocab_size": VOCAB_SIZE,
        "model_type": "unigram",
        "character_coverage": 1.0,
        "user_defined_symbols": USER_SYMBOLS,
        "bos_id": 1,
        "eos_id": 2,
        "unk_id": 0,
        "pad_id": -1,
    }
    spm.SentencePieceTrainer.Train(**trainer_options)

    model_file = f"{TOKENIZER_PREFIX}.model"
    processor = spm.SentencePieceProcessor()
    processor.load(model_file)
    print(f" tokenizer saved at {TOKENIZER_PREFIX}.model")
    return processor
| |
|
| | |
| | |
| | |
def encode_corpus_to_ids(sp: spm.SentencePieceProcessor, corpus_path: Path):
    """Tokenise the corpus and split it into train/validation id tensors.

    The file is split into blank-line-separated blocks. Every non-blank line
    is encoded and followed by an EOS id, and each block gets one additional
    EOS id at its end. The first 97% of the ids form the training tensor and
    the rest the validation tensor, both int64 on DEVICE.

    Returns:
        (train_ids, val_ids, vocab_size)
    """
    print("[4/6] Encoding corpus to token IDs…")
    raw = corpus_path.read_text(encoding="utf-8")
    eos = sp.eos_id()

    token_ids: List[int] = []
    for block in raw.split("\n\n"):
        if not block.strip():
            continue
        for line in block.splitlines():
            if line.strip():
                token_ids += sp.encode(line, out_type=int)
                token_ids.append(eos)
        token_ids.append(eos)

    arr = np.array(token_ids, dtype=np.int32)
    split_at = int(len(arr) * 0.97)
    train_ids = torch.tensor(arr[:split_at], dtype=torch.long, device=DEVICE)
    val_ids = torch.tensor(arr[split_at:], dtype=torch.long, device=DEVICE)
    print(f" tokens: train={train_ids.numel():,}, val={val_ids.numel():,}, vocab={sp.vocab_size()}")
    return train_ids, val_ids, sp.vocab_size()
| |
|
| | |
| | |
| | |
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask.

    Input and output are (B, T, C) tensors with C == n_embd and
    T <= block_size. Query/key/value come from one fused bias-free linear
    layer; the mask is stored as a buffer of shape
    (1, 1, block_size, block_size).
    """

    def __init__(self, n_embd, n_head, dropout=0.0, block_size=256):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Creation order (qkv, then proj) is kept so seeded initialisation
        # matches the previous implementation.
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        lower_tri = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", lower_tri.view(1, 1, block_size, block_size))

    def _split_heads(self, t, B, T):
        """(B, T, C) -> (B, n_head, T, head_dim)."""
        return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = (self._split_heads(part, B, T) for part in self.qkv(x).chunk(3, dim=-1))
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Positions above the diagonal (future tokens) get -inf before softmax.
        hidden = self.mask[:, :, :T, :T] == 0
        scores = scores.masked_fill(hidden, float("-inf"))
        weights = self.attn_drop(torch.softmax(scores, dim=-1))
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(merged))
| |
|
class Block(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP, each added back residually."""

    def __init__(self, n_embd, n_head, dropout=0.0, block_size=256):
        super().__init__()
        hidden = 4 * n_embd
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, dropout, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        # Position-wise feed-forward: expand to 4x width, GELU, project back.
        feed_forward = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))
| |
|
class TinyGPT(nn.Module):
    """Small decoder-only transformer language model with learned position embeddings.

    Args:
        vocab_size: Number of token ids.
        n_layer: Number of transformer blocks.
        n_head: Attention heads per block.
        n_embd: Embedding width.
        block_size: Maximum context length in tokens.
        dropout: Dropout probability passed to every block.
    """

    def __init__(self, vocab_size, n_layer, n_head, n_embd, block_size, dropout=0.0):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, dropout, block_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.apply(self._init)

    def _init(self, m):
        """Initialise Linear/Embedding weights from N(0, 0.02), zero biases, and reset LayerNorm to identity."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)
        if isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids with ``T <= block_size``.
            targets: Optional (B, T) tensor of target ids.

        Returns:
            ``(logits, loss)``: logits has shape (B, T, vocab_size); loss is
            ``None`` when no targets are supplied.
        """
        B, T = idx.shape
        assert T <= self.block_size, "sequence is longer than block_size"
        pos = torch.arange(0, T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        for blk in self.blocks:
            x = blk(x)
        logits = self.head(self.ln_f(x))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=0.8, top_k=50, top_p=0.95, repetition_penalty=1.0):
        """Sample ``max_new_tokens`` continuation tokens for each row of ``idx``.

        Args:
            idx: (B, T) tensor of prompt token ids.
            max_new_tokens: Number of tokens to append.
            temperature: Logit divisor; values near zero approach greedy decoding.
            top_k: Keep only the ``top_k`` highest logits; ``None`` or a value
                ``<= 0`` disables the filter.
            top_p: Nucleus threshold; keeps the smallest set of most-likely
                tokens whose cumulative probability reaches ``top_p``.
                ``None`` disables the filter.
            repetition_penalty: Values ``> 1`` make tokens already present in
                the current context window less likely; ``1.0`` disables it.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.

        The model is switched to eval mode while sampling and its previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]

                if repetition_penalty != 1.0:
                    # Penalise per batch row. Dividing a negative logit would
                    # raise its probability, so negative logits are multiplied
                    # and positive ones divided.
                    seen = logits.gather(1, idx_cond)
                    seen = torch.where(seen > 0, seen / repetition_penalty, seen * repetition_penalty)
                    logits = logits.scatter(1, idx_cond, seen)

                logits = logits / max(1e-8, temperature)

                if top_k is not None and top_k > 0:
                    k = min(top_k, logits.size(-1))
                    kth_best = torch.topk(logits, k).values[:, -1:]
                    logits = logits.masked_fill(logits < kth_best, float("-inf"))

                if top_p is not None:
                    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                    cdf = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    remove = cdf > top_p
                    # Shift right by one so the token that crosses the threshold
                    # stays in the nucleus; the best token is always kept.
                    remove[:, 1:] = remove[:, :-1].clone()
                    remove[:, 0] = False
                    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
                    # sorted_idx is a permutation, so every slot is overwritten.
                    logits = torch.empty_like(logits).scatter(1, sorted_idx, sorted_logits)

                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, next_id], dim=1)
        finally:
            self.train(was_training)
        return idx
| |
|
| | |
| | |
| | |
def get_batch(split_ids: torch.Tensor, B: int, T: int):
    """Sample ``B`` random length-``T`` windows and their next-token targets.

    Args:
        split_ids: 1-D tensor of token ids.
        B: Number of windows to draw.
        T: Window length in tokens.

    Returns:
        ``(x, y)``, both of shape (B, T) on ``split_ids``'s device, where
        ``y`` is ``x`` shifted one position to the right in the stream.

    Raises:
        ValueError: If the split has fewer than ``T + 1`` tokens, so no
            window with a full set of targets exists.
    """
    # A window starting at i needs ids[i : i + T + 1], so valid starts are
    # 0 .. numel - T - 1 inclusive.
    n_starts = split_ids.numel() - T
    if n_starts < 1:
        raise ValueError(
            f"split has {split_ids.numel()} tokens; need at least {T + 1} for block size {T}"
        )
    starts = torch.randint(0, n_starts, (B,), device=split_ids.device)
    # One vectorised gather instead of slicing per sampled index in Python.
    offsets = starts.unsqueeze(1) + torch.arange(T, device=split_ids.device)
    x = split_ids[offsets]
    y = split_ids[offsets + 1]
    return x, y
| |
|
| | |
| | |
| | |
def train_model(vocab_size, train_ids, val_ids):
    """Train a TinyGPT on ``train_ids``, periodically evaluating on ``val_ids``, then save it.

    Model and optimisation hyperparameters (``n_layer``, ``batch_size``,
    ``train_steps``, ``base_lr`` …) are read from module-level settings.
    Training stops after ``train_steps`` optimiser steps or once
    ``MAX_SECONDS`` of wall-clock time has passed, whichever comes first.
    On CUDA the forward pass runs under fp16 autocast with gradient scaling.

    The final weights are written to ``SAVE_DIR / "tinygpt.pt"`` and the model
    hyperparameters to ``SAVE_DIR / "model_config.json"``.

    Returns:
        The trained model.
    """
    print("[5/6] Training tiny GPT on", DEVICE.type.upper(), "…")
    model = TinyGPT(vocab_size, n_layer, n_head, n_embd, block_size, dropout).to(DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    print(f" params: {n_params / 1e6:.2f}M")
    opt = torch.optim.AdamW(model.parameters(), lr=base_lr, betas=(0.9, 0.95), weight_decay=0.0)

    on_cuda = DEVICE.type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=on_cuda)

    def autocast():
        # fp16 autocast on CUDA, a no-op context everywhere else.
        if on_cuda:
            return torch.amp.autocast("cuda", dtype=torch.float16)
        return nullcontext()

    start = time.time()
    best_val = float("inf")

    def lr_at(step):
        # Linear warmup followed by cosine decay from base_lr down to min_lr.
        warmup = max(1, int(train_steps * warmup_ratio))
        if step < warmup:
            return base_lr * (step + 1) / warmup
        progress = (step - warmup) / max(1, train_steps - warmup)
        cosine = 1 + math.cos(math.pi * min(1.0, progress))
        return min_lr + 0.5 * (base_lr - min_lr) * cosine

    @torch.no_grad()
    def eval_loss(iters=80):
        # Mean validation loss over `iters` random batches; restores train mode.
        model.eval()
        batch_losses = []
        for _ in range(iters):
            inputs, targets = get_batch(val_ids, min(batch_size, 32), block_size)
            with autocast():
                _, batch_loss = model(inputs, targets)
            batch_losses.append(batch_loss.item())
        model.train()
        return float(sum(batch_losses) / len(batch_losses))

    model.train()
    step = 0
    progress_bar = tqdm(total=train_steps, ncols=100, desc="[train]")
    while True:
        if step >= train_steps or (time.time() - start) >= MAX_SECONDS:
            break

        lr = lr_at(step)
        for group in opt.param_groups:
            group["lr"] = lr
        opt.zero_grad(set_to_none=True)

        # Gradient accumulation: each micro-batch contributes loss / accum_steps.
        total_loss = 0.0
        for _ in range(accum_steps):
            inputs, targets = get_batch(train_ids, batch_size, block_size)
            with autocast():
                _, loss = model(inputs, targets)
            micro_loss = loss / accum_steps
            if on_cuda:
                scaler.scale(micro_loss).backward()
            else:
                micro_loss.backward()
            total_loss += loss.item()

        # Gradients must be unscaled before clipping when a scaler is in use.
        if on_cuda:
            scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        if on_cuda:
            scaler.step(opt)
            scaler.update()
        else:
            opt.step()

        step += 1
        progress_bar.update(1)
        if step == 1 or step % log_interval == 0:
            progress_bar.set_postfix(train=f"{total_loss/accum_steps:.3f}", lr=f"{lr:.2e}")

        if step % eval_every == 0:
            vl = eval_loss()
            best_val = min(best_val, vl)
            print(f" eval loss {vl:.3f} | best {best_val:.3f}")

    progress_bar.close()
    elapsed = time.time() - start
    print(f" done in {elapsed:.1f}s | best val {best_val:.3f}")

    # Persist the final weights together with the hyperparameters needed to rebuild the model.
    ckpt_path = SAVE_DIR / "tinygpt.pt"
    torch.save(model.state_dict(), ckpt_path)
    config = {
        "vocab_size": int(vocab_size),
        "n_layer": n_layer, "n_head": n_head, "n_embd": n_embd,
        "block_size": block_size, "dropout": dropout
    }
    (SAVE_DIR / "model_config.json").write_text(json.dumps(config, indent=2))
    print(f"[saved] weights → {ckpt_path}")
    return model
| |
|
| | |
| | |
| | |
def sample_chat(sp: spm.SentencePieceProcessor, model: TinyGPT, prompt: str, max_new_tokens=200):
    """Generate a reply to ``prompt`` and return the whole exchange as text.

    The prompt is laid out the same way the training stream is built: each
    line is encoded on its own and terminated with the tokenizer's EOS id,
    followed by an open ``Bot:`` line for the model to continue. Encoding the
    prompt as a single string with an embedded newline would give the model a
    token sequence with no EOS between the two turns.

    Args:
        sp: Loaded tokenizer; its EOS id is assumed to be enabled, as the
            corpus encoder also relies on it.
        model: Trained model to sample from.
        prompt: User message (may span several lines).
        max_new_tokens: Number of tokens to sample.

    Returns:
        The prompt and the sampled continuation, one line per EOS-terminated
        segment, joined with newlines.
    """
    eos = sp.eos_id()

    ids = []
    for line in f"You: {prompt}".splitlines():
        if line.strip():
            ids.extend(sp.encode(line, out_type=int))
            ids.append(eos)
    ids.extend(sp.encode("Bot:", out_type=int))

    x = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        y = model.generate(x, max_new_tokens=max_new_tokens, temperature=TEMP, top_k=TOP_K, top_p=TOP_P, repetition_penalty=REP_PEN)

    # EOS ids mark line ends in the training stream, so split on them to
    # recover the line structure instead of decoding everything as one run.
    lines, current = [], []
    for tok in y[0].tolist():
        if tok == eos:
            lines.append(sp.decode(current))
            current = []
        else:
            current.append(tok)
    if current:
        lines.append(sp.decode(current))
    return "\n".join(lines).rstrip("\n")
| |
|
| | |
| | |
| | |
def main():
    """Run the pipeline end to end, reusing cached artifacts in ``SAVE_DIR`` where they are still valid.

    Steps: build (or reuse) the corpus, train (or load) the tokenizer, encode
    (or load) the token ids, train the model, then sample a fixed list of
    prompts and write the results to ``SAVE_DIR / "samples.txt"``.

    Cached token ids are only reused when neither the corpus nor the tokenizer
    was regenerated in this run and the cached vocab size matches the loaded
    tokenizer; otherwise they would belong to a different corpus/tokenizer and
    are re-encoded.
    """
    corpus_path = SAVE_DIR / "corpus.txt"
    spm_model = SAVE_DIR / "spm_chat.model"
    upstream_rebuilt = False

    if not corpus_path.exists():
        corpus_path = build_corpus()
        upstream_rebuilt = True
    else:
        print(f"[cache] using {corpus_path}")

    sp = spm.SentencePieceProcessor()
    if not spm_model.exists():
        sp = train_spm(corpus_path)
        upstream_rebuilt = True
    else:
        sp.load(str(spm_model))
        print(f"[cache] using {spm_model}")

    enc_train = SAVE_DIR / "train_ids.pt"
    enc_val = SAVE_DIR / "val_ids.pt"
    vocab_txt = SAVE_DIR / "vocab_size.txt"

    train_ids = val_ids = vocab_size = None
    have_cache = enc_train.exists() and enc_val.exists() and vocab_txt.exists()
    if have_cache and not upstream_rebuilt:
        cached_vocab = int(vocab_txt.read_text())
        if cached_vocab == sp.vocab_size():
            train_ids = torch.load(enc_train, map_location=DEVICE)
            val_ids = torch.load(enc_val, map_location=DEVICE)
            vocab_size = cached_vocab
            print(f"[cache] loaded ids: train={train_ids.numel():,}, val={val_ids.numel():,}, vocab={vocab_size}")
        else:
            print("[cache] encoded ids do not match the tokenizer; re-encoding")
    elif have_cache:
        print("[cache] corpus or tokenizer was rebuilt; re-encoding")

    if train_ids is None:
        train_ids, val_ids, vocab_size = encode_corpus_to_ids(sp, corpus_path)
        torch.save(train_ids, enc_train)
        torch.save(val_ids, enc_val)
        vocab_txt.write_text(str(vocab_size))
        print("[cache] saved encoded ids")

    model = train_model(vocab_size, train_ids, val_ids)

    print("\n[6/6] Samples:\n")
    prompts = [
        "Give me a spicy take on AI.",
        "Roast my messy desk.",
        "Explain recursion like you're annoyed.",
        "Write a satirical headline about coffee.",
        "Give me a shower thought about umbrellas.",
        "Tell me a one-liner about deadlines.",
        "Stay in Shakespeare mode and flatter me.",
        "Reply sarcastically to: I love meetings.",
        "What's a good way to say no to a meeting politely?",
        "Roleplay as my productivity coach for two turns.",
    ]
    out_path = SAVE_DIR / "samples.txt"
    with out_path.open("w", encoding="utf-8") as f:
        for p in prompts:
            txt = sample_chat(sp, model, p, max_new_tokens=200)
            print("----\n" + txt)
            f.write("----\n" + txt + "\n")
    print(f"\n[saved] samples → {out_path}")


if __name__ == "__main__":
    main()
| |
|