import torch
from torch import nn

from utils.position_coding import (
    LearnablePositionalEmbedding,
    RotaryPositionalEncoding,
    SinusoidalPositionalEncoding,
)


class QFormerBlock(nn.Module):
    """One pre-norm Q-Former layer: query self-attention, cross-attention, then an MLP.

    Each sub-layer normalizes its input first and is wrapped in a residual
    connection. The MLP carries its own LayerNorm as the first module of
    ``self.ffn``.

    Args:
        dim: Embedding width shared by the queries and the memory.
        num_heads: Number of heads for both attention sub-layers.
        dropout: Dropout used inside both attention layers and in the MLP.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
    """

    def __init__(self, dim, num_heads, dropout=0.1, mlp_ratio=4.0):
        super().__init__()
        self.query_norm = nn.LayerNorm(dim)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.cross_query_norm = nn.LayerNorm(dim)
        self.memory_norm = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        hidden_dim = int(dim * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, queries, memory):
        """Refine ``queries`` by attending to each other and then to ``memory``.

        Args:
            queries: ``(batch, num_queries, dim)`` tensor (attention is ``batch_first``).
            memory: ``(batch, num_memory_tokens, dim)`` tensor used as key and value.

        Returns:
            Tensor with the same shape as ``queries``.
        """
        normed_queries = self.query_norm(queries)
        queries = queries + self.self_attn(
            normed_queries,
            normed_queries,
            normed_queries,
            need_weights=False,
        )[0]

        # Normalize the memory once and reuse it as both key and value.
        normed_memory = self.memory_norm(memory)
        queries = queries + self.cross_attn(
            self.cross_query_norm(queries),
            normed_memory,
            normed_memory,
            need_weights=False,
        )[0]

        return queries + self.ffn(queries)


