| import torch |
| from torch import nn |
| from utils.position_coding import ( |
| LearnablePositionalEmbedding, |
| RotaryPositionalEncoding, |
| SinusoidalPositionalEncoding, |
| ) |
|
|
|
|
class QFormerBlock(nn.Module):
    """One Q-Former layer: query self-attention, cross-attention, feed-forward.

    Every sub-layer normalizes its input and adds its output back onto the
    un-normalized input (pre-norm residual). Both attention modules are built
    with ``batch_first=True``, so batched inputs are ``(batch, seq, dim)``.
    """

    def __init__(self, dim, num_heads, dropout=0.1, mlp_ratio=4.0):
        """Create the sub-layers.

        Args:
            dim: Embedding width shared by queries and memory.
            num_heads: Number of heads for both attention modules.
            dropout: Dropout probability used inside both attention modules
                and after each linear stage of the feed-forward network.
            mlp_ratio: Feed-forward hidden width as a multiple of ``dim``
                (truncated to an int).
        """
        super().__init__()
        self.query_norm = nn.LayerNorm(dim)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        self.cross_query_norm = nn.LayerNorm(dim)
        self.memory_norm = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        hidden_dim = int(dim * mlp_ratio)
        # The feed-forward network carries its own leading LayerNorm, so the
        # residual in ``forward`` is added around the whole Sequential.
        self.ffn = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, queries, memory):
        """Refine ``queries`` by attending to themselves and then to ``memory``.

        Args:
            queries: Query tokens; the last dimension must be ``dim``.
            memory: Tokens used as keys and values for cross-attention; the
                last dimension must be ``dim``.

        Returns:
            The updated queries, same shape as the input ``queries``.
        """
        # 1) Self-attention among the query tokens.
        norm_queries = self.query_norm(queries)
        self_out = self.self_attn(
            norm_queries,
            norm_queries,
            norm_queries,
            need_weights=False,
        )[0]
        queries = queries + self_out

        # 2) Cross-attention from queries to memory. The memory is normalized
        #    once and reused as both key and value; LayerNorm is deterministic,
        #    so a second pass would only repeat the same computation.
        norm_memory = self.memory_norm(memory)
        cross_out = self.cross_attn(
            self.cross_query_norm(queries),
            norm_memory,
            norm_memory,
            need_weights=False,
        )[0]
        queries = queries + cross_out

        # 3) Position-wise feed-forward network.
        return queries + self.ffn(queries)
