|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
from typing import List, Optional, Dict |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
class SinusoidalPositionEmbeddings(nn.Module):
    """Fixed sinusoidal embedding of a 1-D batch of (diffusion) timesteps.

    ``forward(time)`` maps a tensor of shape ``(batch,)`` to ``(batch, dim)``.
    The first ``dim // 2`` columns are ``sin(time * f_i)``, the next ``dim // 2``
    are ``cos(time * f_i)``. For ``dim >= 4`` the frequencies ``f_i`` are spaced
    geometrically from 1 down to 1/10000; for ``dim`` 2 or 3 a single frequency
    of 1 is used. For odd ``dim`` a trailing zero column is appended so the
    output width always equals ``dim``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        """Return the ``(batch, dim)`` embedding of the 1-D tensor ``time``."""
        half_dim = self.dim // 2
        # With half_dim == 1 the spacing denominator would be 0 (ZeroDivisionError);
        # clamping it leaves every dim >= 4 numerically unchanged.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        args = time[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2:
            # sin+cos give 2 * (dim // 2) == dim - 1 columns; pad so callers that
            # size a Linear(dim, ...) from `dim` receive the width they expect.
            emb = F.pad(emb, (0, 1))
        return emb
|
|
|
|
|
|
|
|
class ResnetBlock1D(nn.Module):
    """Residual block of two width-3 convolutions over ``(batch, channels, length)`` input.

    If the block is built with ``time_emb_dim`` and called with ``time_emb``
    (shape ``(batch, time_emb_dim)``), the embedding is mapped to a per-channel
    ``(scale, shift)`` pair. That pair modulates the first normalised activation
    as ``h * (1 + scale) + shift``. The first GroupNorm is created with
    ``affine=False``, so this modulation (when present) is the only per-channel
    affine step there. Both GroupNorms use 8 groups, so ``out_channels`` must be
    divisible by 8. A 1x1 convolution adapts the skip path when the channel
    count changes; otherwise the skip is the identity.
    """

    def __init__(self, in_channels: int, out_channels: int, *, time_emb_dim: Optional[int] = None, dropout: float = 0.1):
        super().__init__()

        if time_emb_dim is None:
            self.time_mlp = None
        else:
            # Twice out_channels: one half becomes the scale, the other the shift.
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, out_channels * 2),
            )

        # Stage 1: conv -> affine-free norm -> (optional modulation) -> SiLU.
        self.block1_conv = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.block1_norm = nn.GroupNorm(8, out_channels, affine=False)
        self.block1_act = nn.SiLU()

        # Stage 2: conv -> norm -> SiLU -> dropout.
        self.block2_conv = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.block2_norm = nn.GroupNorm(8, out_channels)
        self.block2_act = nn.SiLU()
        self.block2_dropout = nn.Dropout(dropout)

        if in_channels == out_channels:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor = None) -> torch.Tensor:
        """Apply the block; ``time_emb`` is ignored if the block has no ``time_mlp``."""
        skip = self.res_conv(x)

        out = self.block1_norm(self.block1_conv(x))
        if self.time_mlp is not None and time_emb is not None:
            # (batch, 2*C) -> (batch, 2*C, 1) -> two (batch, C, 1) halves, broadcast over length.
            scale, shift = self.time_mlp(time_emb).unsqueeze(-1).chunk(2, dim=1)
            out = out * (scale + 1) + shift
        out = self.block1_act(out)

        out = self.block2_conv(out)
        out = self.block2_norm(out)
        out = self.block2_act(out)
        out = self.block2_dropout(out)
        return out + skip
|
|
|
|
|
|
|
|
class AttentionBlock1D(nn.Module):
    """Residual multi-head self-attention along the length axis of ``(batch, channels, length)`` input.

    The input is group-normalised (8 groups, so ``channels`` must be divisible
    by 8) and projected to queries/keys/values by one 1x1 convolution. The
    projections are split into ``num_heads`` heads and attended with
    ``F.scaled_dot_product_attention``. The result is projected back by a
    second 1x1 convolution and added to the input.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``channels``.
    """

    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        # Explicit check rather than `assert`, which is removed under `python -O`.
        if num_heads <= 0 or channels % num_heads != 0:
            raise ValueError("channels must be divisible by num_heads")
        self.channels = channels
        self.num_heads = num_heads
        self.head_dim = channels // num_heads

        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus attended features; output shape equals input shape."""
        B, C, L = x.shape
        # Channel layout of the qkv conv output is [q | k | v], each split into heads.
        qkv = self.qkv(self.norm(x)).view(B, 3, self.num_heads, self.head_dim, L)
        # -> (3, B, heads, L, head_dim): the per-tensor layout SDPA expects.
        q, k, v = qkv.permute(1, 0, 2, 4, 3).unbind(0)
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        # (B, heads, L, head_dim) -> (B, heads, head_dim, L) -> (B, C, L): inverse of the split.
        out = out.transpose(2, 3).reshape(B, C, L)
        return x + self.proj(out)
