import re
from typing import List, Dict
from tqdm import tqdm

# === Replace with your own input and output paths ===
input_path = "/home/data/raw/test/1159-L6.parquet"
output_path = "/home/data/raw/test/1159-L6_format.parquet"

# Trailing <|eot_id|> token (plus trailing whitespace) at the very end of a string.
EOT_TAIL = re.compile(r"<\|eot_id\|>\s*$")
# Trailing eot marker with optional angle brackets: matches <|eot_id|> or |eot_id| (tail only).
TAIL_TAGS = re.compile(r"[<]?\|eot_id\|[>]?\s*$")
# Trailing |xxx| pipe tag plus everything after it on that line (tail junk).
PIPE_TRAIL = re.compile(r"(?:\|[A-Za-z0-9_]+\|[^\n]*)\s*$")
def is_mistral_format(text):
    """Return True when *text* already carries ChatML turn markers."""
    return all(marker in text for marker in ("<|im_start|>", "<|im_end|>"))
def convert_to_mistral_format(text: str, add_generation_prompt: bool = False) -> str:
    """Convert a chat transcript in one of several known markups into ChatML.

    Handles, in priority order: ChatML itself (normalized in place), LLaMA-3
    header/eot markup, Mistral-style ``[INST]...[/INST]``, Gemma
    ``<start_of_turn>`` markup, and Pygmalion-style ``Name: line`` dialogue.

    Returns the ChatML string, or None when no known format is detected.
    The final assistant turn is intentionally left unclosed (no trailing
    ``<|im_end|>``) so the result can be used as a generation prefix.
    """
    # 1. Already ChatML: normalize whitespace before <|im_end|> and strip the
    #    final <|im_end|> so the last turn stays open.
    if "<|im_start|>" in text and "<|im_end|>" in text:
        # Swallow any whitespace (incl. newlines) before <|im_end|>, keep exactly one newline
        text = re.sub(r"\s*<\|im_end\|>", r"\n<|im_end|>", text)
        text = re.sub(r"\n{3,}", "\n\n", text)  # collapse 3+ consecutive newlines
        text = re.sub(r"\s*<\|im_end\|>\s*$", "", text).rstrip()
        return text
    output = ""
    # === 2. LLaMA-3 format ===
    if "<|start_header_id|>" in text and "<|end_header_id|>" in text:
        # re.split with one capture group alternates [prefix, role, body, role, body, ...]
        segments = re.split(r"<\|start_header_id\|>(.*?)<\|end_header_id\|>", text, flags=re.S)
        role_content_pairs = []
        for i in range(1, len(segments), 2):
            role = segments[i].strip()
            content_block = segments[i + 1].strip()
            # Split each body on <|eot_id|> terminators
            for part in re.split(r"<\|eot_id\|>", content_block):
                part = part.strip()
                if part:
                    role_content_pairs.append((role, part))
        # Re-emit as ChatML
        for idx, (role, content) in enumerate(role_content_pairs):
            is_last_pair = idx == len(role_content_pairs) - 1
            if role == "system":
                output += f"<|im_start|>system\n{content}\n<|im_end|>\n"
            elif role == "user":
                output += f"<|im_start|>user\n{content}\n<|im_end|>\n"
            elif role == "assistant":
                if is_last_pair:  # only the final assistant segment stays unclosed
                    # Don't append a newline if the content already ends with one
                    if not content.endswith("\n"):
                        content += "\n"
                    output += f"<|im_start|>assistant\n{content}"
                else:
                    output += f"<|im_start|>assistant\n{content}\n<|im_end|>\n"
    # === 3. [INST] format (MistralV2 / Ministral) ===
    elif "[INST]" in text and "[/INST]" in text:
        system_match = re.search(r"\[SYSTEM_PROMPT\](.*?)\[/SYSTEM_PROMPT\]", text, re.S)
        if system_match:
            output += f"<|im_start|>system\n{system_match.group(1).strip()}\n<|im_end|>\n"
        # Each turn: user text up to [/INST], assistant text up to the next [INST], </s>, or EOS
        turns = re.findall(r"\[INST\](.*?)\[/INST\](.*?)(?=(\[INST\]|</s>|$))", text, re.S)
        for user_msg, assistant_msg, _ in turns:
            output += f"<|im_start|>user\n{user_msg.strip()}\n<|im_end|>\n"
            if assistant_msg.strip():
                output += f"<|im_start|>assistant\n{assistant_msg.strip()}\n<|im_end|>\n"
            else:
                # Empty assistant reply -> leave an open generation prompt
                output += f"<|im_start|>assistant\n"
    # === 4. <start_of_turn> format (Gemma) ===
    elif "<start_of_turn>" in text:
        # (1) system prompt written as "[System: ...]"
        system_match = re.search(r"\[System:(.*?)\]", text, re.S)
        if system_match:
            output += f"<|im_start|>system\n{system_match.group(1).strip()}\n<|im_end|>\n"
        # (2) user/model turns; Gemma's "model" role maps to ChatML "assistant"
        turns = re.findall(r"<start_of_turn>(user|model)\s*\n?(.*?)<end_of_turn>", text, re.S)
        for idx, (role, content) in enumerate(turns):
            role = "assistant" if role == "model" else "user"
            is_last = idx == len(turns) - 1
            if role == "assistant" and is_last:
                # Keep the final assistant turn open
                if not content.endswith("\n"):
                    content += "\n"
                output += f"<|im_start|>assistant\n{content}"
            else:
                output += f"<|im_start|>{role}\n{content.strip()}\n<|im_end|>\n"
    # === 5. Pygmalion format ===
    # NOTE(review): the second regex alternative matches any line containing an ASCII
    # or full-width colon, so this branch will also catch most colon-bearing plain
    # text that reached this point — confirm that is the intended fallback.
    elif "<start>" in text or re.search(r"(?m)^You[::]|^.*?[::].*?$", text):
        # (1) everything before <start> is treated as the persona/system prompt
        persona_match = re.search(r"(.*?)<start>", text, re.S)
        if persona_match:
            output += f"<|im_start|>system\n{persona_match.group(1).strip()}\n<|im_end|>\n"
        # (2) dialogue after the last <start>; only lines with an ASCII colon are kept
        #     (lines that use only a full-width colon are silently dropped here)
        dialogue = text.split("<start>")[-1]
        lines = [l.strip() for l in dialogue.strip().split("\n") if ":" in l]
        for idx, line in enumerate(lines):
            is_last = idx == len(lines) - 1
            if re.match(r"^(You|User|你)[::]", line):
                content = re.sub(r"^(You|User|你)[::]", "", line).strip()
                output += f"<|im_start|>user\n{content}\n<|im_end|>\n"
            else:
                # Any other "Name: text" line is treated as an assistant turn
                _, content = line.split(":", 1)
                content = content.strip()
                if is_last:
                    # Last line and it's the assistant -> keep the turn open
                    if not content.endswith("\n"):
                        content += "\n"
                    output += f"<|im_start|>assistant\n{content}"
                else:
                    output += f"<|im_start|>assistant\n{content}\n<|im_end|>\n"
    # === 6. Fallback: unrecognized format ===
    else:
        return None
    # === Final fix-ups on the assembled ChatML ===
    output = output.strip()
    # Drop a trailing empty assistant turn (i.e. <|im_start|>assistant\n<|im_end|>)
    if output.endswith("<|im_start|>assistant\n<|im_end|>"):
        output = output[:-len("<|im_end|>")].rstrip()
    # Also drop the <|im_end|> when the last assistant "reply" is just a name prompt.
    # Example: <|im_start|>assistant\nFlo:<|im_end|>  ->  <|im_start|>assistant\nFlo:
    # Condition: output ends with "<|im_start|>assistant\n(short name/word):<|im_end|>"
    last_assistant_pattern = r"<\|im_start\|>assistant\n([^\n<\|]{1,100}):\s*<\|im_end\|>$"
    if re.search(last_assistant_pattern, output):
        output = re.sub(r"<\|im_end\|>$", "", output).rstrip()
    # Optionally append a generation prompt (open <|im_start|>assistant for the next turn)
    if add_generation_prompt and not output.endswith("<|im_start|>assistant"):
        output += f"\n<|im_start|>assistant"
    return output.strip()
