# rm_code / format.py
# Uploaded by hahayang012 via huggingface_hub ("Upload folder using huggingface_hub"), commit d8a76be (verified).
import re
from typing import Dict, List, Optional

from tqdm import tqdm
# === Replace these with your own input and output paths ===
input_path = "/home/data/raw/test/1159-L6.parquet"
output_path = "/home/data/raw/test/1159-L6_format.parquet"

# A trailing "<|eot_id|>" (plus any whitespace after it) at the very end of the text.
EOT_TAIL = re.compile(r"<\|eot_id\|>\s*$")
# A trailing "<|eot_id|>" or "|eot_id|" (angle brackets optional), end of text only.
TAIL_TAGS = re.compile(r"[<]?\|eot_id\|[>]?\s*$")
# A pipe tag such as "|start_header_id|" or "<|start_header_id|>" on the last
# non-blank line, together with everything after it.  The optional leading "<"
# belongs to the match so that "answer<|start_header_id|>user" is cleaned to
# "answer" instead of leaving a dangling "answer<".
# NOTE(review): this also strips a final markdown-table row written without
# spaces (e.g. "|a|b|"); confirm that is acceptable for the data.
PIPE_TRAIL = re.compile(r"(?:<?\|[A-Za-z0-9_]+\|[^\n]*)\s*$")
def is_mistral_format(text):
    """Return True when *text* already carries both ChatML markers.

    Both "<|im_start|>" and "<|im_end|>" must be present.
    """
    return all(marker in text for marker in ("<|im_start|>", "<|im_end|>"))
def _close_turn(role: str, content: str) -> str:
    """Render one finished ChatML turn: header, content, newline, <|im_end|>, newline."""
    return f"<|im_start|>{role}\n{content}\n<|im_end|>\n"


def _open_assistant_turn(content: str) -> str:
    """Render a final assistant turn without <|im_end|> so a model can continue it."""
    if not content.endswith("\n"):
        content += "\n"
    return f"<|im_start|>assistant\n{content}"


def _normalize_chatml(text: str) -> str:
    """Tidy text that is already ChatML and drop its trailing <|im_end|>."""
    # Swallow any whitespace (newlines included) before <|im_end|>, keeping exactly one newline.
    text = re.sub(r"\s*<\|im_end\|>", r"\n<|im_end|>", text)
    text = re.sub(r"\n{3,}", "\n\n", text)  # collapse runs of 3+ newlines
    # Remove the final <|im_end|> to leave the last turn open.
    # NOTE(review): this happens whatever the last role is, so a transcript ending
    # in a user turn leaves that user turn unclosed -- confirm this is intended.
    return re.sub(r"\s*<\|im_end\|>\s*$", "", text).rstrip()


def _llama_to_chatml(text: str) -> str:
    """Convert LLaMA-3 style ``<|start_header_id|>role<|end_header_id|> ... <|eot_id|>`` text."""
    # With one capture group, re.split yields [prefix, role1, body1, role2, body2, ...].
    segments = re.split(r"<\|start_header_id\|>(.*?)<\|end_header_id\|>", text, flags=re.S)
    pairs = []
    for i in range(1, len(segments), 2):
        role = segments[i].strip()
        for part in re.split(r"<\|eot_id\|>", segments[i + 1].strip()):
            part = part.strip()
            if part:
                pairs.append((role, part))
    output = ""
    for idx, (role, content) in enumerate(pairs):
        if role in ("system", "user"):
            output += _close_turn(role, content)
        elif role == "assistant":
            # Only an assistant message that is the very last message stays open.
            if idx == len(pairs) - 1:
                output += _open_assistant_turn(content)
            else:
                output += _close_turn("assistant", content)
        # Messages with any other role are dropped.
    return output


def _inst_to_chatml(text: str) -> str:
    """Convert ``[INST] ... [/INST]`` (Mistral / Ministral) text, with optional [SYSTEM_PROMPT]."""
    output = ""
    system_match = re.search(r"\[SYSTEM_PROMPT\](.*?)\[/SYSTEM_PROMPT\]", text, re.S)
    if system_match:
        output += _close_turn("system", system_match.group(1).strip())
    # Each assistant reply runs until the next [INST], a </s>, or the end of the text.
    turns = re.findall(r"\[INST\](.*?)\[/INST\](.*?)(?=(\[INST\]|</s>|$))", text, re.S)
    for user_msg, assistant_msg, _ in turns:
        output += _close_turn("user", user_msg.strip())
        reply = assistant_msg.strip()
        if reply:
            output += _close_turn("assistant", reply)
        else:
            output += "<|im_start|>assistant\n"
    return output


