import os import torch from datasets import load_dataset from transformers import ( AutoModelForMaskedLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling, set_seed ) # --- 配置参数 --- DOMAIN_TEXT_FILE = "/home/hsichen/part_time/BERT_finetune/dataset_pretrain/domain_corpus.txt" MODEL_NAME = "valuesimplex-ai-lab/FinBERT2-base" OUTPUT_DIR = "./bert_dapt_model" # 预训练超参数 DAPT_LR = 1e-5 # 较低的学习率,防止破坏原有知识 DAPT_EPOCHS = 3 # 适中的训练轮数 BATCH_SIZE = 16 # 批次大小 (请根据您的 GPU 显存调整) MLM_PROBABILITY = 0.15 # 掩码比例 SEED = 42 NUM_PROC = 64 # 并行处理的进程数 # 设置随机种子以保证结果可复现 set_seed(SEED) def domain_adaptive_pretrain(): # 路径检查 if not os.path.exists(DOMAIN_TEXT_FILE): print(f"致命错误:领域语料库文件未找到在 {DOMAIN_TEXT_FILE}。请先运行数据预处理脚本。") return # 1. 加载模型和分词器 print("--- 1. 加载模型和分词器 ---") # AutoTokenizer 会自动识别模型对应的分词器 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) # AutoModelForMaskedLM 专门用于 MLM 任务 model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME) # 2. 加载和处理文本数据集 print("--- 2. 加载和处理文本数据集 ---") # 使用 datasets 库加载纯文本文件 # 文件必须包含在 'train' 键下,以支持 Trainer raw_datasets = load_dataset("text", data_files={"train": DOMAIN_TEXT_FILE}) # 定义 tokenization 函数 def tokenize_function(examples): # 截断但不填充,因为 DataCollatorForLanguageModeling 会处理填充 return tokenizer( examples["text"], truncation=True, max_length=512, # 推荐修改 return_special_tokens_mask=True ) tokenized_datasets = raw_datasets.map( tokenize_function, batched=True, remove_columns=["text"], num_proc=NUM_PROC ) # 将长文本切块 (Chunking) 和分组 (Grouping) def group_texts(examples): # 拼接所有文本 concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # 设定切块大小 chunk_size = 512 # print(f"Total length: {total_length}, after chunking: {total_length // chunk_size}") # 通过截断 total_length 来丢弃最后一个不完整的切块 total_length = (total_length // chunk_size) * chunk_size # 将文本切分成 max_length (512) 的块 result = { k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)] for k, t in concatenated_examples.items() } # 标签 ID 设为 input_ids,DataCollator 会将非掩码位置设置为 -100 result["labels"] = result["input_ids"].copy() return result # 最终的 DAPT 训练数据集 lm_datasets = tokenized_datasets.map( group_texts, batched=True, num_proc=NUM_PROC ) # # 3. 数据收集器 (动态掩码) # # 这个 Collator 会在每个批次中随机应用 15% 的掩码 data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=True, mlm_probability=MLM_PROBABILITY ) # 4. 设置训练参数 print("--- 3. 设置训练参数 ---") training_args = TrainingArguments( output_dir=OUTPUT_DIR, num_train_epochs=DAPT_EPOCHS, per_device_train_batch_size=BATCH_SIZE, learning_rate=DAPT_LR, weight_decay=0.01, logging_steps=50, save_strategy="epoch", report_to="wandb", ) # 5. 初始化 Trainer trainer = Trainer( model=model, args=training_args, train_dataset=lm_datasets["train"], data_collator=data_collator, ) # 6. 开始继续预训练 print("--- 4. 开始继续预训练 ---") trainer.train() # 7. 保存 DAPT 模型 trainer.save_model(OUTPUT_DIR) tokenizer.save_pretrained(OUTPUT_DIR) print(f"DAPT 模型已保存至: {OUTPUT_DIR}") if __name__ == "__main__": domain_adaptive_pretrain()