"""Turn a local JSONL file of chat conversations into a source/target CSV.

Each row's ``conversation`` list is formatted into "Role: text" turns; the
last ``max_turn_pairs`` user+assistant pairs become one example whose source
is a ``chat:`` prompt and whose target is the final assistant reply.
"""
import csv
import re  # noqa: F401  (unused here; kept from the original import list)
from functools import lru_cache
from typing import Callable, List, Optional, Tuple

import pandas as pd

# ── Config ────────────────────────────────────────────────────────────────
jsonl_path = "lmsys_chat_1m_full.jsonl"   # local file
output_path = "chat_1turn.csv"            # CSV written by main()
use_subset = False                        # False ⇒ full 1 M rows
num_samples = 500                         # row cap when use_subset is True
max_turn_pairs = 1                        # user+assistant pairs per example (1 pair = 2 turns)
max_input_tokens = 512                    # fits t5-small/base
min_prompt_chars = 30                     # shorter prompts are discarded
min_target_chars = 10                     # shorter targets are discarded
# ──────────────────────────────────────────────────────────────────────────

PROMPT_PREFIX = "chat:\n\n"
TURN_SEPARATOR = "\n\n"
ASSISTANT_PREFIX = "Assistant: "


@lru_cache(maxsize=1)
def _default_token_counter() -> Callable[[str], int]:
    """Load the ``t5-small`` tokenizer once and return a token-counting function."""
    # Deferred import: only needed when no custom counter is supplied, which
    # keeps the pure helpers importable without the transformers package.
    from transformers import T5Tokenizer

    tok = T5Tokenizer.from_pretrained("t5-small")
    # NOTE(review): tokenize() presumably excludes the EOS token appended at
    # encoding time — confirm whether the limit should leave room for it.
    return lambda text: len(tok.tokenize(text))


def mostly_ascii(s: str, threshold: float = .3) -> bool:
    """Return True when the share of non-ASCII characters in *s* is below *threshold*.

    An empty string returns False, so callers treat it as "reject".
    """
    if not s:
        return False
    non_ascii = sum(ord(ch) > 127 for ch in s)
    return non_ascii / len(s) < threshold


def format_turns(conv) -> List[str]:
    """Render each message dict of *conv* as ``"Role: content"``.

    The role is capitalized and the content stripped. A missing or ``None``
    content is rendered as an empty string instead of raising.
    """
    formatted = []
    for message in conv:
        role = message["role"].capitalize()
        content = (message.get("content") or "").strip()
        formatted.append(f"{role}: {content}")
    return formatted


def build_pair(
    turns: List[str],
    max_tokens: int = 512,
    count_tokens: Optional[Callable[[str], int]] = None,
    turn_pairs: Optional[int] = None,
) -> Optional[Tuple[str, str]]:
    """Build one ``(prompt, target)`` example from formatted turns.

    The last ``turn_pairs * 2`` turns are used: all but the final one form the
    prompt, and the final one (which must be an assistant turn) is the target.
    If the prompt exceeds *max_tokens*, the oldest turn is dropped and the
    prompt re-checked, up to ``turn_pairs`` times.

    Args:
        turns: Strings as produced by ``format_turns``.
        max_tokens: Maximum token count allowed for the prompt.
        count_tokens: Function returning the token count of a string.
            Defaults to a counter backed by the ``t5-small`` tokenizer.
        turn_pairs: Number of user+assistant pairs to use. Defaults to the
            module-level ``max_turn_pairs``.

    Returns:
        ``(prompt, target)``, or ``None`` when the example is rejected: too
        few turns, final turn not from the assistant, prompt still too long
        after trimming, prompt/target too short, or text not mostly ASCII.
    """
    if turn_pairs is None:
        turn_pairs = max_turn_pairs
    if turn_pairs < 1 or len(turns) < turn_pairs * 2:
        return None

    *context, last_turn = turns[-(turn_pairs * 2):]

    # A pair whose final turn is not the assistant's has no valid target.
    if not last_turn.startswith(ASSISTANT_PREFIX):
        return None
    # Strip only the leading role prefix, never an occurrence inside the text.
    target = last_turn[len(ASSISTANT_PREFIX):]

    if count_tokens is None:
        count_tokens = _default_token_counter()

    # One initial check plus a re-check after each of up to `turn_pairs` trims.
    # Trimming works on the list of turns, so blank lines inside a turn's
    # content can never be mistaken for a turn boundary.
    for _ in range(turn_pairs + 1):
        prompt = PROMPT_PREFIX + TURN_SEPARATOR.join(context)
        if count_tokens(prompt) <= max_tokens:
            break
        if len(context) == 1:  # no more turns to drop
            return None
        del context[0]
    else:  # still too long after all trims
        return None

    if len(prompt) < min_prompt_chars or len(target) < min_target_chars:
        return None
    if not mostly_ascii(prompt + target):
        return None
    return prompt, target


def main() -> None:
    """Load the dataset, build the examples, and write them to ``output_path``."""
    # Heavy third-party imports are deferred to the script entry point.
    from datasets import load_dataset
    from tqdm import tqdm

    ds = load_dataset("json", data_files=jsonl_path, split="train")
    if use_subset:
        ds = ds.select(range(min(num_samples, len(ds))))
        print(f"🔍 subset → {len(ds)} rows")

    count_tokens = _default_token_counter()
    rows = []
    for ex in tqdm(ds, desc="formatting"):
        conv = ex.get("conversation")
        if not isinstance(conv, list):
            continue
        pair = build_pair(
            format_turns(conv),
            max_tokens=max_input_tokens,
            count_tokens=count_tokens,
        )
        if pair is not None:
            source, target = pair
            rows.append({"source": source, "target": target})

    print(f"✅ kept {len(rows)} examples")

    # Explicit columns keep the header row even when no example survived.
    pd.DataFrame(rows, columns=["source", "target"]).to_csv(
        output_path,
        index=False,
        quoting=csv.QUOTE_ALL,  # preserves embedded newlines
        encoding="utf-8",
    )
    print(f"💾 saved → {output_path}")


if __name__ == "__main__":
    main()