# deal_data/qiege.py
# Uploaded by "yuccaaa" via huggingface_hub (commit b776e9e, verified).
# NOTE: the four lines above were residue from the Hugging Face file-listing
# page; converted to comments so this file parses as Python.
import json
from transformers import AutoTokenizer
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing
from tqdm import tqdm
# Path to the local model checkpoint whose tokenizer is used for chunking.
local_model_path = "/nas/shared/kilab/hf-hub/Qwen3-32B"
# Tokenizer loaded in the main process; per the original comment it was meant
# for optionally estimating the total token count. NOTE(review): nothing in
# this file actually uses it — workers each load their own via
# init_tokenizer() — so this load may be removable; confirm before deleting.
tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)
def init_tokenizer():
    """Process-pool initializer: load a dedicated tokenizer for this worker.

    Stores the tokenizer in the module-level global ``tokenizer_worker``,
    which ``process_line`` reads inside each worker process.
    """
    global tokenizer_worker
    tokenizer_worker = AutoTokenizer.from_pretrained(
        local_model_path,
        trust_remote_code=True,
    )
def process_line(line, max_tokens=4096):
    """Parse one JSONL line and split its ``content`` into token-bounded chunks.

    Args:
        line: Raw JSONL line; expected to be a JSON object with a "content" key.
        max_tokens: Maximum tokens per output chunk (default 4096, matching
            the original hard-coded limit; now parameterized for reuse).

    Returns:
        A list of JSON strings (``ensure_ascii=False``), each carrying a
        "content" field of at most ``max_tokens`` tokens, or ``None`` when
        the line is not valid JSON or has no "content" key.
    """
    global tokenizer_worker
    try:
        data = json.loads(line.strip())
        if 'content' not in data:
            return None
        output_text = data["content"]
        tokens = tokenizer_worker.encode(output_text)
        if len(tokens) <= max_tokens:
            # Short enough: pass the original text through unchanged.
            return [json.dumps({"content": output_text}, ensure_ascii=False)]
        # Slice the token list into fixed-size windows. This yields exactly
        # the same chunk boundaries as the original per-token accumulation
        # loop (full windows of max_tokens, then the remainder) but without
        # a Python-level branch per token.
        return [
            json.dumps(
                {"content": tokenizer_worker.decode(tokens[i:i + max_tokens])},
                ensure_ascii=False,
            )
            for i in range(0, len(tokens), max_tokens)
        ]
    except Exception:
        # Deliberate best-effort: malformed lines are dropped (caller treats
        # None as "skip"), keeping the bulk pipeline running.
        return None
# Input/output paths: source JSONL corpus and the chunked output file.
input_file = '/nas/shared/kilab/wangyujia/pretrain_data/cot/clean/merge_cot.jsonl'
output_file = '/nas/shared/kilab/wangyujia/pretrain_data/cot/clean/merge_cot_new.jsonl'
if __name__ == '__main__':
    try:
        # Read the whole file up front so the progress bar has a known total.
        # NOTE(review): this holds the entire corpus in memory; switch to
        # streaming if the input outgrows available RAM.
        with open(input_file, 'r', encoding='utf-8') as infile:
            lines = infile.readlines()
        total_lines = len(lines)
        with ProcessPoolExecutor(
            max_workers=multiprocessing.cpu_count(),
            initializer=init_tokenizer,
        ) as executor:
            futures = [executor.submit(process_line, line) for line in lines]
            with open(output_file, 'w', encoding='utf-8') as outfile, \
                    tqdm(total=total_lines, desc="处理进度") as pbar:
                for future in as_completed(futures):
                    result = future.result()
                    if result:
                        # Batch the per-chunk writes into one call.
                        outfile.writelines(r + '\n' for r in result)
                    # Advance for every input line — including skipped
                    # (None) ones — so the bar always reaches 100%.
                    pbar.update(1)
        print(f"\n✅ 处理完成!共处理 {total_lines} 行,输出保存至 {output_file}")
    except FileNotFoundError:
        print(f"❌ 文件 {input_file} 未找到。")
    except Exception as e:
        # Top-level boundary: report and dump the traceback for debugging.
        print(f"❌ 发生错误: {e}")
        import traceback
        traceback.print_exc()