SolarSys2025 committed on
Commit
1082dd4
·
verified ·
1 Parent(s): e689010

Delete Data_generation_tool_kit/Hidiff_energy

Browse files
Data_generation_tool_kit/Hidiff_energy/__init__.py DELETED
File without changes
Data_generation_tool_kit/Hidiff_energy/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (157 Bytes)
 
Data_generation_tool_kit/Hidiff_energy/__pycache__/dataloader.cpython-312.pyc DELETED
Binary file (13.7 kB)
 
Data_generation_tool_kit/Hidiff_energy/__pycache__/hierarchial_diffusion_model.cpython-312.pyc DELETED
Binary file (26.1 kB)
 
Data_generation_tool_kit/Hidiff_energy/global_scaler.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b56b496f31ec90c863f0841bc8dd9c60d990662d93093a1dad427c012a721c2f
3
- size 477
 
 
 
 
Data_generation_tool_kit/Hidiff_energy/hierarchial_diffusion_model.py DELETED
@@ -1,384 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import math
5
- from typing import List, Optional, Dict
6
- from tqdm import tqdm
7
-
8
-
9
class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embeddings for diffusion timesteps.

    Maps a batch of scalar timesteps [B] to embeddings [B, dim], with the
    first half of the channels holding sines and the second half cosines.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        # Geometric frequency ladder spanning 1 .. 1/10000 over half the width.
        half = self.dim // 2
        scale = -math.log(10000) / (half - 1)
        freqs = torch.exp(scale * torch.arange(half, device=time.device))
        angles = time[:, None] * freqs[None, :]
        # Concatenate sin and cos halves -> [B, dim].
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
22
-
23
-
24
class ResnetBlock1D(nn.Module):
    """1D residual block with optional FiLM-style time conditioning.

    When ``time_emb_dim`` is given, the time embedding is projected to a
    per-channel (scale, shift) pair applied after the first normalisation.
    """

    def __init__(self, in_channels: int, out_channels: int, *, time_emb_dim: int = None, dropout: float = 0.1):
        super().__init__()
        if time_emb_dim is not None:
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, out_channels * 2)
            )
        else:
            self.time_mlp = None

        # affine=False: the scale/shift come from the time embedding instead.
        self.block1_conv = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.block1_norm = nn.GroupNorm(8, out_channels, affine=False)
        self.block1_act = nn.SiLU()

        self.block2_conv = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.block2_norm = nn.GroupNorm(8, out_channels)
        self.block2_act = nn.SiLU()
        self.block2_dropout = nn.Dropout(dropout)

        # 1x1 projection so the residual matches when channel counts differ.
        if in_channels != out_channels:
            self.res_conv = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor = None) -> torch.Tensor:
        out = self.block1_norm(self.block1_conv(x))

        if self.time_mlp is not None and time_emb is not None:
            # FiLM modulation: scale is offset by 1 so a zero projection
            # leaves the activations untouched.
            scale, shift = self.time_mlp(time_emb).chunk(2, dim=1)
            out = out * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1)

        out = self.block1_act(out)
        out = self.block2_conv(out)
        out = self.block2_dropout(self.block2_act(self.block2_norm(out)))
        return out + self.res_conv(x)
57
-
58
-
59
class AttentionBlock1D(nn.Module):
    """Multi-head self-attention over the length axis of a [B, C, L] tensor,
    wrapped in a residual connection."""

    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        assert channels % num_heads == 0, "channels must be divisible by num_heads"
        self.channels = channels
        self.num_heads = num_heads
        self.head_dim = channels // num_heads

        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, channels, length = x.shape
        normed = self.norm(x)

        # [B, 3C, L] -> three tensors of shape [B, heads, L, head_dim].
        qkv = self.qkv(normed).view(batch, 3, self.num_heads, self.head_dim, length)
        q, k, v = qkv.permute(1, 0, 2, 4, 3).unbind(0)

        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

        # Fold heads back into channels: [B, heads, L, head_dim] -> [B, C, L].
        attended = attended.permute(0, 1, 3, 2).contiguous().view(batch, channels, length)
        return x + self.proj(attended)