def _gemma_to_chatml(text: str) -> str:
    """Convert Gemma ``<start_of_turn>user|model ... <end_of_turn>`` text."""
    output = ""
    system_match = re.search(r"\[System:(.*?)\]", text, re.S)
    if system_match:
        output += _close_turn("system", system_match.group(1).strip())
    # Only complete turns (terminated by <end_of_turn>) are captured.
    turns = re.findall(r"<start_of_turn>(user|model)\s*\n?(.*?)<end_of_turn>", text, re.S)
    for idx, (speaker, content) in enumerate(turns):
        role = "assistant" if speaker == "model" else "user"
        if role == "assistant" and idx == len(turns) - 1:
            output += _open_assistant_turn(content)
        else:
            output += _close_turn(role, content.strip())
    return output


def _pygmalion_to_chatml(text: str) -> str:
    """Convert Pygmalion-style "Name: message" dialogue, with an optional persona before <start>."""
    output = ""
    persona_match = re.search(r"(.*?)<start>", text, re.S)
    if persona_match:
        output += _close_turn("system", persona_match.group(1).strip())
    dialogue = text.split("<start>")[-1]
    # Speaker separators may be ASCII ':' or full-width '\uff1a', matching the
    # user-line pattern below.  Lines with neither separator are skipped.
    lines = [
        line.strip()
        for line in dialogue.strip().split("\n")
        if re.search(r"[:\uff1a]", line)
    ]
    for idx, line in enumerate(lines):
        user_prefix = re.match(r"^(You|User|你)[:\uff1a]", line)
        if user_prefix:
            output += _close_turn("user", line[user_prefix.end():].strip())
            continue
        # Any other speaker is the assistant; split at the FIRST separator of either kind.
        content = re.split(r"[:\uff1a]", line, maxsplit=1)[1].strip()
        if idx == len(lines) - 1:
            output += _open_assistant_turn(content)
        else:
            output += _close_turn("assistant", content)
    return output


def _finalize_chatml(output: str, add_generation_prompt: bool) -> str:
    """Apply the shared end-of-transcript fix-ups to a converted transcript."""
    output = output.strip()
    # Drop the <|im_end|> of a trailing empty assistant turn.
    if output.endswith("<|im_start|>assistant\n<|im_end|>"):
        output = output[:-len("<|im_end|>")].rstrip()
    # Reopen a closed final assistant turn that is only a speaker cue, e.g.
    # "<|im_start|>assistant\nFlo:\n<|im_end|>" -> "<|im_start|>assistant\nFlo:".
    if re.search(r"<\|im_start\|>assistant\n([^\n<\|]{1,100}):\s*<\|im_end\|>$", output):
        output = re.sub(r"<\|im_end\|>$", "", output).rstrip()
    # Start a new assistant turn only after a closed turn (or on empty output);
    # never nest one inside an assistant turn that was deliberately left open.
    if add_generation_prompt and (not output or output.endswith("<|im_end|>")):
        output += "\n<|im_start|>assistant"
    return output.strip()


def convert_to_mistral_format(text: str, add_generation_prompt: bool = False) -> Optional[str]:
    """Convert a chat transcript to ChatML (``<|im_start|>role ... <|im_end|>``).

    Despite the name, the target format is ChatML.  Input formats are detected
    by marker tokens, checked in this order:

    1. ChatML (both ``<|im_start|>`` and ``<|im_end|>``): whitespace-normalized
       and returned with its trailing ``<|im_end|>`` removed.
    2. LLaMA-3 header tokens (``<|start_header_id|>`` / ``<|end_header_id|>``).
    3. ``[INST]`` / ``[/INST]`` (Mistral / Ministral).
    4. Gemma ``<start_of_turn>``.
    5. Pygmalion: ``<start>`` present, or ANY line containing a colon.

    For formats 2, 4 and 5, a final assistant message is left open (no
    ``<|im_end|>``).  For format 3, a non-empty final reply stays closed unless it
    is only a speaker cue such as ``"Flo:"``.

    Args:
        text: Raw prompt / transcript string.
        add_generation_prompt: When True, append ``<|im_start|>assistant`` if the
            converted transcript ends with a closed turn.  Not applied to
            ChatML input (format 1).

    Returns:
        The ChatML string, or None when no format is recognized (no marker
        tokens and no colon anywhere in *text*).
    """
    if "<|im_start|>" in text and "<|im_end|>" in text:
        return _normalize_chatml(text)
    if "<|start_header_id|>" in text and "<|end_header_id|>" in text:
        output = _llama_to_chatml(text)
    elif "[INST]" in text and "[/INST]" in text:
        output = _inst_to_chatml(text)
    elif "<start_of_turn>" in text:
        output = _gemma_to_chatml(text)
    elif "<start>" in text or re.search(r"(?m)^You[:\uff1a]|^.*?[:\uff1a].*?$", text):
        output = _pygmalion_to_chatml(text)
    else:
        return None
    return _finalize_chatml(output, add_generation_prompt)
