"""Train / smoke-test the RobotTHP model with semantic features.

Shows how to use the RobotTHP model inside the EasyTPP framework and how to
feed it semantic vectors, deviation features and the spontaneous/mentioned
flag alongside the usual time and event-type sequences.
"""

import torch
from torch.utils.data import DataLoader

from easy_tpp.config_factory import DataSpecConfig, ModelConfig
from easy_tpp.model import TorchRobotTHP
from easy_tpp.preprocess.robert_dataset import RobertTPPDataset
from easy_tpp.preprocess.robert_tokenizer import RobertEventTokenizer
from easy_tpp.preprocess.data_collator import TPPDataCollator

# Dimension of the per-event semantic embedding used throughout this example
# (the sample data mimics 768-dim BERT vectors).
SEMANTIC_DIM = 768


def prepare_robert_data() -> dict:
    """Build a tiny hard-coded sample of "Robert" comment data.

    This is example data only; real usage should load and preprocess the
    sequences from JSON files.

    Returns:
        A dict with the keys ``time_seqs``, ``type_seqs``,
        ``time_delta_seqs``, ``semantic_vectors``, ``deviation_features``
        and ``is_spontaneous``; each value is a list holding one entry per
        sequence (two sequences, of length 4 and 3).
    """
    # Example data.
    time_seqs = [
        [0.0, 10.5, 25.3, 45.2],
        [0.0, 5.2, 12.8],
    ]
    # NOTE(review): the original annotation reads "post, bot_comment,
    # user_comment, user_comment", yet the 4th event of the first sequence
    # carries type 1 (the bot_comment id per that same annotation) and is
    # flagged "not applicable" in is_spontaneous -- confirm which is intended.
    type_seqs = [
        [0, 1, 2, 1],
        [0, 1, 2],
    ]
    time_delta_seqs = [
        [0.0, 10.5, 14.8, 19.9],
        [0.0, 5.2, 7.6],
    ]

    # Semantic vectors (example: constant-filled stand-ins for BERT vectors).
    semantic_vectors = [
        [[0.1] * SEMANTIC_DIM, [0.2] * SEMANTIC_DIM,
         [0.3] * SEMANTIC_DIM, [0.4] * SEMANTIC_DIM],
        [[0.1] * SEMANTIC_DIM, [0.2] * SEMANTIC_DIM, [0.3] * SEMANTIC_DIM],
    ]

    # Deviation features (example: 3 dims
    # [context deviation, sentiment deviation, perplexity]).
    deviation_features = [
        [[0.0, 0.0, 0.0], [0.7, 0.5, 0.3], [0.2, 0.1, 0.1], [0.3, 0.2, 0.1]],
        [[0.0, 0.0, 0.0], [0.6, 0.4, 0.2], [0.1, 0.1, 0.1]],
    ]

    # Spontaneous / mentioned flag (-1 = not applicable, 0 = was @-mentioned,
    # 1 = spontaneous).
    is_spontaneous = [
        # post: n/a, Robert spontaneous, user comments: n/a
        [-1.0, 1.0, -1.0, -1.0],
        # post: n/a, Robert @-mentioned, user comment: n/a
        [-1.0, 0.0, -1.0],
    ]

    return {
        'time_seqs': time_seqs,
        'type_seqs': type_seqs,
        'time_delta_seqs': time_delta_seqs,
        'semantic_vectors': semantic_vectors,
        'deviation_features': deviation_features,
        'is_spontaneous': is_spontaneous,
    }


def create_data_loader(data_dict, config, use_semantic=True,
                       use_deviation=True) -> DataLoader:
    """Create a shuffling ``DataLoader`` over the given sequence data.

    Args:
        data_dict: Data dictionary passed to ``RobertTPPDataset``.
        config: Data configuration; passed to the tokenizer, and its
            ``batch_size`` attribute sets the loader's batch size.
        use_semantic: Whether to use the semantic features.
        use_deviation: Whether to use the deviation features.

    Returns:
        DataLoader: The data loader, collating batches with
        ``TPPDataCollator``.
    """
    # Dataset.
    dataset = RobertTPPDataset(data_dict)

    # Tokenizer.
    tokenizer = RobertEventTokenizer(
        config,
        use_semantic=use_semantic,
        use_deviation=use_deviation,
        semantic_dim=SEMANTIC_DIM,
    )

    # Collator: fall back to "pad, do not truncate" when the tokenizer does
    # not specify a strategy.
    padding = (True if tokenizer.padding_strategy is None
               else tokenizer.padding_strategy)
    truncation = (False if tokenizer.truncation_strategy is None
                  else tokenizer.truncation_strategy)
    data_collator = TPPDataCollator(
        tokenizer=tokenizer,
        return_tensors='pt',
        max_length=tokenizer.model_max_length,
        padding=padding,
        truncation=truncation,
    )

    # Data loader.
    data_loader = DataLoader(
        dataset,
        collate_fn=data_collator,
        batch_size=config.batch_size,
        shuffle=True,
    )
    return data_loader


