"""Metadata embeddings - SinusoidalEmbedding + MLPs for metadata conditioning."""

import torch
import torch.nn as nn
| |
|
| |
|
class SinusoidalEmbedding(nn.Module):
    """Sinusoidal embedding for metadata.

    Each of the M metadata scalars in a (batch, M) input is expanded into an
    ``embedding_dim``-sized sin/cos feature vector; the per-scalar vectors are
    concatenated into a (batch, M * embedding_dim) output.
    """

    def __init__(self, max_value, embedding_dim):
        super().__init__()
        self.max_value = max_value
        self.embedding_dim = embedding_dim
        # Base of the geometric frequency progression (transformer convention).
        self.omega = 10000.0

    def forward(self, k):
        """Embed metadata tensor ``k`` of shape (batch, M).

        Returns a tensor of shape (batch, M * embedding_dim).
        """
        # NOTE(review): this scales *up* by max_value rather than dividing to
        # normalize — confirm with callers that this is intended.
        k_normalized = k * self.max_value
        half = self.embedding_dim // 2
        # Frequencies omega^(-2i/dim), i = 0..half-1.  Computed in float64
        # (matching the original Python-float arithmetic) then cast to the
        # input dtype.
        exponents = (
            -2.0
            * torch.arange(half, dtype=torch.float64, device=k.device)
            / self.embedding_dim
        )
        freqs = (self.omega ** exponents).to(k.dtype)
        # angles: (batch, M, half) — one frequency row per metadata scalar.
        angles = k_normalized.unsqueeze(-1) * freqs
        embedding = torch.zeros(
            (k.size(0), k.size(1), self.embedding_dim),
            device=k.device,
            dtype=k.dtype,
        )
        # Interleave: even slots get sin, odd slots get cos.  When
        # embedding_dim is odd, the final column stays zero — same as the
        # previous element-wise loop implementation.
        embedding[..., 0 : 2 * half : 2] = torch.sin(angles)
        embedding[..., 1 : 2 * half : 2] = torch.cos(angles)
        return embedding.view(k.size(0), -1)
| |
|
| |
|
def create_condition_vector(embedded_metadata, mlp_models, embedding_dim):
    """Combine per-field metadata embeddings into one condition vector.

    ``embedded_metadata`` holds the flattened per-field sinusoidal
    embeddings, ``embedding_dim`` wide each; field ``i`` is projected by
    ``mlp_models[i]`` and the projections are summed.
    """
    width = embedding_dim
    projected = (
        mlp(embedded_metadata[:, idx * width : (idx + 1) * width])
        for idx, mlp in enumerate(mlp_models)
    )
    return sum(projected)
| |
|
| |
|
class MetadataMLP(nn.Module):
    """Single linear projection used to embed one metadata field.

    Despite the name, this is a one-layer affine map, not a multi-layer
    perceptron.
    """

    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        # Attribute kept as `fc1`: state-dict keys depend on it.
        self.fc1 = nn.Linear(input_dim, embedding_dim)

    def forward(self, x):
        projected = self.fc1(x)
        return projected
| |
|
| |
|
class MetadataEmbeddings(nn.Module):
    """Metadata embeddings - SinusoidalEmbedding + MLPs.

    Sinusoidally embeds each metadata scalar, projects each field through
    its own MetadataMLP, and sums the projections into one condition vector.
    """

    def __init__(self, max_value, embedding_dim, max_period, metadata_dim):
        super().__init__()
        self.sinusoidal_embedding = SinusoidalEmbedding(max_value, embedding_dim)
        # One projection head per metadata field.
        self.mlp_models = nn.ModuleList(
            MetadataMLP(embedding_dim, embedding_dim * 4)
            for _ in range(metadata_dim)
        )
        self.max_period = max_period  # stored but not read here
        self.embedding_dim = embedding_dim
        self.metadata_dim = metadata_dim
        self.max_value = max_value

    def forward(self, metadata=None):
        # Peel off singleton list/tuple nesting, e.g. [[tensor]] -> tensor.
        unwrapped = metadata
        while isinstance(unwrapped, (list, tuple)) and len(unwrapped) == 1:
            unwrapped = unwrapped[0]
        # Promote a single sample to a batch of one.
        if unwrapped.dim() == 1:
            unwrapped = unwrapped.unsqueeze(0)
        embedded = self.sinusoidal_embedding(unwrapped)
        return create_condition_vector(embedded, self.mlp_models, self.embedding_dim)
| |
|
| |
|
# Lowercase alias for MetadataEmbeddings — presumably kept so callers can
# import `metadata_embeddings`; verify before removing.
metadata_embeddings = MetadataEmbeddings
| |
|