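"""Domain-adaptive pretraining (DAPT) for FinBERT2-base.

Continues masked-language-model pretraining of the base checkpoint on a
domain corpus (one text per line in DOMAIN_TEXT_FILE) and saves the adapted
model to OUTPUT_DIR for later downstream fine-tuning.
"""
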
import os

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed,
)

# --- Paths and model ---
DOMAIN_TEXT_FILE = "/home/hsichen/part_time/BERT_finetune/dataset_pretrain/domain_corpus.txt"
MODEL_NAME = "valuesimplex-ai-lab/FinBERT2-base"
OUTPUT_DIR = "./bert_dapt_model"

# --- DAPT hyperparameters ---
DAPT_LR = 1e-5
DAPT_EPOCHS = 3
BATCH_SIZE = 16
MLM_PROBABILITY = 0.15
SEED = 42
NUM_PROC = 64

set_seed(SEED)

def domain_adaptive_pretrain():
    if not os.path.exists(DOMAIN_TEXT_FILE):
        print(f"Fatal error: domain corpus file not found at {DOMAIN_TEXT_FILE}. "
              f"Please run the data preprocessing script first.")
        return

    print("--- 1. Loading the model and tokenizer ---")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)

    print("--- 2. Loading and processing the text dataset ---")
    # Each line of the corpus file becomes one "text" example.
    raw_datasets = load_dataset("text", data_files={"train": DOMAIN_TEXT_FILE})

    def tokenize_function(examples):
        # return_special_tokens_mask lets the collator avoid masking [CLS]/[SEP].
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=512,
            return_special_tokens_mask=True,
        )

    tokenized_datasets = raw_datasets.map(
        tokenize_function, batched=True, remove_columns=["text"], num_proc=NUM_PROC
    )

    def group_texts(examples):
        # Concatenate all token sequences, then split them into fixed-size chunks
        # so every training example is a full 512-token block.
        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])

        chunk_size = 512
        # Drop the trailing remainder that is shorter than one chunk.
        total_length = (total_length // chunk_size) * chunk_size

        result = {
            k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
            for k, t in concatenated_examples.items()
        }
        # For MLM the labels start as a copy of input_ids; the collator masks them.
        result["labels"] = result["input_ids"].copy()
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts, batched=True, num_proc=NUM_PROC
    )

    # Dynamic masking: the collator masks MLM_PROBABILITY of the tokens in each
    # batch (the standard BERT 80% [MASK] / 10% random / 10% unchanged scheme).
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=MLM_PROBABILITY,
    )

print("--- 3. 设置训练参数 ---") |
|
|
training_args = TrainingArguments( |
|
|
output_dir=OUTPUT_DIR, |
|
|
num_train_epochs=DAPT_EPOCHS, |
|
|
per_device_train_batch_size=BATCH_SIZE, |
|
|
learning_rate=DAPT_LR, |
|
|
weight_decay=0.01, |
|
|
logging_steps=50, |
|
|
save_strategy="epoch", |
|
|
report_to="wandb", |
|
|
) |
|
|
|
|
|
|
|
|
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_datasets["train"],
        data_collator=data_collator,
    )

print("--- 4. 开始继续预训练 ---") |
|
|
trainer.train() |
|
|
|
|
|
|
|
|
trainer.save_model(OUTPUT_DIR) |
|
|
tokenizer.save_pretrained(OUTPUT_DIR) |
|
|
print(f"DAPT 模型已保存至: {OUTPUT_DIR}") |
|
|
|
|
|
if __name__ == "__main__":
    domain_adaptive_pretrain()
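
# A minimal sketch (not executed by this script) of how the saved DAPT checkpoint
# could be reused for downstream fine-tuning. "num_labels=2" and the task dataset
# are placeholders, not part of this project:
#
#     from transformers import AutoModelForSequenceClassification, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)
#     model = AutoModelForSequenceClassification.from_pretrained(OUTPUT_DIR, num_labels=2)
#     # ...then fine-tune with a Trainer on the labeled task dataset.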