# Uploaded by SolarSys2025 ("Upload 10 files").
# Commit f884fa5 (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Optional, Dict
from tqdm import tqdm
class SinusoidalPositionEmbeddings(nn.Module):
    """Encode a batch of scalar timesteps as sinusoidal feature vectors.

    The first ``dim // 2`` output columns are sines and the next ``dim // 2``
    are cosines of ``time * freq``, where the frequencies decay geometrically
    from 1 down to 1/10000.

    Args:
        dim: Number of output features per timestep.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        """Return embeddings of shape ``(len(time), dim)``.

        Args:
            time: 1-D tensor of timesteps (it is indexed as ``time[:, None]``).

        Returns:
            Tensor with ``dim`` features per timestep. For odd ``dim`` the
            last column is zero padding so the width always equals ``dim``.
        """
        device = time.device
        half_dim = self.dim // 2
        # With a single frequency (dim 2 or 3) ``half_dim - 1`` is zero; clamp
        # the denominator so we do not raise ZeroDivisionError. For
        # half_dim >= 2 this is the same value as before.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        angles = time[:, None] * freqs[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos pairs only cover an even width; pad one zero column so
            # downstream layers sized with ``dim`` receive the expected width.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings
class ResnetBlock1D(nn.Module):
    """Two-convolution 1-D residual block with optional scale/shift conditioning.

    The first convolution is followed by a non-affine GroupNorm; when a time
    embedding is supplied it is projected to a per-channel (scale, shift) pair
    that modulates the normalised activations. The second convolution uses an
    affine GroupNorm, SiLU and dropout. The input is added back through a 1x1
    convolution when the channel count changes, otherwise unchanged.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor (GroupNorm uses 8 groups).
        time_emb_dim: Width of the conditioning embedding, or ``None`` to
            build the block without a conditioning projection.
        dropout: Dropout probability applied after the second activation.
    """

    def __init__(self, in_channels: int, out_channels: int, *, time_emb_dim: int = None, dropout: float = 0.1):
        super().__init__()
        # Submodules are created in a fixed order so that seeded
        # initialisation and state_dict key order are stable.
        if time_emb_dim is None:
            self.time_mlp = None
        else:
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, out_channels * 2),
            )

        self.block1_conv = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        # No learnable affine here: the conditioning supplies scale and shift.
        self.block1_norm = nn.GroupNorm(8, out_channels, affine=False)
        self.block1_act = nn.SiLU()

        self.block2_conv = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.block2_norm = nn.GroupNorm(8, out_channels)
        self.block2_act = nn.SiLU()
        self.block2_dropout = nn.Dropout(dropout)

        if in_channels == out_channels:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor = None) -> torch.Tensor:
        """Apply the block to ``x`` (channels on dim 1, length on dim 2).

        Conditioning is skipped when either the block was built without a
        projection or ``time_emb`` is not given.
        """
        hidden = self.block1_norm(self.block1_conv(x))

        conditioned = self.time_mlp is not None and time_emb is not None
        if conditioned:
            scale, shift = torch.chunk(self.time_mlp(time_emb), 2, dim=1)
            # Broadcast the per-channel modulation over the length axis.
            hidden = hidden * (scale.unsqueeze(-1) + 1) + shift.unsqueeze(-1)
        hidden = self.block1_act(hidden)

        hidden = self.block2_conv(hidden)
        hidden = self.block2_norm(hidden)
        hidden = self.block2_act(hidden)
        hidden = self.block2_dropout(hidden)

        return hidden + self.res_conv(x)
class AttentionBlock1D(nn.Module):
    """Multi-head self-attention over the length axis of a ``(B, C, L)`` tensor.

    The input is group-normalised (8 groups), projected to queries, keys and
    values with a 1x1 convolution, attended per head, projected back with a
    second 1x1 convolution and added to the input.

    Args:
        channels: Number of channels; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        head_dim, leftover = divmod(channels, num_heads)
        assert leftover == 0, "channels must be divisible by num_heads"
        self.head_dim = head_dim
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the projected attention output (same shape as ``x``)."""
        batch, channels, length = x.shape
        projected = self.qkv(self.norm(x))
        # Channel axis is laid out as (q|k|v, head, head_dim).
        projected = projected.view(batch, 3, self.num_heads, self.head_dim, length)
        # Each part becomes (batch, heads, length, head_dim) for attention.
        query, key, value = (part.transpose(-1, -2) for part in projected.unbind(dim=1))
        attended = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)
        # Back to (batch, heads, head_dim, length), then merge heads into channels.
        merged = attended.transpose(-1, -2).contiguous().view(batch, channels, length)
        return x + self.proj(merged)
class DownBlock1D(nn.Module):
    """Encoder stage: residual blocks, optional attention, then 2x downsampling.

    Args:
        in_channels: Channels entering the stage.
        out_channels: Channels produced by every residual block in the stage.
        time_emb_dim: Width of the conditioning embedding given to each block.
        dropout: Dropout probability used inside the residual blocks.
        use_attention: Insert a self-attention block after the residual
            blocks when true; otherwise an identity is used.
        num_blocks: Number of residual blocks.
    """

    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, dropout: float, use_attention: bool, num_blocks: int = 2):
        super().__init__()
        self.resnets = nn.ModuleList()
        block_in = in_channels
        for _ in range(num_blocks):
            self.resnets.append(
                ResnetBlock1D(block_in, out_channels, time_emb_dim=time_emb_dim, dropout=dropout)
            )
            # Only the first block changes the channel count.
            block_in = out_channels

        if use_attention:
            self.attn = AttentionBlock1D(out_channels)
        else:
            self.attn = nn.Identity()

        self.downsampler = nn.Conv1d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, time_emb):
        """Return ``(downsampled, skip)`` where ``skip`` is the pre-downsample activation."""
        hidden = x
        for block in self.resnets:
            hidden = block(hidden, time_emb)
        skip = self.attn(hidden)
        return self.downsampler(skip), skip
class UpBlock1D(nn.Module):
    """Decoder stage: 2x upsampling, skip concatenation, residual blocks, optional attention.

    Args:
        in_channels: Channels of the incoming tensor; the skip tensor is
            concatenated onto it, so the first residual block sees
            ``in_channels * 2``.
        out_channels: Channels produced by the residual blocks.
        time_emb_dim: Width of the conditioning embedding given to each block.
        dropout: Dropout probability used inside the residual blocks.
        use_attention: Append a self-attention block when true; otherwise an
            identity is used.
        num_blocks: Number of residual blocks (at least one is always built).
    """

    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, dropout: float, use_attention: bool, num_blocks: int = 2):
        super().__init__()
        # One block is always created, followed by ``num_blocks - 1`` more.
        total_blocks = max(num_blocks, 1)
        self.resnets = nn.ModuleList([
            ResnetBlock1D(
                in_channels * 2 if position == 0 else out_channels,
                out_channels,
                time_emb_dim=time_emb_dim,
                dropout=dropout,
            )
            for position in range(total_blocks)
        ])

        if use_attention:
            self.attn = AttentionBlock1D(out_channels)
        else:
            self.attn = nn.Identity()

        self.upsampler = nn.ConvTranspose1d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, skip_x, time_emb):
        """Upsample ``x``, align its length with ``skip_x``, merge and refine."""
        x = self.upsampler(x)

        target_len = skip_x.size(-1)
        shortfall = target_len - x.size(-1)
        if shortfall > 0:
            # Too short: zero-pad both ends, extra element on the right.
            left = shortfall // 2
            x = F.pad(x, [left, shortfall - left])
        elif shortfall < 0:
            # Too long: keep the leading ``target_len`` positions.
            x = x[:, :, :target_len]

        merged = torch.cat([skip_x, x], dim=1)
        for block in self.resnets:
            merged = block(merged, time_emb)
        return self.attn(merged)
class ConditionalUnet(nn.Module):
    """1-D U-Net noise predictor conditioned on timestep, house and calendar ids.

    The timestep embedding and the projected house / day-of-week / day-of-year
    embeddings are summed into a single conditioning vector that is fed to the
    residual blocks of the encoder, middle and decoder.

    Args:
        in_channels: Feature channels of the signal being denoised; also the
            number of output channels.
        num_houses: Size of the house-id embedding table.
        embedding_dim: Width of the house and calendar embeddings.
        hidden_dims: Channel widths per resolution level. Defaults to
            ``[64, 128, 256]`` when ``None``. Each width must be divisible by
            8 because the blocks use ``GroupNorm(8, ...)``.
        dropout: Dropout probability inside the residual blocks.
        use_attention: Whether down/up blocks include self-attention.
        cond_channels: Channels of the optional conditioning signal that is
            concatenated to the input before the first convolution.
        blocks_per_level: Residual blocks per down/up stage.
    """

    def __init__(self, in_channels: int, num_houses: int, embedding_dim: int = 64,
                 hidden_dims: Optional[List[int]] = None,
                 dropout: float = 0.1, use_attention: bool = True,
                 cond_channels: int = 0, blocks_per_level: int = 2):
        super().__init__()
        # ``None`` sentinel instead of a shared mutable list default; copy so
        # the caller's sequence is never aliased.
        hidden_dims = [64, 128, 256] if hidden_dims is None else list(hidden_dims)

        time_emb_dim = hidden_dims[0] * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(hidden_dims[0]),
            nn.Linear(hidden_dims[0], time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )
        self.house_embedding = nn.Embedding(num_houses, embedding_dim)
        self.house_proj = nn.Linear(embedding_dim, time_emb_dim)
        # Table sizes fix the valid index ranges: day_of_week in [0, 7),
        # day_of_year in [0, 366).
        self.day_of_week_embedding = nn.Embedding(7, embedding_dim)
        self.day_of_year_embedding = nn.Embedding(366, embedding_dim)
        self.day_of_week_proj = nn.Linear(embedding_dim, time_emb_dim)
        self.day_of_year_proj = nn.Linear(embedding_dim, time_emb_dim)
        self.init_conv = nn.Conv1d(in_channels + cond_channels, hidden_dims[0], kernel_size=7, padding=3)
        num_resolutions = len(hidden_dims)
        self.down_blocks = nn.ModuleList([
            DownBlock1D(hidden_dims[i], hidden_dims[i + 1], time_emb_dim, dropout, use_attention, blocks_per_level)
            for i in range(num_resolutions - 1)
        ])
        self.mid_block1 = ResnetBlock1D(hidden_dims[-1], hidden_dims[-1], time_emb_dim=time_emb_dim, dropout=dropout)
        self.mid_attn = AttentionBlock1D(hidden_dims[-1])
        self.mid_block2 = ResnetBlock1D(hidden_dims[-1], hidden_dims[-1], time_emb_dim=time_emb_dim, dropout=dropout)
        self.up_blocks = nn.ModuleList([
            UpBlock1D(hidden_dims[i + 1], hidden_dims[i], time_emb_dim, dropout, use_attention, blocks_per_level)
            for i in reversed(range(num_resolutions - 1))
        ])
        # NOTE(review): nn.Sequential calls this ResnetBlock1D with ``x`` only,
        # so the final block runs without the conditioning embedding and its
        # time_mlp parameters are unused. Left as-is because routing ``emb``
        # here would change the outputs of existing checkpoints.
        self.final_conv = nn.Sequential(
            ResnetBlock1D(hidden_dims[0], hidden_dims[0], time_emb_dim=time_emb_dim, dropout=dropout),
            nn.Conv1d(hidden_dims[0], in_channels, 1)
        )

    def forward(self, x: torch.Tensor, timestep: torch.Tensor, conditions: Dict[str, torch.Tensor],
                conditioning_signal: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Predict noise for ``x``.

        Args:
            x: Channel-last input ``(batch, length, in_channels)``; it is
                permuted to channel-first before the convolutions.
            timestep: 1-D tensor of diffusion steps, one per batch element.
            conditions: Must contain index tensors under the keys
                ``"house_id"``, ``"day_of_week"`` and ``"day_of_year"``.
            conditioning_signal: Optional channel-last tensor concatenated to
                ``x`` along the channel axis; its channel count has to match
                the ``cond_channels`` the model was built with.

        Returns:
            Channel-last tensor with ``in_channels`` features.

        Raises:
            KeyError: If one of the three condition keys is missing.
        """
        time_emb = self.time_mlp(timestep)
        house_id = conditions["house_id"]
        day_of_week = conditions["day_of_week"]
        day_of_year = conditions["day_of_year"]
        house_emb = self.house_proj(self.house_embedding(house_id))
        dow_emb = self.day_of_week_proj(self.day_of_week_embedding(day_of_week))
        doy_emb = self.day_of_year_proj(self.day_of_year_embedding(day_of_year))
        # All conditioning sources share one additive embedding.
        emb = time_emb + house_emb + dow_emb + doy_emb

        x = x.permute(0, 2, 1)
        if conditioning_signal is not None:
            x = torch.cat([x, conditioning_signal.permute(0, 2, 1)], dim=1)
        x = self.init_conv(x)

        skip_connections = []
        for down_block in self.down_blocks:
            x, skip_x = down_block(x, emb)
            skip_connections.append(skip_x)

        x = self.mid_block1(x, emb)
        x = self.mid_attn(x)
        x = self.mid_block2(x, emb)

        # Skips are consumed in reverse order (deepest first).
        for up_block in self.up_blocks:
            x = up_block(x, skip_connections.pop(), emb)
        return self.final_conv(x).permute(0, 2, 1)
