import os
import torch
import numpy as np
from datasets import load_dataset
import evaluate
from transformers import BertForSequenceClassification, BertTokenizerFast
from transformers import TrainingArguments, Trainer


# --- Configuration ---
# Directory containing the preprocessed train/validation/test CSV files
DATA_DIR = "./processed_data_task1"

# Path to the DAPT (domain-adaptive pre-trained) FinBERT checkpoint to fine-tune
MODEL_NAME = '/home/hsichen/part_time/BERT_finetune/outputs/finbert2_dapt_model'

# Binary classification
NUM_LABELS = 2

# Where checkpoints and the final fine-tuned model are written
OUTPUT_DIR = "./finbert2_bilabel_finetuned_model_from_dapt"

# Training hyperparameters
EPOCHS = 3
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
SEED = 42


def compute_metrics(p):
    """
    Compute evaluation metrics (accuracy, F1, precision, recall).
    """
    preds = np.argmax(p.predictions, axis=1)
    labels = p.label_ids

    # Metrics are loaded from the `evaluate` library on each call; for long
    # evaluation loops they could be loaded once at module level instead.
    accuracy = evaluate.load("accuracy").compute(predictions=preds, references=labels)["accuracy"]
    f1 = evaluate.load("f1").compute(predictions=preds, references=labels, average="binary")["f1"]
    precision = evaluate.load("precision").compute(predictions=preds, references=labels, average="binary")["precision"]
    recall = evaluate.load("recall").compute(predictions=preds, references=labels, average="binary")["recall"]

    return {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }


def finetune_bert():
    """
    Run fine-tuning of the BERT model.
    """

    print("--- 1. Loading dataset ---")
    try:
        # Each CSV is expected to contain at least a "text" and a "label" column.
        data_files = {
            "train": os.path.join(DATA_DIR, "train.csv"),
            "validation": os.path.join(DATA_DIR, "validation.csv"),
            "test": os.path.join(DATA_DIR, "test.csv")
        }
        raw_datasets = load_dataset("csv", data_files=data_files)
        print(raw_datasets)
    except Exception as e:
        print(f"Error while loading the dataset; please check the CSV files under {DATA_DIR}: {e}")
        return

    print("--- 2. Loading tokenizer and model ---")
    tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
    model = BertForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=NUM_LABELS
    )

    def tokenize_function(examples):
        # Pad every example to the maximum length and truncate longer texts.
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)

    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets["validation"]
    test_dataset = tokenized_datasets["test"]
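
    # Optional alternative (not part of the original script): padding every example
    # to max_length is simple but wastes compute on short texts. A sketch of dynamic
    # per-batch padding would tokenize with truncation only and hand a
    # DataCollatorWithPadding to the Trainer, roughly:
    #
    #     from transformers import DataCollatorWithPadding
    #     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    #     # ...then pass data_collator=data_collator to Trainer below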

    print("--- 3. Setting up training arguments and Trainer ---")
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        warmup_steps=500,
        weight_decay=0.01,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        seed=SEED,
        learning_rate=LEARNING_RATE,
        report_to="wandb"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )
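
    # Optional (not in the original script): since load_best_model_at_end and
    # metric_for_best_model are already set, early stopping could be added with
    # the built-in callback, roughly:
    #
    #     from transformers import EarlyStoppingCallback
    #     trainer = Trainer(..., callbacks=[EarlyStoppingCallback(early_stopping_patience=2)])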

    print("--- 4. Starting training ---")
    trainer.train()

    print("--- 5. Evaluating on the test set ---")
    results = trainer.evaluate(test_dataset)
    print(f"Test set evaluation results: {results}")

    # Save the fine-tuned model and the tokenizer together so the final directory
    # can be loaded directly with from_pretrained().
    final_dir = os.path.join(OUTPUT_DIR, "final")
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)
    print(f"Model and tokenizer saved to: {final_dir}")
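
# Usage note (illustrative, not part of the original script): the saved directory
# can later be loaded for inference, for example:
#
#     from transformers import pipeline
#     clf = pipeline("text-classification", model=os.path.join(OUTPUT_DIR, "final"))
#     print(clf("Example sentence to classify."))
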
if __name__ == "__main__":
    finetune_bert()