File size: 3,023 Bytes
7cd7caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from datasets import load_dataset
from transformers import T5Tokenizer
import pandas as pd, csv, re
from tqdm import tqdm

# ── Config ────────────────────────────────────────────────────────────────
jsonl_path       = "lmsys_chat_1m_full.jsonl"   # local JSON-lines file
use_subset       = False                        # False ⇒ process every row
num_samples      = 500                          # rows kept when use_subset
max_turn_pairs   = 1                            # user+assistant pairs per example (N pairs = 2N turns)
max_input_tokens = 512                          # prompt token budget; fits t5-small/base
# ──────────────────────────────────────────────────────────────────────────

if max_turn_pairs < 1:
    # A window of turns[-0:] would silently select the whole conversation.
    raise ValueError("max_turn_pairs must be at least 1")

tok = T5Tokenizer.from_pretrained("t5-small")
ds  = load_dataset("json", data_files=jsonl_path, split="train")

if use_subset:
    ds = ds.select(range(min(num_samples, len(ds))))
    print(f"🔍 subset → {len(ds)} rows")

def mostly_ascii(s: str, threshold: float = .3) -> bool:
    """Report whether fewer than *threshold* of the characters in *s* are non-ASCII.

    A character is non-ASCII when its code point is above 127. The empty
    string has no characters to measure and is reported as False.
    """
    if not s:
        return False
    non_ascii = sum(1 for ch in s if not ch.isascii())
    return non_ascii / len(s) < threshold

def format_turns(conv):
    """Render each message of *conv* as a ``"Role: content"`` line.

    Each message is read through its ``"role"`` and ``"content"`` keys; the
    role is capitalized and the content stripped of surrounding whitespace.
    Lines are returned in the same order as the messages.
    """
    lines = []
    for msg in conv:
        role = msg["role"].capitalize()
        text = msg["content"].strip()
        lines.append(f"{role}: {text}")
    return lines

def build_pair(turns, max_tokens=512):
    """Build one (prompt, target) training pair from a formatted conversation.

    *turns* is a list of ``"Role: content"`` strings as produced by
    ``format_turns``. Only the last ``max_turn_pairs`` user/assistant pairs
    are used. Every turn but the final one becomes the prompt: a ``chat:``
    header followed by the turns separated by blank lines. The final turn
    must be an assistant reply; it becomes the target with its
    ``"Assistant: "`` label removed.

    If the prompt exceeds *max_tokens* tokens (counted with the module-level
    ``tok`` tokenizer, including the special tokens it appends to model
    inputs), whole turns are dropped from the front until it fits.

    Returns ``(prompt, target)``, or ``None`` in any of these cases:
    the conversation is too short; it does not end with an assistant turn;
    the prompt cannot be trimmed to fit; the result is too short; or the
    text is not mostly ASCII.
    """
    n_turns = max_turn_pairs * 2
    if len(turns) < n_turns:
        return None

    # last N pairs
    window = turns[-n_turns:]

    # The target must be an assistant reply; a conversation that ends on a
    # user turn would otherwise produce a "User: ..." target.
    label = "Assistant: "
    if not window[-1].startswith(label):
        return None
    target = window[-1][len(label):]

    # Trim on the list of turns rather than on the joined string: message
    # content can itself contain blank lines, so searching the prompt for
    # "\n\n" would cut a turn in half. Each trim is re-checked before the
    # pair is given up on.
    budget = max_tokens - tok.num_special_tokens_to_add()
    context = list(window[:-1])
    while True:
        prompt = "chat:\n\n" + "\n\n".join(context)
        if len(tok.tokenize(prompt)) <= budget:
            break
        if len(context) <= 1:
            return None  # even a single remaining turn is over budget
        del context[0]   # drop the oldest remaining turn

    if len(prompt) < 30 or len(target) < 10:
        return None
    if not mostly_ascii(prompt + target):
        return None
    return prompt, target


# Convert every conversation into at most one (source, target) row.
rows = []
for ex in tqdm(ds, desc="formatting"):
    conv = ex.get("conversation")
    if not isinstance(conv, list):
        continue
    pair = build_pair(format_turns(conv), max_tokens=max_input_tokens)
    if pair is not None:
        source, target = pair
        rows.append({"source": source, "target": target})
kept = len(rows)

print(f"✅ kept {kept} examples")

# Use a single name for the file that is written and the file that is
# reported. It tracks the configured pair count (1 → "chat_1turn.csv").
out_path = f"chat_{max_turn_pairs}turn.csv"
# Explicit columns so the header is still written when no rows were kept.
pd.DataFrame(rows, columns=["source", "target"]).to_csv(
    out_path,
    index=False,
    quoting=csv.QUOTE_ALL,     # quote every field; multi-line turns stay intact
    encoding="utf-8",
)
print(f"💾 saved → {out_path}")