File size: 9,556 Bytes
d8a76be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import re
from typing import List, Dict
from tqdm import tqdm

# === Replace with your own input/output paths ===
input_path  = "/home/data/raw/test/1159-L6.parquet"
output_path = "/home/data/raw/test/1159-L6_format.parquet"

# Matches a single trailing <|eot_id|> (plus optional whitespace) at end of string.
EOT_TAIL = re.compile(r"<\|eot_id\|>\s*$")   
TAIL_TAGS = re.compile(r"[<]?\|eot_id\|[>]?\s*$")          # matches trailing <|eot_id|> or bare |eot_id| (tail only)
PIPE_TRAIL = re.compile(r"(?:\|[A-Za-z0-9_]+\|[^\n]*)\s*$")    # matches a trailing |xxx| marker and everything after it on that line
def is_mistral_format(text):
    """Return True when *text* already carries both ChatML turn markers."""
    return all(marker in text for marker in ("<|im_start|>", "<|im_end|>"))

def convert_to_mistral_format(text: str, add_generation_prompt: bool = False) -> str:
    """Convert a chat transcript in one of several known formats to ChatML.

    Recognized formats: ChatML (normalized in place), LLaMA-3 header tags,
    [INST] tags (Mistral V2 / Ministral), Gemma <start_of_turn> tags, and
    Pygmalion-style "Name: line" dialogue.

    The final assistant turn is deliberately left open (no trailing
    <|im_end|>) so that generation can continue from it.

    Args:
        text: The raw transcript.
        add_generation_prompt: When True, append an open assistant header
            if the output does not already end with one.

    Returns:
        The ChatML string, or None when no known format is detected.
    """
    # 1. Already ChatML: normalize the whitespace before each <|im_end|> to a
    #    single newline, then drop the final <|im_end|> to keep the last turn open.
    if "<|im_start|>" in text and "<|im_end|>" in text:
        text = re.sub(r"\s*<\|im_end\|>", r"\n<|im_end|>", text)
        text = re.sub(r"\n{3,}", "\n\n", text)              # collapse runs of 3+ newlines
        text = re.sub(r"\s*<\|im_end\|>\s*$", "", text).rstrip()
        return text
    output = ""

    # === 2. LLaMA-3 header format ===
    if "<|start_header_id|>" in text and "<|end_header_id|>" in text:
        # re.split with one capture group yields [prefix, role, body, role, body, ...]
        segments = re.split(r"<\|start_header_id\|>(.*?)<\|end_header_id\|>", text, flags=re.S)

        role_content_pairs = []
        for i in range(1, len(segments), 2):
            role = segments[i].strip()
            content_block = segments[i + 1].strip()

            # Split on <|eot_id|> in case one header's body spans several turns.
            for part in re.split(r"<\|eot_id\|>", content_block):
                part = part.strip()
                if part:
                    role_content_pairs.append((role, part))

        # Emit ChatML.
        for idx, (role, content) in enumerate(role_content_pairs):
            is_last_pair = idx == len(role_content_pairs) - 1

            if role == "system":
                output += f"<|im_start|>system\n{content}\n<|im_end|>\n"

            elif role == "user":
                output += f"<|im_start|>user\n{content}\n<|im_end|>\n"

            elif role == "assistant":
                if is_last_pair:             # only the final assistant turn stays open
                    # Don't add a newline if content already ends with one.
                    if not content.endswith("\n"):
                        content += "\n"
                    output += f"<|im_start|>assistant\n{content}"
                else:
                    output += f"<|im_start|>assistant\n{content}\n<|im_end|>\n"

    # === 3. [INST] format (Mistral V2 / Ministral) ===
    elif "[INST]" in text and "[/INST]" in text:
        system_match = re.search(r"\[SYSTEM_PROMPT\](.*?)\[/SYSTEM_PROMPT\]", text, re.S)
        if system_match:
            output += f"<|im_start|>system\n{system_match.group(1).strip()}\n<|im_end|>\n"

        # The lookahead's inner group is captured too, hence the throwaway _ below.
        turns = re.findall(r"\[INST\](.*?)\[/INST\](.*?)(?=(\[INST\]|</s>|$))", text, re.S)
        for user_msg, assistant_msg, _ in turns:
            output += f"<|im_start|>user\n{user_msg.strip()}\n<|im_end|>\n"
            if assistant_msg.strip():
                output += f"<|im_start|>assistant\n{assistant_msg.strip()}\n<|im_end|>\n"
            else:
                output += f"<|im_start|>assistant\n"


    # === 4. <start_of_turn> format (Gemma) ===
    elif "<start_of_turn>" in text:
        # A system prompt, if any, appears as "[System: ...]".
        system_match = re.search(r"\[System:(.*?)\]", text, re.S)
        if system_match:
            output += f"<|im_start|>system\n{system_match.group(1).strip()}\n<|im_end|>\n"

        # Conversation turns.
        turns = re.findall(r"<start_of_turn>(user|model)\s*\n?(.*?)<end_of_turn>", text, re.S)

        for idx, (role, content) in enumerate(turns):
            role = "assistant" if role == "model" else "user"
            is_last = idx == len(turns) - 1

            if role == "assistant" and is_last:
                # Keep the final assistant turn open.
                if not content.endswith("\n"):
                    content += "\n"
                output += f"<|im_start|>assistant\n{content}"
            else:
                output += f"<|im_start|>{role}\n{content.strip()}\n<|im_end|>\n"


    # === 5. Pygmalion format ===
    # NOTE(review): the second alternative `^.*?[::].*?$` matches ANY line that
    # contains a colon, so nearly all colon-bearing free text lands in this
    # branch instead of falling through to the None fallback — confirm intended.
    elif "<start>" in text or re.search(r"(?m)^You[::]|^.*?[::].*?$", text):
    # Persona text before <start> becomes the system message.
        persona_match = re.search(r"(.*?)<start>", text, re.S)
        if persona_match:
            output += f"<|im_start|>system\n{persona_match.group(1).strip()}\n<|im_end|>\n"

        # Dialogue: keep only "Name: line" style lines after the last <start>.
        dialogue = text.split("<start>")[-1]
        lines = [l.strip() for l in dialogue.strip().split("\n") if ":" in l]

        for idx, line in enumerate(lines):
            is_last = idx == len(lines) - 1

            if re.match(r"^(You|User|你)[::]", line):
                content = re.sub(r"^(You|User|你)[::]", "", line).strip()
                output += f"<|im_start|>user\n{content}\n<|im_end|>\n"
            else:
                _, content = line.split(":", 1)
                content = content.strip()
                if is_last:
                    # Final line and it's an assistant turn → keep it open.
                    if not content.endswith("\n"):
                        content += "\n"
                    output += f"<|im_start|>assistant\n{content}"
                else:
                    output += f"<|im_start|>assistant\n{content}\n<|im_end|>\n"
    # === 6. Unrecognized format ===
    else:
        return None

    # === Final fix-ups on the assembled output ===
    output = output.strip()

    # Drop a trailing empty assistant turn (<|im_start|>assistant\n<|im_end|>).
    if output.endswith("<|im_start|>assistant\n<|im_end|>"):
        output = output[:-len("<|im_end|>")].rstrip()

    # Re-open an assistant turn that is just a speaker-name prompt, e.g.
    # "<|im_start|>assistant\nFlo:<|im_end|>" → "<|im_start|>assistant\nFlo:"
    last_assistant_pattern = r"<\|im_start\|>assistant\n([^\n<\|]{1,100}):\s*<\|im_end\|>$"
    if re.search(last_assistant_pattern, output):
        output = re.sub(r"<\|im_end\|>$", "", output).rstrip()

    # Optionally append a generation prompt so the next turn starts as assistant.
    if add_generation_prompt and not output.endswith("<|im_start|>assistant"):
        output += f"\n<|im_start|>assistant"

    return output.strip()

