|
|
"""
Train the RobotTHP model (with semantic features).

Demonstrates how to use the RobotTHP model within the EasyTPP framework,
loading semantic features, deviation features, etc.
"""
|
|
|
|
|
import torch |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from easy_tpp.config_factory import DataSpecConfig |
|
|
from easy_tpp.model import TorchRobotTHP |
|
|
from easy_tpp.preprocess.robert_dataset import RobertTPPDataset |
|
|
from easy_tpp.preprocess.robert_tokenizer import RobertEventTokenizer |
|
|
from easy_tpp.preprocess.data_collator import TPPDataCollator |
|
|
|
|
|
|
|
|
def prepare_robert_data():
    """Build a small in-memory sample dataset for the RobotTHP example.

    In real usage this data should be loaded and processed from a JSON file.

    Returns:
        dict: Keys are ``time_seqs``, ``type_seqs``, ``time_delta_seqs``,
        ``semantic_vectors``, ``deviation_features`` and ``is_spontaneous``;
        each value is a list with one entry per sequence.
    """
    # Two toy event sequences: absolute timestamps, event-type ids,
    # and inter-event time deltas (first delta is 0.0 by convention).
    times = [
        [0.0, 10.5, 25.3, 45.2],
        [0.0, 5.2, 12.8],
    ]
    types = [
        [0, 1, 2, 1],
        [0, 1, 2],
    ]
    deltas = [
        [0.0, 10.5, 14.8, 19.9],
        [0.0, 5.2, 7.6],
    ]

    # 768-dim constant semantic embeddings, one vector per event.
    semantics = [
        [[v] * 768 for v in (0.1, 0.2, 0.3, 0.4)],
        [[v] * 768 for v in (0.1, 0.2, 0.3)],
    ]

    # Per-event 3-dim deviation features (all zeros for the first event).
    deviations = [
        [[0.0, 0.0, 0.0], [0.7, 0.5, 0.3], [0.2, 0.1, 0.1], [0.3, 0.2, 0.1]],
        [[0.0, 0.0, 0.0], [0.6, 0.4, 0.2], [0.1, 0.1, 0.1]],
    ]

    # Spontaneity flags: 1.0 spontaneous, 0.0 not, -1.0 unknown/unlabeled.
    spontaneous = [
        [-1.0, 1.0, -1.0, -1.0],
        [-1.0, 0.0, -1.0],
    ]

    return {
        'time_seqs': times,
        'type_seqs': types,
        'time_delta_seqs': deltas,
        'semantic_vectors': semantics,
        'deviation_features': deviations,
        'is_spontaneous': spontaneous,
    }
|
|
|
|
|
|
|
|
def create_data_loader(data_dict, config, use_semantic=True, use_deviation=True):
    """Build a :class:`DataLoader` over the RobotTHP example data.

    Args:
        data_dict: Dict of per-sequence lists (see ``prepare_robert_data``).
        config: Data-spec config; ``batch_size`` is read from it.
        use_semantic: Whether semantic features are included.
        use_deviation: Whether deviation features are included.

    Returns:
        DataLoader: Shuffled loader that collates batches via TPPDataCollator.
    """
    event_dataset = RobertTPPDataset(data_dict)

    # Tokenizer knows how to pad/encode events plus the optional
    # semantic (768-dim) and deviation features.
    event_tokenizer = RobertEventTokenizer(
        config,
        use_semantic=use_semantic,
        use_deviation=use_deviation,
        semantic_dim=768
    )

    # Fall back to sensible defaults when the tokenizer does not
    # specify explicit padding/truncation strategies.
    pad_strategy = event_tokenizer.padding_strategy
    if pad_strategy is None:
        pad_strategy = True
    trunc_strategy = event_tokenizer.truncation_strategy
    if trunc_strategy is None:
        trunc_strategy = False

    collator = TPPDataCollator(
        tokenizer=event_tokenizer,
        return_tensors='pt',
        max_length=event_tokenizer.model_max_length,
        padding=pad_strategy,
        truncation=trunc_strategy
    )

    return DataLoader(
        event_dataset,
        collate_fn=collator,
        batch_size=config.batch_size,
        shuffle=True
    )
|
|
|
|
|
|
|
|
def main():
    """Run the end-to-end RobotTHP example: data, config, model, one batch."""
    print("=" * 60)
    print("训练RobotTHP模型(带语义特征)")
    print("=" * 60)

    print("\n1. 准备数据...")
    data_dict = prepare_robert_data()
    print(f"   序列数: {len(data_dict['time_seqs'])}")

    print("\n2. 创建配置...")
    config = DataSpecConfig.parse_from_yaml_config({
        'num_event_types': 4,
        'batch_size': 2,
        'pad_token_id': 4
    })

    print("\n3. 创建数据加载器...")
    data_loader = create_data_loader(
        data_dict,
        config,
        use_semantic=True,
        use_deviation=True
    )

    print("\n4. 创建模型...")
    from easy_tpp.config_factory import ModelConfig

    model_config = ModelConfig.parse_from_yaml_config({
        'hidden_size': 128,
        'num_layers': 3,
        'num_heads': 6,
        'dropout_rate': 0.1,
        'num_event_types': 4,
        'num_event_types_pad': 5,
        'pad_token_id': 4,
        'semantic_dim': 768,
        'use_semantic': True,
        'use_deviation': True,
        'use_structure_mask': False,
        'loss_integral_num_sample_per_step': 20,
        'use_mc_samples': True,
        'gpu': -1  # -1 = CPU
    })

    model = TorchRobotTHP(model_config)
    print(f"   模型参数数量: {sum(p.numel() for p in model.parameters()):,}")

    print("\n5. 测试数据加载...")
    for batch in data_loader:
        # BUG FIX: dict.values() returns a non-subscriptable view;
        # materialize it as a list so batch_values[i] works below.
        batch_values = list(batch.values())

        print(f"   批次大小: {len(batch_values[0])}")
        print(f"   序列长度: {batch_values[0].shape[1]}")
        print(f"   时间序列形状: {batch_values[0].shape}")
        print(f"   事件类型形状: {batch_values[2].shape}")

        # Optional features may be absent depending on tokenizer settings.
        if len(batch_values) > 5:
            print(f"   语义向量形状: {batch_values[5].shape if batch_values[5] is not None else 'None'}")
        if len(batch_values) > 6:
            print(f"   偏差特征形状: {batch_values[6].shape if batch_values[6] is not None else 'None'}")
        if len(batch_values) > 7:
            print(f"   自发标记形状: {batch_values[7].shape if batch_values[7] is not None else 'None'}")

        print("\n6. 测试前向传播...")
        model.eval()
        with torch.no_grad():
            loss, num_events = model.loglike_loss(batch_values)
            print(f"   损失值: {loss.item():.4f}")
            print(f"   事件数: {num_events}")

        # Only the first batch is needed for this smoke test.
        break

    print("\n✅ 测试完成!")
    print("\n使用说明:")
    print("1. 将你的JSON数据转换为上述格式")
    print("2. 使用RobertTPPDataset和RobertEventTokenizer加载数据")
    print("3. 在EasyTPP配置文件中设置model_id为RobotTHP")
    print("4. 运行训练即可")
|
|
|
|
|
# Script entry point: run the full example when executed directly.
if __name__ == '__main__':


    main()
|
|
|
|
|
|