File size: 5,038 Bytes
d8a76be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# pip install datasets pyarrow regex
import re
import random
from datasets import load_dataset

# ========= Regexes =========
# Leading ChatML system turn. `^` without re.M anchors it at the very start of the
# string; re.S lets `.` span newlines; trailing whitespace is included in the match.
SYS_HEAD = re.compile(r"^<\|im_start\|>system\s.*?<\|im_end\|>\s*", re.S)
# One complete user/assistant turn, up to its first <|im_end|>.
# group(1) = the whole turn text, group(2) = the role ("user" or "assistant").
TURN_WITH_ROLE = re.compile(r"(<\|im_start\|>(user|assistant)\s*.*?<\|im_end\|>)", re.S)

# A "speaker name + colon" line (word chars incl. CJK, digits, spaces, underscores),
# e.g. "Kerensa:" / "小明:".
# NOTE(review): not referenced by any code visible in this file — confirm whether it is
# still needed or was meant to be used by the "Name:" line-break rule.
NAME_COLON = re.compile(r"^[\w\u4e00-\u9fa5][\w\u4e00-\u9fa5 _]{0,40}:\s*$")
in_path  = "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet"   # input
# Output (renamed to distinguish it from the input).
# NOTE(review): the name says "sys3round" while the transform below keeps the last
# 4 rounds — confirm the intended file name.
out_path = "/home/data/raw/test/4201_2355_full_label_1000-8192_sys3round.parquet"   # output
def join_with_clean_gap(system_block: str, body: str) -> str:
    """
    Concatenate a system block and the dialogue body that follows it.

    The system block is kept verbatim. A single newline is inserted between the two
    parts only when the block does not already end with one. An empty (falsy)
    system block yields ``body`` unchanged.
    """
    if not system_block:
        return body
    separator = "" if system_block.endswith("\n") else "\n"
    return system_block + separator + body

def last_4rounds_user_to_open_assistant(chatml: str, rounds: int = 4) -> str:
    """
    Keep the leading system block plus the most recent ``rounds`` dialogue rounds,
    ending in an open (unterminated) assistant turn.

    With the default ``rounds=4`` the kept dialogue is
    user → assistant → user → assistant → user → assistant → user → assistant (open),
    i.e. ``2 * rounds - 1`` closed history turns followed by the open assistant turn.
    When fewer history turns exist, all of them are kept.

    Args:
        chatml: ChatML-formatted prompt. Non-``str`` values are returned unchanged.
        rounds: Number of user turns to keep (default 4, the original behaviour).

    Returns:
        The trimmed prompt. The system block (if any) is kept verbatim and separated
        from the dialogue by at least one newline.

    Raises:
        ValueError: if ``rounds`` is smaller than 1.
    """
    if rounds < 1:
        raise ValueError(f"rounds must be >= 1, got {rounds!r}")
    if not isinstance(chatml, str):
        return chatml

    # Each round is one user + one assistant turn; the last assistant turn is the open
    # one, so the closed history window is U, A, ..., U = 2 * rounds - 1 turns.
    history_len = 2 * rounds - 1

    # Extract the leading system block verbatim (SYS_HEAD is anchored at the start,
    # so the dialogue is simply everything after the match).
    m_sys = SYS_HEAD.match(chatml)
    system_block = m_sys.group(0) if m_sys else ""
    text = chatml[m_sys.end():] if m_sys else chatml

    # Not ChatML: return conservatively (with the system block).
    if "<|im_start|>user" not in text and "<|im_start|>assistant" not in text:
        return join_with_clean_gap(system_block, text)

    # The last assistant tag marks the start of the open generation slot.
    last_ast = text.rfind("<|im_start|>assistant")
    if last_ast == -1:
        return join_with_clean_gap(system_block, text.strip())

    # Open the final assistant turn: drop its <|im_end|> and everything after it.
    final_assistant_open = re.sub(r"<\|im_end\|>.*$", "", text[last_ast:], flags=re.S)

    # Closed user/assistant turns preceding the open assistant turn, as (role, text).
    turns = [(m.group(2), m.group(1)) for m in TURN_WITH_ROLE.finditer(text[:last_ast])]

    if len(turns) < history_len:
        # Not enough history for the full window: keep everything available.
        selected = [t[1] for t in turns]
    else:
        # The window ends at the most recent user turn.
        j = next((k for k in reversed(range(len(turns))) if turns[k][0] == "user"), None)
        if j is None:
            # Fallback: no user turn at all, keep the last history_len turns.
            selected = [t[1] for t in turns[-history_len:]]
        else:
            start = max(0, j - (history_len - 1))
            selected = [t[1] for t in turns[start:j + 1]]

    prefix = ("\n".join(selected) + "\n") if selected else ""
    return join_with_clean_gap(system_block, prefix + final_assistant_open)