|
|
|
|
class QFormerAdapter(nn.Module):
    """Stage-aware Q-Former adapter baseline.

    It keeps the Q-Former learnable query tokens, but preserves the time-series
    structure before flattening memory for cross-attention. Stages 3 and 4 use
    cycle positional encoding so cross-cycle degradation trends remain visible.
    """

    # Stages whose memory is position-encoded along the cycle axis (``l``).
    _CROSS_CYCLE_STAGES = (3, 4)

    def __init__(self, args):
        """Build the adapter from a config object.

        Reads ``prefix_num``, ``it_d_model``, ``it_n_heads``, ``it_dropout``
        and ``it_layers`` from ``args``.
        """
        super().__init__()
        self.prefix_num = args.prefix_num
        # Learnable query tokens shared by every sample in the batch.
        self.prefix_token = nn.Parameter(torch.randn(1, args.prefix_num, args.it_d_model))
        self.instruction_norm = nn.LayerNorm(args.it_d_model)
        # Five stage ids (0-4); ``forward`` clamps ids into this range.
        self.stage_embedding = nn.Embedding(5, args.it_d_model)
        self.memory_norm = nn.LayerNorm(args.it_d_model)
        self.time_pos = SinusoidalPositionalEncoding(args.it_d_model)
        self.var_pos = LearnablePositionalEmbedding(args.it_d_model)
        self.cycle_pos = RotaryPositionalEncoding(args.it_d_model)
        self.layers = nn.ModuleList(
            [
                QFormerBlock(
                    dim=args.it_d_model,
                    num_heads=args.it_n_heads,
                    dropout=args.it_dropout,
                    mlp_ratio=4.0,
                )
                for _ in range(args.it_layers)
            ]
        )
        self.norm = nn.LayerNorm(args.it_d_model)

    def _stage_list(self, stage, batch_size, device):
        """Normalize ``stage`` into a Python list of ``batch_size`` ints.

        Args:
            stage: ``None`` (every sample gets stage 0), a tensor, an
                ``int``/``float`` scalar, or an iterable of values accepted
                by ``int()``.
            batch_size: Number of entries the result must have.
            device: Unused. Kept for interface compatibility; the values are
                read back to the host, so moving the tensor first would only
                add a device round trip.

        Returns:
            A list of ``batch_size`` ints. A single value is repeated to fill
            the batch.

        Raises:
            ValueError: If more than one value is given and the count does not
                match ``batch_size``.
        """
        if stage is None:
            return [0] * batch_size

        if torch.is_tensor(stage):
            # ``reshape`` (not ``view``) so non-contiguous tensors also work.
            values = stage.detach().to(dtype=torch.long).reshape(-1).tolist()
        elif isinstance(stage, (int, float)):
            values = [int(stage)]
        else:
            values = [int(s) for s in stage]

        if len(values) == 1 and batch_size > 1:
            return values * batch_size
        if len(values) != batch_size:
            raise ValueError(f"Expected {batch_size} stage values, got {len(values)}.")
        return values

    def _encode_memory(self, memory, stage):
        """Add positional information to ``memory`` and flatten it to tokens.

        ``memory`` must be 4-D and is unpacked as ``(b, l, v, d)``. Samples in
        stage 3 or 4 receive ``cycle_pos`` along axis ``l``; all other samples
        receive ``time_pos`` along axis ``v``. Every sample then receives
        ``var_pos`` along axis ``v`` before being flattened to
        ``(b, l * v, d)`` and layer-normalized.

        NOTE(review): this assumes each positional module encodes along dim 1
        of a ``(N, S, d)`` input and returns something addable to it; the
        modules live in ``utils.position_coding`` -- confirm.
        NOTE(review): for non-cross-cycle samples both ``time_pos`` and
        ``var_pos`` are applied along axis ``v`` -- confirm this is intended.

        Args:
            memory: 4-D tensor ``(b, l, v, d)``.
            stage: Anything accepted by ``_stage_list``.

        Returns:
            Tensor of shape ``(b, l * v, d)``.
        """
        batch_size = memory.shape[0]
        stage_list = self._stage_list(stage, batch_size, memory.device)

        within_idx = [i for i, s in enumerate(stage_list) if s not in self._CROSS_CYCLE_STAGES]
        cross_idx = [i for i, s in enumerate(stage_list) if s in self._CROSS_CYCLE_STAGES]

        encoded_groups = []

        if within_idx:
            sub_memory = memory[within_idx]
            b, l, v, d = sub_memory.shape
            sub_memory = sub_memory.reshape(b * l, v, d)
            sub_memory = sub_memory + self.time_pos(sub_memory)
            encoded_groups.append(sub_memory.reshape(b, l, v, d))

        if cross_idx:
            sub_memory = memory[cross_idx]
            b, l, v, d = sub_memory.shape
            # Bring the cycle axis to dim 1 before folding ``v`` into the
            # batch. A bare ``view(b * v, l, d)`` would only reinterpret the
            # flat buffer and attach positions to the wrong tokens.
            sub_memory = sub_memory.permute(0, 2, 1, 3).reshape(b * v, l, d)
            sub_memory = sub_memory + self.cycle_pos(sub_memory)
            sub_memory = sub_memory.reshape(b, v, l, d).permute(0, 2, 1, 3)
            encoded_groups.append(sub_memory)

        if len(encoded_groups) == 2:
            # Row ``pos`` of the concatenation holds original sample
            # ``source``; invert that mapping to restore the batch order.
            restore = [0] * batch_size
            for pos, source in enumerate(within_idx + cross_idx):
                restore[source] = pos
            memory = torch.cat(encoded_groups, dim=0)[restore]
        elif encoded_groups:
            # A single stage group is already in batch order.
            memory = encoded_groups[0]

        b, l, v, d = memory.shape
        memory = memory.reshape(b * l, v, d)
        memory = memory + self.var_pos(memory)
        memory = memory.reshape(b, l * v, d)
        return self.memory_norm(memory)

    def forward(self, x, memory, stage=None, attn_mask=None):
        """Produce ``prefix_num`` query tokens per sample conditioned on memory.

        Args:
            x: Instruction embeddings, or ``None``. When given and non-empty,
                their layer-normalized mean over dim 1 is added to every
                query token. Presumably ``(batch, tokens, d_model)`` --
                confirm against the caller.
            memory: 4-D tensor consumed by ``_encode_memory``.
            stage: Per-sample stage id(s); see ``_stage_list``. Ids are
                clamped to the embedding table range for the stage embedding
                only; positional routing uses the raw ids.
            attn_mask: Accepted for interface compatibility; not used by this
                implementation.

        Returns:
            Layer-normalized queries of shape ``(batch, prefix_num, d_model)``.
        """
        # ``x`` is optional, so fall back to the memory for the batch size.
        batch_size = x.shape[0] if x is not None else memory.shape[0]

        # ``expand`` avoids materializing copies; the additions below always
        # produce a fresh tensor, so the parameter itself is never modified.
        queries = self.prefix_token.expand(batch_size, -1, -1)

        if x is not None and x.numel() > 0:
            instruction_bias = self.instruction_norm(x).mean(dim=1, keepdim=True)
            queries = queries + instruction_bias

        # Normalize the stage spec once and reuse it for the embedding lookup
        # and for the positional routing.
        stage_list = self._stage_list(stage, batch_size, queries.device)
        stage_ids = torch.tensor(
            stage_list,
            device=queries.device,
            dtype=torch.long,
        ).clamp(min=0, max=self.stage_embedding.num_embeddings - 1)
        queries = queries + self.stage_embedding(stage_ids).unsqueeze(1)

        memory = self._encode_memory(memory, stage_list)

        for layer in self.layers:
            queries = layer(queries, memory)

        return self.norm(queries)
|
|