"""Split over-long ``content`` records of a JSONL file into token-bounded chunks.

Each input line is expected to be a JSON object with a ``content`` field.
Records whose tokenized length exceeds the limit are split into several
records of at most that many tokens each; shorter records are re-emitted
unchanged. Lines that cannot be parsed, or that have no ``content`` key, are
skipped and counted.
"""

import json
import multiprocessing
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed

from tqdm import tqdm
from transformers import AutoTokenizer

# Local model path
local_model_path = "/nas/shared/kilab/hf-hub/Qwen3-32B"

# Maximum number of tokens allowed in a single output record.
MAX_TOKENS_PER_CHUNK = 4096

# Number of input lines shipped to a worker per task; batching avoids paying
# inter-process overhead once per line.
MAP_CHUNKSIZE = 256

# Tokenizer loaded in the main process (originally intended for optionally
# estimating the total token count).
# NOTE(review): nothing else in this file references it; it is kept so that
# the module-level name `tokenizer` remains available. Consider moving it
# under the __main__ guard if no other module relies on it.
tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)


def init_tokenizer():
    """Load a tokenizer into the ``tokenizer_worker`` global of a worker process.

    Passed as the ``initializer`` of the process pool so that each worker
    loads the tokenizer once instead of once per task.
    """
    global tokenizer_worker
    tokenizer_worker = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)


def process_line(line, max_tokens=MAX_TOKENS_PER_CHUNK):
    """Parse one JSONL line and split its ``content`` if it is too long.

    Args:
        line: One raw line of the input file (a JSON document as text).
        max_tokens: Upper bound on the number of tokens per emitted record.

    Returns:
        A list of JSON strings, each of the form ``{"content": ...}``: a
        single element when the text fits within ``max_tokens``, otherwise
        one element per consecutive ``max_tokens``-sized token slice.
        ``None`` when the line has no ``content`` key or any error occurs
        while parsing/tokenizing (deliberate best-effort: bad lines are
        dropped rather than aborting the whole run).
    """
    try:
        data = json.loads(line.strip())
        if 'content' not in data:
            return None

        output_text = data["content"]
        tokens = tokenizer_worker.encode(output_text)

        # Short enough: emit the original text untouched (no decode round-trip).
        if len(tokens) <= max_tokens:
            return [json.dumps({"content": output_text}, ensure_ascii=False)]

        # NOTE(review): slices are cut at arbitrary token boundaries; if this
        # tokenizer can split a multi-byte character across tokens, decoding
        # a slice may not reproduce the original text exactly — confirm.
        return [
            json.dumps(
                {"content": tokenizer_worker.decode(tokens[start:start + max_tokens])},
                ensure_ascii=False,
            )
            for start in range(0, len(tokens), max_tokens)
        ]
    except Exception:
        # Best-effort: the caller counts and reports dropped lines.
        return None


# Input and output paths
input_file = '/nas/shared/kilab/wangyujia/pretrain_data/cot/clean/merge_cot.jsonl'
output_file = '/nas/shared/kilab/wangyujia/pretrain_data/cot/clean/merge_cot_new.jsonl'

if __name__ == '__main__':
    try:
        with open(input_file, 'r', encoding='utf-8') as infile:
            lines = infile.readlines()

        total_lines = len(lines)
        skipped = 0

        # Open the output first so an invalid output path fails before any
        # worker has spent time loading a tokenizer.
        with open(output_file, 'w', encoding='utf-8') as outfile:
            with ProcessPoolExecutor(
                max_workers=multiprocessing.cpu_count(),
                initializer=init_tokenizer,
            ) as executor:
                with tqdm(total=total_lines, desc="处理进度") as pbar:
                    # map() yields results in input order and, with
                    # chunksize, sends lines to workers in batches.
                    results = executor.map(process_line, lines, chunksize=MAP_CHUNKSIZE)
                    for result in results:
                        if result:
                            outfile.writelines(r + '\n' for r in result)
                        else:
                            skipped += 1
                        pbar.update(1)

        print(f"\n✅ 处理完成!共处理 {total_lines} 行,输出保存至 {output_file}")
        if skipped:
            print(f"⚠️ 跳过 {skipped} 行(缺少 content 字段或处理失败)")

    except FileNotFoundError as e:
        # The missing path may be the input file or the output location;
        # report the one the OS actually complained about.
        missing_path = e.filename or input_file
        print(f"❌ 文件 {missing_path} 未找到。")
    except Exception as e:
        print(f"❌ 发生错误: {e}")
        traceback.print_exc()