86
-
87
-
88
class DownBlock1D(nn.Module):
    """Encoder stage: resnet stack, optional attention, stride-2 downsample.

    ``forward`` returns (downsampled features, pre-downsample skip tensor).
    """

    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, dropout: float, use_attention: bool, num_blocks: int = 2):
        super().__init__()
        # Only the first resnet changes the channel count.
        self.resnets = nn.ModuleList([
            ResnetBlock1D(in_channels if i == 0 else out_channels, out_channels, time_emb_dim=time_emb_dim, dropout=dropout)
            for i in range(num_blocks)
        ])
        self.attn = AttentionBlock1D(out_channels) if use_attention else nn.Identity()
        self.downsampler = nn.Conv1d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, time_emb):
        for block in self.resnets:
            x = block(x, time_emb)
        # The attended features double as the skip connection.
        skip = self.attn(x)
        return self.downsampler(skip), skip
105
-
106
-
107
class UpBlock1D(nn.Module):
    """Decoder stage: upsample, length-align with the skip, concat, resnets."""

    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, dropout: float, use_attention: bool, num_blocks: int = 2):
        super().__init__()
        self.resnets = nn.ModuleList()
        # First resnet consumes the concatenation of skip + upsampled features.
        self.resnets.append(ResnetBlock1D(in_channels * 2, out_channels, time_emb_dim=time_emb_dim, dropout=dropout))
        for _ in range(num_blocks - 1):
            self.resnets.append(ResnetBlock1D(out_channels, out_channels, time_emb_dim=time_emb_dim, dropout=dropout))
        self.attn = AttentionBlock1D(out_channels) if use_attention else nn.Identity()
        self.upsampler = nn.ConvTranspose1d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, skip_x, time_emb):
        x = self.upsampler(x)

        # Reconcile length drift introduced by strided down/upsampling.
        gap = skip_x.size(-1) - x.size(-1)
        if gap > 0:
            x = F.pad(x, [gap // 2, gap - gap // 2])
        elif gap < 0:
            x = x[:, :, :skip_x.size(-1)]

        x = torch.cat([skip_x, x], dim=1)

        for block in self.resnets:
            x = block(x, time_emb)
        return self.attn(x)
132
-
133
-
134
class ConditionalUnet(nn.Module):
    """Conditional 1D U-Net noise predictor.

    Sequences are laid out as [B, L, C]; channels are moved to dim 1
    internally for the Conv1d stack. Conditioning comes from:

    * the diffusion timestep (sinusoidal embedding refined by an MLP),
    * learned house / day-of-week / day-of-year embeddings, all projected
      to the time-embedding width and summed,
    * an optional per-step conditioning signal (same [B, L, C] layout)
      concatenated channel-wise before the first convolution; its channel
      count must equal ``cond_channels``.
    """

    def __init__(self, in_channels: int, num_houses: int, embedding_dim: int = 64,
                 hidden_dims: Optional[List[int]] = None,
                 dropout: float = 0.1, use_attention: bool = True,
                 cond_channels: int = 0, blocks_per_level: int = 2):
        super().__init__()
        # Fix: the original signature used a mutable list literal as the
        # default argument; a None sentinel avoids shared-state surprises
        # while keeping the same effective default.
        if hidden_dims is None:
            hidden_dims = [64, 128, 256]
        time_emb_dim = hidden_dims[0] * 4

        # Timestep embedding: sinusoidal features refined by a small MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(hidden_dims[0]),
            nn.Linear(hidden_dims[0], time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )

        # Categorical conditions, each projected to time_emb_dim so they can
        # be summed with the timestep embedding.
        self.house_embedding = nn.Embedding(num_houses, embedding_dim)
        self.house_proj = nn.Linear(embedding_dim, time_emb_dim)

        self.day_of_week_embedding = nn.Embedding(7, embedding_dim)
        self.day_of_year_embedding = nn.Embedding(366, embedding_dim)  # 366 covers leap years

        self.day_of_week_proj = nn.Linear(embedding_dim, time_emb_dim)
        self.day_of_year_proj = nn.Linear(embedding_dim, time_emb_dim)

        self.init_conv = nn.Conv1d(in_channels + cond_channels, hidden_dims[0], kernel_size=7, padding=3)

        num_resolutions = len(hidden_dims)
        self.down_blocks = nn.ModuleList([
            DownBlock1D(hidden_dims[i], hidden_dims[i + 1], time_emb_dim, dropout, use_attention, blocks_per_level)
            for i in range(num_resolutions - 1)
        ])

        self.mid_block1 = ResnetBlock1D(hidden_dims[-1], hidden_dims[-1], time_emb_dim=time_emb_dim, dropout=dropout)
        self.mid_attn = AttentionBlock1D(hidden_dims[-1])
        self.mid_block2 = ResnetBlock1D(hidden_dims[-1], hidden_dims[-1], time_emb_dim=time_emb_dim, dropout=dropout)

        # Decoder mirrors the encoder in reverse channel order.
        self.up_blocks = nn.ModuleList([
            UpBlock1D(hidden_dims[i + 1], hidden_dims[i], time_emb_dim, dropout, use_attention, blocks_per_level)
            for i in reversed(range(num_resolutions - 1))
        ])

        self.final_conv = nn.Sequential(
            ResnetBlock1D(hidden_dims[0], hidden_dims[0], time_emb_dim=time_emb_dim, dropout=dropout),
            nn.Conv1d(hidden_dims[0], in_channels, 1)
        )

    def forward(self, x: torch.Tensor, timestep: torch.Tensor, conditions: Dict[str, torch.Tensor],
                conditioning_signal: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Predict noise for ``x`` ([B, L, C]) at the given timesteps.

        ``conditions`` must hold integer tensors under the keys "house_id",
        "day_of_week" and "day_of_year". Returns a tensor shaped like ``x``.
        """
        time_emb = self.time_mlp(timestep)

        house_id = conditions["house_id"]
        day_of_week = conditions["day_of_week"]
        day_of_year = conditions["day_of_year"]

        house_emb = self.house_proj(self.house_embedding(house_id))
        dow_emb = self.day_of_week_proj(self.day_of_week_embedding(day_of_week))
        doy_emb = self.day_of_year_proj(self.day_of_year_embedding(day_of_year))

        # All conditioning collapses into one additive embedding vector.
        emb = time_emb + house_emb + dow_emb + doy_emb

        x = x.permute(0, 2, 1)  # [B, L, C] -> [B, C, L] for Conv1d
        if conditioning_signal is not None:
            x = torch.cat([x, conditioning_signal.permute(0, 2, 1)], dim=1)

        x = self.init_conv(x)

        skip_connections = []
        for down_block in self.down_blocks:
            x, skip_x = down_block(x, emb)
            skip_connections.append(skip_x)

        x = self.mid_block1(x, emb)
        x = self.mid_attn(x)
        x = self.mid_block2(x, emb)

        # Skips are consumed LIFO to mirror the encoder path.
        for up_block in self.up_blocks:
            x = up_block(x, skip_connections.pop(), emb)

        return self.final_conv(x).permute(0, 2, 1)
213
-
214
-
215
class ImprovedDiffusionModel(nn.Module):
    """DDPM wrapper around a noise-prediction network.

    Precomputes a cosine beta schedule and its derived quantities as buffers,
    implements the forward (noising) process, the training loss (Huber on
    predicted vs. true noise, optionally channel-weighted), and ancestral
    sampling.
    """

    def __init__(self, base_model: ConditionalUnet, num_timesteps: int, channel_weights: torch.Tensor = None):
        super().__init__()
        self.model = base_model
        self.num_timesteps = num_timesteps
        # NOTE(review): stored as a plain attribute, not a buffer, so it does
        # not follow .to(device) — forward() compensates with .to(loss.device).
        self.channel_weights = channel_weights

        betas = self._cosine_beta_schedule(num_timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        # Shifted cumulative product with an implicit alpha_bar_{-1} = 1.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # Registered as buffers so they move with the module and are saved
        # in its state dict.
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))

        # Posterior q(x_{t-1} | x_t, x_0) variance; clamped away from zero
        # before the sqrt taken at sampling time.
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        posterior_variance = torch.clamp(posterior_variance, min=1e-20)
        self.register_buffer('posterior_variance', posterior_variance)

    def _cosine_beta_schedule(self, timesteps, s=0.008):
        """Cosine noise schedule (Nichol & Dhariwal); returns float32 betas."""
        steps = timesteps + 1
        # float64 keeps the ratio below numerically stable before the cast.
        x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999).float()

    def q_sample(self, x_start, t, noise=None):
        """Diffuse x_start to timestep t: sqrt(a_bar)*x0 + sqrt(1-a_bar)*eps."""
        if noise is None: noise = torch.randn_like(x_start)
        # view(-1, 1, 1) broadcasts the per-sample coefficients over [B, L, C].
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def forward(self, x_0: torch.Tensor, conditions: Dict[str, torch.Tensor],
                conditioning_signal: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the scalar training loss for a batch x_0 ([B, L, C])."""
        # Uniformly sampled timestep per batch element.
        t = torch.randint(0, self.num_timesteps, (x_0.shape[0],), device=x_0.device).long()
        noise = torch.randn_like(x_0)
        x_t = self.q_sample(x_0, t, noise)
        predicted_noise = self.model(x_t, t, conditions, conditioning_signal)

        # --- START: MODIFIED LOSS CALCULATION ---
        # Elementwise Huber so optional per-channel weights can be applied
        # before reduction.
        loss = F.huber_loss(noise, predicted_noise, reduction='none')

        if self.channel_weights is not None:
            # Apply weights [B, L, C] * [1, 1, C]
            weights = self.channel_weights.to(loss.device).view(1, 1, -1)
            loss = (loss * weights).mean()
        else:
            loss = loss.mean()

        return loss
        # --- END: MODIFIED LOSS CALCULATION ---

    @torch.no_grad()
    def sample(self, num_samples: int, conditions: Dict[str, torch.Tensor], shape: tuple,
               conditioning_signal: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Ancestral DDPM sampling: iterate t = T-1 .. 0 from pure noise.

        ``shape`` is the per-sample shape (length, channels); returns a
        tensor of shape (num_samples, *shape) on the model's device.
        """
        device = next(self.model.parameters()).device
        x = torch.randn(num_samples, *shape, device=device)

        for t in tqdm(reversed(range(self.num_timesteps)), desc="Sampling", total=self.num_timesteps, leave=False):
            t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)
            predicted_noise = self.model(x, t_batch, conditions, conditioning_signal)

            alpha_t = self.alphas[t]
            sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t]

            # Posterior mean of x_{t-1} given the predicted noise.
            mean = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / sqrt_one_minus_alpha_cumprod_t) * predicted_noise)

            if t > 0:
                noise = torch.randn_like(x)
                variance = self.posterior_variance[t]
                x = mean + torch.sqrt(variance) * noise
            else:
                # Final step is deterministic: no noise is added at t == 0.
                x = mean

        return x
295
-
296
-
297
class HierarchicalDiffusionModel(nn.Module):
    """Two-stage (coarse + fine) diffusion model for long sequences.

    A coarse model diffuses a strided-downsampled version of the sequence;
    a fine model diffuses the residual between the full-resolution sequence
    and the upsampled coarse signal, trained on random fixed-size chunks and
    conditioned on the corresponding coarse chunk.
    """

    def __init__(self, in_channels: int, num_houses: int, downscale_factor: int, channel_weights: Optional[torch.Tensor] = None, **model_kwargs):
        super().__init__()
        self.downscale_factor = downscale_factor
        # Fine model works on windows of 2 * 96 steps.
        # NOTE(review): 96 steps suggests one day of 15-minute data — confirm.
        self.fine_chunk_size = 2 * 96

        # Pop num_timesteps once so it is not also forwarded to the U-Nets
        # inside model_kwargs.
        num_timesteps = model_kwargs.pop("num_timesteps")

        self.downsampler = nn.Conv1d(in_channels, in_channels, kernel_size=downscale_factor, stride=downscale_factor)
        self.upsampler = nn.ConvTranspose1d(in_channels, in_channels, kernel_size=downscale_factor, stride=downscale_factor)

        self.coarse_model = ImprovedDiffusionModel(
            ConditionalUnet(in_channels=in_channels, num_houses=num_houses, **model_kwargs),
            num_timesteps,
            channel_weights=channel_weights
        )
        # The fine model additionally sees the upsampled coarse signal as
        # extra input channels (cond_channels).
        self.fine_model = ImprovedDiffusionModel(
            ConditionalUnet(in_channels=in_channels, num_houses=num_houses,
                            cond_channels=in_channels, **model_kwargs),
            num_timesteps,
            channel_weights=channel_weights
        )

    def forward(self, x_0: torch.Tensor, conditions: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return coarse loss + weighted fine loss for a batch x_0 ([B, L, C])."""
        x_0_coarse = self.downsampler(x_0.permute(0, 2, 1)).permute(0, 2, 1)
        coarse_loss = self.coarse_model(x_0_coarse, conditions)

        # The fine target is the residual against the upsampled coarse signal;
        # no gradients should flow into the upsampler from this path.
        with torch.no_grad():
            x_0_coarse_upsampled = self.upsampler(x_0_coarse.detach().permute(0, 2, 1)).permute(0, 2, 1)

            # Pad/crop to reconcile rounding from strided down/upsampling.
            if x_0_coarse_upsampled.shape[1] != x_0.shape[1]:
                diff = x_0.shape[1] - x_0_coarse_upsampled.shape[1]
                if diff > 0:
                    x_0_coarse_upsampled = F.pad(x_0_coarse_upsampled, [0, 0, 0, diff])
                else:
                    x_0_coarse_upsampled = x_0_coarse_upsampled[:, :x_0.shape[1], :]
            x_0_fine_residual = x_0 - x_0_coarse_upsampled

        full_length = x_0.shape[1]
        # Fix: use a LOCAL chunk size instead of mutating self.fine_chunk_size.
        # The original assigned self.fine_chunk_size = full_length when a
        # short batch arrived, permanently shrinking the chunk size for all
        # subsequent training steps and for sample().
        chunk_size = min(self.fine_chunk_size, full_length)
        if full_length > chunk_size:
            # Random crop so the fine model sees all positions over training.
            start_index = torch.randint(0, full_length - chunk_size + 1, (1,)).item()
        else:
            start_index = 0

        residual_chunk = x_0_fine_residual[:, start_index:start_index + chunk_size, :]
        conditioning_chunk = x_0_coarse_upsampled[:, start_index:start_index + chunk_size, :]

        fine_loss = self.fine_model(residual_chunk, conditions, conditioning_signal=conditioning_chunk)

        # Fine details are weighted up slightly relative to coarse structure.
        fine_loss_weight = 1.5
        return coarse_loss + (fine_loss * fine_loss_weight)

    @torch.no_grad()
    def sample(self, num_samples: int, conditions: Dict[str, torch.Tensor], shape: tuple) -> torch.Tensor:
        """Generate ``num_samples`` sequences of shape (length, features).

        Stage 1 samples the coarse structure; stage 2 samples the fine
        residual chunk by chunk, conditioned on the upsampled coarse signal,
        and stitches the chunks back together.
        """
        full_length, num_features = shape
        device = next(self.parameters()).device

        conditions = {k: v.to(device) for k, v in conditions.items()}

        print("--- Stage 1: Sampling Coarse Structure ---")
        coarse_shape = (full_length // self.downscale_factor, num_features)
        generated_coarse = self.coarse_model.sample(num_samples, conditions, shape=coarse_shape)
        upsampled_coarse = self.upsampler(generated_coarse.permute(0, 2, 1)).permute(0, 2, 1)

        # Same length reconciliation as in training.
        if upsampled_coarse.shape[1] != full_length:
            diff = full_length - upsampled_coarse.shape[1]
            if diff > 0:
                upsampled_coarse = F.pad(upsampled_coarse, [0, 0, 0, diff])
            else:
                upsampled_coarse = upsampled_coarse[:, :full_length, :]

        print("--- Stage 2: Sampling Fine Details ---")
        stitched_fine_residual = torch.zeros_like(upsampled_coarse)

        for start_index in tqdm(range(0, full_length, self.fine_chunk_size), desc="Fine chunks"):
            end_index = min(start_index + self.fine_chunk_size, full_length)
            chunk_length = end_index - start_index
            fine_shape = (chunk_length, num_features)
            conditioning_chunk = upsampled_coarse[:, start_index:end_index, :]

            generated_fine_chunk = self.fine_model.sample(
                num_samples, conditions, shape=fine_shape,
                conditioning_signal=conditioning_chunk
            )

            stitched_fine_residual[:, start_index:end_index, :] = generated_fine_chunk

        final_sample = upsampled_coarse + stitched_fine_residual
        return final_sample