|
|
|
|
|
|
|
|
class DownBlock1D(nn.Module):
    """Encoder stage of the 1-D U-Net: residual blocks, optional attention, then 2x downsampling.

    ``forward(x, time_emb)`` returns ``(downsampled, skip)``. ``skip`` is the
    feature map before downsampling, at the input resolution, with
    ``out_channels`` channels.
    """

    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, dropout: float, use_attention: bool, num_blocks: int = 2):
        super().__init__()
        stage = []
        for idx in range(num_blocks):
            # Only the first block changes the channel count.
            src_channels = out_channels if idx else in_channels
            stage.append(ResnetBlock1D(src_channels, out_channels, time_emb_dim=time_emb_dim, dropout=dropout))
        self.resnets = nn.ModuleList(stage)

        if use_attention:
            self.attn = AttentionBlock1D(out_channels)
        else:
            self.attn = nn.Identity()

        # kernel 4 / stride 2 / padding 1 maps length L (L >= 2) to floor(L / 2).
        self.downsampler = nn.Conv1d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, time_emb):
        """Run the stage; returns ``(downsampled features, skip features)``."""
        h = x
        for block in self.resnets:
            h = block(h, time_emb)
        skip = self.attn(h)
        return self.downsampler(skip), skip
|
|
|
|
|
|
|
|
class UpBlock1D(nn.Module):
    """Decoder stage of the 1-D U-Net: 2x upsampling, skip concatenation, residual blocks, optional attention.

    The upsampled input (``in_channels`` channels) is concatenated after
    ``skip_x`` along the channel axis. The first residual block expects
    ``in_channels * 2`` channels, so ``skip_x`` must carry ``in_channels``
    channels.
    """

    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, dropout: float, use_attention: bool, num_blocks: int = 2):
        super().__init__()
        first = ResnetBlock1D(in_channels * 2, out_channels, time_emb_dim=time_emb_dim, dropout=dropout)
        rest = [
            ResnetBlock1D(out_channels, out_channels, time_emb_dim=time_emb_dim, dropout=dropout)
            for _ in range(num_blocks - 1)
        ]
        self.resnets = nn.ModuleList([first, *rest])
        self.attn = AttentionBlock1D(out_channels) if use_attention else nn.Identity()
        # kernel 4 / stride 2 / padding 1 maps length L to exactly 2L.
        self.upsampler = nn.ConvTranspose1d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, skip_x, time_emb):
        """Upsample ``x``, align it to ``skip_x``'s length, fuse and refine."""
        h = self.upsampler(x)

        # Odd encoder lengths leave the upsampled map shorter (or, in general,
        # longer) than the skip: centre-pad with zeros or crop from the end.
        target = skip_x.size(-1)
        gap = target - h.size(-1)
        if gap > 0:
            left = gap // 2
            h = F.pad(h, [left, gap - left])
        elif gap < 0:
            h = h[:, :, :target]

        h = torch.cat([skip_x, h], dim=1)
        for block in self.resnets:
            h = block(h, time_emb)
        return self.attn(h)
