File size: 3,302 Bytes
f43af3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
评论罗伯特数据集 - 支持语义特征、偏差特征等

扩展TPPDataset以支持多模态特征
"""

from typing import Dict, Optional
import numpy as np
from easy_tpp.preprocess.dataset import TPPDataset
from easy_tpp.utils import py_assert


class RobertTPPDataset(TPPDataset):
    """TPP dataset that can carry extra per-sequence side features.

    On top of the standard ``TPPDataset`` fields (``time_seqs``,
    ``time_delta_seqs``, ``type_seqs``), each sample may also expose:

    - ``semantic_vectors``: list of semantic vectors
    - ``deviation_features``: list of deviation features
    - ``is_spontaneous``: list of spontaneous / @-mentioned flags

    A feature that is missing from the input dict (or explicitly ``None``)
    is stored as ``None`` and left out of the samples built by
    ``__getitem__``.
    """

    # Optional feature keys, in the order they are validated and emitted.
    _OPTIONAL_KEYS = ('semantic_vectors', 'deviation_features', 'is_spontaneous')

    def __init__(self, data: Dict):
        """Build the dataset and attach whichever optional features exist.

        Args:
            data: Input dictionary. ``time_seqs``, ``type_seqs`` and
                ``time_delta_seqs`` are consumed by ``TPPDataset``. The
                following keys are optional (shapes as documented by the
                original author):
                - ``semantic_vectors``: [num_seqs, seq_len, semantic_dim]
                - ``deviation_features``: [num_seqs, seq_len, 3]
                - ``is_spontaneous``: [num_seqs, seq_len]

        Raises:
            ValueError: raised via ``py_assert`` when a provided feature's
                outer length differs from ``len(self.time_seqs)``. Only the
                number of sequences is checked, not per-sequence lengths.
        """
        super().__init__(data)

        # ``self.data_dict`` is presumably populated by TPPDataset.__init__.
        # ``.get`` yields None for absent keys, so missing features stay None.
        self.semantic_vectors = self.data_dict.get('semantic_vectors')
        self.deviation_features = self.data_dict.get('deviation_features')
        self.is_spontaneous = self.data_dict.get('is_spontaneous')

        # Each feature that was supplied must hold one entry per sequence.
        for key in self._OPTIONAL_KEYS:
            feature = getattr(self, key)
            if feature is None:
                continue
            py_assert(
                len(feature) == len(self.time_seqs),
                ValueError,
                f"Inconsistent lengths: {key}={len(feature)}, "
                f"time_seqs={len(self.time_seqs)}"
            )

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` as a dictionary.

        Args:
            idx: Sample index.

        Returns:
            dict: Always holds ``time_seqs``, ``time_delta_seqs`` and
            ``type_seqs``. It also holds each optional feature that is not
            ``None``, under the feature's own key.
        """
        sample = dict(
            time_seqs=self.time_seqs[idx],
            time_delta_seqs=self.time_delta_seqs[idx],
            type_seqs=self.type_seqs[idx],
        )

        # Append optional features in a fixed order, skipping absent ones.
        for key in self._OPTIONAL_KEYS:
            values = getattr(self, key)
            if values is not None:
                sample[key] = values[idx]

        return sample