def ensure_linebreak_after_assistant(chosen_prompt: str) -> str:
    """
    Normalize line breaks after every ``<|im_start|>assistant`` tag.

    Rules:
    - ``<|im_start|>assistant`` must be followed by a newline; one is inserted if missing.
    - If the first line after the tag ends in ``Name:`` and is followed by a line break,
      the break (plus surrounding whitespace) is replaced by a single space, so the
      speaker name and the content share one line.

    Both rules are applied to every assistant tag in the prompt, independently.

    Args:
        chosen_prompt: ChatML prompt text. Non-``str`` values are returned unchanged.

    Returns:
        The normalized prompt.
    """
    if not isinstance(chosen_prompt, str):
        return chosen_prompt

    # 1) Insert a newline after any assistant tag not already followed by
    #    (optional whitespace and) a newline.
    chosen_prompt = re.sub(
        r"(<\|im_start\|>assistant)(?!\s*\n)",
        r"\1\n",
        chosen_prompt,
    )

    # 2) Join "Name:" with the following content. re.sub evaluates each assistant turn
    #    on its own match, so every turn is fixed, not just those whose text happens to
    #    equal the first match.
    # NOTE(review): `[^\n]{1,60}:` also matches any short first line ending in ':'
    # (e.g. "Here is the list:"), not only speaker names; the module-level NAME_COLON
    # pattern looks like it was meant to tighten this — confirm before relying on it.
    return re.sub(
        r"(<\|im_start\|>assistant\s*\n)([^\n]{1,60}:)(\s*\r?\n\s*)",
        r"\1\2 ",
        chosen_prompt,
    )

def _map_fn(ex):
    """
    Per-row callback for ``ds.map``.

    Trims ``chosen_prompt`` to the system block plus the most recent rounds, then
    normalizes line breaks after assistant tags. Mutates ``ex`` in place and returns it.
    """
    trimmed = last_4rounds_user_to_open_assistant(ex["chosen_prompt"])
    ex["chosen_prompt"] = ensure_linebreak_after_assistant(trimmed)
    return ex

# ============ Batch processing + save + sample printing ============

ds = load_dataset("parquet", data_files=in_path, split="train")

# # Keep only three columns
# keep_cols = ["chosen_prompt", "chosen", "reject"]
# drop_cols = [c for c in ds.column_names if c not in keep_cols]
# if drop_cols:
#     ds = ds.remove_columns(drop_cols)

# Passing num_proc=4..8 can speed this up (watch memory usage).
ds = ds.map(_map_fn, desc="Keep system + last 4 rounds (open assistant) + linebreak rules")

# Save to a new parquet file.
ds.to_parquet(out_path)
print(f"✅ Saved -> {out_path}")

# Print up to 5 random samples.
idxs = random.sample(range(len(ds)), min(5, len(ds)))
sampled = ds.select(idxs)
for i, ex in enumerate(sampled):
    # Null parquet cells arrive as None (the transform passes non-str prompts through
    # unchanged); treat them as empty so the concatenations below cannot raise TypeError.
    prompt = ex["chosen_prompt"] or ""
    print(f"===== Sample {i+1} / chosen_prompt 原样 =====")
    print(ex["chosen_prompt"])
    print(f"===== Sample {i+1} / chosen_prompt + chosen =====")
    print(prompt + (ex["chosen"] or ""))
    print(f"===== Sample {i+1} / chosen_prompt + reject =====")
    print(prompt + (ex["reject"] or ""))
    print()