Spaces:
Sleeping
Sleeping
| import argparse | |
| import logging | |
| import multiprocessing | |
| import os | |
| import time | |
| from functools import partial | |
| import numpy as np | |
| import tensorflow as tf | |
| from tqdm import tqdm | |
| from data_reader import DataReader_pred, normalize_batch | |
| from model import UNet | |
| from util import * | |
# Silence TensorFlow's own Python logging below ERROR level.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# This script builds a TF1-style graph and runs it through tf.compat.v1.Session,
# so eager execution must be switched off before any ops are created.
tf.compat.v1.disable_eager_execution()
def read_args():
    """Parse the prediction command-line options.

    Returns:
        argparse.Namespace with attributes ``format``, ``batch_size``,
        ``output_dir``, ``model_dir``, ``sampling_rate``, ``data_dir``,
        ``data_list`` and the boolean flags ``plot_figure``,
        ``save_signal`` and ``save_noise``.
    """
    option_specs = (
        ("--format", {"default": "numpy", "type": str, "help": "Input data format: numpy or mseed"}),
        ("--batch_size", {"default": 20, "type": int, "help": "Batch size"}),
        ("--output_dir", {"default": "output", "help": "Output directory (default: output)"}),
        ("--model_dir", {"default": None, "help": "Checkpoint directory (default: None)"}),
        ("--sampling_rate", {"default": 100, "type": int, "help": "sampling rate of pred data"}),
        ("--data_dir", {"default": "./Dataset/pred/", "help": "Input file directory"}),
        ("--data_list", {"default": "./Dataset/pred.csv", "help": "Input csv file"}),
        ("--plot_figure", {"action": "store_true", "help": "If plot figure"}),
        ("--save_signal", {"action": "store_true", "help": "If save denoised signal"}),
        ("--save_noise", {"action": "store_true", "help": "If save denoised noise"}),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def _prepare_output_dirs(args, log_dir, figure_dir, result_dir):
    """Resolve and create the log, figure and result directories for a run.

    Args:
        args: parsed CLI namespace (reads ``plot_figure``, ``save_signal``,
            ``save_noise``, ``output_dir`` and, if present, ``log_dir``).
        log_dir: base directory for this run; when None a timestamped
            ``<base>/pred/<yymmdd-HHMMSS>`` directory is used.
        figure_dir: figure directory; when None and plotting is enabled it
            defaults to ``<log_dir>/figures``.
        result_dir: result directory; when None and saving is enabled it
            defaults to ``<log_dir>/results``.

    Returns:
        Tuple ``(log_dir, figure_dir, result_dir)``.
    """
    if log_dir is None:
        # read_args() defines no --log_dir option, so fall back to output_dir
        # instead of raising AttributeError.
        base_dir = getattr(args, "log_dir", None) or args.output_dir
        log_dir = os.path.join(base_dir, "pred", time.strftime("%y%m%d-%H%M%S"))
    logging.info("Pred log: %s", log_dir)
    os.makedirs(log_dir, exist_ok=True)

    # Caller-supplied directories are honoured; only missing ones get defaults.
    if args.plot_figure:
        figure_dir = figure_dir or os.path.join(log_dir, "figures")
        os.makedirs(figure_dir, exist_ok=True)
    if args.save_signal or args.save_noise:
        result_dir = result_dir or os.path.join(log_dir, "results")
        os.makedirs(result_dir, exist_ok=True)
    return log_dir, figure_dir, result_dir


def _restore_checkpoint(sess, model_dir):
    """Initialise all global variables in ``sess``, then restore the latest checkpoint.

    Raises:
        ValueError: if ``model_dir`` is None or contains no checkpoint.
    """
    saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
    sess.run(tf.compat.v1.global_variables_initializer())
    if model_dir is None:
        raise ValueError("model_dir is required to restore a checkpoint (use --model_dir)")
    latest_check_point = tf.train.latest_checkpoint(model_dir)
    if latest_check_point is None:
        # latest_checkpoint returns None when nothing is found; fail clearly
        # instead of letting saver.restore(None) raise a cryptic error.
        raise ValueError(f"No checkpoint found in {model_dir}")
    logging.info("restoring models: %s", latest_check_point)
    saver.restore(sess, latest_check_point)


def _predict_batch(sess, model, X_batch):
    """Run the model on one 6-D input batch and return predictions of matching leading shape."""
    nbt, nch, nst, nf, nt, nimg = X_batch.shape
    # model.X is fed a 4-D array, so fold the three leading axes into one.
    X_flat = normalize_batch(np.reshape(X_batch, [nbt * nch * nst, nf, nt, nimg]))
    preds_batch = sess.run(
        model.preds,
        feed_dict={model.X: X_flat, model.drop_rate: 0, model.is_training: False},
    )
    # Restore the original leading axes; the last axis comes from the model output.
    return np.reshape(preds_batch, [nbt, nch, nst, nf, nt, preds_batch.shape[-1]])


def pred_fn(args, data_reader, figure_dir=None, result_dir=None, log_dir=None):
    """Restore the latest UNet checkpoint and run prediction over ``data_reader``.

    For each batch, optionally saves results via ``save_results`` (when
    ``--save_signal`` or ``--save_noise`` is set) and optionally plots
    figures in a worker pool via ``plot_figures`` (when ``--plot_figure``).

    Args:
        args: parsed CLI namespace (see ``read_args``).
        data_reader: reader exposing ``dataset(batch_size)`` and ``n_signal``.
        figure_dir: optional figure directory (default ``<log_dir>/figures``).
        result_dir: optional result directory (default ``<log_dir>/results``).
        log_dir: optional run directory (default: timestamped under
            ``args.log_dir`` if present, else ``args.output_dir``).

    Returns:
        0 on completion.

    Raises:
        ValueError: if no checkpoint can be found in ``args.model_dir``.
    """
    log_dir, figure_dir, result_dir = _prepare_output_dirs(args, log_dir, figure_dir, result_dir)

    with tf.compat.v1.name_scope("Input_Batch"):
        data_batch = data_reader.dataset(args.batch_size)
    model = UNet(mode="pred")

    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True

    with tf.compat.v1.Session(config=sess_config) as sess:
        _restore_checkpoint(sess, args.model_dir)

        # The pool is only used for plotting, so only start it then. A spawn
        # *context* avoids multiprocessing.set_start_method(), which mutates
        # global state and raises RuntimeError if the method is already set.
        pool = (
            multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count())
            if args.plot_figure
            else None
        )
        try:
            for _ in tqdm(range(0, data_reader.n_signal, args.batch_size), desc="Pred"):
                X_batch, fname_batch, t0_batch = sess.run(data_batch)
                preds_batch = _predict_batch(sess, model, X_batch)
                fnames = [x.decode() for x in fname_batch]

                if args.save_signal or args.save_noise:
                    save_results(
                        preds_batch,
                        X_batch,
                        fname=fnames,
                        t0=[x.decode() for x in t0_batch],
                        save_signal=args.save_signal,
                        save_noise=args.save_noise,
                        result_dir=result_dir,
                    )

                if pool is not None:
                    pool.starmap(
                        partial(plot_figures, figure_dir=figure_dir),
                        zip(preds_batch, X_batch, fnames),
                    )
        finally:
            # Always release worker processes, even if a batch fails.
            if pool is not None:
                pool.close()
                pool.join()
    return 0
def main(args):
    """Set up logging, build the prediction data reader and run ``pred_fn``.

    Results are written under ``args.output_dir``, which is passed to
    ``pred_fn`` as its ``log_dir``. Returns 0.
    """
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

    reader_kwargs = dict(
        format=args.format,
        signal_dir=args.data_dir,
        signal_list=args.data_list,
        sampling_rate=args.sampling_rate,
    )
    with tf.compat.v1.name_scope('create_inputs'):
        data_reader = DataReader_pred(**reader_kwargs)

    size_msg = "Dataset Size: {}".format(data_reader.n_signal)
    logging.info(size_msg)

    pred_fn(args, data_reader, log_dir=args.output_dir)
    return 0
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, then run prediction.
    args = read_args()
    main(args)