Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import numpy as np | |
| os.environ["WANDB_ENABLED"] = "false" | |
| from engine.solver import Trainer | |
| from data.build_dataloader import build_dataloader | |
| from utils.metric_utils import visualization, save_pdf | |
| # from utils.metric_utils import visualization | |
| from utils.io_utils import load_yaml_config, instantiate_from_config | |
| from models.model_utils import unnormalize_to_zero_to_one | |
| from scipy.signal import find_peaks, peak_prominences | |
| # disable user warnings | |
| import warnings | |
| warnings.simplefilter("ignore", UserWarning) | |
| import scipy.stats | |
| import numpy as np | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| from sklearn.manifold import TSNE | |
| from sklearn.decomposition import PCA | |
class Arguments:
    """Run settings derived from a config file path.

    The output directory is ``../../../data/<name>``, where ``<name>`` is the
    config file's base name up to its first dot. The directory is created on
    construction if it does not exist yet.
    """

    def __init__(self, config_path) -> None:
        # "dir/revenue-baseline-365.yaml" -> "revenue-baseline-365"
        config_stem = os.path.basename(config_path).split(".")[0]
        output_dir = "../../../data/" + config_stem
        os.makedirs(output_dir, exist_ok=True)

        self.config_path = config_path
        self.save_dir = output_dir
        self.gpu = 0
        self.mode = "infill"
        self.missing_ratio = 0.95
        self.milestone = 10
| import numpy as np | |
| import matplotlib as mpl | |
def create_color_gradient(sorting_value=None, start_color='#FFFF00', end_color='#00008B'):
    """Create color gradient using matplotlib color interpolation.

    NOTE: a second function with the same name is defined later in this file
    and replaces this one once the module has been executed.

    Parameters:
        sorting_value: dict mapping a key to a number, or None
            Values are min-max normalized and used as the blend ratio
            between ``start_color`` and ``end_color``.
        start_color: color given to the smallest value
        end_color: color given to the largest value

    Returns:
        dict mapping every key of ``sorting_value`` to a hex color string, or
        a single hex string (the midpoint color) when ``sorting_value`` is
        None. An empty mapping yields an empty dict.
    """
    def color_fader(c1, c2, mix=0):
        """Fade from color c1 to c2 with mix ratio."""
        c1 = np.array(mpl.colors.to_rgb(c1))
        c2 = np.array(mpl.colors.to_rgb(c2))
        return mpl.colors.to_hex((1 - mix) * c1 + mix * c2)

    if sorting_value is None:
        # Return middle point color
        return color_fader(start_color, end_color, mix=0.5)
    if not sorting_value:
        # min()/max() of an empty array would raise; there is nothing to color.
        return {}

    values = np.array(list(sorting_value.values()), dtype=float)
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        # All values are equal: 0/0 would produce NaN colors, so use the
        # midpoint, which is also what the "no values" case returns.
        normalized = np.full(values.shape, 0.5)
    else:
        # Normalize values between 0-1
        normalized = (values - lowest) / span

    return {
        key: color_fader(start_color, end_color, mix=norm_val)
        for key, norm_val in zip(sorting_value, normalized)
    }