class QFormerAdapter(nn.Module):
    """Stage-aware Q-Former adapter baseline.

    ``prefix_num`` learnable query tokens are conditioned on the instruction
    embedding and on a per-sample stage embedding. They are then refined by
    ``it_layers`` :class:`QFormerBlock` layers that cross-attend to a memory
    tensor.

    The memory keeps its 4-D ``(batch, L, V, it_d_model)`` structure while
    positional information is added. It is only flattened to
    ``(batch, L * V, it_d_model)`` right before cross-attention. Samples at
    stage 3 or 4 get the cycle positional encoding along axis ``L``, so
    cross-cycle degradation trends remain visible. All other stages get the
    time positional encoding.
    """

    # Stage ids whose memory is encoded with the cycle positional encoding.
    _CROSS_CYCLE_STAGES = (3, 4)

    def __init__(self, args):
        super().__init__()
        self.prefix_num = args.prefix_num
        self.prefix_token = nn.Parameter(torch.randn(1, args.prefix_num, args.it_d_model))
        self.instruction_norm = nn.LayerNorm(args.it_d_model)
        self.stage_embedding = nn.Embedding(5, args.it_d_model)
        self.memory_norm = nn.LayerNorm(args.it_d_model)
        self.time_pos = SinusoidalPositionalEncoding(args.it_d_model)
        self.var_pos = LearnablePositionalEmbedding(args.it_d_model)
        self.cycle_pos = RotaryPositionalEncoding(args.it_d_model)
        self.layers = nn.ModuleList(
            [
                QFormerBlock(
                    dim=args.it_d_model,
                    num_heads=args.it_n_heads,
                    dropout=args.it_dropout,
                    mlp_ratio=4.0,
                )
                for _ in range(args.it_layers)
            ]
        )
        self.norm = nn.LayerNorm(args.it_d_model)

    def _stage_list(self, stage, batch_size, device):
        """Normalize ``stage`` into a Python list with one int per sample.

        Args:
            stage: ``None`` (every sample gets stage 0), an int/float, a tensor,
                or an iterable of stage values. A single value is broadcast to
                the whole batch.
            batch_size: Number of samples the list must cover.
            device: Unused; kept so existing callers keep working.

        Returns:
            ``list[int]`` of length ``batch_size``.

        Raises:
            ValueError: More than one value was given and the count does not
                match ``batch_size``.
        """
        if stage is None:
            return [0] * batch_size
        if torch.is_tensor(stage):
            # The values end up as Python ints, so read them on the host
            # directly instead of copying to ``device`` and back again.
            values = stage.detach().reshape(-1).cpu().to(dtype=torch.long).tolist()
        elif isinstance(stage, (int, float)):
            values = [int(stage)]
        else:
            values = [int(s) for s in stage]

        if len(values) == 1 and batch_size > 1:
            return values * batch_size
        if len(values) != batch_size:
            raise ValueError(f"Expected {batch_size} stage values, got {len(values)}.")
        return values

    @staticmethod
    def _add_positions_along_axis1(memory, encoder):
        """Add ``encoder`` positions along axis 1 of a ``(B, L, V, D)`` tensor.

        Each of the ``V`` slices is treated as an independent length-``L``
        sequence. A real transpose is required here: ``view(B * V, L, D)`` on
        the ``(B, L, V, D)`` layout would interleave variables and positions.
        """
        b, l, v, d = memory.shape
        seq = memory.permute(0, 2, 1, 3).reshape(b * v, l, d)
        seq = seq + encoder(seq)
        return seq.reshape(b, v, l, d).permute(0, 2, 1, 3)

    def _add_time_pos(self, memory):
        """Add ``time_pos`` to a ``(B, L, V, D)`` tensor, using axis ``V`` as the sequence."""
        b, l, v, d = memory.shape
        # NOTE(review): this encodes axis V, the same axis ``var_pos`` encodes
        # later, and leaves axis L without positions. If axis L is the time
        # axis for these stages, this should be
        # ``self._add_positions_along_axis1(memory, self.time_pos)``.
        # The original behavior is kept pending confirmation.
        flat = memory.reshape(b * l, v, d)
        flat = flat + self.time_pos(flat)
        return flat.reshape(b, l, v, d)

    def _encode_memory(self, memory, stage):
        """Add stage-dependent and variable positions, then flatten the memory.

        Args:
            memory: ``(batch, L, V, it_d_model)`` tensor.
            stage: Anything accepted by :meth:`_stage_list`, including an
                already-normalized list.

        Returns:
            ``(batch, L * V, it_d_model)`` tensor after ``memory_norm``.
        """
        b, l, v, d = memory.shape
        stage_list = self._stage_list(stage, b, memory.device)
        within_idx = [i for i, s in enumerate(stage_list) if s not in self._CROSS_CYCLE_STAGES]
        cross_idx = [i for i, s in enumerate(stage_list) if s in self._CROSS_CYCLE_STAGES]

        parts = []
        if within_idx:
            parts.append(self._add_time_pos(memory[within_idx]))
        if cross_idx:
            parts.append(self._add_positions_along_axis1(memory[cross_idx], self.cycle_pos))

        # Samples were grouped by stage above; map them back to batch order.
        position = {idx: pos for pos, idx in enumerate(within_idx + cross_idx)}
        memory = torch.cat(parts, dim=0)[[position[i] for i in range(b)]]

        memory = memory.reshape(b * l, v, d)
        memory = memory + self.var_pos(memory)
        return self.memory_norm(memory.reshape(b, l * v, d))

    def forward(self, x, memory, stage=None, attn_mask=None):
        """Produce ``prefix_num`` query tokens summarizing ``memory`` for each sample.

        Args:
            x: Instruction embeddings, or ``None``/empty to skip instruction
                conditioning. Normalized over its last dim and averaged over
                dim 1; presumably ``(batch, T, it_d_model)`` — TODO confirm.
            memory: ``(batch, L, V, it_d_model)`` tensor; see ``_encode_memory``.
            stage: Per-sample stage ids; see ``_stage_list``.
            attn_mask: Currently unused; accepted for interface compatibility.
                NOTE(review): padded instruction positions are still averaged.

        Returns:
            ``(batch, prefix_num, it_d_model)`` tensor.
        """
        # ``memory`` is always required, so take the batch size from it;
        # ``x`` may legitimately be None.
        batch_size = memory.shape[0]
        stage_list = self._stage_list(stage, batch_size, memory.device)

        queries = self.prefix_token.expand(batch_size, -1, -1)
        if x is not None and x.numel() > 0:
            instruction_bias = self.instruction_norm(x).mean(dim=1, keepdim=True)
            queries = queries + instruction_bias

        # Clamp stage ids into the embedding table's valid index range.
        # NOTE(review): ``_encode_memory`` sees the unclamped ids, so e.g.
        # stage 7 uses embedding row 4 but the within-cycle encoding.
        stage_ids = torch.tensor(
            stage_list,
            device=queries.device,
            dtype=torch.long,
        ).clamp(min=0, max=self.stage_embedding.num_embeddings - 1)
        queries = queries + self.stage_embedding(stage_ids).unsqueeze(1)

        memory = self._encode_memory(memory, stage_list)
        for layer in self.layers:
            queries = layer(queries, memory)
        return self.norm(queries)