""" 评论罗伯特数据集 - 支持语义特征、偏差特征等 扩展TPPDataset以支持多模态特征 """ from typing import Dict, Optional import numpy as np from easy_tpp.preprocess.dataset import TPPDataset from easy_tpp.utils import py_assert class RobertTPPDataset(TPPDataset): """ 支持语义特征、偏差特征等的TPP数据集 扩展标准TPPDataset以支持: - semantic_vectors: 语义向量列表 - deviation_features: 偏差特征列表 - is_spontaneous: 自发/被@标记列表 """ def __init__(self, data: Dict): """ 初始化数据集 Args: data: 数据字典,包含: - time_seqs: 时间序列列表 - type_seqs: 事件类型序列列表 - time_delta_seqs: 时间间隔序列列表 - semantic_vectors: 语义向量列表(可选)[num_seqs, seq_len, semantic_dim] - deviation_features: 偏差特征列表(可选)[num_seqs, seq_len, 3] - is_spontaneous: 自发/被@标记列表(可选)[num_seqs, seq_len] """ super(RobertTPPDataset, self).__init__(data) # 可选特征 self.semantic_vectors = self.data_dict.get('semantic_vectors', None) self.deviation_features = self.data_dict.get('deviation_features', None) self.is_spontaneous = self.data_dict.get('is_spontaneous', None) # 验证数据一致性 if self.semantic_vectors is not None: py_assert( len(self.semantic_vectors) == len(self.time_seqs), ValueError, f"Inconsistent lengths: semantic_vectors={len(self.semantic_vectors)}, " f"time_seqs={len(self.time_seqs)}" ) if self.deviation_features is not None: py_assert( len(self.deviation_features) == len(self.time_seqs), ValueError, f"Inconsistent lengths: deviation_features={len(self.deviation_features)}, " f"time_seqs={len(self.time_seqs)}" ) if self.is_spontaneous is not None: py_assert( len(self.is_spontaneous) == len(self.time_seqs), ValueError, f"Inconsistent lengths: is_spontaneous={len(self.is_spontaneous)}, " f"time_seqs={len(self.time_seqs)}" ) def __getitem__(self, idx): """ 获取单个样本 Args: idx: 样本索引 Returns: dict: 包含时间、类型、可选特征的字典 """ item = { 'time_seqs': self.time_seqs[idx], 'time_delta_seqs': self.time_delta_seqs[idx], 'type_seqs': self.type_seqs[idx] } # 添加可选特征 if self.semantic_vectors is not None: item['semantic_vectors'] = self.semantic_vectors[idx] if self.deviation_features is not None: item['deviation_features'] = self.deviation_features[idx] if self.is_spontaneous is not None: item['is_spontaneous'] = self.is_spontaneous[idx] return item