class ImprovedDiffusionModel(nn.Module):
    """Noise-prediction diffusion wrapper with a cosine beta schedule.

    ``forward`` returns the training loss (Huber loss between the true and the
    predicted noise, optionally weighted per channel); ``sample`` runs the
    reverse process from pure noise.

    Args:
        base_model: Module called as ``base_model(x_t, t, conditions,
            conditioning_signal)`` that returns predicted noise shaped like
            ``x_t``.
        num_timesteps: Number of diffusion steps.
        channel_weights: Optional 1-D tensor of per-channel loss weights. It
            is kept as a non-persistent buffer, so it follows ``.to()`` /
            dtype casts of the module but does not appear in ``state_dict``.
    """

    def __init__(self, base_model: "ConditionalUnet", num_timesteps: int, channel_weights: Optional[torch.Tensor] = None):
        super().__init__()
        self.model = base_model
        self.num_timesteps = num_timesteps
        # persistent=False keeps checkpoint keys identical to the previous
        # plain-attribute version while letting the weights move with the module.
        self.register_buffer('channel_weights', channel_weights, persistent=False)

        betas = self._cosine_beta_schedule(num_timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted by one step, with 1.0 for the step before t=0.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))

        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # The value at t=0 is exactly zero; clamp so sqrt stays well defined.
        posterior_variance = torch.clamp(posterior_variance, min=1e-20)
        self.register_buffer('posterior_variance', posterior_variance)

    def _cosine_beta_schedule(self, timesteps, s=0.008):
        """Return ``timesteps`` betas from a squared-cosine cumulative-alpha curve.

        Computed in float64, clipped to [0.0001, 0.9999] and returned as float32.
        """
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999).float()

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to step ``t``: ``sqrt(ac_t) * x_start + sqrt(1 - ac_t) * noise``.

        ``t`` holds one step index per batch element; the coefficients are
        reshaped to ``(-1, 1, 1)``, so ``x_start`` is expected to be 3-D.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def forward(self, x_0: torch.Tensor, conditions: Dict[str, torch.Tensor],
                conditioning_signal: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the scalar training loss for a batch of clean samples ``x_0``.

        A random step is drawn per batch element, ``x_0`` is noised to that
        step, and the model's noise prediction is compared to the true noise.
        """
        t = torch.randint(0, self.num_timesteps, (x_0.shape[0],), device=x_0.device).long()
        noise = torch.randn_like(x_0)
        x_t = self.q_sample(x_0, t, noise)
        predicted_noise = self.model(x_t, t, conditions, conditioning_signal)

        # Element-wise Huber loss in (input, target) order; the loss is
        # symmetric, so the values match the previous argument order.
        loss = F.huber_loss(predicted_noise, noise, reduction='none')
        if self.channel_weights is not None:
            # Broadcast the per-channel weights over the last axis of the loss.
            weights = self.channel_weights.to(loss.device).view(1, 1, -1)
            return (loss * weights).mean()
        return loss.mean()

    @torch.no_grad()
    def sample(self, num_samples: int, conditions: Dict[str, torch.Tensor], shape: tuple,
               conditioning_signal: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generate ``num_samples`` samples of per-sample ``shape`` by reverse diffusion.

        Starts from Gaussian noise on the model's device and iterates from the
        last step down to 0. Noise is added at every step except the final
        one. This method does not switch the module to eval mode; the caller
        controls train/eval state.
        """
        device = next(self.model.parameters()).device
        x = torch.randn(num_samples, *shape, device=device)
        for t in tqdm(reversed(range(self.num_timesteps)), desc="Sampling", total=self.num_timesteps, leave=False):
            t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)
            predicted_noise = self.model(x, t_batch, conditions, conditioning_signal)
            alpha_t = self.alphas[t]
            sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t]
            mean = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / sqrt_one_minus_alpha_cumprod_t) * predicted_noise)
            if t > 0:
                noise = torch.randn_like(x)
                variance = self.posterior_variance[t]
                x = mean + torch.sqrt(variance) * noise
            else:
                # Final step is deterministic.
                x = mean
        return x
