Spaces:
Sleeping
Sleeping
"""
Training script for a BERT compliance classifier.

Fine-tunes a sequence-classification model with Hugging Face Transformers.
"""
import json
import math
import os
from pathlib import Path
from typing import Dict, List, Optional

import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)
class ComplianceDataset(Dataset):
    """Map-style dataset that tokenizes compliance samples on access.

    Each sample is a dict with a ``'text'`` entry, which is handed to the
    tokenizer, and a ``'label'`` entry, which is looked up in ``label2id``.
    Sequences are truncated but never padded here; padding is left to the
    batch collator so every batch is only padded to its own longest member.

    Args:
        data: Samples, each providing ``'text'`` and ``'label'``.
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=True, max_length=..., padding=False)``
            that returns a mutable mapping.
        label2id: Mapping from label value to integer class id.
        max_length: Maximum tokenized length; longer inputs are truncated.
            Defaults to 128, the value that was previously hard-coded.
    """

    def __init__(
        self,
        data: List[Dict],
        tokenizer,
        label2id: Dict,
        max_length: int = 128,
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and attach its integer label.

        Returns:
            The tokenizer output with an added ``'label'`` entry.

        Raises:
            KeyError: If the sample's label is missing from ``label2id``.
        """
        item = self.data[idx]
        encoding = self.tokenizer(
            item['text'],
            truncation=True,
            max_length=self.max_length,
            padding=False,  # padding is applied per batch by the data collator
        )

        label = item['label']
        try:
            encoding['label'] = self.label2id[label]
        except KeyError:
            # Same exception type as a plain lookup, but say which sample and
            # label are at fault instead of surfacing a bare key.
            raise KeyError(
                f"Unknown label {label!r} at index {idx}; "
                f"known labels: {sorted(map(str, self.label2id))}"
            ) from None
        return encoding
class BERTClassifierTrainer:
    """Fine-tune and evaluate a sequence-classification model.

    The label mapping is read from ``label_mapping.json`` next to the
    training file when that file exists; otherwise it is derived from the
    training data and written to ``output_dir / 'label_mapping.json'``.
    """

    def __init__(
        self,
        model_name: str,
        train_data_path: str,
        output_dir: str,
        val_data_path: Optional[str] = None
    ):
        """Resolve the label mapping, load the tokenizer and report the device.

        Args:
            model_name: Pretrained model name or path passed to
                ``from_pretrained``.
            train_data_path: Path to the training data (``.jsonl`` or JSON).
            output_dir: Directory for checkpoints and the final model; it is
                created if missing.
            val_data_path: Optional path to validation data.
        """
        self.model_name = model_name
        self.train_data_path = train_data_path
        self.val_data_path = val_data_path
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Prefer a mapping shipped next to the training data; fall back to
        # deriving one from the labels found in the training set.
        label_mapping_path = Path(train_data_path).parent / 'label_mapping.json'
        if label_mapping_path.exists():
            self._load_label_mapping(label_mapping_path)
        else:
            self._extract_labels_from_data()

        print(f"标签数量: {self.num_labels}")
        print(f"标签映射: {self.label2id}")

        print(f"\n加载tokenizer: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Informational only: the device is reported here, not used to move
        # tensors anywhere in this class.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"使用设备: {self.device}")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name(0)}")

    def _load_label_mapping(self, mapping_path):
        """Load ``label2id`` / ``id2label`` / ``num_labels`` from a JSON file."""
        with open(mapping_path, 'r', encoding='utf-8') as f:
            label_mapping = json.load(f)

        self.label2id = label_mapping['label2id']
        # JSON object keys are always strings, so restore integer ids to match
        # the mapping produced by _extract_labels_from_data().
        self.id2label = {
            int(idx): label
            for idx, label in label_mapping['id2label'].items()
        }
        self.num_labels = label_mapping.get('num_labels', len(self.label2id))

    def _extract_labels_from_data(self):
        """Build the label mapping from the training data and persist it.

        Labels are sorted so the label-to-id assignment is deterministic.

        Raises:
            ValueError: If the training data contains no samples.
        """
        print("\n从训练数据中提取标签...")

        train_data = self._load_data(self.train_data_path)
        unique_labels = sorted({item['label'] for item in train_data})
        if not unique_labels:
            raise ValueError(
                f"No labels found in training data: {self.train_data_path}"
            )

        self.label2id = {label: idx for idx, label in enumerate(unique_labels)}
        self.id2label = {idx: label for label, idx in self.label2id.items()}
        self.num_labels = len(unique_labels)

        # NOTE: the mapping is written to the output directory, whereas
        # __init__ looks for it next to the training data, so this copy is
        # not picked up on a later run unless the two directories coincide.
        label_mapping = {
            'label2id': self.label2id,
            'id2label': self.id2label,
            'num_labels': self.num_labels
        }
        mapping_path = self.output_dir / 'label_mapping.json'
        with open(mapping_path, 'w', encoding='utf-8') as f:
            json.dump(label_mapping, f, ensure_ascii=False, indent=2)

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load samples from a ``.jsonl`` file or a single JSON document.

        Args:
            data_path: File to read. A ``.jsonl`` suffix selects line-by-line
                parsing; anything else is parsed as one JSON document.

        Returns:
            The loaded samples.
        """
        print(f"加载数据: {data_path}")
        data_path = Path(data_path)

        with open(data_path, 'r', encoding='utf-8') as f:
            if data_path.suffix.lower() == '.jsonl':
                # Skip blank lines (commonly a trailing newline), which would
                # otherwise make json.loads fail on an empty string.
                data = [json.loads(line) for line in f if line.strip()]
            else:
                data = json.load(f)

        print(f"加载样本数: {len(data)}")
        return data

    def train(
        self,
        num_epochs: int = 3,
        batch_size: int = 16,
        learning_rate: float = 2e-5,
        warmup_steps: int = 100,
        weight_decay: float = 0.01
    ):
        """Fine-tune the model and save it to ``output_dir / 'final_model'``.

        Args:
            num_epochs: Number of training epochs.
            batch_size: Per-device batch size for training and evaluation.
            learning_rate: Learning rate passed to ``TrainingArguments``.
            warmup_steps: Number of warmup steps.
            weight_decay: Weight decay coefficient.

        Returns:
            The ``Trainer`` instance used for training.
        """
        print("\n" + "="*50)
        print("开始训练")
        print("="*50)

        train_data = self._load_data(self.train_data_path)
        train_dataset = ComplianceDataset(
            train_data,
            self.tokenizer,
            self.label2id
        )

        eval_dataset = None
        if self.val_data_path and Path(self.val_data_path).exists():
            val_data = self._load_data(self.val_data_path)
            eval_dataset = ComplianceDataset(
                val_data,
                self.tokenizer,
                self.label2id
            )
        # Evaluation is only enabled for a non-empty validation set.
        has_eval = eval_dataset is not None and len(eval_dataset) > 0

        print(f"\n加载模型: {self.model_name}")
        model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=self.num_labels,
            id2label=self.id2label,
            label2id=self.label2id
        )

        # Display-only estimate; it does not feed into the training arguments.
        total_steps = math.ceil(len(train_dataset) / batch_size) * num_epochs

        training_args = TrainingArguments(
            output_dir=str(self.output_dir),
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            learning_rate=learning_rate,
            logging_dir=str(self.output_dir / 'logs'),
            logging_steps=10,
            eval_strategy="steps" if has_eval else "no",
            eval_steps=50 if has_eval else None,
            save_strategy="steps",
            save_steps=50,
            save_total_limit=3,
            load_best_model_at_end=has_eval,
            metric_for_best_model="f1" if has_eval else None,
            greater_is_better=True,
            # The string "none" disables wandb/tensorboard reporting; passing
            # None instead selects every installed integration.
            report_to="none",
        )

        print(f"\n训练配置:")
        print(f" 训练样本: {len(train_dataset)}")
        print(f" 批次大小: {batch_size}")
        print(f" 训练轮数: {num_epochs}")
        print(f" 总步数: {total_steps}")
        print(f" 学习率: {learning_rate}")

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=DataCollatorWithPadding(self.tokenizer),
            compute_metrics=self._compute_metrics if has_eval else None,
        )

        print("\n开始训练...")
        trainer.train()

        # Save the model and tokenizer together so the directory can be
        # loaded on its own later.
        final_model_path = self.output_dir / 'final_model'
        trainer.save_model(str(final_model_path))
        self.tokenizer.save_pretrained(str(final_model_path))
        print(f"\n✓ 最终模型已保存: {final_model_path}")
        return trainer

    def _compute_metrics(self, eval_pred):
        """Compute accuracy and weighted precision/recall/F1.

        Args:
            eval_pred: A ``(predictions, labels)`` pair; the predicted class
                is the argmax over the last axis of ``predictions``.

        Returns:
            Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``.
        """
        predictions, labels = eval_pred
        # Some models return a tuple of outputs; presumably the first entry
        # holds the class scores (verify for non-BERT architectures).
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        predictions = predictions.argmax(axis=-1)

        precision, recall, f1, _ = precision_recall_fscore_support(
            labels,
            predictions,
            average='weighted',
            zero_division=0,  # same score as the default, without the warning
        )
        accuracy = accuracy_score(labels, predictions)
        return {
            'accuracy': accuracy,
            'f1': f1,
            'precision': precision,
            'recall': recall
        }

    def evaluate(self, test_data_path: Optional[str] = None):
        """Evaluate the saved final model on a test or validation file.

        Args:
            test_data_path: File to evaluate on; defaults to the validation
                path given at construction.

        Returns:
            The metrics dict from ``Trainer.evaluate``, or ``None`` when no
            usable data file is available.

        Raises:
            FileNotFoundError: If ``output_dir / 'final_model'`` is missing.
        """
        if test_data_path is None:
            test_data_path = self.val_data_path
        if test_data_path is None or not Path(test_data_path).exists():
            print("\n未提供测试数据,跳过评估")
            return None

        print(f"\n评估模型: {test_data_path}")

        model_path = self.output_dir / 'final_model'
        if not model_path.exists():
            raise FileNotFoundError(
                f"Trained model not found at {model_path}; run train() first"
            )
        model = AutoModelForSequenceClassification.from_pretrained(model_path)

        test_data = self._load_data(test_data_path)
        test_dataset = ComplianceDataset(
            test_data,
            self.tokenizer,
            self.label2id
        )

        # Explicit arguments keep artefacts inside output_dir and disable
        # experiment trackers; compute_metrics is required for anything
        # beyond the loss to be reported.
        eval_args = TrainingArguments(
            output_dir=str(self.output_dir),
            report_to="none",
        )
        trainer = Trainer(
            model=model,
            args=eval_args,
            data_collator=DataCollatorWithPadding(self.tokenizer),
            compute_metrics=self._compute_metrics,
        )

        results = trainer.evaluate(test_dataset)
        print(f"\n评估结果:")
        for key, value in results.items():
            if 'eval_' in key:
                name = key.replace('eval_', '')
                if isinstance(value, (int, float)):
                    print(f" {name}: {value:.4f}")
                else:
                    print(f" {name}: {value}")
        return results