# 处理整个数据集(列表)
def standardize_dataset_to_mistral_format(dataset: List[Dict]) -> List[Dict]:
    """Convert each sample's "text" field to ChatML, dropping unconvertible rows."""
    results: List[Dict] = []
    for item in tqdm(dataset):
        converted_text = convert_to_mistral_format(item.get("text", ""))
        if converted_text:
            results.append({"text": converted_text})
    return results

# `re` and `tqdm` are already imported at the top of the file; only the
# third-party dataset loader is needed here.
from datasets import load_dataset


def clean_chosen_tail(text: str) -> str:
    """Strip trailing junk (eot tags, pipe markers, whitespace) from the text tail."""
    if not isinstance(text, str):
        return text
    # First remove a trailing eot marker (both <|eot_id|> and |eot_id| spellings).
    without_eot = TAIL_TAGS.sub("", text)
    # Then remove any trailing |xxx|-style pipe marker plus whatever follows it.
    without_pipe_tail = PIPE_TRAIL.sub("", without_eot)
    return without_pipe_tail.rstrip()

def apply_format_conversion(example):
    """Clean tail markers on chosen/reject and convert chosen_prompt to ChatML."""
    # ① Drop a single trailing <|eot_id|> from both response fields.
    for field in ("chosen", "reject"):
        value = example[field]
        if isinstance(value, str):
            example[field] = EOT_TAIL.sub("", value)

    # ② Shave any remaining tail markers (pipe tags, whitespace) off both responses.
    chosen_value = example.get("chosen")
    if isinstance(chosen_value, str):
        example["chosen"] = clean_chosen_tail(chosen_value)
        example["reject"] = clean_chosen_tail(example["reject"])

    # ③ Convert the prompt to ChatML.
    # NOTE(review): datasets.Dataset.map generally expects a dict back; verify
    # that returning None here behaves as intended with the installed version.
    converted_prompt = convert_to_mistral_format(example["chosen_prompt"], add_generation_prompt=False)
    if converted_prompt is None:
        return None
    example["chosen_prompt"] = converted_prompt
    return example

# === Load and process the dataset ===
dataset = load_dataset("parquet", data_files=input_path, split="train")
converted_dataset = dataset.map(apply_format_conversion)
# NOTE(review): this filter assumes rows mapped to None survive map() and can
# be dropped here — confirm against the installed `datasets` version.
converted_dataset = converted_dataset.filter(lambda example: example is not None)
# Optionally keep only these columns:
# columns_to_keep = ['chosen', 'chosen_prompt', 'reject']
# converted_dataset = converted_dataset.remove_columns([col for col in converted_dataset.column_names if col not in columns_to_keep])
converted_dataset.to_pandas().to_parquet(output_path, index=False)

print(f"✅ 处理完成,共保留样本 {len(converted_dataset)} 条,已保存至:{output_path}")

import random

# 1. Load the processed parquet file back for a spot check.
dataset1 = load_dataset("parquet", data_files=output_path, split="train")

# Pick 15 random samples to inspect.
indices = random.sample(range(len(dataset1)), 15)
samples = dataset1.select(indices)

# Print each sampled row in full.
for idx, item in zip(indices, samples):
    print(f"\n=== Sample index {idx} ===")
    for key, value in item.items():
        print(f"[{key}]")
        print(value)            # print the raw value unchanged
        print("-" * 60)