File size: 3,302 Bytes
f43af3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
"""
评论罗伯特数据集 - 支持语义特征、偏差特征等
扩展TPPDataset以支持多模态特征
"""
from typing import Dict, Optional
import numpy as np
from easy_tpp.preprocess.dataset import TPPDataset
from easy_tpp.utils import py_assert
class RobertTPPDataset(TPPDataset):
    """TPP dataset that carries optional per-sequence side features.

    Extends the standard ``TPPDataset`` (time / type / time-delta sequences)
    with three optional feature lists:

    - ``semantic_vectors``: semantic vectors, documented as
      ``[num_seqs, seq_len, semantic_dim]``
    - ``deviation_features``: deviation features, documented as
      ``[num_seqs, seq_len, 3]``
    - ``is_spontaneous``: spontaneous vs. @-mentioned flags, documented as
      ``[num_seqs, seq_len]``

    A feature missing from the input dict (or stored as ``None``) is kept as
    ``None`` and is omitted from the items returned by ``__getitem__``.
    """

    # Names of the optional features. Each one is stored as an attribute of
    # the same name and emitted under the same key by ``__getitem__``; the
    # tuple order fixes both the validation order and the output key order.
    _OPTIONAL_FEATURES = ('semantic_vectors', 'deviation_features', 'is_spontaneous')

    def __init__(self, data: Dict):
        """Build the dataset and validate any optional features present.

        Args:
            data: dict containing
                - 'time_seqs': list of time sequences
                - 'type_seqs': list of event-type sequences
                - 'time_delta_seqs': list of inter-event time sequences
                - 'semantic_vectors' (optional)
                - 'deviation_features' (optional)
                - 'is_spontaneous' (optional)

        Raises:
            ValueError (via ``py_assert``): if a provided optional feature does
                not have the same number of sequences as ``time_seqs``.
        """
        super(RobertTPPDataset, self).__init__(data)

        # Optional features; ``None`` means "not provided".
        self.semantic_vectors = self.data_dict.get('semantic_vectors', None)
        self.deviation_features = self.data_dict.get('deviation_features', None)
        self.is_spontaneous = self.data_dict.get('is_spontaneous', None)

        # Only the outer (number-of-sequences) length is checked here; the
        # per-event lengths inside each sequence are not validated.
        for name in self._OPTIONAL_FEATURES:
            self._check_num_seqs(name, getattr(self, name))

    def _check_num_seqs(self, name: str, values) -> None:
        """Check via ``py_assert`` that ``values`` (if given) matches ``len(time_seqs)``."""
        if values is None:
            return
        py_assert(
            len(values) == len(self.time_seqs),
            ValueError,
            f"Inconsistent lengths: {name}={len(values)}, "
            f"time_seqs={len(self.time_seqs)}"
        )

    def __getitem__(self, idx) -> Dict:
        """Return the ``idx``-th sample.

        Args:
            idx: sample index.

        Returns:
            dict with 'time_seqs', 'time_delta_seqs' and 'type_seqs', plus one
            entry for each optional feature that is not ``None``.
        """
        item = {
            'time_seqs': self.time_seqs[idx],
            'time_delta_seqs': self.time_delta_seqs[idx],
            'type_seqs': self.type_seqs[idx],
        }
        for name in self._OPTIONAL_FEATURES:
            values = getattr(self, name)
            if values is not None:
                item[name] = values[idx]
        return item
|