Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import warnings | |
# Quiet the run: raise TensorFlow's native log threshold and ignore Python
# warnings. Keep this above `import tensorflow` -- the environment variable is
# presumably read when TensorFlow loads, so setting it afterwards may be too late.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')
warnings.simplefilter("ignore")
| import numpy as np | |
| import torch | |
| import tensorflow as tf | |
# Add the parent of this script's directory to the import path so the local
# `experiment` / `utils` modules resolve no matter which working directory the
# script is launched from. (The previous `os.path.dirname('__file__')` passed a
# string literal, returned '' and therefore added '../' relative to the CWD.)
# Fall back to the CWD when `__file__` is undefined, e.g. in an interactive session.
_script_dir = (
    os.path.dirname(os.path.abspath(__file__))
    if '__file__' in globals()
    else os.getcwd()
)
sys.path.append(os.path.join(_script_dir, '..'))
| # Local imports | |
| from experiment import run | |
| from utils.context_fid import Context_FID | |
| from utils.cross_correlation import CrossCorrelLoss | |
| from utils.metric_utils import display_scores | |
| from utils.discriminative_metric import discriminative_score_metrics | |
| from utils.predictive_metric import predictive_score_metrics | |
# Configure GPU memory growth for every visible GPU so TensorFlow claims device
# memory incrementally rather than up front (presumably so it can share the
# GPU with the PyTorch code used later in this script).
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    # An empty device list simply makes this loop a no-op.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as exc:
    # Best effort: report the failure and continue; devices after the failing
    # one are left unconfigured because the exception ends the loop.
    print(exc)
# Global settings read by compute_metrics.
# Number of repetitions whose scores are summarised per metric.
iterations = 5
# Metric switches (truthy = compute the metric).
enable_dis, enable_pred = 1, 1  # discriminative / predictive scores
enable_fid, enable_corr = 0, 0  # context-FID / cross-correlation
| # all_results = {} | |
| # for config_path in [ | |
| # "./config/modified/sines.yaml", | |
| # "./config/modified/revenue-baseline-365.yaml", | |
| # "./config/modified/energy.yaml", | |
| # "./config/modified/fmri.yaml", | |
| # "./config/modified/96/energy.yaml", | |
| # "./config/modified/192/energy.yaml", | |
| # "./config/modified/384/energy.yaml", | |
| # "./config/modified/96/fmri.yaml", | |
| # "./config/modified/192/fmri.yaml", | |
| # "./config/modified/384/fmri.yaml", | |
| # "./config/modified/96/sines.yaml", | |
| # "./config/modified/192/sines.yaml", | |
| # "./config/modified/384/sines.yaml", | |
| # "./config/modified/192/revenue.yaml", | |
| # "./config/modified/96/revenue.yaml", | |
| # "./config/modified/384/revenue.yaml", | |
| # ]: | |
| # class Args: | |
| # config_path = config_path | |
| # gpu = 0 | |
| # results, dataset_name, seq_length = run(Args()) | |
| # all_results[config_path] = (results, dataset_name, seq_length) | |
| # python run.py ./config/modified/energy.yaml | |
| # python run.py ./config/modified/fmri.yaml | |
| # python run.py ./config/modified/sines.yaml | |
| # python run.py ./config/modified/revenue-baseline-365.yaml | |
| # python run.py ./config/modified/96/energy.yaml | |
| # python run.py ./config/modified/192/energy.yaml | |
| # python run.py ./config/modified/384/energy.yaml | |
| # python run.py ./config/modified/96/fmri.yaml | |
| # python run.py ./config/modified/192/fmri.yaml | |
| # python run.py ./config/modified/384/fmri.yaml | |
| # python run.py ./config/modified/96/sines.yaml | |
| # python run.py ./config/modified/192/sines.yaml | |
| # python run.py ./config/modified/384/sines.yaml | |
| # python run.py ./config/modified/192/revenue.yaml | |
| # python run.py ./config/modified/96/revenue.yaml | |
| # python run.py ./config/modified/384/revenue.yaml | |
# Human-readable dataset labels, keyed by the dataset name returned by `run`;
# the label is used in the log file name and in every logged result line.
ds_name_display = dict(
    sines="Sine",
    revenue="Revenue",
    energy="ETTh",
    fmri="fMRI",
)
def random_choice(size, num_select=100):
    """Return ``num_select`` random indices drawn uniformly from ``[0, size)``.

    Indices are sampled with replacement, so duplicates are possible.

    Args:
        size: exclusive upper bound for the indices.
        num_select: how many indices to draw.

    Returns:
        A 1-D NumPy integer array of length ``num_select``.
    """
    return np.random.randint(0, size, num_select)
