# FINBERT2_finetune / continue_pretrain.py
import os
import torch
from datasets import load_dataset
from transformers import (
AutoModelForMaskedLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
set_seed
)
# --- Configuration ---
DOMAIN_TEXT_FILE = "/home/hsichen/part_time/BERT_finetune/dataset_pretrain/domain_corpus.txt"
MODEL_NAME = "valuesimplex-ai-lab/FinBERT2-base"
OUTPUT_DIR = "./bert_dapt_model"
# Pretraining hyperparameters
DAPT_LR = 1e-5          # low learning rate, to avoid destroying the pretrained knowledge
DAPT_EPOCHS = 3         # a moderate number of training epochs
BATCH_SIZE = 16         # batch size (adjust to your GPU memory)
MLM_PROBABILITY = 0.15  # masking probability
SEED = 42
NUM_PROC = 64           # number of processes for parallel preprocessing
# Fix the random seed so results are reproducible
set_seed(SEED)
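# Note: transformers.set_seed seeds Python's random module, NumPy, and PyTorch
# (including CUDA) in one call.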
def domain_adaptive_pretrain():
    # Path check
    if not os.path.exists(DOMAIN_TEXT_FILE):
        print(f"Fatal error: domain corpus file not found at {DOMAIN_TEXT_FILE}. Run the data preprocessing script first.")
        return
    # 1. Load the model and tokenizer
    print("--- 1. Loading model and tokenizer ---")
    # AutoTokenizer automatically resolves the tokenizer that matches the checkpoint
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # AutoModelForMaskedLM is the model variant with the masked-language-modeling head
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
    # 2. Load and process the text dataset
    print("--- 2. Loading and processing the text dataset ---")
    # Load the plain-text file with the datasets library.
    # The file must be registered under the 'train' key so the Trainer can use it.
    raw_datasets = load_dataset("text", data_files={"train": DOMAIN_TEXT_FILE})
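    # The "text" loading script yields one example per line, so DOMAIN_TEXT_FILE
    # is expected to contain one document (or paragraph) per line.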
    # Define the tokenization function
    def tokenize_function(examples):
        # Truncate but do not pad; DataCollatorForLanguageModeling pads each batch dynamically
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=512,  # adjust to your corpus if needed
            return_special_tokens_mask=True
        )
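    # return_special_tokens_mask=True lets DataCollatorForLanguageModeling skip
    # special tokens such as [CLS] and [SEP] when sampling mask positions.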
tokenized_datasets = raw_datasets.map(
tokenize_function, batched=True, remove_columns=["text"], num_proc=NUM_PROC
)
    # Chunk long texts and group them into fixed-size blocks
    def group_texts(examples):
        # Concatenate all sequences in the batch
        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # Set the chunk size
        chunk_size = 512
        # print(f"Total length: {total_length}, after chunking: {total_length // chunk_size}")
        # Drop the last incomplete chunk by truncating total_length
        total_length = (total_length // chunk_size) * chunk_size
        # Split the concatenated sequences into chunks of chunk_size (512) tokens
        result = {
            k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
            for k, t in concatenated_examples.items()
        }
        # Use a copy of input_ids as labels; DataCollatorForLanguageModeling later
        # rebuilds labels per batch and sets non-masked positions to -100, so this
        # line is kept mainly for clarity
        result["labels"] = result["input_ids"].copy()
        return result
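    # Worked example of the chunking arithmetic above: two tokenized documents of
    # 300 and 400 tokens concatenate to 700 tokens; 700 // 512 * 512 = 512, so one
    # full 512-token chunk is kept and the trailing 188 tokens are discarded.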
    # Final DAPT training dataset
lm_datasets = tokenized_datasets.map(
group_texts, batched=True, num_proc=NUM_PROC
)
    # 3. Data collator (dynamic masking)
    # This collator randomly masks MLM_PROBABILITY (15%) of the tokens in every batch
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=True,
mlm_probability=MLM_PROBABILITY
)
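    # Following the BERT recipe, of the tokens selected for prediction the collator
    # replaces 80% with [MASK], 10% with a random vocabulary token, and leaves 10%
    # unchanged; masks are re-sampled on every pass (hence "dynamic" masking).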
    # 4. Configure the training arguments
    print("--- 4. Configuring training arguments ---")
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
num_train_epochs=DAPT_EPOCHS,
per_device_train_batch_size=BATCH_SIZE,
learning_rate=DAPT_LR,
weight_decay=0.01,
logging_steps=50,
save_strategy="epoch",
report_to="wandb",
)
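    # Depending on your hardware, standard TrainingArguments knobs worth considering
    # (not set in the original script): fp16=True for mixed precision on CUDA GPUs,
    # gradient_accumulation_steps to emulate a larger effective batch, and
    # warmup_ratio for a learning-rate warmup phase.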
    # 5. Initialize the Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=lm_datasets["train"],
data_collator=data_collator,
)
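    # Note: no eval_dataset is passed, so the Trainer reports training loss only.
    # To track held-out MLM loss, one option (an assumption, not part of the
    # original script) is to split the data first, e.g.:
    #   split = lm_datasets["train"].train_test_split(test_size=0.05, seed=SEED)
    # and pass split["test"] as eval_dataset together with an eval strategy.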
    # 6. Start continued pretraining
    print("--- 6. Starting continued pretraining ---")
trainer.train()
    # 7. Save the DAPT model and tokenizer
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print(f"DAPT model saved to: {OUTPUT_DIR}")
if __name__ == "__main__":
domain_adaptive_pretrain()
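
# Usage sketch (an illustration, not part of the training run): after DAPT the
# checkpoint in ./bert_dapt_model can be loaded for downstream fine-tuning;
# num_labels=3 below is a hypothetical placeholder for your task.
#
#   from transformers import AutoModelForSequenceClassification, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("./bert_dapt_model")
#   model = AutoModelForSequenceClassification.from_pretrained(
#       "./bert_dapt_model", num_labels=3
#   )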