|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from .configuration_aurora import AuroraConfig |
|
|
from .util_functions import sinusoidal_position_embedding, causal_attention_mask |
|
|
|
|
|
|
|
|
class PrototypeRetriever(nn.Module):
    """Mix a learnable bank of waveform prototypes with retrieved weights.

    The module owns ``num_prototypes`` learnable vectors of length
    ``token_len`` and a ``Retriever`` that turns an input representation into
    one weight vector over that bank per requested output token. ``forward``
    blends the bank with those weights and standardizes every blended vector.
    """

    def __init__(self, config: AuroraConfig):
        """Create the prototype bank and the retriever network.

        Args:
            config: Model configuration. ``hidden_size``,
                ``intermediate_size``, ``num_prototypes`` and ``token_len``
                are read here; the whole object is forwarded to
                ``Retriever``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_prototypes = config.num_prototypes
        self.token_len = config.token_len

        # Allocated uninitialized on purpose: the values are written by
        # _initialize_prototypes() on the next line.
        self.prototypes = nn.Parameter(
            torch.empty(self.num_prototypes, self.token_len)
        )
        self._initialize_prototypes()

        self.retriever = Retriever(config)

    def _initialize_prototypes(self, random_seed=42):
        """Fill ``self.prototypes`` with seeded synthetic curves.

        Every prototype is one curve evaluated on ``token_len`` evenly spaced
        points of ``[0, 10]`` (sine, cosine, logarithm, exponential, line, or
        a Dirichlet-weighted mix of them) plus zero-mean Gaussian noise.

        Randomness is drawn from a private ``np.random.RandomState`` instead
        of reseeding NumPy's global generator, so constructing the module has
        no effect on random numbers drawn elsewhere in the process. A
        ``RandomState`` seeded with ``random_seed`` yields the same legacy
        stream as ``np.random.seed(random_seed)``, and the draws below happen
        in the same order as before, so the generated values are unchanged.

        Args:
            random_seed: Seed for the private generator.
        """
        rng = np.random.RandomState(random_seed)
        length = self.token_len
        x = np.linspace(0, 10, length)

        def generate_sin():
            """Sine curve with random frequency, amplitude and phase."""
            freq = rng.uniform(0.3, 2.0)
            amp = rng.uniform(0.5, 2.0)
            phase = rng.uniform(0, np.pi)
            return amp * np.sin(freq * x + phase)

        def generate_cos():
            """Cosine curve with random frequency, amplitude and phase."""
            freq = rng.uniform(0.3, 2.0)
            amp = rng.uniform(0.5, 2.0)
            phase = rng.uniform(0, np.pi)
            return amp * np.cos(freq * x + phase)

        def generate_log():
            """Logarithmic trend with random shift, slope and offset."""
            # The shift is at least 0.5, so log() never sees x == 0.
            x_log = x + rng.uniform(0.5, 2.0)
            slope = rng.uniform(0.3, 1.5)
            offset = rng.uniform(-2.0, 2.0)
            return slope * np.log(x_log) + offset

        def generate_exponential():
            """Exponential trend; the rate may be negative (decay)."""
            growth = rng.uniform(-0.3, 0.3)
            amp = rng.uniform(0.5, 2.0)
            return amp * np.exp(growth * x)

        def generate_linear():
            """Straight line with random slope and intercept."""
            slope = rng.uniform(-1.0, 1.0)
            intercept = rng.uniform(-2.0, 2.0)
            return slope * x + intercept

        def generate_combination():
            """Convex mix of a sine, a line and an exp-or-log trend."""
            weights = rng.dirichlet(np.ones(3))
            func1 = generate_sin()
            func2 = generate_linear()
            # The coin flip is drawn before the chosen trend generator runs.
            func3 = generate_exponential() if rng.random_sample() > 0.5 else generate_log()
            return weights[0] * func1 + weights[1] * func2 + weights[2] * func3

        # (generator, sampling probability); the probabilities sum to 1.
        functions = [
            (generate_sin, 0.2),
            (generate_cos, 0.2),
            (generate_log, 0.15),
            (generate_exponential, 0.15),
            (generate_linear, 0.1),
            (generate_combination, 0.2),
        ]
        funcs, probs = zip(*functions)

        # Pre-sized so the array is (num_prototypes, token_len) even when
        # num_prototypes is 0; stacking an empty list would give shape (0,).
        bank = np.empty((self.num_prototypes, length))
        for row in range(self.num_prototypes):
            func = funcs[rng.choice(len(funcs), p=probs)]
            prototype = func()

            noise_level = rng.uniform(0.05, 0.2)
            noise = rng.normal(0, noise_level, length)
            bank[row] = prototype + noise

        # NumPy produced float64; cast to float32 and overwrite the parameter
        # in place so it stays the same registered nn.Parameter object.
        tensor_data = torch.from_numpy(bank).float()
        with torch.no_grad():
            self.prototypes.copy_(tensor_data)

    def forward(self, x, output_token_len):
        """Produce standardized prototype mixtures for the requested tokens.

        Args:
            x: Input representation, passed unchanged to the retriever. The
                original note gives its shape as ``[B, k, d]``.
            output_token_len: Number of output tokens, passed unchanged to
                the retriever.

        Returns:
            The retriever's weights multiplied by the prototype bank and then
            standardized along the last dimension. The last dimension has
            size ``token_len``; the leading dimensions are those of the
            retriever output.
        """
        # Weights over the bank; the last dim must equal num_prototypes for
        # the matmul below.
        dist = self.retriever(x, output_token_len)

        # Weighted sum of prototypes for every weight vector.
        synthetic_protos = torch.matmul(dist, self.prototypes)

        # Per-vector standardization. The statistics are detached, so they act
        # as constants in the backward pass; 1e-5 guards against a zero std.
        # NOTE(review): torch.std is the unbiased estimator by default, which
        # is NaN when token_len == 1 - confirm token_len > 1 is guaranteed.
        mean = synthetic_protos.mean(dim=-1, keepdim=True).detach()
        std = synthetic_protos.std(dim=-1, keepdim=True).detach() + 1e-5
        synthetic_protos = (synthetic_protos - mean) / std

        return synthetic_protos
|
|
|
|
|
|
|
|
class Retriever(nn.Module):
    """Map an input sequence to per-token weights over the prototype bank.

    Pipeline: LayerNorm + Linear input embedding, a masked Transformer
    encoder, a second masked Transformer stack (called ``decoder``) run over
    ``output_token_len`` copies of the last encoder position, and an MLP head
    whose final layer is a softmax over ``config.num_prototypes``.
    """

    def __init__(self, config: AuroraConfig):
        """Build the embedding, the two Transformer stacks and the head.

        Args:
            config: Model configuration supplying ``hidden_size``,
                ``intermediate_size``, ``num_attention_heads``,
                ``dropout_rate``, ``num_retriever_enc_layers``,
                ``num_retriever_dec_layers`` and ``num_prototypes``.
        """
        super().__init__()
        width = config.hidden_size
        ff_width = config.intermediate_size

        self.input_emb = nn.Sequential(
            nn.LayerNorm(width),
            nn.Linear(width, width),
        )

        # Both stacks share one recipe and differ only in depth.
        self.encoder = self._make_stack(config, config.num_retriever_enc_layers)
        self.decoder = self._make_stack(config, config.num_retriever_dec_layers)

        # Hidden state -> probability vector over the prototypes.
        self.head = nn.Sequential(
            nn.Linear(width, ff_width),
            nn.LayerNorm(ff_width),
            nn.SiLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(ff_width, config.num_prototypes),
            nn.Softmax(dim=-1),
        )

        self.hidden_size = width

    @staticmethod
    def _make_stack(config, depth):
        """Return a batch-first ``nn.TransformerEncoder`` with a final LayerNorm."""
        layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.dropout_rate,
            batch_first=True,
        )
        final_norm = nn.LayerNorm(config.hidden_size)
        return nn.TransformerEncoder(layer, norm=final_norm, num_layers=depth)

    def forward(self, x, output_token_len):
        """Compute prototype weights for ``output_token_len`` output tokens.

        Args:
            x: Batch-first input sequence whose last dimension is
                ``hidden_size`` (it feeds ``LayerNorm(hidden_size)``);
                ``x.shape[1]`` is used as the sequence length.
            output_token_len: Number of output positions to produce.

        Returns:
            Output of the softmax head: one weight vector over the
            prototypes for each of the ``output_token_len`` positions.
        """
        # Mask built by the project helper for the input length.
        # NOTE(review): the two squeeze(0) calls suggest the helper returns
        # leading singleton dims (e.g. [1, 1, L, L]) - confirm.
        history_len = x.shape[1]
        enc_mask = causal_attention_mask(history_len).to(x.device)
        enc_mask = enc_mask.squeeze(0).squeeze(0)
        memory = self.encoder(self.input_emb(x), mask=enc_mask)

        # Keep only the last encoder position and tile it along the sequence
        # axis, once per requested output token.
        summary = memory[:, -1:, :]
        queries = summary.repeat(1, output_token_len, 1)

        # Positional signal from the project helper; requested with a single
        # head, and that axis is squeezed away before the addition.
        positions = sinusoidal_position_embedding(
            batch_size=queries.shape[0],
            num_heads=1,
            max_len=output_token_len,
            output_dim=self.hidden_size,
            device=queries.device,
        ).squeeze(1)

        dec_mask = causal_attention_mask(output_token_len).to(x.device)
        dec_mask = dec_mask.squeeze(0).squeeze(0)
        decoded = self.decoder(queries + positions, mask=dec_mask)

        return self.head(decoded)
|
|
|