# ITFormer / models / QFormerAdapter.py
# a12354's picture
# Add files using upload-large-folder tool
# c8aad8f verified
# Raw
# History Blame Contribute Delete
# 5.92 kB
import torch
from torch import nn
from utils.position_coding import (
LearnablePositionalEmbedding,
RotaryPositionalEncoding,
SinusoidalPositionalEncoding,
)
class QFormerBlock(nn.Module):
    """One pre-norm Q-Former layer: self-attention, cross-attention, FFN.

    Each of the three sub-layers is wrapped in a residual connection. Both
    attention modules are created with ``batch_first=True``, so ``queries``
    and ``memory`` are laid out as ``(batch, sequence, dim)``.

    Args:
        dim: Embedding width shared by the queries and the memory.
        num_heads: Number of heads used by both attention modules.
        dropout: Dropout probability used inside attention and the FFN.
        mlp_ratio: FFN hidden width expressed as a multiple of ``dim``.
    """

    def __init__(self, dim, num_heads, dropout=0.1, mlp_ratio=4.0):
        super().__init__()
        # Sub-modules are created in a fixed order so parameter
        # initialisation stays reproducible under a given RNG seed.
        self.query_norm = nn.LayerNorm(dim)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.cross_query_norm = nn.LayerNorm(dim)
        self.memory_norm = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        hidden_dim = int(dim * mlp_ratio)
        # The FFN carries its own leading LayerNorm, so ``forward`` feeds it
        # the un-normalised residual stream.
        self.ffn = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, queries, memory):
        """Update ``queries`` by attending to themselves and then to ``memory``.

        Args:
            queries: Query tokens, ``(batch, num_queries, dim)``.
            memory: Tokens to cross-attend to, ``(batch, memory_len, dim)``.

        Returns:
            The updated queries, same shape as the ``queries`` input.
        """
        # 1) Self-attention among the queries (pre-norm + residual).
        normed_queries = self.query_norm(queries)
        queries = queries + self.self_attn(
            normed_queries,
            normed_queries,
            normed_queries,
            need_weights=False,
        )[0]

        # 2) Cross-attention from queries to memory. The memory is normalised
        #    once and the same tensor is used as key and value, instead of
        #    running the LayerNorm twice on identical input.
        normed_memory = self.memory_norm(memory)
        queries = queries + self.cross_attn(
            self.cross_query_norm(queries),
            normed_memory,
            normed_memory,
            need_weights=False,
        )[0]

        # 3) Position-wise feed-forward with residual.
        return queries + self.ffn(queries)
