File size: 4,139 Bytes
d8a76be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# pip install datasets pyarrow regex
import re
import random
from datasets import load_dataset

# ========= Regular expressions =========
# A ChatML system segment at the very start of the text, plus trailing whitespace.
SYS_HEAD = re.compile(
    r"^<\|im_start\|>system\s.*?<\|im_end\|>\s*",
    flags=re.DOTALL,
)
# One closed user/assistant turn: group 1 = the whole turn, group 2 = the role.
TURN_WITH_ROLE = re.compile(
    r"(<\|im_start\|>(user|assistant)\s*.*?<\|im_end\|>)",
    flags=re.DOTALL,
)

# A line that is just "name:" (word chars incl. CJK, spaces, underscores; 1-41 chars),
# e.g. `Kerensa:` / `小明:`. NOTE(review): not referenced in the visible code.
NAME_COLON = re.compile(r"^[\w\u4e00-\u9fa5][\w\u4e00-\u9fa5 _]{0,40}:\s*$")

def last_3rounds_user_to_open_assistant(chatml: str) -> str:
    """
    Trim a ChatML prompt to its most recent three rounds.

    Target layout: user -> assistant -> user -> assistant -> user -> assistant,
    where the final assistant segment is left open. Its first ``<|im_end|>``
    and everything after it are dropped. A leading system segment is removed.

    Returns:
        - the input unchanged when it is not a ``str``;
        - the system-stripped text when it has neither a user nor an assistant tag;
        - the system-stripped text, ``.strip()``-ed, when there is no assistant tag;
        - otherwise up to five closed turns ending at the last user turn (or all
          closed turns when there are fewer than five), each followed by ``\\n``,
          then the open final assistant segment.

    NOTE(review): this assumes the prompt ends with the assistant header. Any
    turn that follows the last assistant's ``<|im_end|>`` is discarded.
    """
    if not isinstance(chatml, str):
        return chatml

    body = SYS_HEAD.sub("", chatml)

    # Not ChatML-like: keep it as is (apart from the system strip).
    has_user = "<|im_start|>user" in body
    has_assistant = "<|im_start|>assistant" in body
    if not (has_user or has_assistant):
        return body

    # The last assistant header marks the start of the open-ended turn.
    open_at = body.rfind("<|im_start|>assistant")
    if open_at < 0:
        return body.strip()

    # Cut the final assistant segment at its first <|im_end|> (and drop the rest).
    tail = re.sub(r"<\|im_end\|>.*$", "", body[open_at:], flags=re.S)

    # Closed user/assistant turns before the open assistant, as (role, text) pairs.
    history = [(m.group(2), m.group(1)) for m in TURN_WITH_ROLE.finditer(body[:open_at])]

    if len(history) >= 5:
        # Locate the last user turn; keep it and up to four turns before it (U, A, U, A, U).
        last_user = None
        for idx in reversed(range(len(history))):
            if history[idx][0] == "user":
                last_user = idx
                break
        if last_user is None:
            kept = [seg for _, seg in history[-5:]]  # fallback: no user turn at all
        else:
            kept = [seg for _, seg in history[max(0, last_user - 4):last_user + 1]]
    else:
        # Fewer than three rounds of history: keep everything available.
        kept = [seg for _, seg in history]

    prefix = "".join(seg + "\n" for seg in kept)
    return prefix + tail


# ============ Batch processing + sample printing ============
in_path = "/home/data/train_v3full.parquet"   # input
# NOTE(review): the output name says "2round", but the pipeline below keeps the
# last 3 rounds; confirm the intended file name.
out_path = "/home/data/train_2round.parquet"  # output

ds = load_dataset("parquet", data_files=in_path, split="train")

# Keep only these three columns; every other column is removed.
keep_cols = ["chosen_prompt", "chosen", "reject"]
drop_cols = [name for name in ds.column_names if name not in keep_cols]
if drop_cols:
    ds = ds.remove_columns(drop_cols)
def ensure_linebreak_after_assistant(chosen_prompt: str) -> str:
    """
    Normalize line breaks right after ``<|im_start|>assistant`` headers.

    Rules:
    - Every ``<|im_start|>assistant`` tag must be followed by a line break.
      One is inserted when the tag is not followed by (optional whitespace and)
      a newline.
    - When an assistant turn's first line ends with a ``name:`` (1-60 chars
      before the colon) and is followed by a line break, that break and any
      whitespace after it are collapsed into a single space. This keeps the
      speaker name and its text on one line, and applies to every assistant
      turn, not just the first.

    Args:
        chosen_prompt: ChatML prompt text. Non-``str`` values (e.g. ``None`` from
            a null cell, which the round-trimming step passes through unchanged)
            are returned as-is instead of raising ``TypeError`` in ``re.sub``.

    Returns:
        The normalized prompt.
    """
    if not isinstance(chosen_prompt, str):
        return chosen_prompt

    # 1) Insert a newline after any assistant tag not already followed by one.
    chosen_prompt = re.sub(
        r"(<\|im_start\|>assistant)(?!\s*\n)",  # not followed by a newline
        r"\1\n",
        chosen_prompt,
    )

    # 2) Join "name:" with the text that follows it, for every assistant turn.
    # The previous re.search + str.replace only rewrote the first matching turn
    # (and identical copies of it), so later turns, including the final open
    # assistant turn, kept "name:\n".
    chosen_prompt = re.sub(
        r"(<\|im_start\|>assistant\s*\n)([^\n]{1,60}:)(\s*\r?\n\s*)",
        r"\1\2 ",
        chosen_prompt,
    )
    return chosen_prompt

def _map_fn(ex):
    """Per-example ``ds.map`` hook: trim ``chosen_prompt`` to the recent rounds,
    normalize its assistant line breaks, write it back and return the example."""
    ex["chosen_prompt"] = ensure_linebreak_after_assistant(
        last_3rounds_user_to_open_assistant(ex["chosen_prompt"])
    )
    return ex

# Tip: passing num_proc=4..8 to map() speeds this up, at the cost of more memory.
ds = ds.map(_map_fn, desc="Build last 3 rounds (open assistant) + linebreak rules")

ds.to_parquet(out_path)
print(f"✅ Saved -> {out_path}")

# Print up to 5 random samples: the raw prompt, then the prompt concatenated with
# each response. This makes it easy to spot extra blank lines and to check that
# "name:" stays on the same line as its text.
n_show = min(5, len(ds))
idxs = random.sample(range(len(ds)), n_show)
sampled = ds.select(idxs)
for i, ex in enumerate(sampled):
    prompt = ex["chosen_prompt"]
    print(f"===== Sample {i+1} / chosen_prompt 原样 =====")
    print(prompt)
    print(f"===== Sample {i+1} / chosen_prompt + chosen =====")
    print(prompt + ex["chosen"])
    print(f"===== Sample {i+1} / chosen_prompt + reject =====")
    print(prompt + ex["reject"])
    print()