# Convert an entire dataset (list of samples) to ChatML, dropping failures.
def standardize_dataset_to_mistral_format(dataset: List[Dict]) -> List[Dict]:
    """Run convert_to_mistral_format over every sample's "text" field.

    Samples whose conversion yields a falsy result (unrecognized format)
    are dropped from the returned list.
    """
    conversions = (
        convert_to_mistral_format(sample.get("text", ""))
        for sample in tqdm(dataset)
    )
    return [{"text": chatml} for chatml in conversions if chatml]
from datasets import load_dataset
from tqdm import tqdm
import re
def clean_chosen_tail(text: str) -> str:
    """Strip trailing junk (eot tags, pipe markers, whitespace) off *text*, keeping the body.

    Non-string inputs are returned unchanged.
    """
    if not isinstance(text, str):
        return text
    # First drop a trailing eot marker (angle-bracketed or bare),
    # then any trailing |tag| marker together with whatever follows it.
    without_eot = TAIL_TAGS.sub("", text)
    without_pipe_junk = PIPE_TRAIL.sub("", without_eot)
    return without_pipe_junk.rstrip()
def apply_format_conversion(example):
    """datasets.map callback: clean chosen/reject tails and convert chosen_prompt to ChatML."""
    # (1) Drop a trailing <|eot_id|> from chosen/reject (the lightweight cleanup)
    for k in ("chosen", "reject"):
        if isinstance(example[k], str):
            example[k] = EOT_TAIL.sub("", example[k])
    # (2) Stricter tail cleanup for chosen (pipe tags, whitespace, etc.)
    if isinstance(example.get("chosen"), str):
        example["chosen"] = clean_chosen_tail(example["chosen"])
        # Apply the same cleanup to reject as well
        example["reject"] = clean_chosen_tail(example["reject"])
    # (3) Convert the prompt to ChatML
    new_prompt = convert_to_mistral_format(example["chosen_prompt"], add_generation_prompt=False)
    if new_prompt is None:
        # NOTE(review): datasets.map() requires its callback to return a dict;
        # returning None raises at map time instead of dropping the row, so the
        # downstream `is not None` filter never fires — verify against a sample
        # whose prompt fails conversion.
        return None
    example["chosen_prompt"] = new_prompt
    return example