def create_color_gradient(sorting_value=None, start_color='#FFFF00', middle_color='#00FF00', end_color='#00008B'):
    """Create color gradient using matplotlib interpolation with middle color.

    Parameters:
        sorting_value: dict mapping a key to a number, or None
            Values are min-max normalized to 0-1. The lower half of that
            range blends ``start_color`` into ``middle_color``, the upper
            half blends ``middle_color`` into ``end_color``.
        start_color: color given to the smallest value
        middle_color: color given to the midpoint of the value range
        end_color: color given to the largest value

    Returns:
        dict mapping every key of ``sorting_value`` to a hex color string, or
        ``middle_color`` unchanged when ``sorting_value`` is None. An empty
        mapping yields an empty dict.
    """
    def color_fader(c1, c2, mix=0):
        """Fade from color c1 to c2 with mix ratio."""
        c1 = np.array(mpl.colors.to_rgb(c1))
        c2 = np.array(mpl.colors.to_rgb(c2))
        return mpl.colors.to_hex((1 - mix) * c1 + mix * c2)

    if sorting_value is None:
        return middle_color  # Return middle color directly
    if not sorting_value:
        # min()/max() of an empty array would raise; there is nothing to color.
        return {}

    values = np.array(list(sorting_value.values()), dtype=float)
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        # All values are equal: 0/0 would produce NaN colors, so place every
        # key at the midpoint of the gradient.
        normalized = np.full(values.shape, 0.5)
    else:
        normalized = (values - lowest) / span

    colors = {}
    for key, norm_val in zip(sorting_value, normalized):
        if norm_val <= 0.5:
            # Interpolate between start and middle (scale 0-0.5 to 0-1)
            colors[key] = color_fader(start_color, middle_color, norm_val * 2)
        else:
            # Interpolate between middle and end (scale 0.5-1 to 0-1)
            colors[key] = color_fader(middle_color, end_color, (norm_val - 0.5) * 2)
    return colors
def evaluate_peak_detection(data, target_peaks, window_size=7, min_distance=5, prominence_threshold=0.1):
    """
    Evaluate peak detection accuracy by comparing detected peaks with target peaks.

    Peaks are searched with ``scipy.signal.find_peaks`` in feature 0 of every
    sequence. A target index counts as matched when at least one detected
    peak lies within ``window_size // 2`` steps of it. The first (up to)
    eight sequences are also drawn on a 4x2 grid shown with ``plt.show()``.

    Parameters:
        data: numpy array of shape (batch_size, seq_length, features)
            The generated sequences to analyze
        target_peaks: iterable of int
            The indices where peaks should occur (e.g., every 7 steps for
            weekly peaks)
        window_size: int
            Size of window to consider a peak match
        min_distance: int
            Passed to ``find_peaks`` as ``distance``
        prominence_threshold: float
            Passed to ``find_peaks`` as ``prominence``

    Returns:
        (accuracy_metrics, peaks)
        ``accuracy_metrics`` is a dict with keys 'accuracy' (matched targets
        / all targets), 'precision' (matched targets / all detected peaks),
        'total_targets', 'detected_peaks' and 'matched_peaks', all summed
        over the whole batch. Ratios are 0 when their denominator is 0.
        ``peaks`` holds the peak indices of the last sequence in the batch
        (an empty array for an empty batch).
    """
    batch_size = data.shape[0]
    # Materialize once: callers pass e.g. a range, and the same positions are
    # used both for indexing the sequence and for the match test.
    target_positions = np.asarray(list(target_peaks), dtype=int)
    half_window = window_size // 2
    n_plotted = min(8, batch_size)  # the grid below has room for 8 sequences

    # Create figure for visualization
    fig, axes = plt.subplots(4, 2, figsize=(20, 12))
    axes = axes.flatten()
    for unused_ax in axes[n_plotted:]:
        unused_ax.set_visible(False)

    overall_matched = 0
    overall_targets = 0
    overall_detected = 0
    peaks = np.array([], dtype=int)

    for i in range(batch_size):
        sequence = data[i, :, 0]  # sequence i, all timepoints, first feature

        # Find peaks using scipy
        peaks, _ = find_peaks(sequence,
                              distance=min_distance,
                              prominence=prominence_threshold)
        overall_detected += len(peaks)

        # Count targets that have at least one detected peak within the window
        overall_matched += sum(
            1 for target in target_positions
            if np.any(np.abs(peaks - target) <= half_window)
        )
        overall_targets += len(target_positions)

        if i < n_plotted:
            ax = axes[i]
            # Plot original sequence and detected peaks
            ax.plot(sequence, label='Generated Sequence')
            ax.plot(peaks, sequence[peaks], "x", label='Detected Peaks')
            # Plot target peak positions
            ax.plot(target_positions, sequence[target_positions], "o",
                    label='Target Peak Positions')
            ax.set_title(f'Sequence {i+1} Peak Detection Analysis')
            ax.legend()
            ax.grid(True)

    # Calculate overall metrics
    accuracy = overall_matched / overall_targets if overall_targets > 0 else 0
    precision = overall_matched / overall_detected if overall_detected > 0 else 0
    accuracy_metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'total_targets': overall_targets,
        'detected_peaks': overall_detected,
        'matched_peaks': overall_matched
    }

    plt.tight_layout()
    plt.show()
    return accuracy_metrics, peaks