|
|
|
|
|
|
|
|
class ConditionalUnet(nn.Module):
    """1-D U-Net noise predictor conditioned on timestep, house id and calendar indices.

    Sequences are laid out as ``(batch, length, channels)`` at the interface
    and ``(batch, channels, length)`` internally. The timestep (via a
    sinusoidal embedding + MLP) and the three categorical conditions (via
    ``nn.Embedding`` + linear projections) are each mapped to
    ``time_emb_dim = 4 * hidden_dims[0]`` and summed. The sum is the single
    embedding passed to every residual block.

    Args:
        in_channels: channels of the noisy input ``x`` and of the output.
        num_houses: number of rows in the house-id embedding table.
        embedding_dim: width of the categorical embedding tables.
        hidden_dims: channel widths per resolution level; ``None`` means
            ``[64, 128, 256]``. Each width must be divisible by 8 (GroupNorm /
            attention heads).
        dropout: dropout probability inside residual blocks.
        use_attention: add self-attention at the end of each down/up stage.
        cond_channels: channels of the optional ``conditioning_signal``,
            concatenated to ``x`` before the first convolution.
        blocks_per_level: residual blocks per down/up stage.
    """

    def __init__(self, in_channels: int, num_houses: int, embedding_dim: int = 64,
                 hidden_dims: Optional[List[int]] = None,
                 dropout: float = 0.1, use_attention: bool = True,
                 cond_channels: int = 0, blocks_per_level: int = 2):
        super().__init__()
        # `None` selects the historical default; copying avoids a shared mutable
        # default and never aliases the caller's list.
        hidden_dims = [64, 128, 256] if hidden_dims is None else list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one channel width")
        time_emb_dim = hidden_dims[0] * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(hidden_dims[0]),
            nn.Linear(hidden_dims[0], time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        self.house_embedding = nn.Embedding(num_houses, embedding_dim)
        self.house_proj = nn.Linear(embedding_dim, time_emb_dim)

        # 7 rows -> day_of_week indices 0..6.
        self.day_of_week_embedding = nn.Embedding(7, embedding_dim)
        # 366 rows -> day_of_year indices 0..365. NOTE(review): if callers pass a
        # 1-based day of year, day 366 of a leap year is out of range — confirm.
        self.day_of_year_embedding = nn.Embedding(366, embedding_dim)

        self.day_of_week_proj = nn.Linear(embedding_dim, time_emb_dim)
        self.day_of_year_proj = nn.Linear(embedding_dim, time_emb_dim)

        self.init_conv = nn.Conv1d(in_channels + cond_channels, hidden_dims[0], kernel_size=7, padding=3)

        num_resolutions = len(hidden_dims)
        self.down_blocks = nn.ModuleList([
            DownBlock1D(hidden_dims[i], hidden_dims[i + 1], time_emb_dim, dropout, use_attention, blocks_per_level)
            for i in range(num_resolutions - 1)
        ])

        self.mid_block1 = ResnetBlock1D(hidden_dims[-1], hidden_dims[-1], time_emb_dim=time_emb_dim, dropout=dropout)
        self.mid_attn = AttentionBlock1D(hidden_dims[-1])
        self.mid_block2 = ResnetBlock1D(hidden_dims[-1], hidden_dims[-1], time_emb_dim=time_emb_dim, dropout=dropout)

        self.up_blocks = nn.ModuleList([
            UpBlock1D(hidden_dims[i + 1], hidden_dims[i], time_emb_dim, dropout, use_attention, blocks_per_level)
            for i in reversed(range(num_resolutions - 1))
        ])

        # Kept as an nn.Sequential so state_dict keys (final_conv.0.*, final_conv.1.*)
        # stay unchanged; forward() calls the two parts explicitly.
        self.final_conv = nn.Sequential(
            ResnetBlock1D(hidden_dims[0], hidden_dims[0], time_emb_dim=time_emb_dim, dropout=dropout),
            nn.Conv1d(hidden_dims[0], in_channels, 1),
        )

    def forward(self, x: torch.Tensor, timestep: torch.Tensor, conditions: Dict[str, torch.Tensor],
                conditioning_signal: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Predict a tensor shaped like ``x`` (``(batch, length, in_channels)``).

        Args:
            x: noisy input, ``(batch, length, in_channels)``.
            timestep: per-sample timesteps, shape ``(batch,)``.
            conditions: integer index tensors under ``"house_id"``,
                ``"day_of_week"`` and ``"day_of_year"``.
            conditioning_signal: optional ``(batch, length, cond_channels)``
                tensor concatenated to ``x`` along channels.

        Raises:
            ValueError: if ``x`` plus ``conditioning_signal`` do not provide the
                channel count ``init_conv`` was built for.
        """
        expected_channels = self.init_conv.in_channels
        provided_channels = x.shape[-1] + (0 if conditioning_signal is None else conditioning_signal.shape[-1])
        if provided_channels != expected_channels:
            raise ValueError(
                f"expected {expected_channels} input channels (x plus conditioning_signal), "
                f"got {provided_channels}"
            )

        emb = (
            self.time_mlp(timestep)
            + self.house_proj(self.house_embedding(conditions["house_id"]))
            + self.day_of_week_proj(self.day_of_week_embedding(conditions["day_of_week"]))
            + self.day_of_year_proj(self.day_of_year_embedding(conditions["day_of_year"]))
        )

        x = x.permute(0, 2, 1)
        if conditioning_signal is not None:
            x = torch.cat([x, conditioning_signal.permute(0, 2, 1)], dim=1)
        x = self.init_conv(x)

        skip_connections = []
        for down_block in self.down_blocks:
            x, skip_x = down_block(x, emb)
            skip_connections.append(skip_x)

        x = self.mid_block1(x, emb)
        x = self.mid_attn(x)
        x = self.mid_block2(x, emb)

        for up_block in self.up_blocks:
            x = up_block(x, skip_connections.pop(), emb)

        # nn.Sequential.__call__ passes only `x`, which left the final residual
        # block unconditioned and its time_mlp without gradients. Call it with
        # `emb` like every other block.
        final_block, out_conv = self.final_conv[0], self.final_conv[1]
        x = final_block(x, emb)
        return out_conv(x).permute(0, 2, 1)
|
|
|
|
|
|
|
|
class ImprovedDiffusionModel(nn.Module):
    """DDPM wrapper: cosine noise schedule, epsilon-prediction loss and ancestral sampling.

    ``base_model`` is called as ``model(x_t, t, conditions, conditioning_signal)``
    and must return a noise prediction shaped like ``x_t``. Data tensors are
    3-D (the per-timestep coefficients are viewed as ``(-1, 1, 1)``), with
    channels on the last axis.

    Args:
        base_model: the noise-prediction network.
        num_timesteps: number of diffusion steps T.
        channel_weights: optional per-channel loss weights (length = size of
            the last axis of ``x_0``). Stored as a non-persistent buffer: it
            follows ``.to()``/``.cuda()`` but is not written to the state_dict.
    """

    def __init__(self, base_model: "ConditionalUnet", num_timesteps: int, channel_weights: Optional[torch.Tensor] = None):
        super().__init__()
        self.model = base_model
        self.num_timesteps = num_timesteps
        # A buffer moves with the module (no host->device copy every step);
        # persistent=False keeps state_dict keys identical to before.
        self.register_buffer(
            'channel_weights',
            None if channel_weights is None else torch.as_tensor(channel_weights),
            persistent=False,
        )

        betas = self._cosine_beta_schedule(num_timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alpha-bar_{t-1}, with alpha-bar_{-1} := 1.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))

        # Variance of q(x_{t-1} | x_t, x_0); it is exactly 0 at t = 0, so clamp
        # it away from zero.
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        posterior_variance = torch.clamp(posterior_variance, min=1e-20)
        self.register_buffer('posterior_variance', posterior_variance)

    @staticmethod
    def _cosine_beta_schedule(timesteps, s=0.008):
        """Return float32 betas of length ``timesteps`` from the cosine alpha-bar schedule, clipped to [1e-4, 0.9999]."""
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999).float()

    def q_sample(self, x_start, t, noise=None):
        """Draw x_t ~ q(x_t | x_0) = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise."""
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def forward(self, x_0: torch.Tensor, conditions: Dict[str, torch.Tensor],
                conditioning_signal: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the scalar training loss for a batch of clean data ``x_0``.

        A uniform random timestep is drawn per sample. The loss is the
        element-wise Huber loss between the true and predicted noise,
        optionally weighted per channel (last axis), averaged over all elements.
        """
        t = torch.randint(0, self.num_timesteps, (x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = self.q_sample(x_0, t, noise)
        predicted_noise = self.model(x_t, t, conditions, conditioning_signal)

        loss = F.huber_loss(noise, predicted_noise, reduction='none')
        if self.channel_weights is None:
            return loss.mean()
        # Match the loss's device *and* dtype so e.g. float64 weights don't
        # silently promote the loss (and its backward pass) to float64.
        weights = self.channel_weights.to(device=loss.device, dtype=loss.dtype).view(1, 1, -1)
        return (loss * weights).mean()

    @torch.no_grad()
    def sample(self, num_samples: int, conditions: Dict[str, torch.Tensor], shape: tuple,
               conditioning_signal: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generate ``(num_samples, *shape)`` samples by ancestral DDPM sampling.

        ``conditions`` and ``conditioning_signal`` are passed to the model
        unchanged at every step; they must already be on the model's device
        and have batch size ``num_samples``. Dropout is active unless the
        caller has put the module in eval mode.
        """
        device = next(self.model.parameters()).device
        x = torch.randn(num_samples, *shape, device=device)

        for t in tqdm(reversed(range(self.num_timesteps)), desc="Sampling", total=self.num_timesteps, leave=False):
            t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)
            predicted_noise = self.model(x, t_batch, conditions, conditioning_signal)

            alpha_t = self.alphas[t]
            sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t]
            # Posterior mean: (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t).
            mean = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / sqrt_one_minus_alpha_cumprod_t) * predicted_noise)

            if t > 0:
                noise = torch.randn_like(x)
                x = mean + torch.sqrt(self.posterior_variance[t]) * noise
            else:
                # No noise is added on the final step.
                x = mean

        return x
|
|
|
|
|
|
|
|
class HierarchicalDiffusionModel(nn.Module):
    """Two-stage (coarse -> fine) diffusion model for ``(batch, length, channels)`` sequences.

    Stage 1 ("coarse") models the sequence averaged over non-overlapping
    blocks of ``downscale_factor`` steps. Stage 2 ("fine") models the
    residual between the full-resolution sequence and the coarse sequence
    repeated back to full length. It works on windows of ``fine_chunk_size``
    steps, conditioned on the matching window of the upsampled coarse sequence.

    Keyword arguments:
        num_timesteps (required): diffusion steps for both stages (KeyError if missing).
        fine_chunk_size (optional, default 192): fine-stage window length.
        Everything else in ``model_kwargs`` is forwarded to both ConditionalUnets.
    """

    def __init__(self, in_channels: int, num_houses: int, downscale_factor: int, channel_weights: Optional[torch.Tensor] = None, **model_kwargs):
        super().__init__()
        self.downscale_factor = downscale_factor
        self.fine_chunk_size = int(model_kwargs.pop("fine_chunk_size", 2 * 96))
        if self.fine_chunk_size < 1:
            raise ValueError("fine_chunk_size must be a positive integer")
        num_timesteps = model_kwargs.pop("num_timesteps")

        self.downsampler = nn.Conv1d(in_channels, in_channels, kernel_size=downscale_factor, stride=downscale_factor)
        self.upsampler = nn.ConvTranspose1d(in_channels, in_channels, kernel_size=downscale_factor, stride=downscale_factor)
        self._init_fixed_resamplers(in_channels, downscale_factor)

        self.coarse_model = ImprovedDiffusionModel(
            ConditionalUnet(in_channels=in_channels, num_houses=num_houses, **model_kwargs),
            num_timesteps,
            channel_weights=channel_weights,
        )
        self.fine_model = ImprovedDiffusionModel(
            ConditionalUnet(in_channels=in_channels, num_houses=num_houses,
                            cond_channels=in_channels, **model_kwargs),
            num_timesteps,
            channel_weights=channel_weights,
        )

    def _init_fixed_resamplers(self, channels: int, factor: int) -> None:
        """Make the resamplers per-channel block-mean / block-repeat and freeze them.

        With learnable, randomly initialised resamplers the downsampler was
        trained only by the coarse denoising loss, whose trivial minimiser is a
        zero coarse signal. The upsampler was only used under no_grad, so it
        never left its random init. Loading an older checkpoint still
        overwrites these weights, keeping such models self-consistent.
        """
        eye = torch.eye(channels).unsqueeze(-1).expand(channels, channels, factor)
        with torch.no_grad():
            # Conv1d weight is (out, in, k): out_c = mean of in_c over each block.
            self.downsampler.weight.copy_(eye / factor)
            self.downsampler.bias.zero_()
            # ConvTranspose1d weight is (in, out, k): each coarse value is repeated `factor` times.
            self.upsampler.weight.copy_(eye)
            self.upsampler.bias.zero_()
        self.downsampler.requires_grad_(False)
        self.upsampler.requires_grad_(False)

    @staticmethod
    def _match_length(x: torch.Tensor, length: int) -> torch.Tensor:
        """Zero-pad at the end, or truncate, axis 1 of ``x`` to ``length``."""
        diff = length - x.shape[1]
        if diff > 0:
            return F.pad(x, [0, 0, 0, diff])
        return x[:, :length, :]

    def forward(self, x_0: torch.Tensor, conditions: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return ``coarse_loss + 1.5 * fine_loss`` for clean data ``x_0``.

        The fine loss is computed on one random window (shared across the
        batch) of length ``min(fine_chunk_size, length)``.
        """
        full_length = x_0.shape[1]

        x_0_coarse = self.downsampler(x_0.permute(0, 2, 1)).permute(0, 2, 1)
        coarse_loss = self.coarse_model(x_0_coarse, conditions)

        with torch.no_grad():
            x_0_coarse_upsampled = self.upsampler(x_0_coarse.detach().permute(0, 2, 1)).permute(0, 2, 1)
        # Lengths not divisible by downscale_factor lose their tail when
        # downsampling; that tail gets a zero coarse estimate.
        x_0_coarse_upsampled = self._match_length(x_0_coarse_upsampled, full_length)
        x_0_fine_residual = x_0 - x_0_coarse_upsampled

        # Local window size. The original code assigned this back to
        # self.fine_chunk_size, permanently shrinking training windows and the
        # sampling window after a single short batch.
        chunk_size = min(self.fine_chunk_size, full_length)
        if full_length > chunk_size:
            start_index = torch.randint(0, full_length - chunk_size + 1, (1,)).item()
        else:
            start_index = 0

        window = slice(start_index, start_index + chunk_size)
        residual_chunk = x_0_fine_residual[:, window, :]
        conditioning_chunk = x_0_coarse_upsampled[:, window, :]

        fine_loss = self.fine_model(residual_chunk, conditions, conditioning_signal=conditioning_chunk)

        # Fixed up-weighting of the fine stage relative to the coarse stage.
        fine_loss_weight = 1.5
        return coarse_loss + (fine_loss * fine_loss_weight)

    @torch.no_grad()
    def sample(self, num_samples: int, conditions: Dict[str, torch.Tensor], shape: tuple) -> torch.Tensor:
        """Generate ``(num_samples, length, num_features)`` samples, where ``shape = (length, num_features)``.

        ``num_features`` presumably must equal ``in_channels`` (TODO confirm
        with callers). Fine windows are sampled independently, conditioned
        only on the coarse signal, so continuity across window boundaries is
        not enforced.

        Raises:
            ValueError: if ``length < downscale_factor`` (empty coarse sequence).
        """
        full_length, num_features = shape
        coarse_length = full_length // self.downscale_factor
        if coarse_length < 1:
            raise ValueError(
                f"shape[0]={full_length} must be at least downscale_factor={self.downscale_factor}"
            )
        device = next(self.parameters()).device
        conditions = {k: v.to(device) for k, v in conditions.items()}

        print("--- Stage 1: Sampling Coarse Structure ---")
        generated_coarse = self.coarse_model.sample(num_samples, conditions, shape=(coarse_length, num_features))
        upsampled_coarse = self.upsampler(generated_coarse.permute(0, 2, 1)).permute(0, 2, 1)
        upsampled_coarse = self._match_length(upsampled_coarse, full_length)

        print("--- Stage 2: Sampling Fine Details ---")
        chunk_size = min(self.fine_chunk_size, full_length)
        stitched_fine_residual = torch.zeros_like(upsampled_coarse)

        for start_index in tqdm(range(0, full_length, chunk_size), desc="Fine chunks"):
            end_index = min(start_index + chunk_size, full_length)
            # Always sample a full-size window ending at end_index. The fine model
            # only saw chunk_size-long windows in training, and very short tail
            # windows (e.g. 1-3 steps) crash the U-Net's stride-2 downsamplers.
            # Only the not-yet-filled part [start_index, end_index) is kept.
            window_start = end_index - chunk_size
            conditioning_chunk = upsampled_coarse[:, window_start:end_index, :]

            generated_fine_chunk = self.fine_model.sample(
                num_samples, conditions, shape=(chunk_size, num_features),
                conditioning_signal=conditioning_chunk,
            )
            stitched_fine_residual[:, start_index:end_index, :] = generated_fine_chunk[:, start_index - window_start:, :]

        return upsampled_coarse + stitched_fine_residual