# === Load and process the dataset ===
dataset = load_dataset("parquet", data_files=input_path, split="train")
converted_dataset = dataset.map(apply_format_conversion)
# NOTE(review): map() never yields None rows (it raises if the callback returns
# None), so this filter is a no-op — confirm with a sample that fails conversion.
converted_dataset = converted_dataset.filter(lambda example: example is not None)
# columns_to_keep = ['chosen', 'chosen_prompt', 'reject']
# converted_dataset = converted_dataset.remove_columns([col for col in converted_dataset.column_names if col not in columns_to_keep])
converted_dataset.to_pandas().to_parquet(output_path, index=False)
print(f"✅ 处理完成,共保留样本 {len(converted_dataset)} 条,已保存至:{output_path}")
import random

# === Spot-check: reload the saved parquet and print random samples in full ===
dataset1 = load_dataset("parquet", data_files=output_path, split="train")
# Fix: cap the sample size at the dataset length — random.sample raises
# ValueError when asked for more elements than the population holds.
sample_size = min(15, len(dataset1))
indices = random.sample(range(len(dataset1)), sample_size)
samples = dataset1.select(indices)
# Print every field of each sampled row verbatim
for idx, item in zip(indices, samples):
    print(f"\n=== Sample index {idx} ===")
    for key, value in item.items():
        print(f"[{key}]")
        print(value)  # raw output, no truncation
    print("-" * 60)