Delete Data_generation_tool_kit
Data_generation_tool_kit/.DS_Store
DELETED
Binary file (6.15 kB)
Data_generation_tool_kit/dataloader.py
DELETED
@@ -1,179 +0,0 @@

import torch
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import joblib
import os
from typing import Tuple, Dict
import warnings
warnings.filterwarnings('ignore')

class MultiHouseDataset(torch.utils.data.Dataset):

    def __init__(self, data_dir: str, window_size: int = 96, step_size: int = 1,
                 scaler_path: str = 'global_scaler.gz', cache_in_memory: bool = True,
                 dtype: torch.dtype = torch.float32, limit_to_one_year: bool = True):
        self.window_size = window_size
        self.step_size = step_size
        self.cache_in_memory = cache_in_memory
        self.dtype = dtype
        self.limit_to_one_year = limit_to_one_year

        all_files = sorted([f for f in os.listdir(data_dir) if f.endswith('.csv')])
        print(f"Found {len(all_files)} house files in '{data_dir}'.")

        self.num_houses = len(all_files)

        print("Reading house data...")
        if self.limit_to_one_year:
            print("INFO: Limiting data to the first year (17,520 samples) for each house.")

        data_per_house = []
        timestamps_per_house = []

        SAMPLES_PER_YEAR = 17520

        for filename in all_files:
            df = pd.read_csv(os.path.join(data_dir, filename), parse_dates=['timestamp'])
            timestamps_per_house.append(df['timestamp'].values)
            time_series_values = df[['grid_usage', 'solar_generation']].values.astype(np.float32)

            if self.limit_to_one_year:
                time_series_values = time_series_values[:SAMPLES_PER_YEAR]

            num_timesteps = len(time_series_values)
            timesteps_of_day = np.arange(num_timesteps) % 48

            sin_time = np.sin(2 * np.pi * timesteps_of_day / 48.0).astype(np.float32)
            cos_time = np.cos(2 * np.pi * timesteps_of_day / 48.0).astype(np.float32)

            time_series_values = np.concatenate([
                time_series_values,
                sin_time[:, np.newaxis],
                cos_time[:, np.newaxis]
            ], axis=1)

            data_per_house.append(time_series_values)

        if os.path.exists(scaler_path):
            scaler = joblib.load(scaler_path)
            print(f"Scaler loaded from {scaler_path}")
        else:
            print("Fitting global scaler...")
            combined_data = np.vstack(data_per_house)
            scaler = MinMaxScaler(feature_range=(-1, 1))
            scaler.fit(combined_data)
            joblib.dump(scaler, scaler_path)
            print(f"Scaler saved to {scaler_path}")

        if self.cache_in_memory:
            print("Caching normalized data...")
            self.normalized_data_per_house = []
            for series in data_per_house:
                normalized = scaler.transform(series)
                tensor_data = torch.from_numpy(normalized).to(dtype=self.dtype)
                self.normalized_data_per_house.append(tensor_data)
        else:
            self.normalized_data_per_house = []
            for series in data_per_house:
                self.normalized_data_per_house.append(scaler.transform(series))

        del data_per_house

        print("Pre-computing mappings...")

        self.windows_per_house = [(len(d) - self.window_size) // self.step_size + 1 for d in self.normalized_data_per_house]
        self.cumulative_windows = np.cumsum([0] + self.windows_per_house)
        self.total_windows = self.cumulative_windows[-1]

        self.sample_to_house = np.empty(self.total_windows, dtype=np.int32)
        self.sample_to_local_idx = np.empty(self.total_windows, dtype=np.int32)
        self.sample_to_day_of_week = np.empty(self.total_windows, dtype=np.int32)
        self.sample_to_day_of_year = np.empty(self.total_windows, dtype=np.int32)

        for house_idx in range(self.num_houses):
            start_global_idx = self.cumulative_windows[house_idx]
            end_global_idx = self.cumulative_windows[house_idx + 1]
            num_windows_for_this_house = self.windows_per_house[house_idx]

            self.sample_to_house[start_global_idx:end_global_idx] = house_idx

            local_indices = np.arange(num_windows_for_this_house) * self.step_size
            self.sample_to_local_idx[start_global_idx:end_global_idx] = local_indices

            house_timestamps = pd.Series(timestamps_per_house[house_idx][local_indices])
            self.sample_to_day_of_week[start_global_idx:end_global_idx] = house_timestamps.dt.dayofweek
            self.sample_to_day_of_year[start_global_idx:end_global_idx] = house_timestamps.dt.dayofyear - 1

        print(f"Dataset initialized. Total windows: {self.total_windows} from {self.num_houses} houses.")
        memory_usage = sum(data.numel() * data.element_size() for data in self.normalized_data_per_house) / 1e6 if self.cache_in_memory else 0
        print(f"Memory usage for cached tensors: {memory_usage:.1f} MB")

    def __len__(self) -> int:
        return self.total_windows

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        if idx < 0 or idx >= self.total_windows:
            raise IndexError("Index out of range")

        house_index = self.sample_to_house[idx]
        local_start_pos = self.sample_to_local_idx[idx]

        window_data = self.normalized_data_per_house[house_index][local_start_pos : local_start_pos + self.window_size]

        conditions = {
            "house_id": torch.tensor(house_index, dtype=torch.long),
            "day_of_week": torch.tensor(self.sample_to_day_of_week[idx], dtype=torch.long),
            "day_of_year": torch.tensor(self.sample_to_day_of_year[idx], dtype=torch.long),
        }

        return window_data, conditions

    def get_memory_usage(self) -> dict:
        if self.cache_in_memory:
            tensor_memory = sum(data.numel() * data.element_size() for data in self.normalized_data_per_house) / 1e6
        else:
            tensor_memory = 0

        mapping_memory = (self.sample_to_house.nbytes + self.sample_to_local_idx.nbytes) / 1e6

        return {
            'tensor_cache_mb': tensor_memory,
            'mapping_arrays_mb': mapping_memory,
            'total_mb': tensor_memory + mapping_memory
        }

class LatentDataset(torch.utils.data.Dataset):
    def __init__(self, latent_vectors: torch.Tensor, house_ids: torch.Tensor):
        assert len(latent_vectors) == len(house_ids), "Latent vectors and house IDs must have same length"
        self.latent_vectors = latent_vectors.contiguous()
        self.house_ids = house_ids.contiguous()

    def __len__(self) -> int:
        return len(self.latent_vectors)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.latent_vectors[idx], self.house_ids[idx]

if __name__ == "__main__":
    import time

    DATA_DIRECTORY = './data/per_house/'

    if os.path.exists(DATA_DIRECTORY):
        print("--- Testing Dataset Setup ---")

        start_time = time.time()
        dataset = MultiHouseDataset(data_dir=DATA_DIRECTORY, window_size=96, step_size=96)
        init_time = time.time() - start_time
        print(f"Dataset initialization: {init_time:.2f}s")
        print(f"Memory usage: {dataset.get_memory_usage()}")

        if len(dataset) > 0:
            first_sample, first_conditions = dataset[0]

            print(f"\nSample data shape: {first_sample.shape}")
            print(f"Sample conditions: {first_conditions}")
            print(f"Total houses: {dataset.num_houses}")
    else:
        print(f"ERROR: Data directory not found at '{DATA_DIRECTORY}'. Please create and populate this directory.")
Data_generation_tool_kit/generate.py
DELETED
@@ -1,291 +0,0 @@

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import os
import joblib
import math
import datetime
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# =============================================================================
# 1. MODEL CLASS DEFINITIONS
# =============================================================================

try:
    from hierarchical_diffusion_model import (
        HierarchicalDiffusionModel, ConditionalUnet, ResnetBlock1D,
        AttentionBlock1D, DownBlock1D, UpBlock1D,
        SinusoidalPositionEmbeddings, ImprovedDiffusionModel
    )
    print("Diffusion model classes imported.")
except ImportError:
    print("="*50)
    print("ERROR: Could not import model classes from 'hierarchical_diffusion_model.py'.")
    print("="*50)
    exit()


# =============================================================================
# 2. HELPER FUNCTIONS
# =============================================================================

def add_amplitude_jitter(series, daily_samples=48, scale=0.05):
    series = series.copy()
    num_days = len(series) // daily_samples
    if num_days == 0: return series
    factors = np.random.normal(1.0, scale, size=num_days)
    for d in range(num_days):
        start, end = d * daily_samples, (d + 1) * daily_samples
        series[start:end] *= factors[d]
    return series

def add_cloud_variability(pv, timestamps, base_sigma=0.25):
    pv = pv.copy()
    if len(pv) == 0: return pv
    days = pd.Series(pv, index=timestamps).groupby(timestamps.date)
    adjusted = []
    for day, vals in days:
        cloud_factor = np.random.lognormal(mean=-0.02, sigma=base_sigma)
        hour = vals.index.hour
        day_pv = np.where((hour >= 6) & (hour <= 18), vals * cloud_factor, 0.0)
        adjusted.append(day_pv)
    if not adjusted: return np.array([])
    return np.concatenate(adjusted)

def enforce_physics(df: pd.DataFrame, pv_cap_kw: float | None = None) -> pd.DataFrame:
    df = df.copy()
    df['solar_generation'] = np.clip(df['solar_generation'], 0.0, None)
    hour = df.index.hour
    night = (hour < 7) | (hour > 18)
    df.loc[night, 'solar_generation'] = 0.0
    export_mask = df['grid_usage'] < 0
    if export_mask.any():
        limited_export = -np.minimum(-df.loc[export_mask, 'grid_usage'], df.loc[export_mask, 'solar_generation'])
        df.loc[export_mask, 'grid_usage'] = limited_export
        zero_pv_neg_grid = export_mask & (df['solar_generation'] <= 1e-6)
        df.loc[zero_pv_neg_grid, 'grid_usage'] = 0.0
    if pv_cap_kw is not None:
        df['solar_generation'] = np.clip(df['solar_generation'], 0.0, pv_cap_kw)
    return df

def calculate_generation_length(duration: str, samples_per_day: int) -> int:
    """Calculate samples needed."""
    if duration == '1_year':
        return 365 * samples_per_day
    elif duration == '6_months':
        return 182 * samples_per_day
    elif duration == '2_months':
        return 60 * samples_per_day
    elif duration == '1_month':
        return 30 * samples_per_day
    elif duration == '14_days':
        return 14 * samples_per_day
    elif duration == '7_days':
        return 7 * samples_per_day
    elif duration == '2_days':
        return 2 * samples_per_day
    else:
        print(f"Warning: Unknown duration '{duration}'. Defaulting to 1 year.")
        return 365 * samples_per_day

# =============================================================================
# 3. HARDCODED CONFIGURATION
# =============================================================================

class Config:
    # --- Paths and Directories ---
    MODEL_PATH = './trained_model/best_hierarchical_model.pth'
    SCALER_PATH = './data/global_scaler.gz'
    ORIGINAL_DATA_DIR = './data/per_house'
    OUTPUT_DIR = './generated_data'

    # --- Generation Parameters ---
    GENERATION_DURATION = '1_year'
    NUM_PROFILES_TO_GENERATE = 2000
    PLOTS_TO_GENERATE = 20
    GENERATION_BATCH_SIZE = 128

    # --- Model & Training Parameters ---
    TRAINING_WINDOW_DAYS = 14

    NUM_HOUSES_TRAINED_ON = 300
    SAMPLES_PER_DAY = 48
    NUM_FEATURES = 4
    DOWNSCALE_FACTOR = 4
    EMBEDDING_DIM = 64
    HIDDEN_SIZE = 512
    HIDDEN_DIMS = [HIDDEN_SIZE // 4, HIDDEN_SIZE // 2, HIDDEN_SIZE]
    DROPOUT = 0.1
    USE_ATTENTION = True
    DIFFUSION_TIMESTEPS = 500
    BLOCKS_PER_LEVEL = 3


# =============================================================================
# 4. MAIN GENERATION LOGIC
# =============================================================================

def main(cfg, run_output_dir):
    """Main generation logic."""
    DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {DEVICE}")

    csv_output_dir = os.path.join(run_output_dir, 'csv')
    plot_output_dir = os.path.join(run_output_dir, 'plots')
    os.makedirs(csv_output_dir, exist_ok=True)
    os.makedirs(plot_output_dir, exist_ok=True)

    print("Loading resources...")
    try:
        scaler = joblib.load(cfg.SCALER_PATH)
        if scaler.n_features_in_ != cfg.NUM_FEATURES:
            print(f"WARNING: Scaler was fit on {scaler.n_features_in_} features, but model expects {cfg.NUM_FEATURES}.")

        original_files = sorted([f for f in os.listdir(cfg.ORIGINAL_DATA_DIR) if f.endswith('.csv')])
        if not original_files:
            raise FileNotFoundError("No original data files found to extract timestamps.")

        sample_original_df = pd.read_csv(os.path.join(cfg.ORIGINAL_DATA_DIR, original_files[0]), index_col='timestamp', parse_dates=True)

        # Load 1 year timestamps
        full_timestamps = sample_original_df.index[:(365 * cfg.SAMPLES_PER_DAY)]

        # Goal length
        total_samples_needed = calculate_generation_length(cfg.GENERATION_DURATION, cfg.SAMPLES_PER_DAY)

        # Training window length
        TRAINING_WINDOW_SAMPLES = cfg.TRAINING_WINDOW_DAYS * cfg.SAMPLES_PER_DAY

        # Clamping to max
        if total_samples_needed > len(full_timestamps):
            print(f"Warning: Requested {total_samples_needed} samples, but file has {len(full_timestamps)}. Clamping to max.")
            total_samples_needed = len(full_timestamps)

        print(f"Goal: Generate {total_samples_needed} samples ({cfg.GENERATION_DURATION}) per profile.")
        print(f"Strategy: Stitching {TRAINING_WINDOW_SAMPLES}-sample chunks.")

        model = HierarchicalDiffusionModel(
            in_channels=cfg.NUM_FEATURES,
            num_houses=cfg.NUM_HOUSES_TRAINED_ON,
            downscale_factor=cfg.DOWNSCALE_FACTOR,
            embedding_dim=cfg.EMBEDDING_DIM,
            hidden_dims=cfg.HIDDEN_DIMS,
            dropout=cfg.DROPOUT,
            use_attention=cfg.USE_ATTENTION,
            num_timesteps=cfg.DIFFUSION_TIMESTEPS,
            blocks_per_level=cfg.BLOCKS_PER_LEVEL
        )

        model.load_state_dict(torch.load(cfg.MODEL_PATH, map_location=DEVICE))
        model.to(DEVICE)
        model.eval()
        print("Model, scaler, timestamps ready.")

    except FileNotFoundError as e:
        print(f"ERROR: A required file was not found. Details: {e}")
        return
    except Exception as e:
        print(f"An error occurred during setup: {e}")
        return

    num_batches = math.ceil(cfg.NUM_PROFILES_TO_GENERATE / cfg.GENERATION_BATCH_SIZE)
    house_counter = 0

    pbar = tqdm(range(num_batches), desc="Generating Batches")
    for i in pbar:
        current_batch_size = min(cfg.GENERATION_BATCH_SIZE, cfg.NUM_PROFILES_TO_GENERATE - house_counter)
        if current_batch_size <= 0: break
        pbar.set_postfix({'batch_size': current_batch_size})

        # --- STITCHING LOGIC ---
        num_chunks_needed = math.ceil(total_samples_needed / TRAINING_WINDOW_SAMPLES)
        batch_chunks_list = []

        for chunk_idx in range(num_chunks_needed):
            # Calculate chunk length
            samples_remaining = total_samples_needed - (chunk_idx * TRAINING_WINDOW_SAMPLES)
            current_chunk_length = min(TRAINING_WINDOW_SAMPLES, samples_remaining)

            shape_to_generate = (current_chunk_length, cfg.NUM_FEATURES)

            # Generate random conditions
            sample_conditions = {
                "house_id": torch.randint(0, cfg.NUM_HOUSES_TRAINED_ON, (current_batch_size,), device=DEVICE),
                "day_of_week": torch.randint(0, 7, (current_batch_size,), device=DEVICE),
                "day_of_year": torch.randint(0, 365, (current_batch_size,), device=DEVICE)
            }

            with torch.no_grad():
                # Generate one chunk
                generated_chunk_data = model.sample(current_batch_size, sample_conditions, shape=shape_to_generate)

            batch_chunks_list.append(generated_chunk_data.cpu().numpy())

        # Stitch chunks together
        generated_data_np = np.concatenate(batch_chunks_list, axis=1)
        # --- END OF STITCHING LOGIC ---

        # --- Post-processing loop ---
        for j in range(current_batch_size):
            current_house_num = house_counter + 1
            # Select timestamps
            profile_timestamps = full_timestamps[:total_samples_needed]
            normalized_series = generated_data_np[j]

            unscaled_series = scaler.inverse_transform(normalized_series)

            df = pd.DataFrame(
                unscaled_series,
                columns=['grid_usage', 'solar_generation', 'sin_time', 'cos_time'],
                index=profile_timestamps
            )

            df = enforce_physics(df)
            df['grid_usage'] = add_amplitude_jitter(df['grid_usage'].values, scale=0.08, daily_samples=cfg.SAMPLES_PER_DAY)
            df['solar_generation'] = add_cloud_variability(df['solar_generation'].values, df.index, base_sigma=0.3)
            df = enforce_physics(df)

            df_to_save = df[['grid_usage', 'solar_generation']]
            df_to_save.to_csv(os.path.join(csv_output_dir, f'generated_house_{current_house_num}.csv'))

            if house_counter < cfg.PLOTS_TO_GENERATE:
                plot_df = df_to_save.head(cfg.SAMPLES_PER_DAY * 14)
                plt.figure(figsize=(15, 6))
                plt.plot(plot_df.index, plot_df['grid_usage'], label='Grid Usage', color='dodgerblue', alpha=0.9)
                plt.plot(plot_df.index, plot_df['solar_generation'], label='Solar Generation', color='darkorange', alpha=0.9)
                plt.title(f'Generated Data for Profile {current_house_num} (First 14 Days)')
                plt.xlabel('Timestamp'); plt.ylabel('Power (kW)'); plt.legend(); plt.grid(True, which='both', linestyle='--', linewidth=0.5)
                plt.tight_layout()
                plt.savefig(os.path.join(plot_output_dir, f'generated_profile_{current_house_num}_plot.png'))
                plt.close()

            house_counter += 1

    print(f"\nSuccessfully generated and saved {house_counter} house profiles.")
    if cfg.PLOTS_TO_GENERATE > 0:
        print(f"Plots saved to '{plot_output_dir}'.")


# =============================================================================
# 5. --- Main execution block ---
# =============================================================================

if __name__ == '__main__':
    config = Config()

    # Create unique output directory
    run_timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = f"generation_run_{config.GENERATION_DURATION}_{run_timestamp}"
    run_output_dir = os.path.join(config.OUTPUT_DIR, run_name)
    os.makedirs(run_output_dir, exist_ok=True)

    print(f"Starting new generation run: {run_name}")
    print(f"All outputs will be saved to: {run_output_dir}")

    # Run generation
    main(config, run_output_dir)

    print("\nGeneration process complete.")
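To make the post-processing rules concrete, here is a standalone toy check of the enforce_physics constraints from the file above. The two timestamps and kW values are invented purely for illustration, and the sketch assumes enforce_physics (as defined in the deleted generate.py) is in scope.

# Toy check of enforce_physics; assumes the function above is importable/in scope.
import numpy as np
import pandas as pd

idx = pd.to_datetime(['2023-01-01 02:00', '2023-01-01 12:00'])
toy = pd.DataFrame({'grid_usage': [-0.5, -2.0],
                    'solar_generation': [0.3, 1.0]}, index=idx)

# At 02:00 (night), PV is zeroed, so the negative grid reading collapses to 0.0.
# At 12:00, export is capped at available PV: grid_usage goes from -2.0 to -1.0.
clean = enforce_physics(toy)
print(clean)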
Data_generation_tool_kit/train.py
DELETED
@@ -1,209 +0,0 @@

import torch
from torch.utils.data import DataLoader, random_split, Subset
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import numpy as np
import os
import datetime
import pandas as pd
import matplotlib.pyplot as plt
import math
import joblib

from dataloader import MultiHouseDataset
from hierarchical_diffusion_model import HierarchicalDiffusionModel

if torch.cuda.is_available():
    DEVICE = "cuda"
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    print("Using NVIDIA CUDA backend.")
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    print("Using Apple MPS backend.")
else:
    DEVICE = "cpu"
    print("Using CPU.")

EPOCHS = 200
LEARNING_RATE = 1e-4
BATCH_SIZE = 512
USE_AMP = True
GRADIENT_CLIP_VAL = 0.1

WINDOW_DURATION = '14_days'

DATA_DIRECTORY = './data/per_house'
NUM_WORKERS = os.cpu_count() // 2
PIN_MEMORY = True
USE_ATTENTION = True
DROPOUT = 0.1
HIDDEN_SIZE = 512
EMBEDDING_DIM = 64
DIFFUSION_TIMESTEPS = 500
DOWNSCALE_FACTOR = 4

def calculate_window_size(duration: str) -> int:
    SAMPLES_PER_DAY = 48
    mapping = {
        '2_days': 2 * SAMPLES_PER_DAY,
        '7_days': 7 * SAMPLES_PER_DAY,
        '14_days': 14 * SAMPLES_PER_DAY,
        '15_days': 15 * SAMPLES_PER_DAY,
        '30_days': 30 * SAMPLES_PER_DAY
    }
    if duration not in mapping:
        raise ValueError(f"Invalid WINDOW_DURATION: {duration}")
    return mapping[duration]

def denormalize_data(normalized_data, scaler_path='global_scaler.gz'):
    scaler = joblib.load(scaler_path)
    original_shape = normalized_data.shape
    if len(original_shape) == 3:
        batch_size, seq_len, features = original_shape
        normalized_flat = normalized_data.reshape(-1, features)
        denormalized_flat = scaler.inverse_transform(normalized_flat)
        return denormalized_flat.reshape(original_shape)
    else:
        return scaler.inverse_transform(normalized_data)

def moving_average(data, window_size):
    return np.convolve(data, np.ones(window_size), 'valid') / window_size

def save_and_plot_loss(loss_dict, title, filepath, window_size=10):
    plt.figure(figsize=(12, 6))
    for label, losses in loss_dict.items():
        pd.DataFrame({label: losses}).to_csv(f"{filepath}_{label.lower().replace(' ', '_')}.csv", index=False)
        plt.plot(losses, label=f'Raw {label}', alpha=0.3)
        if len(losses) > window_size:
            smoothed_losses = moving_average(losses, window_size)
            plt.plot(np.arange(window_size - 1, len(losses)), smoothed_losses, label=f'Smoothed {label}')
    plt.title(title)
    plt.xlabel('Epoch'); plt.ylabel('Loss')
    plt.legend(); plt.grid(True)
    plt.savefig(f"{filepath}.png"); plt.close()
    print(f" Loss plot saved to {filepath}.png")

def train_diffusion(log_dir, model_save_path):
    print("--- Starting Hierarchical Diffusion Training ---")
    window_size = calculate_window_size(WINDOW_DURATION)
    print(f"Using window duration: {WINDOW_DURATION} ({window_size} samples)")

    dataset = MultiHouseDataset(
        data_dir=DATA_DIRECTORY,
        window_size=window_size,
        step_size=window_size//2,
        limit_to_one_year=False
    )
    print(f"Dataset loaded: {len(dataset)} samples, {dataset.num_houses} houses, {dataset[0][0].shape[1]} features.")

    val_split = 0.1
    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    print(f"Train size: {train_size}, Validation size: {val_size}")

    train_dataloader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=BATCH_SIZE*2, shuffle=False,
        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
    )

    channel_weights = torch.tensor([1.0, 8.0, 1.0, 1.0], device=DEVICE)
    print(f"Using channel weights: {channel_weights}")

    model = HierarchicalDiffusionModel(
        in_channels=dataset[0][0].shape[1],
        num_houses=dataset.num_houses,
        downscale_factor=DOWNSCALE_FACTOR,
        channel_weights=channel_weights,
        embedding_dim=EMBEDDING_DIM,
        hidden_dims=[HIDDEN_SIZE // 4, HIDDEN_SIZE // 2, HIDDEN_SIZE],
        dropout=DROPOUT,
        use_attention=USE_ATTENTION,
        num_timesteps=DIFFUSION_TIMESTEPS,
        blocks_per_level=3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    scaler = GradScaler(enabled=(USE_AMP and DEVICE == "cuda"))

    train_losses, val_losses = [], []
    best_val_loss = float('inf')

    print(f"Starting training for {EPOCHS} epochs...")
    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0.0
        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{EPOCHS} (Train)")

        for clean_data, conditions in pbar:
            clean_data = clean_data.to(DEVICE, non_blocking=PIN_MEMORY)
            conditions = {k: v.to(DEVICE, non_blocking=PIN_MEMORY) for k, v in conditions.items()}

            optimizer.zero_grad(set_to_none=True)
            with autocast(enabled=(USE_AMP and DEVICE == "cuda")):
                loss = model(clean_data, conditions)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VAL)
            scaler.step(optimizer)
            scaler.update()

            total_train_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.6f}', 'lr': f'{scheduler.get_last_lr()[0]:.2e}'})

        avg_train_loss = total_train_loss / len(train_dataloader)
        train_losses.append(avg_train_loss)

        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for clean_data, conditions in tqdm(val_dataloader, desc="Validating"):
                clean_data = clean_data.to(DEVICE, non_blocking=PIN_MEMORY)
                conditions = {k: v.to(DEVICE, non_blocking=PIN_MEMORY) for k, v in conditions.items()}
                with autocast(enabled=(USE_AMP and DEVICE == "cuda")):
                    loss = model(clean_data, conditions)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_dataloader)
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), model_save_path)
            print(f"New best model saved to {model_save_path} (Val Loss: {best_val_loss:.6f})")

        scheduler.step()

    print("--- Training complete ---")
    save_and_plot_loss(
        {'Train Loss': train_losses, 'Validation Loss': val_losses},
        'Hierarchical Diffusion Model Training & Validation Loss',
        os.path.join(log_dir, 'diffusion_loss_curves')
    )

    return dataset

if __name__ == "__main__":
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = f"hierarchical_diffusion_{WINDOW_DURATION}_{timestamp}"
    log_dir = os.path.join("./training_logs", run_name)
    os.makedirs(log_dir, exist_ok=True)
    model_path = os.path.join(log_dir, 'best_hierarchical_model.pth')

    print(f"Starting new run: {run_name}")
    print(f"Logs and models will be saved to: {log_dir}")

    full_dataset = train_diffusion(log_dir=log_dir, model_save_path=model_path)

    print("\nTraining and best model saving complete.")
    print(f"Model saved to: {model_path}")
    print(f"Loss curves saved to: {os.path.join(log_dir, 'diffusion_loss_curves.png')}")
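Note that train.py writes its best checkpoint into a timestamped ./training_logs/ run directory, while generate.py's Config expects a fixed path under ./trained_model/. A minimal glue step bridging the two might look like the sketch below; the run-directory name is a hypothetical example, not a real run from this repo.

# Hypothetical glue step between the two deleted scripts above: copy the best
# checkpoint from a timestamped training run to the path Config.MODEL_PATH expects.
import os
import shutil

log_dir = './training_logs/hierarchical_diffusion_14_days_2024-01-01_00-00-00'  # example run dir
os.makedirs('./trained_model', exist_ok=True)
shutil.copy(os.path.join(log_dir, 'best_hierarchical_model.pth'),
            './trained_model/best_hierarchical_model.pth')  # matches Config.MODEL_PATH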