class QFormerAdapter(nn.Module):
    """Stage-aware Q-Former adapter baseline.

    It keeps the Q-Former learnable query tokens, but preserves the time-series
    structure before flattening memory for cross-attention. Stage 3 and 4 use
    cycle positional encoding so cross-cycle degradation trends remain visible.

    Args:
        args: Configuration object. The attributes read here are
            ``prefix_num``, ``it_d_model``, ``it_n_heads``, ``it_dropout``
            and ``it_layers``.
    """

    def __init__(self, args):
        super().__init__()
        self.prefix_num = args.prefix_num
        # Learnable query tokens shared by every sample in the batch.
        self.prefix_token = nn.Parameter(torch.randn(1, args.prefix_num, args.it_d_model))
        self.instruction_norm = nn.LayerNorm(args.it_d_model)
        # Five embedding rows; ``forward`` clamps stage ids into this range.
        self.stage_embedding = nn.Embedding(5, args.it_d_model)
        self.memory_norm = nn.LayerNorm(args.it_d_model)
        self.time_pos = SinusoidalPositionalEncoding(args.it_d_model)
        self.var_pos = LearnablePositionalEmbedding(args.it_d_model)
        self.cycle_pos = RotaryPositionalEncoding(args.it_d_model)
        self.layers = nn.ModuleList(
            [
                QFormerBlock(
                    dim=args.it_d_model,
                    num_heads=args.it_n_heads,
                    dropout=args.it_dropout,
                    mlp_ratio=4.0,
                )
                for _ in range(args.it_layers)
            ]
        )
        self.norm = nn.LayerNorm(args.it_d_model)

    def _stage_list(self, stage, batch_size, device):
        """Normalise ``stage`` into a Python list of ``batch_size`` ints.

        Args:
            stage: ``None`` (every sample gets stage 0), a tensor of any
                shape (flattened), a single ``int``/``float``, or an iterable
                of values convertible with ``int``.
            batch_size: Number of samples the result must describe.
            device: Kept for interface compatibility. The result is a plain
                Python list, so no device transfer is performed.

        Returns:
            A list of ``batch_size`` ints. A single value is broadcast to the
            whole batch.

        Raises:
            ValueError: If more than one value is supplied and the count does
                not equal ``batch_size``.
        """
        if stage is None:
            return [0] * batch_size
        if torch.is_tensor(stage):
            # ``reshape`` also handles non-contiguous tensors, where ``view``
            # raises. The values are read back to the host by ``tolist``, so
            # moving the tensor to ``device`` first would be wasted work.
            values = stage.detach().reshape(-1).to(dtype=torch.long).tolist()
        elif isinstance(stage, (int, float)):
            values = [int(stage)]
        else:
            values = [int(s) for s in stage]
        if len(values) == 1:
            return values * batch_size
        if len(values) != batch_size:
            raise ValueError(f"Expected {batch_size} stage values, got {len(values)}.")
        return values

    def _encode_memory(self, memory, stage):
        """Add positional information to ``memory`` and flatten it.

        Args:
            memory: 4-D tensor unpacked as ``(b, l, v, d)``.
            stage: Anything accepted by ``_stage_list``.

        Returns:
            Tensor of shape ``(b, l * v, d)`` after ``memory_norm``.
        """
        batch_size = memory.shape[0]
        stage_list = self._stage_list(stage, batch_size, memory.device)

        # Samples in stage 3 or 4 take the ``cycle_pos`` branch, every other
        # stage takes the ``time_pos`` branch.
        cycle_index = [i for i, s in enumerate(stage_list) if s not in (3, 4)]
        cross_cycle_index = [i for i, s in enumerate(stage_list) if s in (3, 4)]

        processed_memories = []
        if cycle_index:
            sub_memory = memory[cycle_index]
            n_b, n_l, n_v, n_d = sub_memory.shape
            # Each (b, l) slice is encoded as a sequence along axis ``v``.
            sub_memory = sub_memory.reshape(n_b * n_l, n_v, n_d)
            sub_memory = sub_memory + self.time_pos(sub_memory)
            processed_memories.append(sub_memory.reshape(n_b, n_l, n_v, n_d))
        if cross_cycle_index:
            sub_memory = memory[cross_cycle_index]
            n_b, n_l, n_v, n_d = sub_memory.shape
            # NOTE(review): this reinterprets the (b, l, v, d) buffer as
            # (b * v, l, d) without transposing ``l`` and ``v``, so the rows
            # handed to ``cycle_pos`` are not per-``v`` sequences along ``l``.
            # If that was the intent, a ``permute(0, 2, 1, 3)`` is needed
            # before and after. Kept as-is so model numerics do not change;
            # confirm against ``RotaryPositionalEncoding``.
            sub_memory = sub_memory.reshape(n_b * n_v, n_l, n_d)
            sub_memory = sub_memory + self.cycle_pos(sub_memory)
            processed_memories.append(sub_memory.reshape(n_b, n_l, n_v, n_d))

        if len(processed_memories) == 2:
            # Both groups are present: the concatenation is ordered
            # [cycle..., cross_cycle...], so gather rows back into the
            # original batch order.
            restore = [0] * batch_size
            for position, original in enumerate(cycle_index + cross_cycle_index):
                restore[original] = position
            memory = torch.cat(processed_memories, dim=0)[restore]
        elif processed_memories:
            # A single group already is the whole batch in its original order.
            memory = processed_memories[0]
        # An empty batch leaves ``memory`` untouched (nothing to concatenate).

        n_b, n_l, n_v, n_d = memory.shape
        memory = memory.reshape(n_b * n_l, n_v, n_d)
        memory = memory + self.var_pos(memory)
        memory = memory.reshape(n_b, n_l * n_v, n_d)
        return self.memory_norm(memory)

    def forward(self, x, memory, stage=None, attn_mask=None):
        """Produce query tokens conditioned on instruction, stage and memory.

        Args:
            x: Instruction embeddings ``(batch, seq, d_model)``, or ``None`` /
                an empty tensor to skip the instruction bias.
            memory: 4-D memory tensor passed to ``_encode_memory``.
            stage: Per-sample stage ids in any form accepted by
                ``_stage_list``. Ids are clamped to the embedding table range
                for the stage embedding only.
            attn_mask: Accepted for interface compatibility; not used.

        Returns:
            Tensor of shape ``(batch, prefix_num, d_model)``.
        """
        # ``x`` is optional, so the batch size must not depend on it alone.
        batch_size = x.shape[0] if x is not None else memory.shape[0]
        # Parse the stage once and reuse the list for both consumers, which
        # avoids a second host read-back when ``stage`` is a tensor.
        stage_list = self._stage_list(stage, batch_size, self.prefix_token.device)

        # ``expand`` is a broadcast view; the additions below are out of
        # place, so the shared parameter is never written to.
        queries = self.prefix_token.expand(batch_size, -1, -1)
        if x is not None and x.numel() > 0:
            instruction_bias = self.instruction_norm(x).mean(dim=1, keepdim=True)
            queries = queries + instruction_bias

        stage_ids = torch.tensor(
            stage_list,
            device=queries.device,
            dtype=torch.long,
        ).clamp(min=0, max=self.stage_embedding.num_embeddings - 1)
        queries = queries + self.stage_embedding(stage_ids).unsqueeze(1)

        memory = self._encode_memory(memory, stage_list)
        for layer in self.layers:
            queries = layer(queries, memory)
        return self.norm(queries)