class HierarchicalDiffusionModel(nn.Module):
    """Two-stage diffusion: a coarse model on a downsampled signal plus a fine residual model.

    A strided convolution downsamples the input for the coarse model. A
    transposed convolution maps the coarse signal back to full resolution; the
    fine model is trained on the difference between the input and that
    upsampled signal, conditioned on the upsampled signal itself, using crops
    of at most ``fine_chunk_size`` steps.

    Args:
        in_channels: Feature channels of the signal.
        num_houses: Passed to both ``ConditionalUnet`` instances.
        downscale_factor: Kernel size and stride of the down/up-sampling
            convolutions.
        channel_weights: Optional per-channel loss weights shared by both
            diffusion models.
        **model_kwargs: Must contain ``num_timesteps``; the remaining entries
            are forwarded to ``ConditionalUnet``.

    Raises:
        KeyError: If ``num_timesteps`` is missing from ``model_kwargs``.
    """

    def __init__(self, in_channels: int, num_houses: int, downscale_factor: int, channel_weights: Optional[torch.Tensor] = None, **model_kwargs):
        super().__init__()
        self.downscale_factor = downscale_factor
        # Maximum number of time steps the fine model sees at once.
        self.fine_chunk_size = 2 * 96
        # Remove num_timesteps once so the remaining kwargs can be forwarded
        # unchanged to both U-Nets.
        num_timesteps = model_kwargs.pop("num_timesteps")
        self.downsampler = nn.Conv1d(in_channels, in_channels, kernel_size=downscale_factor, stride=downscale_factor)
        self.upsampler = nn.ConvTranspose1d(in_channels, in_channels, kernel_size=downscale_factor, stride=downscale_factor)
        self.coarse_model = ImprovedDiffusionModel(
            ConditionalUnet(in_channels=in_channels, num_houses=num_houses, **model_kwargs),
            num_timesteps,
            channel_weights=channel_weights
        )
        # The fine model additionally receives the upsampled coarse signal,
        # hence cond_channels=in_channels.
        self.fine_model = ImprovedDiffusionModel(
            ConditionalUnet(in_channels=in_channels, num_houses=num_houses,
                            cond_channels=in_channels, **model_kwargs),
            num_timesteps,
            channel_weights=channel_weights
        )

    def forward(self, x_0: torch.Tensor, conditions: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return ``coarse_loss + 1.5 * fine_loss`` for a channel-last batch ``x_0``.

        The fine loss is computed on one crop of at most ``fine_chunk_size``
        steps; the crop start is random when the sequence is longer than that.
        """
        x_0_coarse = self.downsampler(x_0.permute(0, 2, 1)).permute(0, 2, 1)
        coarse_loss = self.coarse_model(x_0_coarse, conditions)

        # The upsampled coarse signal is built without gradient tracking, so
        # the upsampler receives no gradient from this loss.
        with torch.no_grad():
            x_0_coarse_upsampled = self.upsampler(x_0_coarse.detach().permute(0, 2, 1)).permute(0, 2, 1)
            if x_0_coarse_upsampled.shape[1] != x_0.shape[1]:
                diff = x_0.shape[1] - x_0_coarse_upsampled.shape[1]
                if diff > 0:
                    # Zero-pad the end of the time axis.
                    x_0_coarse_upsampled = F.pad(x_0_coarse_upsampled, [0, 0, 0, diff])
                else:
                    x_0_coarse_upsampled = x_0_coarse_upsampled[:, :x_0.shape[1], :]
            x_0_fine_residual = x_0 - x_0_coarse_upsampled

        full_length = x_0.shape[1]
        # Use a local effective size: overwriting self.fine_chunk_size here
        # would permanently shrink the crop for every later batch and for
        # the chunking in sample() after a single short batch.
        chunk_size = min(self.fine_chunk_size, full_length)
        if full_length > chunk_size:
            start_index = torch.randint(0, full_length - chunk_size + 1, (1,)).item()
        else:
            start_index = 0
        end_index = start_index + chunk_size

        residual_chunk = x_0_fine_residual[:, start_index:end_index, :]
        conditioning_chunk = x_0_coarse_upsampled[:, start_index:end_index, :]
        fine_loss = self.fine_model(residual_chunk, conditions, conditioning_signal=conditioning_chunk)

        fine_loss_weight = 1.5
        return coarse_loss + (fine_loss * fine_loss_weight)

    @torch.no_grad()
    def sample(self, num_samples: int, conditions: Dict[str, torch.Tensor], shape: tuple) -> torch.Tensor:
        """Generate ``num_samples`` sequences of per-sample ``shape = (length, features)``.

        Stage 1 samples the coarse signal at ``length // downscale_factor``
        steps and upsamples it. Stage 2 samples the fine residual in
        consecutive chunks of ``fine_chunk_size`` steps (the last chunk may be
        shorter), each conditioned on the matching slice of the upsampled
        coarse signal. The result is the sum of both.
        """
        full_length, num_features = shape
        device = next(self.parameters()).device
        conditions = {k: v.to(device) for k, v in conditions.items()}

        print("--- Stage 1: Sampling Coarse Structure ---")
        coarse_shape = (full_length // self.downscale_factor, num_features)
        generated_coarse = self.coarse_model.sample(num_samples, conditions, shape=coarse_shape)
        upsampled_coarse = self.upsampler(generated_coarse.permute(0, 2, 1)).permute(0, 2, 1)
        if upsampled_coarse.shape[1] != full_length:
            diff = full_length - upsampled_coarse.shape[1]
            if diff > 0:
                upsampled_coarse = F.pad(upsampled_coarse, [0, 0, 0, diff])
            else:
                upsampled_coarse = upsampled_coarse[:, :full_length, :]

        print("--- Stage 2: Sampling Fine Details ---")
        stitched_fine_residual = torch.zeros_like(upsampled_coarse)
        for start_index in tqdm(range(0, full_length, self.fine_chunk_size), desc="Fine chunks"):
            end_index = min(start_index + self.fine_chunk_size, full_length)
            chunk_length = end_index - start_index
            fine_shape = (chunk_length, num_features)
            conditioning_chunk = upsampled_coarse[:, start_index:end_index, :]
            generated_fine_chunk = self.fine_model.sample(
                num_samples, conditions, shape=fine_shape,
                conditioning_signal=conditioning_chunk
            )
            stitched_fine_residual[:, start_index:end_index, :] = generated_fine_chunk

        final_sample = upsampled_coarse + stitched_fine_residual
        return final_sample