|
|
""" |
|
|
评论罗伯特数据集 - 支持语义特征、偏差特征等 |
|
|
|
|
|
扩展TPPDataset以支持多模态特征 |
|
|
""" |
|
|
|
|
|
from typing import Dict, Optional |
|
|
import numpy as np |
|
|
from easy_tpp.preprocess.dataset import TPPDataset |
|
|
from easy_tpp.utils import py_assert |
|
|
|
|
|
|
|
|
class RobertTPPDataset(TPPDataset):
    """TPP dataset that can carry optional per-sequence side features.

    In addition to the three standard sequences (``time_seqs``,
    ``time_delta_seqs``, ``type_seqs``), the dataset exposes, when present
    in the input dictionary:

        - semantic_vectors: list of semantic vectors
        - deviation_features: list of deviation features
        - is_spontaneous: list of spontaneous / @-mentioned flags
    """

    # Optional feature keys, in the order they are validated and emitted.
    _OPTIONAL_FEATURES = ('semantic_vectors', 'deviation_features', 'is_spontaneous')

    def __init__(self, data: Dict):
        """Build the dataset and attach any optional features found in ``data``.

        Args:
            data: Data dictionary containing:
                - time_seqs: list of event-time sequences
                - type_seqs: list of event-type sequences
                - time_delta_seqs: list of inter-event-time sequences
                - semantic_vectors: optional list of semantic vectors,
                  documented as [num_seqs, seq_len, semantic_dim]
                - deviation_features: optional list of deviation features,
                  documented as [num_seqs, seq_len, 3]
                - is_spontaneous: optional list of spontaneous/@-mentioned
                  flags, documented as [num_seqs, seq_len]

        Raises:
            ValueError: Via ``py_assert`` when a provided optional feature
                does not hold the same number of sequences as ``time_seqs``.
        """
        super().__init__(data)

        # NOTE(review): ``self.data_dict`` and ``self.time_seqs`` are expected
        # to be set by ``TPPDataset.__init__`` (not visible in this file).
        source = self.data_dict
        self.semantic_vectors = source.get('semantic_vectors')
        self.deviation_features = source.get('deviation_features')
        self.is_spontaneous = source.get('is_spontaneous')

        # Only the outer length (number of sequences) is compared here;
        # per-sequence lengths are not inspected.
        for name in self._OPTIONAL_FEATURES:
            values = getattr(self, name)
            if values is None:
                continue
            py_assert(
                len(values) == len(self.time_seqs),
                ValueError,
                f"Inconsistent lengths: {name}={len(values)}, time_seqs={len(self.time_seqs)}",
            )

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``.

        Args:
            idx: Index of the sample.

        Returns:
            dict: Always holds ``time_seqs``, ``time_delta_seqs`` and
            ``type_seqs``; each optional feature is added under its own key
            only when it was supplied at construction time.
        """
        sample = dict(
            time_seqs=self.time_seqs[idx],
            time_delta_seqs=self.time_delta_seqs[idx],
            type_seqs=self.type_seqs[idx],
        )

        # Optional features are appended in a fixed order so the key order of
        # the returned dict stays stable.
        for name in self._OPTIONAL_FEATURES:
            values = getattr(self, name)
            if values is not None:
                sample[name] = values[idx]

        return sample
|
|
|
|
|
|