Zorrojurro committed on
Commit
162ee52
·
verified ·
1 Parent(s): 49c96b6

Upload src/models/sequence_analyzer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/models/sequence_analyzer.py +147 -0
src/models/sequence_analyzer.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sequence Analyzer — Bidirectional LSTM with Self-Attention.
3
+
4
+ Consumes a sequence of CNN feature embeddings and produces a single
5
+ temporal pattern encoding that captures heat-pattern evolution.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class SelfAttention(nn.Module):
14
+ """
15
+ Additive (Bahdanau-style) self-attention over a sequence of hidden states.
16
+
17
+ Learns which timesteps are most informative and produces a
18
+ weighted context vector.
19
+ """
20
+
21
+ def __init__(self, hidden_size: int):
22
+ super().__init__()
23
+ self.attention_fc = nn.Sequential(
24
+ nn.Linear(hidden_size, hidden_size // 2),
25
+ nn.Tanh(),
26
+ nn.Linear(hidden_size // 2, 1, bias=False),
27
+ )
28
+
29
+ def forward(
30
+ self, hidden_states: torch.Tensor
31
+ ) -> tuple[torch.Tensor, torch.Tensor]:
32
+ """
33
+ Args:
34
+ hidden_states: (B, T, H)
35
+
36
+ Returns:
37
+ context: (B, H) β€” weighted sum
38
+ weights: (B, T) β€” attention weights (for visualisation)
39
+ """
40
+ scores = self.attention_fc(hidden_states).squeeze(-1) # (B, T)
41
+ weights = F.softmax(scores, dim=1) # (B, T)
42
+ context = torch.bmm(
43
+ weights.unsqueeze(1), hidden_states
44
+ ).squeeze(1) # (B, H)
45
+ return context, weights
46
+
47
+
48
+ class SequenceAnalyzer(nn.Module):
49
+ """
50
+ Bidirectional LSTM + Self-Attention for temporal analysis
51
+ of CNN feature sequences.
52
+
53
+ Architecture:
54
+ Input features (B, T, D)
55
+ β†’ LayerNorm
56
+ β†’ Bi-LSTM (2 layers, hidden=128)
57
+ β†’ Self-Attention β†’ context (B, 2*hidden)
58
+ β†’ FC projection β†’ (B, output_dim=256)
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ input_dim: int = 256,
64
+ hidden_size: int = 128,
65
+ num_layers: int = 2,
66
+ output_dim: int = 256,
67
+ bidirectional: bool = True,
68
+ dropout: float = 0.3,
69
+ use_attention: bool = True,
70
+ ):
71
+ super().__init__()
72
+ self.hidden_size = hidden_size
73
+ self.num_layers = num_layers
74
+ self.bidirectional = bidirectional
75
+ self.use_attention = use_attention
76
+ self.num_directions = 2 if bidirectional else 1
77
+
78
+ # Normalise input features
79
+ self.input_norm = nn.LayerNorm(input_dim)
80
+
81
+ # LSTM
82
+ self.lstm = nn.LSTM(
83
+ input_size=input_dim,
84
+ hidden_size=hidden_size,
85
+ num_layers=num_layers,
86
+ batch_first=True,
87
+ bidirectional=bidirectional,
88
+ dropout=dropout if num_layers > 1 else 0.0,
89
+ )
90
+
91
+ lstm_output_dim = hidden_size * self.num_directions
92
+
93
+ # Attention
94
+ if self.use_attention:
95
+ self.attention = SelfAttention(lstm_output_dim)
96
+
97
+ # Projection to output_dim
98
+ self.projection = nn.Sequential(
99
+ nn.Linear(lstm_output_dim, output_dim),
100
+ nn.BatchNorm1d(output_dim),
101
+ nn.ReLU(inplace=True),
102
+ nn.Dropout(p=dropout),
103
+ )
104
+
105
+ @classmethod
106
+ def from_config(cls, config) -> "SequenceAnalyzer":
107
+ """Construct from a Config object."""
108
+ sa = config.model.sequence_analyzer
109
+ fe = config.model.feature_extractor
110
+ return cls(
111
+ input_dim=fe.embedding_dim,
112
+ hidden_size=sa.hidden_size,
113
+ num_layers=sa.num_layers,
114
+ output_dim=fe.embedding_dim,
115
+ bidirectional=sa.bidirectional,
116
+ dropout=sa.dropout,
117
+ use_attention=sa.attention,
118
+ )
119
+
120
+ def forward(
121
+ self, features: torch.Tensor
122
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
123
+ """
124
+ Args:
125
+ features: (B, T, D) β€” sequence of CNN embeddings.
126
+
127
+ Returns:
128
+ encoding: (B, output_dim) β€” temporal pattern encoding.
129
+ attention_weights: (B, T) or None β€” per-timestep importance.
130
+ """
131
+ # Normalise
132
+ normed = self.input_norm(features)
133
+
134
+ # LSTM
135
+ lstm_out, _ = self.lstm(normed) # (B, T, H*num_directions)
136
+
137
+ # Aggregate
138
+ if self.use_attention:
139
+ context, attn_weights = self.attention(lstm_out)
140
+ else:
141
+ # Fallback: use the last hidden state
142
+ context = lstm_out[:, -1, :]
143
+ attn_weights = None
144
+
145
+ # Project
146
+ encoding = self.projection(context)
147
+ return encoding, attn_weights