Abigail99216's picture
Upload folder using huggingface_hub
f43af3c verified

评论罗伯特特征使用指南

概述

本指南说明如何在EasyTPP框架中使用RobotTHP模型,并加载语义特征、偏差特征等自定义特征。

文件说明

  1. robert_dataset.py: 扩展的TPPDataset,支持加载语义特征、偏差特征等
  2. robert_tokenizer.py: 扩展的EventTokenizer,支持自定义特征的padding和批处理
  3. train_robot_thp_with_features.py: 完整的使用示例

数据格式

输入数据字典格式

data_dict = {
    'time_seqs': [[0.0, 10.5, 25.3], ...],  # 时间序列列表
    'type_seqs': [[0, 1, 2], ...],          # 事件类型序列列表
    'time_delta_seqs': [[0.0, 10.5, 14.8], ...],  # 时间间隔序列列表
    'semantic_vectors': [[[0.1]*768, [0.2]*768], ...],  # 语义向量(可选)
    'deviation_features': [[[0.0, 0.0, 0.0], [0.7, 0.5, 0.3]], ...],  # 偏差特征(可选)
    'is_spontaneous': [[-1.0, 1.0, -1.0], ...]  # 自发/被@标记(可选)
}

特征说明

  • semantic_vectors: [num_seqs, seq_len, semantic_dim],BERT语义向量
  • deviation_features: [num_seqs, seq_len, 3],偏差特征 [语境偏差, 情感偏差, 困惑度]
  • is_spontaneous: [num_seqs, seq_len],标记值:
    • -1.0: 不适用(非罗伯特评论)
    • 0.0: 被@(罗伯特被原帖作者@)
    • 1.0: 自发(罗伯特自发评论)

使用方法

1. 准备数据

from easy_tpp.preprocess.robert_dataset import RobertTPPDataset

# 准备数据字典
data_dict = {
    'time_seqs': [...],
    'type_seqs': [...],
    'time_delta_seqs': [...],
    'semantic_vectors': [...],  # 可选
    'deviation_features': [...],  # 可选
    'is_spontaneous': [...]  # 可选
}

# 创建数据集
dataset = RobertTPPDataset(data_dict)

2. 创建分词器

from easy_tpp.preprocess.robert_tokenizer import RobertEventTokenizer
from easy_tpp.config_factory import DataSpecConfig

config = DataSpecConfig.parse_from_yaml_config({
    'num_event_types': 4,
    'batch_size': 32,
    'pad_token_id': 4
})

tokenizer = RobertEventTokenizer(
    config,
    use_semantic=True,      # 是否使用语义特征
    use_deviation=True,      # 是否使用偏差特征
    semantic_dim=768        # 语义向量维度
)

3. 创建数据加载器

from easy_tpp.preprocess.data_collator import TPPDataCollator
from torch.utils.data import DataLoader

data_collator = TPPDataCollator(
    tokenizer=tokenizer,
    return_tensors='pt',
    max_length=tokenizer.model_max_length,
    padding=True,
    truncation=False
)

data_loader = DataLoader(
    dataset,
    collate_fn=data_collator,
    batch_size=32,
    shuffle=True
)

4. 在模型中使用

RobotTHP模型的loglike_loss方法会自动从batch中提取这些特征:

from easy_tpp.model import TorchRobotTHP

model = TorchRobotTHP(model_config)

for batch in data_loader:
    batch_values = batch.values()  # 转换为tuple/list
    loss, num_events = model.loglike_loss(batch_values)

批次格式

批次数据格式(tuple/list):

batch = (
    time_seqs,           # [0] [batch_size, seq_len]
    time_delta_seqs,     # [1] [batch_size, seq_len]
    type_seqs,           # [2] [batch_size, seq_len]
    batch_non_pad_mask,  # [3] [batch_size, seq_len]
    attention_mask,      # [4] [batch_size, seq_len, seq_len]
    semantic_vectors,    # [5] [batch_size, seq_len, semantic_dim] (可选)
    deviation_features,  # [6] [batch_size, seq_len, 3] (可选)
    is_spontaneous,      # [7] [batch_size, seq_len] (可选)
    structure_mask       # [8] [batch_size, seq_len, seq_len] (可选)
)

完整示例

参考 examples/train_robot_thp_with_features.py 获取完整的使用示例。

注意事项

  1. 特征对齐: 确保所有特征序列的长度与时间序列一致
  2. Padding值:
    • 语义向量和偏差特征:padding使用0.0
    • is_spontaneous:padding使用-1.0(不适用)
  3. 可选特征: 如果某个特征未提供,模型会自动跳过该特征的处理
  4. 配置一致性: 确保模型配置中的use_semanticuse_deviation与tokenizer设置一致

与标准EasyTPP的集成

要完全集成到EasyTPP框架中,需要:

  1. 自定义数据加载器: 继承TPPDataLoader并重写_build_input_from_json方法
  2. 配置文件: 在配置文件中指定使用自定义数据集和分词器
  3. 模型配置: 设置use_semantic=Trueuse_deviation=True

从JSON文件加载

如果你的数据是JSON格式,可以参考以下方式加载:

import json
import numpy as np

# 加载JSON数据
with open('your_data.json', 'r') as f:
    json_data = json.load(f)

# 提取特征
time_seqs = [[event['time_since_start'] for event in seq] for seq in json_data]
type_seqs = [[event['type_event'] for event in seq] for seq in json_data]
time_delta_seqs = [[event['time_since_last_event'] for event in seq] for seq in json_data]

# 提取语义特征(如果存在)
semantic_vectors = None
if 'semantic_vectors' in json_data[0][0]:
    semantic_vectors = [[event['semantic_vectors'] for event in seq] for seq in json_data]

# 创建数据字典
data_dict = {
    'time_seqs': time_seqs,
    'type_seqs': type_seqs,
    'time_delta_seqs': time_delta_seqs,
    'semantic_vectors': semantic_vectors,
    # ... 其他特征
}