# Imports used only by the experiment loop below. They were originally
# re-executed inside the loop body on every iteration.
import pickle
from pathlib import Path

from data.build_dataloader import build_dataloader_cond


def visualization_control_subplots(data, analysis="kernel", compare=100, output_label="", highlight=None):
    """Draw one KDE panel per controlled run and save the figure as a PDF.

    Every panel overlays three density curves: ``data["ori_data"]``,
    ``data["Unconditional"]`` and one of the remaining entries of ``data``.
    The figure is written to ``./figures/<output_label>_kde.pdf`` through
    ``save_pdf`` and then shown.

    Parameters:
        data: dict of arrays; must contain "ori_data" and "Unconditional",
            every other key gets its own panel
        analysis: unused, kept for call compatibility
        compare: unused, kept for call compatibility
        output_label: used in the figure title and the output file name
        highlight: when None the AUC value is set in bold in "auc_*" legend
            entries, otherwise the weight is

    Returns:
        None. Nothing is drawn when ``data`` has no controlled entries.
    """
    def get_auc(data_array):
        # NOTE(review): this sums over the LAST axis. The caller in this file
        # passes arrays sliced with [:, :, :1], so that axis has length 1 and
        # the result is the mean value rather than an area over time.
        # Confirm this is the statistic the legend is meant to show.
        return data_array.sum(-1).mean()

    def beautiful_text(key):
        """Turn a result key such as 'auc_20_weight_10' into a legend label."""
        print(key)
        if "auc" in key:
            auc = key.split("_")[1]
            weight = key.split("_")[3]
            if highlight is None:
                return f"AUC: $\\mathbf{{{auc}}}$ Weight: {weight}"
            return f"AUC: {auc} Weight: $\\mathbf{{{weight}}}$"
        if "peak" in key:
            peak = key.split("_")[1]
            weight = key.split("_")[3]
            return f"Peak: {peak} Weight: {weight}"
        return key

    def get_alpha(idx, n_plots):
        """Generate alpha value between 0.5-0.9 based on plot index"""
        return 0.5 + (0.4 * idx / (n_plots - 1)) if n_plots > 1 else 0.8

    # Get AUC values
    auc_orig = get_auc(data["ori_data"])
    auc_uncond = get_auc(data["Unconditional"])

    # Setup subplots
    keys = [k for k in data if k not in ("ori_data", "Unconditional")]
    n_plots = len(keys)
    if n_plots == 0:
        # Without controlled runs the grid would have zero columns.
        return
    n_cols = min(4, n_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols
    # squeeze=False always yields a 2-D axes array, including the 1x1 case.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows), squeeze=False)
    fig.set_dpi(300)

    # NOTE(review): sns.distplot is deprecated in recent seaborn releases;
    # kept here because switching to kdeplot may change the rendered curves.
    for idx, key in enumerate(keys):
        row, col = idx // n_cols, idx % n_cols
        ax = axes[row, col]
        reference_alpha = 0.9 - get_alpha(idx, n_plots) * 0.5

        # Plot distributions
        sns.distplot(data["ori_data"], hist=False, kde=True,
                     kde_kws={"linewidth": 2, "alpha": reference_alpha}, color='red',
                     ax=ax, label=f'Original\n$\\overline{{Area}}={auc_orig:.3f}$')
        sns.distplot(data["Unconditional"], hist=False, kde=True,
                     kde_kws={"linewidth": 2, "linestyle": "--", "alpha": reference_alpha},
                     color='#15B01A', ax=ax,
                     label=f'Unconditional\n$\\overline{{Area}}= {auc_uncond:.3f}$')
        auc_control = get_auc(data[key])
        sns.distplot(data[key], hist=False, kde=True,
                     kde_kws={"linewidth": 2, "alpha": get_alpha(idx, n_plots), "linestyle": "--"},
                     color="#9A0EEA", ax=ax,
                     label=f'{beautiful_text(key)}\n$\\overline{{Area}}= {auc_control:.3f}$')
        ax.legend()

        # Set labels only for first column and last row
        ax.set_ylabel('Density' if col == 0 else '')
        ax.set_xlabel('Value' if row == n_rows - 1 else '')

    fig.suptitle(f"Kernel Density Estimation of {output_label}", fontsize=16)
    plt.tight_layout()

    # Save before showing so the file is written even if an interactive
    # window is closed by the user afterwards.
    # save_pdf is project code; create the directory in case it does not.
    os.makedirs("./figures", exist_ok=True)
    save_pdf(fig, f"./figures/{output_label}_kde.pdf")
    plt.show()
    plt.close(fig)


# Run the summation-control experiment once per dataset config: load the
# model checkpoint, generate (or load cached) samples, and plot the KDEs.
for config_path in [
    "./config/modified/sines.yaml",
    "./config/modified/revenue-baseline-365.yaml",
    "./config/modified/energy.yaml",
    "./config/modified/fmri.yaml",
]:
    args = Arguments(config_path)
    configs = load_yaml_config(args.config_path)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    if use_cuda:
        # Selecting a CUDA device on a CPU-only machine raises, which would
        # defeat the CPU fallback chosen just above.
        torch.cuda.set_device(args.gpu)

    dl_info = build_dataloader(configs, args)
    model = instantiate_from_config(configs["model"]).to(device)
    trainer = Trainer(config=configs, args=args, model=model, dataloader=dl_info)
    # trainer.load(args.milestone, from_folder="../../../data/ckpt_baseline_240")
    # trainer.train()

    # args.milestone
    trainer.load("10")
    test_dl_info = build_dataloader_cond(configs, args)
    test_dataloader, test_dataset = test_dl_info["dataloader"], test_dl_info["dataset"]
    coef = configs["dataloader"]["test_dataset"]["coefficient"]
    stepsize = configs["dataloader"]["test_dataset"]["step_size"]
    sampling_steps = configs["dataloader"]["test_dataset"]["sampling_steps"]
    seq_length, feature_dim = test_dataset.window, test_dataset.var_num
    # samples, ori_data, masks = trainer.restore(
    #     test_dataloader,
    #     [seq_length, feature_dim],
    #     coef,
    #     stepsize,
    #     sampling_steps,
    #     control_signal={},
    #     # test=
    # )
    # if test_dataset.auto_norm:
    #     samples = unnormalize_to_zero_to_one(samples)
    # ori_data = np.load(os.path.join(dataset.dir, f"sine_ground_truth_{seq_length}_test.npy"))

    # "revenue-baseline-365.yaml" -> "revenue"
    dataset_name = os.path.basename(args.config_path).split(".")[0].split("-")[0]
    # Dataset name -> prefix used in the saved .npy file names.
    mapper = {
        "sines": "sines",
        "revenue": "revenue",
        "energy": "energy",
        "fmri": "fMRI",
    }
    gap = seq_length // 5  # only referenced by the disabled anchor-control code below

    samples_dir = os.path.join("../../../data/train/", dataset_name, "samples")
    ori_data = np.load(
        os.path.join(samples_dir, f"{mapper[dataset_name]}_norm_truth_{seq_length}_train.npy")
    )
    masks = np.load(os.path.join(samples_dir, f"{mapper[dataset_name]}_masking_{seq_length}.npy"))
    sample_num, seq_len, feat_dim = masks.shape
    observed = ori_data[:sample_num] * masks
    ori_data = ori_data[:sample_num]

    # Cache directory: one pickle file per generated result set.
    cache_dir = Path(f"../../../data/cache_{dataset_name}")
    cache_dir.mkdir(parents=True, exist_ok=True)

    def load_cached_results():
        """Read every *.pkl file in cache_dir into the results structure.

        File stems map to slots as follows: 'unconditional' fills the
        unconditional slot, 'sum_<key>' goes to sum_controlled[<key>] and
        'anchor_<key>' to anchor_controlled[<key>]; other files are ignored.
        """
        results = {'unconditional': None, 'sum_controlled': {}, 'anchor_controlled': {}}
        for cache_file in cache_dir.glob('*.pkl'):
            stem = cache_file.stem
            # NOTE(security): pickle.load can execute arbitrary code; only
            # keep files written by save_result in this directory.
            with open(cache_file, 'rb') as f:
                if stem == 'unconditional':
                    results['unconditional'] = pickle.load(f)
                elif stem.startswith('sum_'):
                    results['sum_controlled'][stem[len('sum_'):]] = pickle.load(f)
                elif stem.startswith('anchor_'):
                    results['anchor_controlled'][stem[len('anchor_'):]] = pickle.load(f)
        return results

    def save_result(key, subkey, data):
        """Pickle data to cache_dir as '<key>_<subkey>.pkl' ('<key>.pkl' without subkey)."""
        filename = f"{key}_{subkey}.pkl" if subkey else f"{key}.pkl"
        with open(cache_dir / filename, 'wb') as f:
            pickle.dump(data, f)

    results = load_cached_results()

    dataset = dl_info["dataset"]
    seq_length, feature_dim = dataset.window, dataset.var_num
    coef = configs["dataloader"]["test_dataset"]["coefficient"]
    stepsize = configs["dataloader"]["test_dataset"]["step_size"]
    num_samples = min(1000, len(dataset))

    # Unconditional sampling
    if results['unconditional'] is None:
        print("Generating unconditional data...")
        results['unconditional'] = trainer.sample(
            num=num_samples, size_every=500, shape=[seq_length, feature_dim]
        )
        save_result('unconditional', None, results['unconditional'])

    # Sum control: first sweep the AUC target at weight 10, then sweep the
    # weight at AUC -200. The shared (-200, 10) setting is generated once
    # because the second occurrence is already present in the results.
    sum_control_settings = [
        (auc, 10) for auc in [-200, -150, -100, 0, 20, 30, 50, 100, 150]
    ] + [
        (-200, weight) for weight in [1, 10, 50, 100]
    ]
    for auc, weight in sum_control_settings:
        key = f"auc_{auc}_weight_{weight}"
        if key in results['sum_controlled']:
            continue
        print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
        results['sum_controlled'][key] = trainer.control_sample(
            num=num_samples, size_every=500, shape=[seq_length, feature_dim],
            model_kwargs={
                "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                "coef": coef,
                "learning_rate": stepsize
            }
        )
        save_result('sum', key, results['sum_controlled'][key])

    # Different weekly peaks (anchor control is currently disabled)
    peak_values = [0.8, 1.0]
    peak_weights = [0.1, 0.5, 1.0]
    # import matplotlib.pyplot as plt
    # for peak in peak_values:
    #     for weight in peak_weights:
    #         key = f"peak_{peak}_weight_{weight}"
    #         if key not in results['anchor_controlled']:
    #             mask = np.zeros((seq_length, feature_dim), dtype=np.float32)
    #             mask[::gap, 0] = weight
    #             target = np.zeros((seq_length, feature_dim), dtype=np.float32)
    #             target[::gap, 0] = peak
    #             print(f"Generating anchor controlled data - Peak: {peak}, Weight: {weight}")
    #             results['anchor_controlled'][key] = trainer.control_sample(
    #                 num=min(1000, len(dataset)), size_every=500, shape=[seq_length, feature_dim],
    #                 model_kwargs={
    #                     "gradient_control_signal": {"auc": -50, "auc_weight": 10.0},
    #                     "coef": coef,
    #                     "learning_rate": stepsize
    #                 },
    #                 target=target,
    #                 partial_mask=mask
    #             )
    #             save_result('anchor', key, results['anchor_controlled'][key])
    #             # plot mask, target, and generated sequence
    #             plt.figure(figsize=(12, 6))
    #             plt.plot(mask[:, 0], label='Mask')
    #             plt.plot(target[:, 0], label='Target')
    #             plt.plot(results['anchor_controlled'][key][0, :, 0], label='Generated Sequence')
    #             plt.title(f"Anchor Controlled Data - Peak: {peak}, Weight: {weight}")
    #             plt.legend()
    #             plt.show()

    # Unnormalize results if needed. The cache holds the arrays as they were
    # generated (saved above, before this step), so this runs on every pass.
    if dataset.auto_norm:
        for group, payload in results.items():
            if isinstance(payload, dict):
                for subkey, subdata in payload.items():
                    payload[subkey] = unnormalize_to_zero_to_one(subdata)
            else:
                results[group] = unnormalize_to_zero_to_one(payload)

    # Store the results in variables for compatibility with existing code
    unconditional_data = results['unconditional']
    sum_controled_data = results['sum_controlled']  # ['auc_0_weight_10.0'] # default values
    anchor_controled_data = results['anchor_controlled']  # ['peak_0.8_weight_0.1'] # default values

    # Sum control: compare the first feature of each selected run.
    samples = 1000
    data = {
        "ori_data": ori_data[:samples, :, :1],
        "Unconditional": unconditional_data[:samples, :, :1],
    }
    # for key, value in sum_controled_data.items():
    #     if "weight_10" in key:
    #         data[key] = value
    #         print(key)
    keys = [
        # "auc_-200_weight_10",
        "auc_-100_weight_10",
        # "auc_0_weight_10",
        "auc_20_weight_10",
        # "auc_30_weight_10",
        "auc_50_weight_10",
        # "auc_100_weight_10",
        "auc_150_weight_10",
    ]
    for key in keys:
        selected = sum_controled_data[key][:samples, :, :1]
        data[key] = selected
        # print the total of all values divided by the number of sequences
        print(key, " ==> ", selected.sum() / selected.shape[0])

    # visualization_control(
    #     data=data,
    #     analysis="kernel",
    #     compare=ori_data.shape[0],
    #     output_label="revenue"
    # )

    # Dataset name -> name shown in the figure title.
    ds_name_display = {
        "sines": "Synthetic Sine Waves",
        "revenue": "Revenue",
        "energy": "ETTh",
        "fmri": "fMRI",
    }
    visualization_control_subplots(
        data=data,
        analysis="kernel",
        compare=ori_data.shape[0],
        output_label=f"{ds_name_display[dataset_name]} Dataset with Summation Control"
    )

    # peak control
    # data = {
    #     "ori_data": ori_data[:samples, :, :1],
    #     "Unconditional": unconditional_data[:samples, :, :1],
    # }
    # keys = [
    #     "peak_0.8_weight_0.1",
    #     "peak_0.8_weight_0.5",
    #     "peak_0.8_weight_1.0",
    #     "peak_1.0_weight_0.1",
    #     "peak_1.0_weight_0.5",
    #     "peak_1.0_weight_1.0",
    # ]
    # for key in keys:
    #     data[key] = anchor_controled_data[key][:samples, :, :1]
    #     # print peak
    #     print(key, " ==> ", anchor_controled_data[key][:samples, :, :1].max())
    # visualization_control(
    #     data=data,
    #     analysis="kernel",
    #     compare=ori_data.shape[0],
    #     output_label="revenue"
    # )
    # # config_mapping = {
    # #     "sines": {
    # #     }
    # #     "revenue": "revenue",
    # #     "energy": "energy",
    # #     "fmri": "fMRI",
    # # }
    # # Evaluate peak detection for different control settings
    # peak_accuracies = {}
    # for key, data in anchor_controled_data.items():
    #     print(f"\nEvaluating {key}")
    #     metrics, peaks = evaluate_peak_detection(
    #         data,
    #         target_peaks=range(0, seq_length, gap),
    #         window_size=max(1, gap//2),
    #         min_distance=max(1, gap - 1)
    #     )
    #     peak_accuracies[key] = metrics
    #     print(f"Accuracy: {metrics['accuracy']:.3f}")
    #     print(f"Precision: {metrics['precision']:.3f}")
    #     print(f"Matched peaks: {metrics['matched_peaks']} / {metrics['total_targets']}")
    print("="*50)