# -*- coding: utf-8 -*-
import argparse
import re
import pandas as pd
from transformers import AutoTokenizer
# 1) Regex for a closed ChatML block: <|im_start|>role\n content <|im_end|>
CLOSED_PAT = re.compile(
    (
        r"<\|im_start\|>(system|user|assistant)[ \t]*\n"  # role, optional blanks/tabs, mandatory newline
        r"(.*?)"  # message body: non-greedy, may span lines thanks to DOTALL
        r"<\|im_end\|>"  # the closing tag terminates the block
    ),
    re.DOTALL,
)
# 2) Regex for a trailing assistant block that was never closed (runs to the end of the text)
OPEN_ASSIST_TAIL = re.compile(
    r"<\|im_start\|>assistant[ \t]*\n"  # assistant header followed by its mandatory newline
    r"([\s\S]*)\Z",  # capture everything after the header, through the end of the string
    re.DOTALL,
)
def chatml_to_messages_and_tail(text: str):
    """Split ChatML text into closed messages plus an optional open assistant tail.

    Every closed block matched by ``CLOSED_PAT`` becomes a
    ``{"role": ..., "content": ...}`` dict, in order of appearance. The text
    after the last closed block is then searched with ``OPEN_ASSIST_TAIL``
    for an assistant block that was never closed.

    Args:
        text: ChatML text. Non-``None`` values are converted with ``str()``.

    Returns:
        A ``(messages, tail)`` pair. ``tail`` is the text following the open
        assistant header, or ``None`` when no such block exists. ``None``
        input yields ``([], None)``.
    """
    if text is None:
        return [], None

    raw = str(text)
    closed_blocks = list(CLOSED_PAT.finditer(raw))
    # Only leading/trailing newlines are trimmed; newlines inside the content stay.
    messages = [
        {"role": block.group(1), "content": block.group(2).strip("\n")}
        for block in closed_blocks
    ]

    # Resume right after the final closed block (or at the start if none matched).
    resume_at = closed_blocks[-1].end() if closed_blocks else 0
    remainder = raw[resume_at:]
    if not remainder:
        return messages, None

    open_block = OPEN_ASSIST_TAIL.search(remainder)
    if open_block is None:
        return messages, None
    return messages, open_block.group(1)
def transform_one(raw_chatml: str, tok: AutoTokenizer) -> str:
    """Re-render one ChatML string with the tokenizer's chat template.

    Closed blocks are rendered through ``tok.apply_chat_template`` with
    ``add_generation_prompt=False`` and ``tokenize=False``. If the text ends
    with an unclosed assistant block, that block is appended by hand as
    ``"<|im_start|>assistant\\n{tail}"`` and is left unclosed.

    Args:
        raw_chatml: ChatML text to convert; ``None`` is accepted.
        tok: Tokenizer object providing ``apply_chat_template``.

    Returns:
        The re-rendered text. An empty string is returned when the input
        contains neither a closed block nor an open assistant tail.
    """
    messages, tail_assistant = chatml_to_messages_and_tail(raw_chatml)

    if messages:
        rendered_closed = tok.apply_chat_template(
            messages,
            add_generation_prompt=False,
            tokenize=False,
        )
    else:
        # Nothing to render. The template is skipped on purpose: chat
        # templates that index messages[0] raise on an empty list, which
        # would abort a whole column conversion on a single null/empty row.
        rendered_closed = ""

    if tail_assistant is None:
        return rendered_closed

    # Trailing whitespace of the open tail is dropped. The block is appended
    # as-is after the rendered text and deliberately not closed.
    return rendered_closed + f"<|im_start|>assistant\n{tail_assistant.rstrip()}"
def main():
    """Command-line entry point: rewrite one ChatML column of a parquet file.

    Reads ``--input``, passes every value of ``--column`` through
    ``transform_one`` using the tokenizer loaded from ``--model``, stores the
    result in ``--out_column`` (or back into ``--column`` when that option is
    not given), and writes the frame to ``--output``.

    Raises:
        ValueError: If ``--column`` is not a column of the input file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="输入 parquet 路径")
    parser.add_argument("--output", required=True, help="输出 parquet 路径")
    parser.add_argument(
        "--model",
        default="deeppin/Qwen3-Reranker-8B-SequenceClassification",
        help="用于 apply_chat_template 的 tokenizer 模型名/路径",
    )
    parser.add_argument("--column", default="chosen_prompt", help="需要转换的列名")
    parser.add_argument(
        "--out_column",
        default=None,
        help="输出列名(不填则覆盖原列)",
    )
    options = parser.parse_args()

    frame = pd.read_parquet(options.input)
    source_column = options.column
    if source_column not in frame.columns:
        raise ValueError(f"找不到列:{source_column}")

    tokenizer = AutoTokenizer.from_pretrained(
        options.model,
        trust_remote_code=True,
        use_fast=False,
    )

    def render(cell):
        # Bind the loaded tokenizer so the column can be mapped one cell at a time.
        return transform_one(cell, tokenizer)

    # Fall back to overwriting the source column when no output column is given.
    target_column = options.out_column if options.out_column else source_column
    frame[target_column] = frame[source_column].apply(render)
    frame.to_parquet(options.output, index=False)
    print(f"Done. Wrote: {options.output} (transformed column: {target_column})")


if __name__ == "__main__":
    main()