def main() -> None:
    """Run the end-to-end smoke test.

    Builds the sample data, the data loader and the model, then prints the
    tensor shapes of the first batch and the log-likelihood loss computed on
    it.
    """
    print("=" * 60)
    print("训练RobotTHP模型(带语义特征)")
    print("=" * 60)

    # 1. Prepare the data.
    print("\n1. 准备数据...")
    data_dict = prepare_robert_data()
    print(f" 序列数: {len(data_dict['time_seqs'])}")

    # 2. Create the data configuration.
    print("\n2. 创建配置...")
    config = DataSpecConfig.parse_from_yaml_config({
        'num_event_types': 4,
        'batch_size': 2,
        'pad_token_id': 4
    })

    # 3. Create the data loader.
    print("\n3. 创建数据加载器...")
    data_loader = create_data_loader(
        data_dict,
        config,
        use_semantic=True,
        use_deviation=True
    )

    # 4. Create the model configuration and the model.
    print("\n4. 创建模型...")
    # NOTE(review): hidden_size (128) is not divisible by num_heads (6);
    # many multi-head attention implementations require divisibility --
    # confirm that TorchRobotTHP accepts this combination.
    model_config = ModelConfig.parse_from_yaml_config({
        'hidden_size': 128,
        'num_layers': 3,
        'num_heads': 6,
        'dropout_rate': 0.1,
        'num_event_types': 4,
        'num_event_types_pad': 5,
        'pad_token_id': 4,
        'semantic_dim': SEMANTIC_DIM,
        'use_semantic': True,
        'use_deviation': True,
        'use_structure_mask': False,
        'loss_integral_num_sample_per_step': 20,
        'use_mc_samples': True,
        'gpu': -1
    })
    model = TorchRobotTHP(model_config)
    print(f" 模型参数数量: {sum(p.numel() for p in model.parameters()):,}")

    # 5. Inspect a single batch.
    print("\n5. 测试数据加载...")
    for batch in data_loader:
        # The batch is a BatchEncoding object. Materialize its values into a
        # list so they can be indexed below: a dict-style values() view
        # supports len() and iteration but not batch_values[i].
        batch_values = list(batch.values())
        print(f" 批次大小: {len(batch_values[0])}")
        print(f" 序列长度: {batch_values[0].shape[1]}")
        print(f" 时间序列形状: {batch_values[0].shape}")
        print(f" 事件类型形状: {batch_values[2].shape}")
        # The extra feature tensors are only present when the tokenizer
        # emitted them, and may be None.
        if len(batch_values) > 5:
            print(f" 语义向量形状: {batch_values[5].shape if batch_values[5] is not None else 'None'}")
        if len(batch_values) > 6:
            print(f" 偏差特征形状: {batch_values[6].shape if batch_values[6] is not None else 'None'}")
        if len(batch_values) > 7:
            print(f" 自发标记形状: {batch_values[7].shape if batch_values[7] is not None else 'None'}")

        # 6. Forward pass (evaluation mode, no gradients).
        print("\n6. 测试前向传播...")
        model.eval()
        with torch.no_grad():
            loss, num_events = model.loglike_loss(batch_values)
        print(f" 损失值: {loss.item():.4f}")
        print(f" 事件数: {num_events}")
        # One batch is enough for the smoke test.
        break

    print("\n✅ 测试完成!")
    print("\n使用说明:")
    print("1. 将你的JSON数据转换为上述格式")
    print("2. 使用RobertTPPDataset和RobertEventTokenizer加载数据")
    print("3. 在EasyTPP配置文件中设置model_id为RobotTHP")
    print("4. 运行训练即可")


if __name__ == '__main__':
    main()