# Convert a whole dataset (a list of samples).
def standardize_dataset_to_mistral_format(dataset: List[Dict]) -> List[Dict]:
    """Convert each sample's "text" field to ChatML, skipping samples that fail.

    Samples whose conversion yields None (unrecognized format) or an empty
    string are skipped.  Returns a new list of ``{"text": <chatml>}`` dicts;
    the input list is not modified.  Progress is shown with tqdm.
    """
    candidates = (convert_to_mistral_format(sample.get("text", "")) for sample in tqdm(dataset))
    return [{"text": chatml} for chatml in candidates if chatml]
from datasets import load_dataset
from tqdm import tqdm
import re
def clean_chosen_tail(text: str) -> str:
    """Remove trailing non-content junk from a response, keeping the body text.

    Applies, in order:
    1. A trailing ``<|eot_id|>`` / ``|eot_id|`` marker (TAIL_TAGS) is stripped.
    2. A trailing pipe tag such as ``|start_header_id|``, plus the rest of its
       line (PIPE_TRAIL), is stripped.
    3. Trailing whitespace is stripped.

    Non-string input is returned unchanged.
    """
    if isinstance(text, str):
        for tail_pattern in (TAIL_TAGS, PIPE_TRAIL):
            text = tail_pattern.sub("", text)
        return text.rstrip()
    return text
def apply_format_conversion(example):
    """Clean one preference sample's responses and convert its prompt to ChatML.

    Written as a ``datasets.Dataset.map`` callback; *example* is mutated in place.

    Steps:
      1. For "chosen" and "reject": strip a trailing ``<|eot_id|>`` (strings
         only), then run ``clean_chosen_tail``, which passes non-strings through.
      2. Convert "chosen_prompt" with ``convert_to_mistral_format``.

    Returns:
        The updated example, or None when "chosen_prompt" is in an
        unrecognized format.
        NOTE(review): ``datasets.map`` does not drop a row whose callback
        returns None, so callers must filter such rows themselves.
    """
    for field in ("chosen", "reject"):
        value = example[field]
        if isinstance(value, str):
            value = EOT_TAIL.sub("", value)
        example[field] = clean_chosen_tail(value)
    converted_prompt = convert_to_mistral_format(example["chosen_prompt"], add_generation_prompt=False)
    if converted_prompt is None:
        return None
    example["chosen_prompt"] = converted_prompt
    return example
# === Load and process the dataset ===
dataset = load_dataset("parquet", data_files=input_path, split="train")
# Drop rows whose prompt is in an unrecognized format *before* mapping.  The old
# `filter(lambda example: example is not None)` after `map` never removed
# anything: `filter` always receives dict rows.  A `map` callback that returns
# None is treated as "no update" rather than as a deletion.
convertible = dataset.filter(
    lambda example: convert_to_mistral_format(example["chosen_prompt"]) is not None
)
# apply_format_conversion returns None only when that same conversion fails.
# Here it therefore always returns the updated row.
converted_dataset = convertible.map(apply_format_conversion)
# Optionally keep only the core columns:
# columns_to_keep = ['chosen', 'chosen_prompt', 'reject']
# converted_dataset = converted_dataset.remove_columns([col for col in converted_dataset.column_names if col not in columns_to_keep])
converted_dataset.to_pandas().to_parquet(output_path, index=False)
print(f"✅ 处理完成,共保留样本 {len(converted_dataset)} 条,已保存至:{output_path}")
import random
# Reload the processed parquet file and print a random selection of rows in full.
dataset1 = load_dataset("parquet", data_files=output_path, split="train")
# Cap the sample size at the dataset length: random.sample raises ValueError
# when asked for more items than the population holds (e.g. after filtering).
num_samples = min(15, len(dataset1))
indices = random.sample(range(len(dataset1)), num_samples)
samples = dataset1.select(indices)  # rows come back in the order of `indices`
for idx, item in zip(indices, samples):
    print(f"\n=== Sample index {idx} ===")
    for key, value in item.items():
        print(f"[{key}]")
        print(value)  # print the raw value as-is
    print("-" * 60)