"""
Published baseline models for DailyAct-5M benchmark.

ASFormer: Transformer for Action Segmentation (Yi et al., BMVC 2021)
  - Multi-stage encoder-decoder transformer with dilated attention
  - For temporal action segmentation (Exp 2) and contact detection (Exp 3)

TinyHAR: Lightweight Deep Learning Model for HAR (Zhou et al., ISWC 2022 Best Paper)
  - Multi-scale temporal convolution + cross-channel attention + temporal pooling
  - Implemented as backbone in models.py for scene recognition (Exp 1)
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
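
# Shape convention used throughout: batch-first sequences (B, T, C), where B
# is the batch size, T the number of time steps (frames), and C the per-step
# feature dimension.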


# ============================================================
# Positional Encoding (shared)
# ============================================================

class PositionalEncoding1D(nn.Module):
    """Sinusoidal positional encoding."""

    def __init__(self, d_model, dropout=0.1, max_len=10000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
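        # For odd d_model there is one more sine column than cosine column,
        # so the last (lowest-frequency) term of div_term is dropped for the
        # cosine channels.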
        if d_model % 2 == 1:
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


# ============================================================
# ASFormer (Yi et al., BMVC 2021)
# ============================================================

class ConvFeedForward(nn.Module):
    """Position-wise convolution feed-forward used in ASFormer."""

    def __init__(self, d_model, kernel_size=3, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size, padding=kernel_size // 2)
        self.conv2 = nn.Conv1d(d_model * 2, d_model, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (B, T, D)
        residual = x
        x = self.norm(x)
        x = x.permute(0, 2, 1)  # (B, D, T)
        x = self.dropout(F.relu(self.conv1(x)))
        x = self.dropout(self.conv2(x))
        x = x.permute(0, 2, 1)  # (B, T, D)
        return residual + x


class DilatedAttention(nn.Module):
    """Multi-head self-attention with dilated temporal mask.

    At dilation d and window w, position t attends to positions
    {t + k*d : k in [-w, w]}, creating a hierarchical receptive field.
    """

    def __init__(self, d_model, dilation, num_heads=1, dropout=0.1, window_size=5):
        super().__init__()
        self.d_model = d_model
        self.dilation = dilation
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.norm = nn.LayerNorm(d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Cache for dilated masks
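        # (keyed by (T, dilation, window_size, device); one boolean (T, T)
        # mask is built per distinct sequence length seen at runtime)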
        self._mask_cache = {}

    def _get_dilated_mask(self, T, device):
        """Create or retrieve cached dilated attention mask."""
        key = (T, self.dilation, self.window_size, device)
        if key not in self._mask_cache:
            positions = torch.arange(T, device=device)
            diff = positions.unsqueeze(1) - positions.unsqueeze(0)  # (T, T)
            mask = torch.zeros(T, T, dtype=torch.bool, device=device)
            for w in range(-self.window_size, self.window_size + 1):
                mask |= (diff == w * self.dilation)
            self._mask_cache[key] = mask
        return self._mask_cache[key]

    def forward(self, x, cross_kv=None):
        # x: (B, T, D)
        B, T, D = x.shape
        residual = x
        x = self.norm(x)

        if cross_kv is not None:
            q = self.qkv(x)[:, :, :D]  # only use Q from x
            kv = self.qkv(cross_kv)[:, :, D:]  # K, V from cross_kv
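            # NOTE: the reshapes below reuse T from x, so cross_kv must have
            # the same sequence length as x (always true in ASFormer, where
            # decoder and encoder operate on the same T frames).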
            q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
            k = kv[:, :, :D].view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
            v = kv[:, :, D:].view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        else:
            qkv = self.qkv(x).view(B, T, 3, self.num_heads, self.head_dim)
            qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, T, head_dim)
            q, k, v = qkv[0], qkv[1], qkv[2]

        scale = self.head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale  # (B, H, T, T)

        # Apply dilated attention mask
        dilated_mask = self._get_dilated_mask(T, x.device)  # (T, T)
        attn = attn.masked_fill(~dilated_mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, T, D)
        out = self.out_proj(out)
        return residual + self.dropout(out)


class ASFormerEncoderBlock(nn.Module):
    """Single encoder block: dilated self-attention + conv feed-forward."""

    def __init__(self, d_model, dilation, num_heads=1, kernel_size=3,
                 dropout=0.1, window_size=5):
        super().__init__()
        self.self_attn = DilatedAttention(d_model, dilation, num_heads, dropout, window_size)
        self.ffn = ConvFeedForward(d_model, kernel_size, dropout)

    def forward(self, x):
        x = self.self_attn(x)
        x = self.ffn(x)
        return x


class ASFormerDecoderBlock(nn.Module):
    """Single decoder block: self-attention + cross-attention + conv feed-forward."""

    def __init__(self, d_model, dilation, num_heads=1, kernel_size=3,
                 dropout=0.1, window_size=5):
        super().__init__()
        self.self_attn = DilatedAttention(d_model, dilation, num_heads, dropout, window_size)
        self.cross_attn = DilatedAttention(d_model, dilation, num_heads, dropout, window_size)
        self.ffn = ConvFeedForward(d_model, kernel_size, dropout)

    def forward(self, x, enc_features):
        x = self.self_attn(x)
        x = self.cross_attn(x, cross_kv=enc_features)
        x = self.ffn(x)
        return x


class ASFormerEncoder(nn.Module):
    """ASFormer encoder: projection + N dilated attention layers + output head."""

    def __init__(self, input_dim, d_model, num_classes, num_layers=5,
                 num_heads=1, kernel_size=3, dropout=0.1, window_size=5):
        super().__init__()
        self.input_proj = nn.Conv1d(input_dim, d_model, 1)
        self.pos_enc = PositionalEncoding1D(d_model, dropout)
        self.layers = nn.ModuleList([
            ASFormerEncoderBlock(d_model, 2 ** i, num_heads, kernel_size, dropout, window_size)
            for i in range(num_layers)
        ])
        self.output_proj = nn.Conv1d(d_model, num_classes, 1)

    def forward(self, x):
        # x: (B, T, C)
        x = x.permute(0, 2, 1)  # (B, C, T)
        x = self.input_proj(x)   # (B, d_model, T)
        x = x.permute(0, 2, 1)   # (B, T, d_model)
        x = self.pos_enc(x)

        for layer in self.layers:
            x = layer(x)

        features = x
        logits = self.output_proj(x.permute(0, 2, 1)).permute(0, 2, 1)  # (B, T, num_classes)
        return features, logits


class ASFormerDecoder(nn.Module):
    """ASFormer decoder: refinement stage with cross-attention to encoder."""

    def __init__(self, input_dim, d_model, num_classes, num_layers=5,
                 num_heads=1, kernel_size=3, dropout=0.1, window_size=5):
        super().__init__()
        self.input_proj = nn.Conv1d(input_dim, d_model, 1)
        self.pos_enc = PositionalEncoding1D(d_model, dropout)
        self.layers = nn.ModuleList([
            ASFormerDecoderBlock(d_model, 2 ** i, num_heads, kernel_size, dropout, window_size)
            for i in range(num_layers)
        ])
        self.output_proj = nn.Conv1d(d_model, num_classes, 1)

    def forward(self, dec_input, enc_features):
        # dec_input: (B, T, input_dim), enc_features: (B, T, d_model)
        x = dec_input.permute(0, 2, 1)
        x = self.input_proj(x)
        x = x.permute(0, 2, 1)
        x = self.pos_enc(x)

        for layer in self.layers:
            x = layer(x, enc_features)

        logits = self.output_proj(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x, logits


class ASFormer(nn.Module):
    """ASFormer: Transformer for Action Segmentation (Yi et al., BMVC 2021).

    Multi-stage encoder-decoder transformer for frame-level action segmentation.
    Returns a list of per-stage logits for multi-stage training (same interface as MSTCN).

    Args:
        input_dim: Input feature dimension
        num_classes: Number of action classes
        hidden_dim: Hidden dimension (d_model)
        num_layers: Number of attention layers per stage (dilation 1, 2, ..., 2^(num_layers-1))
        num_decoders: Number of decoder (refinement) stages
        num_heads: Number of attention heads
        kernel_size: Feed-forward convolution kernel size
        dropout: Dropout rate
        window_size: Dilated attention window size
    """

    def __init__(self, input_dim, num_classes, hidden_dim=64, num_layers=5,
                 num_decoders=3, num_heads=1, kernel_size=3, dropout=0.1,
                 window_size=5):
        super().__init__()
        self.encoder = ASFormerEncoder(
            input_dim, hidden_dim, num_classes, num_layers,
            num_heads, kernel_size, dropout, window_size
        )
        self.decoders = nn.ModuleList([
            ASFormerDecoder(
                num_classes, hidden_dim, num_classes, num_layers,
                num_heads, kernel_size, dropout, window_size
            ) for _ in range(num_decoders)
        ])

    def forward(self, x):
        # x: (B, T, C)
        outputs = []
        enc_features, enc_logits = self.encoder(x)
        outputs.append(enc_logits)

        for decoder in self.decoders:
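            # Each refinement stage consumes the previous stage's class
            # probabilities; detach() blocks gradients from flowing back
            # through earlier stages' logits (enc_features still carries
            # gradient to the encoder).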
            dec_input = F.softmax(outputs[-1], dim=-1).detach()
            _, dec_logits = decoder(dec_input, enc_features)
            outputs.append(dec_logits)

        return outputs  # list of (B, T, num_classes), compatible with MSTCN interface


class ASFormerContact(nn.Module):
    """ASFormer adapted for binary contact detection (Exp 3).

    Wraps ASFormer to return only the final-stage logits of shape (B, T, 2),
    matching the exp3 training loop. The refinement stages still run
    internally, but only the final output is supervised.
    """

    def __init__(self, input_dim, hidden_dim=64, num_layers=5, num_decoders=2,
                 num_heads=1, dropout=0.1):
        super().__init__()
        self.asformer = ASFormer(
            input_dim, num_classes=2, hidden_dim=hidden_dim,
            num_layers=num_layers, num_decoders=num_decoders,
            num_heads=num_heads, dropout=dropout
        )

    def forward(self, x):
        # x: (B, T, C) -> (B, T, 2)
        outputs = self.asformer(x)
        return outputs[-1]  # Return final stage only
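

# ============================================================
# Smoke test (illustrative)
# ============================================================

if __name__ == "__main__":
    # Minimal shape-check sketch. The hyperparameters below are
    # illustrative only, not the benchmark's published training settings.
    B, T, C, K = 2, 128, 9, 10  # batch, frames, input channels, classes

    model = ASFormer(input_dim=C, num_classes=K, hidden_dim=64,
                     num_layers=4, num_decoders=2)
    x = torch.randn(B, T, C)
    outputs = model(x)
    assert len(outputs) == 3  # encoder stage + 2 refinement stages
    assert all(o.shape == (B, T, K) for o in outputs)

    contact = ASFormerContact(input_dim=C, hidden_dim=64, num_layers=4)
    assert contact(x).shape == (B, T, 2)  # final-stage binary logits
    print("shape checks passed")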