|
|
from datasets import load_dataset
|
|
|
from transformers import T5Tokenizer
|
|
|
import pandas as pd, csv, re
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
# ---- configuration --------------------------------------------------------
# JSON Lines file read below via load_dataset("json", ...).
jsonl_path: str = "lmsys_chat_1m_full.jsonl"

# Quick-run switch: when True only the first `num_samples` rows are processed.
use_subset: bool = False
num_samples: int = 500

# How many trailing turn pairs (2 * N formatted messages) are taken from each
# conversation; the final message of that window is the training target.
max_turn_pairs: int = 1

# Maximum prompt length, in T5 tokens.
max_input_tokens: int = 512
|
|
|
|
|
|
|
|
|
# Tokenizer used by build_pair to measure prompt length against the budget.
tok = T5Tokenizer.from_pretrained("t5-small")

# Load the JSON Lines file as a single Dataset (split="train").
ds = load_dataset("json", data_files=jsonl_path, split="train")

if use_subset:
    # First `num_samples` rows, or every row if the file is smaller.
    ds = ds.select(range(min(num_samples, len(ds))))
    # ASCII-only message: the previous emoji/arrow had been mangled into
    # mojibake, and non-ASCII output can raise UnicodeEncodeError on
    # consoles that are not UTF-8.
    print(f"subset -> {len(ds)} rows")
|
|
|
|
|
|
def mostly_ascii(s: str, threshold: float = .3) -> bool:
    """Return True when the share of non-ASCII characters in ``s`` is below ``threshold``.

    A character counts as non-ASCII when its code point is above 127.
    The empty string returns False, because the ratio is undefined.
    """
    total = len(s)
    if total == 0:
        return False
    non_ascii = len([ch for ch in s if ord(ch) > 127])
    return non_ascii / total < threshold
|
|
|
|
|
|
def format_turns(conv):
    """Render each message of ``conv`` as a ``"Role: content"`` string.

    Each element is indexed by ``'role'`` and ``'content'``. The role is
    passed through ``str.capitalize`` and the content is stripped of
    surrounding whitespace. Output order matches input order.
    """
    rendered = []
    for message in conv:
        role = message['role'].capitalize()
        body = message['content'].strip()
        rendered.append(f"{role}: {body}")
    return rendered
|
|
|
|
|
|
def build_pair(turns, max_tokens=None, tokenizer=None, turn_pairs=None):
    """Turn formatted conversation turns into a ``(prompt, target)`` training pair.

    The last ``2 * turn_pairs`` turns are used. All but the final turn form the
    prompt context, joined with blank lines under a ``"chat:"`` header. The
    final turn must be an assistant reply; its ``"Assistant: "`` prefix is
    removed to produce the target. If the prompt exceeds ``max_tokens``, whole
    turns are dropped from the front (oldest first) until it fits.

    Args:
        turns: Strings shaped like ``"Role: content"`` (as produced by
            ``format_turns``).
        max_tokens: Prompt budget in tokenizer tokens. Defaults to the
            module-level ``max_input_tokens``.
        tokenizer: Object with a ``tokenize(str)`` method returning a sized
            sequence. Defaults to the module-level ``tok``.
        turn_pairs: Number of trailing turn pairs to use. Defaults to the
            module-level ``max_turn_pairs``.

    Returns:
        ``(prompt, target)``, or ``None`` when any of these hold:
        - there are too few turns;
        - the last turn is not an assistant reply;
        - even the single turn preceding the target exceeds the budget;
        - the prompt or target is very short;
        - the text is mostly non-ASCII.

    Raises:
        ValueError: if ``turn_pairs`` is less than 1.
    """
    if max_tokens is None:
        max_tokens = max_input_tokens
    if tokenizer is None:
        tokenizer = tok
    if turn_pairs is None:
        turn_pairs = max_turn_pairs
    if turn_pairs < 1:
        # turns[-0:] would silently select the entire conversation.
        raise ValueError(f"turn_pairs must be >= 1, got {turn_pairs}")

    n_turns = turn_pairs * 2
    if len(turns) < n_turns:
        return None

    use_turns = turns[-n_turns:]
    context, last = use_turns[:-1], use_turns[-1]

    # The target must be an assistant reply. A conversation ending on a user
    # message would otherwise teach the model to produce user text.
    prefix = "Assistant: "
    if not last.startswith(prefix):
        return None
    target = last[len(prefix):]

    # Drop whole turns from the front until the prompt fits. This works on
    # the turn list rather than searching the joined string for "\n\n",
    # because message bodies can themselves contain blank lines. Every
    # candidate is checked, including the last, shortest one.
    for start in range(len(context)):
        prompt = "chat:\n\n" + "\n\n".join(context[start:])
        if len(tokenizer.tokenize(prompt)) <= max_tokens:
            break
    else:
        # Even the turn immediately before the target is over budget.
        return None

    # Discard degenerate pairs and mostly non-ASCII text.
    if len(prompt) < 30 or len(target) < 10:
        return None
    if not mostly_ascii(prompt + target):
        return None
    return prompt, target
|
|
|
|
|
|
|
|
|
# Build one (source, target) row per conversation that survives build_pair.
rows = []
for ex in tqdm(ds, desc="formatting"):
    conv = ex.get("conversation")
    # Skip records whose "conversation" field is missing or not a list.
    if not isinstance(conv, list):
        continue
    pair = build_pair(format_turns(conv), max_tokens=max_input_tokens)
    if pair is not None:
        source, target = pair
        rows.append({"source": source, "target": target})

kept = len(rows)
# ASCII-only: the earlier emoji had been mangled into mojibake, which split
# this f-string across two lines and made the script a SyntaxError.
print(f"kept {kept} examples")

# Derive the path once so the log line always names the file actually
# written, and the name tracks the configured window size ("chat_1turn.csv"
# for max_turn_pairs == 1).
out_path = f"chat_{max_turn_pairs}turn.csv"
# Explicit columns keep the header row even when no examples were kept.
pd.DataFrame(rows, columns=["source", "target"]).to_csv(
    out_path,
    index=False,
    quoting=csv.QUOTE_ALL,
    encoding="utf-8",
)
print(f"saved -> {out_path}")
|
|
|
|