def _append_log(metric, data_name, key, data_len, scores):
    """Summarise ``scores`` with ``display_scores`` and append one result line.

    The line ``"<metric> <data_name> <key> <data_len> <mean> <sigma>"`` is
    appended to ``"log <data_name>.txt"`` in the current working directory.
    """
    mean, sigma = display_scores(scores)
    content = f'{metric} {data_name} {key} {data_len} {mean} {sigma}'
    with open(f'log {data_name}.txt', 'a+') as file:
        file.write(content + '\n')


def compute_metrics(ori_data, fake_data, iterations=5, data_name='sines', data_len=24, key="unconditional"):
    """Evaluate generated samples against real data and log the enabled metrics.

    Each metric switched on by the module-level flags ``enable_dis``,
    ``enable_pred``, ``enable_fid`` and ``enable_corr`` is computed
    ``iterations`` times. Per-iteration values are printed, and the summary
    from ``display_scores`` is appended to ``log <data_name>.txt``.

    Args:
        ori_data: real samples. It is indexed as ``[idx, :, :]`` and passed to
            ``torch.from_numpy`` by the correlation metric, so presumably a 3-D
            NumPy array -- confirm against ``run``.
        fake_data: generated samples in the same layout as ``ori_data``.
        iterations: number of repetitions summarised per metric.
        data_name: dataset label used in the log file name and log lines.
        data_len: sequence length written to the log lines.
        key: label of the generation setting written to the log lines.
    """
    n_real = ori_data.shape[0]
    # Discriminative, predictive and FID scores use at most as many generated
    # samples as there are real ones.
    fake_matched = fake_data[:n_real]

    if enable_dis:
        discriminative_score = []
        for i in range(iterations):
            temp_disc, fake_acc, real_acc = discriminative_score_metrics(ori_data, fake_matched)
            discriminative_score.append(temp_disc)
            print(f'Iter {i}: ', temp_disc, ',', fake_acc, ',', real_acc, '\n')
        _append_log('disc', data_name, key, data_len, discriminative_score)

    if enable_pred:
        predictive_score = []
        for i in range(iterations):
            temp_pred = predictive_score_metrics(ori_data, fake_matched)
            predictive_score.append(temp_pred)
            print(i, ' epoch: ', temp_pred, '\n')
        _append_log('pred', data_name, key, data_len, predictive_score)

    if enable_fid:
        context_fid_score = []
        for i in range(iterations):
            context_fid = Context_FID(ori_data, fake_matched)
            context_fid_score.append(context_fid)
            print(f'Iter {i}: ', 'context-fid =', context_fid, '\n')
        _append_log('fid', data_name, key, data_len, context_fid_score)

    if enable_corr:
        x_real = torch.from_numpy(ori_data)
        x_fake = torch.from_numpy(fake_data)
        correlational_score = []
        # Each repetition scores a random subset of about 1/iterations of the
        # real set. Clamp to at least one sample so datasets smaller than
        # `iterations` do not produce empty batches.
        size = max(1, x_real.shape[0] // iterations)
        for i in range(iterations):
            real_idx = random_choice(x_real.shape[0], size)
            fake_idx = random_choice(x_fake.shape[0], size)
            corr = CrossCorrelLoss(x_real[real_idx, :, :], name='CrossCorrelLoss')
            loss = corr.compute(x_fake[fake_idx, :, :])
            correlational_score.append(loss.item())
            print(f'Iter {i}: ', 'cross-correlation =', loss.item(), '\n')
        _append_log('corr', data_name, key, data_len, correlational_score)
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # nargs='?' makes the positional optional so the default actually applies;
    # a plain positional argument is always required and never uses `default`.
    parser.add_argument('config_path', type=str, nargs='?', default="./config/modified/sines.yaml")
    parser.add_argument('--gpu', type=int, default=0)
    results, dataset_name, seq_length = run(parser.parse_args())

    # Read every expected entry up front so a missing key fails before any
    # (slow) metric is computed.
    ori_data = results["ori_data"]
    unconditional_data = results["unconditional"]
    sum_controlled_data = results["sum_controlled"]
    anchor_controlled_data = results["anchor_controlled"]

    # Datasets missing from the display table fall back to their raw name
    # instead of aborting with a KeyError.
    display_name = ds_name_display.get(dataset_name, dataset_name)

    compute_metrics(ori_data, unconditional_data, iterations=iterations, data_name=display_name,
                    data_len=seq_length, key="unconditional")
    # Controlled generations: each dict entry is evaluated separately and
    # logged under its own key.
    for controlled in (sum_controlled_data, anchor_controlled_data):
        for key, value in controlled.items():
            compute_metrics(ori_data, value, iterations=iterations, data_name=display_name,
                            data_len=seq_length, key=key)