# File size: 4,566 Bytes
# 162ee52
"""
Sequence Analyzer — Bidirectional LSTM with Self-Attention.
Consumes a sequence of CNN feature embeddings and produces a single
temporal pattern encoding that captures heat-pattern evolution.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
"""
Additive (Bahdanau-style) self-attention over a sequence of hidden states.
Learns which timesteps are most informative and produces a
weighted context vector.
"""
def __init__(self, hidden_size: int):
super().__init__()
self.attention_fc = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.Tanh(),
nn.Linear(hidden_size // 2, 1, bias=False),
)
def forward(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
hidden_states: (B, T, H)
Returns:
context: (B, H) β weighted sum
weights: (B, T) β attention weights (for visualisation)
"""
scores = self.attention_fc(hidden_states).squeeze(-1) # (B, T)
weights = F.softmax(scores, dim=1) # (B, T)
context = torch.bmm(
weights.unsqueeze(1), hidden_states
).squeeze(1) # (B, H)
return context, weights
class SequenceAnalyzer(nn.Module):
"""
Bidirectional LSTM + Self-Attention for temporal analysis
of CNN feature sequences.
Architecture:
Input features (B, T, D)
β LayerNorm
β Bi-LSTM (2 layers, hidden=128)
β Self-Attention β context (B, 2*hidden)
β FC projection β (B, output_dim=256)
"""
def __init__(
self,
input_dim: int = 256,
hidden_size: int = 128,
num_layers: int = 2,
output_dim: int = 256,
bidirectional: bool = True,
dropout: float = 0.3,
use_attention: bool = True,
):
super().__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bidirectional = bidirectional
self.use_attention = use_attention
self.num_directions = 2 if bidirectional else 1
# Normalise input features
self.input_norm = nn.LayerNorm(input_dim)
# LSTM
self.lstm = nn.LSTM(
input_size=input_dim,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
bidirectional=bidirectional,
dropout=dropout if num_layers > 1 else 0.0,
)
lstm_output_dim = hidden_size * self.num_directions
# Attention
if self.use_attention:
self.attention = SelfAttention(lstm_output_dim)
# Projection to output_dim
self.projection = nn.Sequential(
nn.Linear(lstm_output_dim, output_dim),
nn.BatchNorm1d(output_dim),
nn.ReLU(inplace=True),
nn.Dropout(p=dropout),
)
@classmethod
def from_config(cls, config) -> "SequenceAnalyzer":
"""Construct from a Config object."""
sa = config.model.sequence_analyzer
fe = config.model.feature_extractor
return cls(
input_dim=fe.embedding_dim,
hidden_size=sa.hidden_size,
num_layers=sa.num_layers,
output_dim=fe.embedding_dim,
bidirectional=sa.bidirectional,
dropout=sa.dropout,
use_attention=sa.attention,
)
def forward(
self, features: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Args:
features: (B, T, D) β sequence of CNN embeddings.
Returns:
encoding: (B, output_dim) β temporal pattern encoding.
attention_weights: (B, T) or None β per-timestep importance.
"""
# Normalise
normed = self.input_norm(features)
# LSTM
lstm_out, _ = self.lstm(normed) # (B, T, H*num_directions)
# Aggregate
if self.use_attention:
context, attn_weights = self.attention(lstm_out)
else:
# Fallback: use the last hidden state
context = lstm_out[:, -1, :]
attn_weights = None
# Project
encoding = self.projection(context)
return encoding, attn_weights
|