import numpy as np
import torch
import torch.nn as nn
from .configuration_aurora import AuroraConfig
from .util_functions import sinusoidal_position_embedding, causal_attention_mask


class PrototypeRetriever(nn.Module):
    def __init__(self, config: AuroraConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_prototypes = config.num_prototypes
        self.token_len = config.token_len
        # Define the learnable prototype parameter container.
        # Initialize an empty Parameter first; it is filled in _initialize_prototypes.
        self.prototypes = nn.Parameter(torch.empty(self.num_prototypes, self.token_len))
        # Initialize prototypes using the function-generator logic below.
        self._initialize_prototypes()
        self.retriever = Retriever(config)

    def _initialize_prototypes(self, random_seed=42):
        """
        Initialize the prototype parameters using diverse function generators.
        Adapted from the generate_prototypes logic to fit the class structure.
        """
        # Seed NumPy for reproducibility. Note that this reseeds the *global*
        # NumPy RNG, which also affects any other NumPy-based sampling.
        np.random.seed(random_seed)
        length = self.token_len
        # Time axis shared by all generators, ranging from 0 to 10.
        x = np.linspace(0, 10, length)
        prototypes_list = []

        # --- Internal generator functions ---
        def generate_sin():
            """Generate a sine wave with random frequency, amplitude, and phase."""
            freq = np.random.uniform(0.3, 2.0)
            amp = np.random.uniform(0.5, 2.0)
            phase = np.random.uniform(0, np.pi)
            return amp * np.sin(freq * x + phase)

        def generate_cos():
            """Generate a cosine wave with random frequency, amplitude, and phase."""
            freq = np.random.uniform(0.3, 2.0)
            amp = np.random.uniform(0.5, 2.0)
            phase = np.random.uniform(0, np.pi)
            return amp * np.cos(freq * x + phase)

        def generate_log():
            """Generate a logarithmic trend."""
            # Shift x to keep the argument of the log strictly positive.
            x_log = x + np.random.uniform(0.5, 2.0)
            slope = np.random.uniform(0.3, 1.5)
            offset = np.random.uniform(-2.0, 2.0)
            return slope * np.log(x_log) + offset

        def generate_exponential():
            """Generate an exponential trend."""
            # The rate can be positive or negative, allowing growth or decay.
            growth = np.random.uniform(-0.3, 0.3)
            amp = np.random.uniform(0.5, 2.0)
            return amp * np.exp(growth * x)

        def generate_linear():
            """Generate a linear trend."""
            slope = np.random.uniform(-1.0, 1.0)
            intercept = np.random.uniform(-2.0, 2.0)
            return slope * x + intercept

        def generate_combination():
            """Generate a convex combination of three basic generators."""
            # Dirichlet weights sum to 1.
            weights = np.random.dirichlet(np.ones(3))
            func1 = generate_sin()
            func2 = generate_linear()
            # Randomly select the third component.
            func3 = generate_exponential() if np.random.random() > 0.5 else generate_log()
            return weights[0] * func1 + weights[1] * func2 + weights[2] * func3
        # Generator functions and their sampling probabilities (they sum to 1).
        functions = [
            (generate_sin, 0.2),
            (generate_cos, 0.2),
            (generate_log, 0.15),
            (generate_exponential, 0.15),
            (generate_linear, 0.1),
            (generate_combination, 0.2),
        ]
        funcs, probs = zip(*functions)

        # --- Prototype generation loop ---
        for _ in range(self.num_prototypes):
            # Sample an index rather than the callables themselves, so the
            # selection does not rely on NumPy building an object array.
            func = funcs[np.random.choice(len(funcs), p=probs)]
            prototype = func()
            # Add a small amount of noise to each prototype.
            noise_level = np.random.uniform(0.05, 0.2)
            noise = np.random.normal(0, noise_level, length)
            prototype += noise
            prototypes_list.append(prototype)

        prototypes_np = np.array(prototypes_list)
        # Convert to a float32 tensor (NumPy defaults to float64, PyTorch
        # typically uses float32) and copy it into the Parameter in place,
        # which keeps the gradient-tracking mechanism intact.
        tensor_data = torch.from_numpy(prototypes_np).float()
        self.prototypes.data.copy_(tensor_data)
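        # Hedged inspection sketch: one quick way to eyeball the initialized
        # bank is to plot its rows (assumes matplotlib is installed):
        #   import matplotlib.pyplot as plt
        #   plt.plot(self.prototypes.detach().cpu().numpy().T)
        #   plt.show()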

    def forward(self, x, output_token_len):
        """
        Args:
            x: Input representation of shape [B, k, d], where B is the batch
                size, k the number of input tokens, and d the hidden size.
            output_token_len: Number of output tokens F to synthesize.
        Returns:
            synthetic_protos: Normalized tensor of shape [B, F, p], where p is
                the prototype length (token_len).
        """
        # Distribution over the M prototypes for each output token: [B, F, M].
        dist = self.retriever(x, output_token_len)
        # Weighted combination of the prototype bank: [B, F, M] @ [M, p] -> [B, F, p].
        synthetic_protos = torch.matmul(dist, self.prototypes)
        # Instance-normalize along the prototype length. Since the
        # initialization logic generates values with large ranges plus noise,
        # this normalization is crucial for output stability.
        mean = synthetic_protos.mean(dim=-1, keepdim=True).detach()
        std = synthetic_protos.std(dim=-1, keepdim=True).detach() + 1e-5
        synthetic_protos = (synthetic_protos - mean) / std
        return synthetic_protos
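

# The Retriever maps the input representation to a per-output-token
# distribution over the M prototypes: causally self-attend over the k input
# tokens, keep the last hidden state as a summary, replicate it
# F = output_token_len times, add sinusoidal position embeddings, causally
# decode, and project through a softmax head to produce [B, F, M].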
class Retriever(nn.Module):
    def __init__(self, config: AuroraConfig):
        super().__init__()
        self.input_emb = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.hidden_size),
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout_rate,
                batch_first=True,
            ),
            norm=nn.LayerNorm(config.hidden_size),
            num_layers=config.num_retriever_enc_layers,
        )
        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout_rate,
                batch_first=True,
            ),
            norm=nn.LayerNorm(config.hidden_size),
            num_layers=config.num_retriever_dec_layers,
        )
        self.head = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),  # Combine context and position information
            nn.LayerNorm(config.intermediate_size),
            nn.SiLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.intermediate_size, config.num_prototypes),  # Predict the prototype distribution
            nn.Softmax(dim=-1),
        )
        self.hidden_size = config.hidden_size

    def forward(self, x, output_token_len):
        x_encoded = self.input_emb(x)
        # causal_attention_mask is assumed to return a [1, 1, L, L] mask;
        # squeeze it to the [L, L] shape nn.TransformerEncoder expects.
        enc_attn_mask = causal_attention_mask(x.shape[1]).to(x.device)
        enc_output = self.encoder(x_encoded, mask=enc_attn_mask.squeeze(0).squeeze(0))  # [B, k, d]
        # Keep only the last token's hidden state as a summary of the input.
        enc_output = enc_output[:, -1:, :]
        # Replicate the summary once per output token, then add positions.
        dec = enc_output.repeat(1, output_token_len, 1)
        pos_embeds = sinusoidal_position_embedding(
            batch_size=dec.shape[0], num_heads=1,
            max_len=output_token_len, output_dim=self.hidden_size,
            device=dec.device).squeeze(1)
        embeds = dec + pos_embeds
        dec_attn_mask = causal_attention_mask(output_token_len).to(x.device)
        dec_output = self.decoder(embeds, mask=dec_attn_mask.squeeze(0).squeeze(0))
        dist = self.head(dec_output)  # [B, F, M]
        return dist
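

# Minimal smoke test: a hedged sketch, not part of the original module. It
# assumes AuroraConfig exposes exactly the fields read above, so a
# SimpleNamespace with hypothetical values stands in for it here. Because of
# the relative imports at the top, run this as a module inside its package
# (e.g. `python -m <package>.<this_module>`).
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=64,                # d; must be divisible by num_attention_heads
        intermediate_size=128,
        num_prototypes=32,             # M
        token_len=96,                  # p
        num_attention_heads=4,
        dropout_rate=0.1,
        num_retriever_enc_layers=2,
        num_retriever_dec_layers=2,
    )
    model = PrototypeRetriever(config)
    model.eval()
    x = torch.randn(8, 16, config.hidden_size)    # [B=8, k=16, d=64]
    with torch.no_grad():
        synthetic = model(x, output_token_len=4)  # [B, F=4, p=96]
    print(synthetic.shape)  # expected: torch.Size([8, 4, 96])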