File size: 3,023 Bytes
7cd7caf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from datasets import load_dataset
from transformers import T5Tokenizer
import pandas as pd, csv, re
from tqdm import tqdm
# -- Config -------------------------------------------------------------
jsonl_path = "lmsys_chat_1m_full.jsonl"  # local JSONL file handed to load_dataset
use_subset = False        # True -> keep only the first num_samples rows
num_samples = 500         # row cap that applies when use_subset is True
max_turn_pairs = 1        # user+assistant pairs taken from the end of a chat (1 pair = 2 lines)
max_input_tokens = 512    # prompt token budget; fits t5-small/base
# -----------------------------------------------------------------------

# Tokenizer used to measure prompt length, and the raw conversation dataset.
tok = T5Tokenizer.from_pretrained("t5-small")
ds = load_dataset("json", data_files=jsonl_path, split="train")

if use_subset:
    # Never ask for more rows than the dataset actually holds.
    subset_size = min(num_samples, len(ds))
    ds = ds.select(range(subset_size))
    print(f"π subset β {len(ds)} rows")
def mostly_ascii(s: str, threshold: float = .3) -> bool:
    """Return True when the fraction of non-ASCII characters in *s* is below *threshold*.

    A character counts as non-ASCII when its code point is greater than 127.
    An empty *s* has no measurable fraction and yields ``False``.
    """
    total = len(s)
    if total == 0:
        return False

    non_ascii = 0
    for ch in s:
        if ord(ch) > 127:
            non_ascii += 1
    return non_ascii / total < threshold
def format_turns(conv):
    """Render every message in *conv* as a ``"Role: content"`` string.

    Each message is indexed by ``'role'`` and ``'content'``. The role is
    passed through ``str.capitalize`` and the content through ``str.strip``.
    The returned list keeps the order of *conv*.
    """
    lines = []
    for message in conv:
        speaker = message['role'].capitalize()
        text = message['content'].strip()
        lines.append(f"{speaker}: {text}")
    return lines
def build_pair(turns, max_tokens=512):
    """Build one ``(prompt, target)`` pair from the tail of a conversation.

    The last ``max_turn_pairs * 2`` entries of *turns* are used. Every entry
    except the final one is joined with blank lines under a ``"chat:"``
    header to form the prompt; the final entry becomes the target, with a
    leading ``"Assistant: "`` label removed when present.

    Args:
        turns: Formatted turn strings, oldest first (e.g. ``"User: hi"``).
        max_tokens: Largest accepted prompt size, measured as
            ``len(tok.tokenize(prompt))`` with the module-level tokenizer.

    Returns:
        ``(prompt, target)``, or ``None`` when there are too few turns, the
        prompt cannot be brought within *max_tokens*, the prompt is shorter
        than 30 characters, the target is shorter than 10 characters, or
        the combined text fails ``mostly_ascii``.
    """
    window = max_turn_pairs * 2
    if window < 2 or len(turns) < window:
        return None

    context = list(turns[-window:-1])
    reply = turns[-1]

    # Strip the speaker label only when it is a real prefix, so the same
    # text occurring inside the message body is left alone.
    label = "Assistant: "
    target = reply[len(label):] if reply.startswith(label) else reply

    header = "chat:\n\n"
    prompt = header + "\n\n".join(context)

    # Trim on the list of turns rather than on the joined string, so a
    # message that itself contains a blank line is never cut in half.
    # The length is re-measured after every drop, including the last one.
    while len(tok.tokenize(prompt)) > max_tokens:
        if len(context) == 1:
            # A single turn that is still too long cannot be shortened.
            return None
        context = context[1:]  # drop the oldest remaining turn
        prompt = header + "\n\n".join(context)

    if len(prompt) < 30 or len(target) < 10:
        return None
    if not mostly_ascii(prompt + target):
        return None
    return prompt, target
# Turn every usable conversation into one {"source", "target"} row.
rows = []
for ex in tqdm(ds, desc="formatting"):
    conv = ex.get("conversation")
    if not isinstance(conv, list):
        continue  # record has no conversation list to format
    # Apply the configured token budget instead of build_pair's default.
    pair = build_pair(format_turns(conv), max_tokens=max_input_tokens)
    if pair:
        source, target = pair
        rows.append({"source": source, "target": target})
kept = len(rows)
print(f"✅ kept {kept} examples")

# Name the output once so the log line always matches the file written.
out_path = "chat_1turn.csv"
# Explicit columns keep the header row even when no example survived.
pd.DataFrame(rows, columns=["source", "target"]).to_csv(
    out_path,
    index=False,
    quoting=csv.QUOTE_ALL,  # quote every field; values may span several lines
    encoding="utf-8",
)
print(f"💾 saved → {out_path}")
|