def main():
    """Command-line entry point: parse options, train, then optionally evaluate.

    Seeds Python's, NumPy's and PyTorch's random generators with 42 before
    the trainer is built. Evaluation runs only when ``--val_data`` is given.
    """
    import argparse

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--model_name', dict(type=str, default='hfl/chinese-bert-wwm-ext', help='预训练模型名称')),
        ('--train_data', dict(type=str, required=True, help='训练数据路径')),
        ('--val_data', dict(type=str, default=None, help='验证数据路径')),
        ('--output_dir', dict(type=str, default='models/bert-compliance', help='输出目录')),
        ('--num_epochs', dict(type=int, default=3, help='训练轮数')),
        ('--batch_size', dict(type=int, default=16, help='批次大小')),
        ('--learning_rate', dict(type=float, default=2e-5, help='学习率')),
    )
    cli = argparse.ArgumentParser(description='训练BERT合规性分类器')
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    # Fix every random source for reproducibility.
    import random
    import numpy as np
    for seed_everything in (random.seed, np.random.seed, torch.manual_seed):
        seed_everything(42)

    classifier_trainer = BERTClassifierTrainer(
        model_name=args.model_name,
        train_data_path=args.train_data,
        output_dir=args.output_dir,
        val_data_path=args.val_data
    )
    classifier_trainer.train(
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate
    )

    # Without a validation file there is nothing to evaluate against.
    if args.val_data:
        classifier_trainer.evaluate()

    print("\n✓ 训练完成!")
    print(f"\n模型已保存到: {args.output_dir}")
    print("\n使用方法:")
    print(" from transformers import AutoModelForSequenceClassification")
    print(f" model = AutoModelForSequenceClassification.from_pretrained('{args.output_dir}/final_model')")


if __name__ == '__main__':
    main()