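"""Controlled sampling with a trained diffusion model.

Generates unconditional, sum-(AUC-)controlled, and anchor-controlled samples,
caches each result to disk, and reports average per-sample inference time.
"""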
import argparse
import os
import pickle
import time
import warnings
from pathlib import Path

import numpy as np
import torch

os.environ["WANDB_ENABLED"] = "false"  # turn off W&B logging before the engine is imported

from engine.solver import Trainer
from data.build_dataloader import build_dataloader, build_dataloader_cond
from utils.io_utils import load_yaml_config, instantiate_from_config

warnings.simplefilter("ignore", UserWarning)
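# Result caches are pickled sample arrays named "<kind>[_<params>].pkl" in cache_dir.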
def load_cached_results(cache_dir):
    results = {"unconditional": None, "sum_controlled": {}, "anchor_controlled": {}}
    for cache_file in cache_dir.glob("*.pkl"):
        with open(cache_file, "rb") as f:
            key = cache_file.stem
            if key == "unconditional":
                results["unconditional"] = pickle.load(f)
            elif key.startswith("sum_"):
                param = key[4:]  # remove the 'sum_' prefix
                results["sum_controlled"][param] = pickle.load(f)
            elif key.startswith("anchor_"):
                param = key[7:]  # remove the 'anchor_' prefix
                results["anchor_controlled"][param] = pickle.load(f)
    return results
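# Persist one result; subkey (e.g. "auc_-100_weight_10") distinguishes settings within a kind.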
def save_result(cache_dir, key, subkey, data):
    filename = f"{key}_{subkey}.pkl" if subkey else f"{key}.pkl"
    with open(cache_dir / filename, "wb") as f:
        pickle.dump(data, f)
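# Minimal stand-in for the CLI namespace that Trainer and the dataloader builders expect.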
class Arguments:
    def __init__(self, config_path, gpu=0) -> None:
        self.config_path = config_path
        self.save_dir = (
            "../../../data/" + os.path.basename(self.config_path).split(".")[0]
        )
        self.gpu = gpu
        os.makedirs(self.save_dir, exist_ok=True)
        self.mode = "infill"
        self.missing_ratio = 0.95
        self.milestone = 10
def parse_args():
    parser = argparse.ArgumentParser(description="Controlled Sampling")
    parser.add_argument(
        "--config_path", type=str, default="./config/modified/energy.yaml"
    )
    parser.add_argument("--gpu", type=int, default=0)
    return parser.parse_args()
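# End-to-end pipeline: restore a checkpoint, then run each controlled-sampling
# sweep, caching and timing every configuration.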
def run(run_args):
    args = Arguments(run_args.config_path, run_args.gpu)
    configs = load_yaml_config(args.config_path)
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
    dl_info = build_dataloader(configs, args)
    model = instantiate_from_config(configs["model"]).to(device)
    trainer = Trainer(config=configs, args=args, model=model, dataloader=dl_info)
    trainer.load(str(args.milestone))  # restore the milestone checkpoint
    test_dl_info = build_dataloader_cond(configs, args)
    test_dataloader, test_dataset = test_dl_info["dataloader"], test_dl_info["dataset"]
    coef = configs["dataloader"]["test_dataset"]["coefficient"]
    stepsize = configs["dataloader"]["test_dataset"]["step_size"]
    sampling_steps = configs["dataloader"]["test_dataset"]["sampling_steps"]
    seq_length, feature_dim = test_dataset.window, test_dataset.var_num
    dataset_name = os.path.basename(args.config_path).split(".")[0].split("-")[0]
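    # Map the config-derived dataset name to the prefix used in the sample file names.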
    mapper = {
        "sines": "sines",
        "revenue": "revenue",
        "energy": "energy",
        "fmri": "fMRI",
    }
    gap = seq_length // 5
    if seq_length in [96, 192, 384]:
        sample_dir = os.path.join(
            "../../../data/train/", str(seq_length), dataset_name, "samples"
        )
        file_prefix = mapper[dataset_name].replace("sines", "sine")
    else:
        sample_dir = os.path.join("../../../data/train/", dataset_name, "samples")
        file_prefix = mapper[dataset_name]
    ori_data = np.load(
        os.path.join(sample_dir, f"{file_prefix}_norm_truth_{seq_length}_train.npy")
    )
    masks = np.load(
        os.path.join(sample_dir, f"{file_prefix}_masking_{seq_length}.npy")
    )
    sample_num, _, _ = masks.shape
    ori_data = ori_data[:sample_num]
    sampling_size = min(1000, len(test_dataset), sample_num)
    batch_size = 500
    print(f"Sampling size: {sampling_size}, Batch size: {batch_size}")
    ### Cache file path
    cache_dir = Path(f"../../../data/cache/{dataset_name}_{seq_length}")
    cache_dir.mkdir(parents=True, exist_ok=True)
    results = load_cached_results(cache_dir)
    def measure_inference_time(func, *args, **kwargs):
        """Call func(*args, **kwargs) and return (result, elapsed wall-clock seconds)."""
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        return result, (end_time - start_time)

    timing_results = {}
    ### Unconditional sampling
    if results["unconditional"] is None:
        print("Generating unconditional data...")
        results["unconditional"], timing = measure_inference_time(
            trainer.control_sample,
            num=sampling_size,
            size_every=batch_size,
            shape=[seq_length, feature_dim],
            model_kwargs={
                "gradient_control_signal": {},
                "coef": coef,
                "learning_rate": stepsize,
            },
        )
        timing_results["unconditional"] = timing / sampling_size
        save_result(cache_dir, "unconditional", "", results["unconditional"])
    ### Different AUC values
    auc_weights = [10]
    auc_values = [-100, 20, 50, 150]  # full sweep: -200, -150, -100, -50, 0, 20, 30, 50, 100, 150
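    # "auc" steers the sum (area under the curve) of each generated series toward
    # the target value; "auc_weight" scales the strength of that gradient signal.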
    for auc in auc_values:
        for weight in auc_weights:
            key = f"auc_{auc}_weight_{weight}"
            if key not in results["sum_controlled"]:
                print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
                results["sum_controlled"][key], timing = measure_inference_time(
                    trainer.control_sample,
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                )
                timing_results[f"sum_controlled_{key}"] = timing / sampling_size
                save_result(cache_dir, "sum", key, results["sum_controlled"][key])
    ### Different AUC weights
    auc_weights = [1, 10, 50, 100]
    auc_values = [-100]
    for auc in auc_values:
        for weight in auc_weights:
            key = f"auc_{auc}_weight_{weight}"
            if key not in results["sum_controlled"]:
                print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
                results["sum_controlled"][key], timing = measure_inference_time(
                    trainer.control_sample,
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                )
                timing_results[f"sum_controlled_{key}"] = timing / sampling_size
                save_result(cache_dir, "sum", key, results["sum_controlled"][key])
    ### Different AUC segments
    auc_weights = [10]
    auc_values = [150]
    auc_average = 10
    auc_segments = ((gap, 2 * gap), (2 * gap, 3 * gap), (3 * gap, 4 * gap))
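    # Segment control constrains the AUC over a sub-window only; the target below
    # scales auc_average by the segment length so the per-step average stays fixed.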
    auc = auc_values[0]
    weight = auc_weights[0]
    for segment in auc_segments:
        key = f"auc_{auc}_weight_{weight}_segment_{segment[0]}_{segment[1]}"
        if key not in results["sum_controlled"]:
            print(
                f"Generating sum controlled data - AUC: {auc}, Weight: {weight}, Segment: {segment}"
            )
            results["sum_controlled"][key], timing = measure_inference_time(
                trainer.control_sample,
                num=sampling_size,
                size_every=batch_size,
                shape=[seq_length, feature_dim],
                model_kwargs={
                    "gradient_control_signal": {
                        "auc": auc_average * (segment[1] - segment[0]),
                        "auc_weight": weight,
                        "segment": [segment],
                    },
                    "coef": coef,
                    "learning_rate": stepsize,
                },
            )
            timing_results[f"sum_controlled_{key}"] = timing / sampling_size
            save_result(cache_dir, "sum", key, results["sum_controlled"][key])
    # Different anchors
    anchor_values = [-0.8, 0.6, 1.0]
    anchor_weights = [0.01, 0.5, 1.0]
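    # Anchor control pins every gap-th step of feature 0 to `peak`, with the
    # per-point strength given by `weight`, via the target/partial_mask pair.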
    for peak in anchor_values:
        for weight in anchor_weights:
            key = f"peak_{peak}_weight_{weight}"
            if key not in results["anchor_controlled"]:
                mask = np.zeros((seq_length, feature_dim), dtype=np.float32)
                mask[gap // 2 :: gap, 0] = weight
                target = np.zeros((seq_length, feature_dim), dtype=np.float32)
                target[gap // 2 :: gap, 0] = peak
                print(f"Anchor controlled data - Peak: {peak}, Weight: {weight}")
                results["anchor_controlled"][key], timing = measure_inference_time(
                    trainer.control_sample,
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                    target=target,
                    partial_mask=mask,
                )
                timing_results[f"anchor_controlled_{key}"] = timing / sampling_size
                save_result(cache_dir, "anchor", key, results["anchor_controlled"][key])
    # After all sampling is done, print timing results
    print("\nAverage Inference Time per Sample:")
    print("-" * 40)
    for key, time_per_sample in timing_results.items():
        print(f"{key}: {time_per_sample:.4f} seconds")
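# Example (the script name is whatever this file is saved as):
#   python <this_script>.py --config_path ./config/modified/energy.yaml --gpu 0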
if __name__ == "__main__":
    args = parse_args()
    run(args)