Spaces:
Running
Running
Commit
·
81c99dc
0
Parent(s):
init
Browse files- Dockerfile +26 -0
- LICENSE +21 -0
- deepdenoiser/__init__.py +0 -0
- deepdenoiser/app.py +180 -0
- deepdenoiser/data_reader.py +816 -0
- deepdenoiser/model.py +495 -0
- deepdenoiser/predict.py +136 -0
- deepdenoiser/train.py +557 -0
- deepdenoiser/util.py +875 -0
- docs/README.md +60 -0
- docs/example_batch_prediction.ipynb +0 -0
- docs/example_interactive.ipynb +0 -0
- env.yml +19 -0
- mkdocs.yml +18 -0
- requirements.txt +5 -0
Dockerfile
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM tensorflow/tensorflow
|
| 2 |
+
|
| 3 |
+
# Create the environment:
|
| 4 |
+
# COPY env.yml /app
|
| 5 |
+
# RUN conda env create --name cs329s --file=env.yml
|
| 6 |
+
# Make RUN commands use the new environment:
|
| 7 |
+
# SHELL ["conda", "run", "-n", "cs329s", "/bin/bash", "-c"]
|
| 8 |
+
|
| 9 |
+
RUN pip install tqdm obspy pandas
|
| 10 |
+
RUN pip install uvicorn fastapi kafka-python
|
| 11 |
+
|
| 12 |
+
WORKDIR /opt
|
| 13 |
+
|
| 14 |
+
# Copy files
|
| 15 |
+
COPY deepdenoiser /opt/deepdenoiser
|
| 16 |
+
# COPY model /opt/model
|
| 17 |
+
RUN wget https://github.com/AI4EPS/models/releases/download/DeepDenoiser/model.tar && tar -xvf model.tar && rm model.tar
|
| 18 |
+
|
| 19 |
+
# Expose API port
|
| 20 |
+
EXPOSE 8000
|
| 21 |
+
|
| 22 |
+
ENV PYTHONUNBUFFERED=1
|
| 23 |
+
|
| 24 |
+
# Start API server
|
| 25 |
+
#ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "cs329s", "uvicorn", "--app-dir", "phasenet", "app:app", "--reload", "--port", "8000", "--host", "0.0.0.0"]
|
| 26 |
+
ENTRYPOINT ["uvicorn", "--app-dir", "deepdenoiser", "app:app", "--reload", "--port", "7860", "--host", "0.0.0.0"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2021 Weiqiang Zhu
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
deepdenoiser/__init__.py
ADDED
|
File without changes
|
deepdenoiser/app.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from collections import defaultdict, namedtuple
|
| 3 |
+
from datetime import datetime, timedelta
|
| 4 |
+
from json import dumps
|
| 5 |
+
from typing import Any, AnyStr, Dict, List, NamedTuple, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import requests
|
| 9 |
+
import tensorflow as tf
|
| 10 |
+
from fastapi import FastAPI
|
| 11 |
+
from kafka import KafkaProducer
|
| 12 |
+
from pydantic import BaseModel
|
| 13 |
+
import scipy
|
| 14 |
+
from scipy.interpolate import interp1d
|
| 15 |
+
|
| 16 |
+
from model import UNet
|
| 17 |
+
|
| 18 |
+
tf.compat.v1.disable_eager_execution()
|
| 19 |
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
| 20 |
+
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
|
| 21 |
+
JSONObject = Dict[AnyStr, Any]
|
| 22 |
+
JSONArray = List[Any]
|
| 23 |
+
JSONStructure = Union[JSONArray, JSONObject]
|
| 24 |
+
|
| 25 |
+
app = FastAPI()
|
| 26 |
+
X_SHAPE = [3000, 1, 3]
|
| 27 |
+
SAMPLING_RATE = 100
|
| 28 |
+
|
| 29 |
+
# load model
|
| 30 |
+
model = UNet(mode="pred")
|
| 31 |
+
sess_config = tf.compat.v1.ConfigProto()
|
| 32 |
+
sess_config.gpu_options.allow_growth = True
|
| 33 |
+
|
| 34 |
+
sess = tf.compat.v1.Session(config=sess_config)
|
| 35 |
+
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
|
| 36 |
+
init = tf.compat.v1.global_variables_initializer()
|
| 37 |
+
sess.run(init)
|
| 38 |
+
latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190614-104802")
|
| 39 |
+
print(f"restoring model {latest_check_point}")
|
| 40 |
+
saver.restore(sess, latest_check_point)
|
| 41 |
+
|
| 42 |
+
# Kafak producer
|
| 43 |
+
use_kafka = False
|
| 44 |
+
# BROKER_URL = 'localhost:9092'
|
| 45 |
+
# BROKER_URL = 'my-kafka-headless:9092'
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
print("Connecting to k8s kafka")
|
| 49 |
+
BROKER_URL = "quakeflow-kafka-headless:9092"
|
| 50 |
+
producer = KafkaProducer(
|
| 51 |
+
bootstrap_servers=[BROKER_URL],
|
| 52 |
+
key_serializer=lambda x: dumps(x).encode("utf-8"),
|
| 53 |
+
value_serializer=lambda x: dumps(x).encode("utf-8"),
|
| 54 |
+
)
|
| 55 |
+
use_kafka = True
|
| 56 |
+
print("k8s kafka connection success!")
|
| 57 |
+
except BaseException:
|
| 58 |
+
print("k8s Kafka connection error")
|
| 59 |
+
try:
|
| 60 |
+
print("Connecting to local kafka")
|
| 61 |
+
producer = KafkaProducer(
|
| 62 |
+
bootstrap_servers=["localhost:9092"],
|
| 63 |
+
key_serializer=lambda x: dumps(x).encode("utf-8"),
|
| 64 |
+
value_serializer=lambda x: dumps(x).encode("utf-8"),
|
| 65 |
+
)
|
| 66 |
+
use_kafka = True
|
| 67 |
+
print("local kafka connection success!")
|
| 68 |
+
except BaseException:
|
| 69 |
+
print("local Kafka connection error")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def normalize_batch(data, window=200):
|
| 73 |
+
"""
|
| 74 |
+
data: nbn, nf, nt, 2
|
| 75 |
+
"""
|
| 76 |
+
assert len(data.shape) == 4
|
| 77 |
+
shift = window // 2
|
| 78 |
+
nbt, nf, nt, nimg = data.shape
|
| 79 |
+
|
| 80 |
+
## std in slide windows
|
| 81 |
+
data_pad = np.pad(data, ((0, 0), (0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
|
| 82 |
+
t = np.arange(0, nt + shift - 1, shift, dtype="int") # 201 => 0, 100, 200
|
| 83 |
+
# print(f"nt = {nt}, nt+window//2 = {nt+window//2}")
|
| 84 |
+
std = np.zeros([nbt, len(t)])
|
| 85 |
+
mean = np.zeros([nbt, len(t)])
|
| 86 |
+
for i in range(std.shape[1]):
|
| 87 |
+
std[:, i] = np.std(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
|
| 88 |
+
mean[:, i] = np.mean(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
|
| 89 |
+
|
| 90 |
+
std[:, -1], mean[:, -1] = std[:, -2], mean[:, -2]
|
| 91 |
+
std[:, 0], mean[:, 0] = std[:, 1], mean[:, 1]
|
| 92 |
+
|
| 93 |
+
## normalize data with interplated std
|
| 94 |
+
t_interp = np.arange(nt, dtype="int")
|
| 95 |
+
std_interp = interp1d(t, std, kind="slinear")(t_interp)
|
| 96 |
+
std_interp[std_interp == 0] = 1.0
|
| 97 |
+
mean_interp = interp1d(t, mean, kind="slinear")(t_interp)
|
| 98 |
+
|
| 99 |
+
data = (data - mean_interp[:, np.newaxis, :, np.newaxis]) / std_interp[:, np.newaxis, :, np.newaxis]
|
| 100 |
+
|
| 101 |
+
if len(t) > 3: ##need to address this normalization issue in training
|
| 102 |
+
data /= 2.0
|
| 103 |
+
|
| 104 |
+
return data
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_prediction(meta):
|
| 108 |
+
|
| 109 |
+
FS = 100
|
| 110 |
+
NPERSEG = 30
|
| 111 |
+
NFFT = 60
|
| 112 |
+
|
| 113 |
+
vec = np.array(meta.vec) # [batch, nt, chn]
|
| 114 |
+
nbt, nt, nch = vec.shape
|
| 115 |
+
vec = np.transpose(vec, [0, 2, 1]) # [batch, chn, nt]
|
| 116 |
+
vec = np.reshape(vec, [nbt * nch, nt]) ## [batch * chn, nt]
|
| 117 |
+
|
| 118 |
+
if np.mod(vec.shape[-1], 3000) == 1: # 3001=>3000
|
| 119 |
+
vec = vec[..., :-1]
|
| 120 |
+
|
| 121 |
+
if meta.dt != 0.01:
|
| 122 |
+
t = np.linspace(0, 1, len(vec))
|
| 123 |
+
t_interp = np.linspace(0, 1, np.int(np.around(len(vec) * meta.dt * FS)))
|
| 124 |
+
vec = interp1d(t, vec, kind="slinear")(t_interp)
|
| 125 |
+
|
| 126 |
+
# sos = scipy.signal.butter(4, 0.1, 'high', fs=100, output='sos') ## for stability of long sequence
|
| 127 |
+
# vec = scipy.signal.sosfilt(sos, vec)
|
| 128 |
+
f, t, tmp_signal = scipy.signal.stft(vec, fs=FS, nperseg=NPERSEG, nfft=NFFT, boundary='zeros')
|
| 129 |
+
noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1) # [batch * chn, nf, nt, 2]
|
| 130 |
+
noisy_signal[np.isnan(noisy_signal)] = 0
|
| 131 |
+
noisy_signal[np.isinf(noisy_signal)] = 0
|
| 132 |
+
X_input = normalize_batch(noisy_signal)
|
| 133 |
+
|
| 134 |
+
feed = {model.X: X_input, model.drop_rate: 0, model.is_training: False}
|
| 135 |
+
preds = sess.run(model.preds, feed_dict=feed)
|
| 136 |
+
|
| 137 |
+
_, denoised_signal = scipy.signal.istft(
|
| 138 |
+
(noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 0],
|
| 139 |
+
fs=FS,
|
| 140 |
+
nperseg=NPERSEG,
|
| 141 |
+
nfft=NFFT,
|
| 142 |
+
boundary='zeros',
|
| 143 |
+
)
|
| 144 |
+
# _, denoised_noise = scipy.signal.istft(
|
| 145 |
+
# (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 1],
|
| 146 |
+
# fs=FS,
|
| 147 |
+
# nperseg=NPERSEG,
|
| 148 |
+
# nfft=NFFT,
|
| 149 |
+
# boundary='zeros',
|
| 150 |
+
# )
|
| 151 |
+
|
| 152 |
+
denoised_signal = np.reshape(denoised_signal, [nbt, nch, nt])
|
| 153 |
+
denoised_signal = np.transpose(denoised_signal, [0, 2, 1])
|
| 154 |
+
|
| 155 |
+
result = meta.copy()
|
| 156 |
+
result.vec = denoised_signal.tolist()
|
| 157 |
+
return result
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class Data(BaseModel):
|
| 161 |
+
# id: Union[List[str], str]
|
| 162 |
+
# timestamp: Union[List[str], str]
|
| 163 |
+
# vec: Union[List[List[List[float]]], List[List[float]]]
|
| 164 |
+
id: List[str]
|
| 165 |
+
timestamp: List[str]
|
| 166 |
+
vec: List[List[List[float]]]
|
| 167 |
+
dt: float = 0.01
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
@app.post("/predict")
|
| 171 |
+
def predict(data: Data):
|
| 172 |
+
|
| 173 |
+
denoised = get_prediction(data)
|
| 174 |
+
|
| 175 |
+
return denoised
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@app.get("/healthz")
|
| 179 |
+
def healthz():
|
| 180 |
+
return {"status": "ok"}
|
deepdenoiser/data_reader.py
ADDED
|
@@ -0,0 +1,816 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import scipy.signal
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
|
| 6 |
+
pd.options.mode.chained_assignment = None
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import threading
|
| 10 |
+
|
| 11 |
+
import obspy
|
| 12 |
+
from scipy.interpolate import interp1d
|
| 13 |
+
|
| 14 |
+
tf.compat.v1.disable_eager_execution()
|
| 15 |
+
# from tensorflow.python.ops.linalg_ops import norm
|
| 16 |
+
# from tensorflow.python.util import nest
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Config:
|
| 20 |
+
seed = 100
|
| 21 |
+
n_class = 2
|
| 22 |
+
fs = 100
|
| 23 |
+
dt = 1.0 / fs
|
| 24 |
+
freq_range = [0, fs / 2]
|
| 25 |
+
time_range = [0, 30]
|
| 26 |
+
nperseg = 30
|
| 27 |
+
nfft = 60
|
| 28 |
+
plot = False
|
| 29 |
+
nt = 3000
|
| 30 |
+
X_shape = [31, 201, 2]
|
| 31 |
+
Y_shape = [31, 201, n_class]
|
| 32 |
+
signal_shape = [31, 201]
|
| 33 |
+
noise_shape = signal_shape
|
| 34 |
+
use_seed = False
|
| 35 |
+
queue_size = 10
|
| 36 |
+
noise_mean = 2
|
| 37 |
+
noise_std = 1
|
| 38 |
+
# noise_low = 1
|
| 39 |
+
# noise_high = 5
|
| 40 |
+
use_buffer = True
|
| 41 |
+
snr_threshold = 10
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# %%
|
| 45 |
+
# def normalize(data, window=3000):
|
| 46 |
+
# """
|
| 47 |
+
# data: nsta, chn, nt
|
| 48 |
+
# """
|
| 49 |
+
# shift = window//2
|
| 50 |
+
# nt = len(data)
|
| 51 |
+
|
| 52 |
+
# ## std in slide windows
|
| 53 |
+
# data_pad = np.pad(data, ((window//2, window//2)), mode="reflect")
|
| 54 |
+
# t = np.arange(0, nt, shift, dtype="int")
|
| 55 |
+
# # print(f"nt = {nt}, nt+window//2 = {nt+window//2}")
|
| 56 |
+
# std = np.zeros(len(t))
|
| 57 |
+
# mean = np.zeros(len(t))
|
| 58 |
+
# for i in range(len(std)):
|
| 59 |
+
# std[i] = np.std(data_pad[i*shift:i*shift+window])
|
| 60 |
+
# mean[i] = np.mean(data_pad[i*shift:i*shift+window])
|
| 61 |
+
|
| 62 |
+
# t = np.append(t, nt)
|
| 63 |
+
# std = np.append(std, [np.std(data_pad[-window:])])
|
| 64 |
+
# mean = np.append(mean, [np.mean(data_pad[-window:])])
|
| 65 |
+
|
| 66 |
+
# # print(t)
|
| 67 |
+
# ## normalize data with interplated std
|
| 68 |
+
# t_interp = np.arange(nt, dtype="int")
|
| 69 |
+
# std_interp = interp1d(t, std, kind="slinear")(t_interp)
|
| 70 |
+
# mean_interp = interp1d(t, mean, kind="slinear")(t_interp)
|
| 71 |
+
# data = (data - mean_interp)/(std_interp)
|
| 72 |
+
# return data, std_interp
|
| 73 |
+
|
| 74 |
+
# %%
|
| 75 |
+
def normalize(data, window=200):
|
| 76 |
+
"""
|
| 77 |
+
data: nsta, chn, nt
|
| 78 |
+
"""
|
| 79 |
+
shift = window // 2
|
| 80 |
+
nt = data.shape[1]
|
| 81 |
+
|
| 82 |
+
## std in slide windows
|
| 83 |
+
data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
|
| 84 |
+
t = np.arange(0, nt, shift, dtype="int")
|
| 85 |
+
# print(f"nt = {nt}, nt+window//2 = {nt+window//2}")
|
| 86 |
+
std = np.zeros(len(t))
|
| 87 |
+
mean = np.zeros(len(t))
|
| 88 |
+
for i in range(len(std)):
|
| 89 |
+
std[i] = np.std(data_pad[:, i * shift : i * shift + window, :])
|
| 90 |
+
mean[i] = np.mean(data_pad[:, i * shift : i * shift + window, :])
|
| 91 |
+
|
| 92 |
+
t = np.append(t, nt)
|
| 93 |
+
std = np.append(std, [np.std(data_pad[:, -window:, :])])
|
| 94 |
+
mean = np.append(mean, [np.mean(data_pad[:, -window:, :])])
|
| 95 |
+
# print(t)
|
| 96 |
+
## normalize data with interplated std
|
| 97 |
+
t_interp = np.arange(nt, dtype="int")
|
| 98 |
+
std_interp = interp1d(t, std, kind="slinear")(t_interp)
|
| 99 |
+
std_interp[std_interp == 0] = 1.0
|
| 100 |
+
mean_interp = interp1d(t, mean, kind="slinear")(t_interp)
|
| 101 |
+
data = (data - mean_interp[np.newaxis, :, np.newaxis]) / std_interp[np.newaxis, :, np.newaxis]
|
| 102 |
+
return data, std_interp
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def normalize_batch(data, window=200):
|
| 106 |
+
"""
|
| 107 |
+
data: nbn, nf, nt, 2
|
| 108 |
+
"""
|
| 109 |
+
assert len(data.shape) == 4
|
| 110 |
+
shift = window // 2
|
| 111 |
+
nbt, nf, nt, nimg = data.shape
|
| 112 |
+
|
| 113 |
+
## std in slide windows
|
| 114 |
+
data_pad = np.pad(data, ((0, 0), (0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
|
| 115 |
+
t = np.arange(0, nt + shift - 1, shift, dtype="int") # 201 => 0, 100, 200
|
| 116 |
+
std = np.zeros([nbt, len(t)])
|
| 117 |
+
mean = np.zeros([nbt, len(t)])
|
| 118 |
+
for i in range(std.shape[1]):
|
| 119 |
+
std[:, i] = np.std(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
|
| 120 |
+
mean[:, i] = np.mean(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
|
| 121 |
+
|
| 122 |
+
std[:, -1], mean[:, -1] = std[:, -2], mean[:, -2]
|
| 123 |
+
std[:, 0], mean[:, 0] = std[:, 1], mean[:, 1]
|
| 124 |
+
|
| 125 |
+
## normalize data with interplated std
|
| 126 |
+
t_interp = np.arange(nt, dtype="int")
|
| 127 |
+
std_interp = interp1d(t, std, kind="slinear")(t_interp) ##nbt, nt
|
| 128 |
+
std_interp[std_interp == 0] = 1.0
|
| 129 |
+
mean_interp = interp1d(t, mean, kind="slinear")(t_interp)
|
| 130 |
+
|
| 131 |
+
data = (data - mean_interp[:, np.newaxis, :, np.newaxis]) / std_interp[:, np.newaxis, :, np.newaxis]
|
| 132 |
+
|
| 133 |
+
if len(t) > 3: ##need to address this normalization issue in training
|
| 134 |
+
data /= 2.0
|
| 135 |
+
|
| 136 |
+
return data
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# %%
|
| 140 |
+
def py_func_decorator(output_types=None, output_shapes=None, name=None):
|
| 141 |
+
def decorator(func):
|
| 142 |
+
def call(*args, **kwargs):
|
| 143 |
+
nonlocal output_shapes
|
| 144 |
+
# flat_output_types = nest.flatten(output_types)
|
| 145 |
+
flat_output_types = tf.nest.flatten(output_types)
|
| 146 |
+
# flat_values = tf.py_func(
|
| 147 |
+
flat_values = tf.numpy_function(func, inp=args, Tout=flat_output_types, name=name)
|
| 148 |
+
if output_shapes is not None:
|
| 149 |
+
for v, s in zip(flat_values, output_shapes):
|
| 150 |
+
v.set_shape(s)
|
| 151 |
+
# return nest.pack_sequence_as(output_types, flat_values)
|
| 152 |
+
return tf.nest.pack_sequence_as(output_types, flat_values)
|
| 153 |
+
|
| 154 |
+
return call
|
| 155 |
+
|
| 156 |
+
return decorator
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def dataset_map(iterator, output_types, output_shapes=None, num_parallel_calls=None, name=None):
|
| 160 |
+
dataset = tf.data.Dataset.range(len(iterator))
|
| 161 |
+
|
| 162 |
+
@py_func_decorator(output_types, output_shapes, name=name)
|
| 163 |
+
def index_to_entry(idx):
|
| 164 |
+
return iterator[idx]
|
| 165 |
+
|
| 166 |
+
return dataset.map(index_to_entry, num_parallel_calls=num_parallel_calls)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class DataReader(object):
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
signal_dir=None,
|
| 173 |
+
signal_list=None,
|
| 174 |
+
noise_dir=None,
|
| 175 |
+
noise_list=None,
|
| 176 |
+
queue_size=None,
|
| 177 |
+
coord=None,
|
| 178 |
+
config=Config(),
|
| 179 |
+
):
|
| 180 |
+
|
| 181 |
+
self.config = config
|
| 182 |
+
|
| 183 |
+
signal_list = pd.read_csv(signal_list, header=0)
|
| 184 |
+
noise_list = pd.read_csv(noise_list, header=0)
|
| 185 |
+
|
| 186 |
+
self.signal = signal_list
|
| 187 |
+
self.noise = noise_list
|
| 188 |
+
self.n_signal = len(self.signal)
|
| 189 |
+
|
| 190 |
+
self.signal_dir = signal_dir
|
| 191 |
+
self.noise_dir = noise_dir
|
| 192 |
+
|
| 193 |
+
self.X_shape = config.X_shape
|
| 194 |
+
self.Y_shape = config.Y_shape
|
| 195 |
+
self.n_class = config.n_class
|
| 196 |
+
|
| 197 |
+
self.coord = coord
|
| 198 |
+
self.threads = []
|
| 199 |
+
self.queue_size = queue_size
|
| 200 |
+
|
| 201 |
+
self.add_queue()
|
| 202 |
+
self.buffer_signal = {}
|
| 203 |
+
self.buffer_noise = {}
|
| 204 |
+
self.buffer_channels_signal = {}
|
| 205 |
+
self.buffer_channels_noise = {}
|
| 206 |
+
|
| 207 |
+
def add_queue(self):
|
| 208 |
+
with tf.device('/cpu:0'):
|
| 209 |
+
self.sample_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
|
| 210 |
+
self.target_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
|
| 211 |
+
self.queue = tf.queue.PaddingFIFOQueue(
|
| 212 |
+
self.queue_size, ['float32', 'float32'], shapes=[self.config.X_shape, self.config.Y_shape]
|
| 213 |
+
)
|
| 214 |
+
self.enqueue = self.queue.enqueue([self.sample_placeholder, self.target_placeholder])
|
| 215 |
+
return 0
|
| 216 |
+
|
| 217 |
+
def dequeue(self, num_elements):
|
| 218 |
+
output = self.queue.dequeue_many(num_elements)
|
| 219 |
+
return output
|
| 220 |
+
|
| 221 |
+
def get_snr(self, data, itp, dit=300):
|
| 222 |
+
tmp_std = np.std(data[itp - dit : itp])
|
| 223 |
+
if tmp_std > 0:
|
| 224 |
+
return np.std(data[itp : itp + dit]) / tmp_std
|
| 225 |
+
else:
|
| 226 |
+
return 0
|
| 227 |
+
|
| 228 |
+
def add_event(self, sample, channels, j):
|
| 229 |
+
while np.random.uniform(0, 1) < 0.2:
|
| 230 |
+
shift = None
|
| 231 |
+
if channels not in self.buffer_channels_signal:
|
| 232 |
+
self.buffer_channels_signal[channels] = self.signal[self.signal['channels'] == channels]
|
| 233 |
+
fname = os.path.join(self.signal_dir, self.buffer_channels_signal[channels].sample(n=1).iloc[0]['fname'])
|
| 234 |
+
try:
|
| 235 |
+
if fname not in self.buffer_signal:
|
| 236 |
+
meta = np.load(fname)
|
| 237 |
+
data_FT = []
|
| 238 |
+
snr = []
|
| 239 |
+
for i in range(3):
|
| 240 |
+
tmp_data = meta['data'][:, i]
|
| 241 |
+
tmp_itp = meta['itp']
|
| 242 |
+
snr.append(self.get_snr(tmp_data, tmp_itp))
|
| 243 |
+
tmp_data -= np.mean(tmp_data)
|
| 244 |
+
f, t, tmp_FT = scipy.signal.stft(
|
| 245 |
+
tmp_data,
|
| 246 |
+
fs=self.config.fs,
|
| 247 |
+
nperseg=self.config.nperseg,
|
| 248 |
+
nfft=self.config.nfft,
|
| 249 |
+
boundary='zeros',
|
| 250 |
+
)
|
| 251 |
+
data_FT.append(tmp_FT)
|
| 252 |
+
data_FT = np.stack(data_FT, axis=-1)
|
| 253 |
+
self.buffer_signal[fname] = {
|
| 254 |
+
'data_FT': data_FT,
|
| 255 |
+
'itp': tmp_itp,
|
| 256 |
+
'channels': meta['channels'],
|
| 257 |
+
'snr': snr,
|
| 258 |
+
}
|
| 259 |
+
meta_signal = self.buffer_signal[fname]
|
| 260 |
+
except:
|
| 261 |
+
logging.error("Failed reading signal: {}".format(fname))
|
| 262 |
+
continue
|
| 263 |
+
if meta_signal['snr'][j] > self.config.snr_threshold:
|
| 264 |
+
tmp_signal = np.zeros([self.X_shape[0], self.X_shape[1]], dtype=np.complex_)
|
| 265 |
+
shift = np.random.randint(-self.X_shape[1], 1, None, 'int')
|
| 266 |
+
tmp_signal[:, -shift:] = meta_signal['data_FT'][:, self.X_shape[1] : 2 * self.X_shape[1] + shift, j]
|
| 267 |
+
if np.isinf(tmp_signal).any() or np.isnan(tmp_signal).any() or (not np.any(tmp_signal)):
|
| 268 |
+
continue
|
| 269 |
+
tmp_signal = tmp_signal / np.std(tmp_signal)
|
| 270 |
+
sample += tmp_signal / np.random.uniform(1, 5)
|
| 271 |
+
return sample
|
| 272 |
+
|
| 273 |
+
def thread_main(self, sess, n_threads=1, start=0):
|
| 274 |
+
stop = False
|
| 275 |
+
while not stop:
|
| 276 |
+
index = list(range(start, self.n_signal, n_threads))
|
| 277 |
+
np.random.shuffle(index)
|
| 278 |
+
for i in index:
|
| 279 |
+
fname_signal = os.path.join(self.signal_dir, self.signal.iloc[i]['fname'])
|
| 280 |
+
try:
|
| 281 |
+
if fname_signal not in self.buffer_signal:
|
| 282 |
+
meta = np.load(fname_signal)
|
| 283 |
+
data_FT = []
|
| 284 |
+
snr = []
|
| 285 |
+
for j in range(3):
|
| 286 |
+
tmp_data = meta['data'][..., j]
|
| 287 |
+
tmp_itp = meta['itp']
|
| 288 |
+
snr.append(self.get_snr(tmp_data, tmp_itp))
|
| 289 |
+
tmp_data -= np.mean(tmp_data)
|
| 290 |
+
f, t, tmp_FT = scipy.signal.stft(
|
| 291 |
+
tmp_data,
|
| 292 |
+
fs=self.config.fs,
|
| 293 |
+
nperseg=self.config.nperseg,
|
| 294 |
+
nfft=self.config.nfft,
|
| 295 |
+
boundary='zeros',
|
| 296 |
+
)
|
| 297 |
+
data_FT.append(tmp_FT)
|
| 298 |
+
data_FT = np.stack(data_FT, axis=-1)
|
| 299 |
+
self.buffer_signal[fname_signal] = {
|
| 300 |
+
'data_FT': data_FT,
|
| 301 |
+
'itp': tmp_itp,
|
| 302 |
+
'channels': meta['channels'],
|
| 303 |
+
'snr': snr,
|
| 304 |
+
}
|
| 305 |
+
meta_signal = self.buffer_signal[fname_signal]
|
| 306 |
+
except:
|
| 307 |
+
logging.error("Failed reading signal: {}".format(fname_signal))
|
| 308 |
+
continue
|
| 309 |
+
channels = meta_signal['channels'].tolist()
|
| 310 |
+
start_tp = meta_signal['itp'].tolist()
|
| 311 |
+
|
| 312 |
+
if channels not in self.buffer_channels_noise:
|
| 313 |
+
self.buffer_channels_noise[channels] = self.noise[self.noise['channels'] == channels]
|
| 314 |
+
fname_noise = os.path.join(
|
| 315 |
+
self.noise_dir, self.buffer_channels_noise[channels].sample(n=1).iloc[0]['fname']
|
| 316 |
+
)
|
| 317 |
+
try:
|
| 318 |
+
if fname_noise not in self.buffer_noise:
|
| 319 |
+
meta = np.load(fname_noise)
|
| 320 |
+
data_FT = []
|
| 321 |
+
for i in range(3):
|
| 322 |
+
tmp_data = meta['data'][: self.config.nt, i]
|
| 323 |
+
tmp_data -= np.mean(tmp_data)
|
| 324 |
+
f, t, tmp_FT = scipy.signal.stft(
|
| 325 |
+
tmp_data,
|
| 326 |
+
fs=self.config.fs,
|
| 327 |
+
nperseg=self.config.nperseg,
|
| 328 |
+
nfft=self.config.nfft,
|
| 329 |
+
boundary='zeros',
|
| 330 |
+
)
|
| 331 |
+
data_FT.append(tmp_FT)
|
| 332 |
+
data_FT = np.stack(data_FT, axis=-1)
|
| 333 |
+
self.buffer_noise[fname_noise] = {'data_FT': data_FT, 'channels': meta['channels']}
|
| 334 |
+
meta_noise = self.buffer_noise[fname_noise]
|
| 335 |
+
except:
|
| 336 |
+
logging.error("Failed reading noise: {}".format(fname_noise))
|
| 337 |
+
continue
|
| 338 |
+
|
| 339 |
+
if self.coord.should_stop():
|
| 340 |
+
stop = True
|
| 341 |
+
break
|
| 342 |
+
|
| 343 |
+
j = np.random.choice([0, 1, 2])
|
| 344 |
+
if meta_signal['snr'][j] <= self.config.snr_threshold:
|
| 345 |
+
continue
|
| 346 |
+
|
| 347 |
+
tmp_noise = meta_noise['data_FT'][..., j]
|
| 348 |
+
if np.isinf(tmp_noise).any() or np.isnan(tmp_noise).any() or (not np.any(tmp_noise)):
|
| 349 |
+
continue
|
| 350 |
+
tmp_noise = tmp_noise / np.std(tmp_noise)
|
| 351 |
+
|
| 352 |
+
tmp_signal = np.zeros([self.X_shape[0], self.X_shape[1]], dtype=np.complex_)
|
| 353 |
+
if np.random.random() < 0.9:
|
| 354 |
+
shift = np.random.randint(-self.X_shape[1], 1, None, 'int')
|
| 355 |
+
tmp_signal[:, -shift:] = meta_signal['data_FT'][:, self.X_shape[1] : 2 * self.X_shape[1] + shift, j]
|
| 356 |
+
if np.isinf(tmp_signal).any() or np.isnan(tmp_signal).any() or (not np.any(tmp_signal)):
|
| 357 |
+
continue
|
| 358 |
+
tmp_signal = tmp_signal / np.std(tmp_signal)
|
| 359 |
+
tmp_signal = self.add_event(tmp_signal, channels, j)
|
| 360 |
+
|
| 361 |
+
if np.random.random() < 0.2:
|
| 362 |
+
tmp_signal = np.fliplr(tmp_signal)
|
| 363 |
+
|
| 364 |
+
ratio = 0
|
| 365 |
+
while ratio <= 0:
|
| 366 |
+
ratio = self.config.noise_mean + np.random.randn() * self.config.noise_std
|
| 367 |
+
# ratio = np.random.uniform(self.config.noise_low, self.config.noise_high)
|
| 368 |
+
tmp_noisy_signal = tmp_signal + ratio * tmp_noise
|
| 369 |
+
noisy_signal = np.stack([tmp_noisy_signal.real, tmp_noisy_signal.imag], axis=-1)
|
| 370 |
+
if np.isnan(noisy_signal).any() or np.isinf(noisy_signal).any():
|
| 371 |
+
continue
|
| 372 |
+
noisy_signal = noisy_signal / np.std(noisy_signal)
|
| 373 |
+
tmp_mask = np.abs(tmp_signal) / (np.abs(tmp_signal) + np.abs(ratio * tmp_noise) + 1e-4)
|
| 374 |
+
tmp_mask[tmp_mask >= 1] = 1
|
| 375 |
+
tmp_mask[tmp_mask <= 0] = 0
|
| 376 |
+
mask = np.zeros([tmp_mask.shape[0], tmp_mask.shape[1], self.n_class])
|
| 377 |
+
mask[:, :, 0] = tmp_mask
|
| 378 |
+
mask[:, :, 1] = 1 - tmp_mask
|
| 379 |
+
sess.run(self.enqueue, feed_dict={self.sample_placeholder: noisy_signal, self.target_placeholder: mask})
|
| 380 |
+
|
| 381 |
+
def start_threads(self, sess, n_threads=8):
|
| 382 |
+
for i in range(n_threads):
|
| 383 |
+
thread = threading.Thread(target=self.thread_main, args=(sess, n_threads, i))
|
| 384 |
+
thread.daemon = True
|
| 385 |
+
thread.start()
|
| 386 |
+
self.threads.append(thread)
|
| 387 |
+
return self.threads
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class DataReader_test(DataReader):
|
| 391 |
+
def __init__(
|
| 392 |
+
self,
|
| 393 |
+
signal_dir=None,
|
| 394 |
+
signal_list=None,
|
| 395 |
+
noise_dir=None,
|
| 396 |
+
noise_list=None,
|
| 397 |
+
queue_size=None,
|
| 398 |
+
coord=None,
|
| 399 |
+
config=Config(),
|
| 400 |
+
):
|
| 401 |
+
self.config = config
|
| 402 |
+
|
| 403 |
+
signal_list = pd.read_csv(signal_list, header=0)
|
| 404 |
+
noise_list = pd.read_csv(noise_list, header=0)
|
| 405 |
+
self.signal = signal_list
|
| 406 |
+
self.noise = noise_list
|
| 407 |
+
self.n_signal = len(self.signal)
|
| 408 |
+
|
| 409 |
+
self.signal_dir = signal_dir
|
| 410 |
+
self.noise_dir = noise_dir
|
| 411 |
+
|
| 412 |
+
self.X_shape = config.X_shape
|
| 413 |
+
self.Y_shape = config.Y_shape
|
| 414 |
+
self.n_class = config.n_class
|
| 415 |
+
|
| 416 |
+
self.coord = coord
|
| 417 |
+
self.threads = []
|
| 418 |
+
self.queue_size = queue_size
|
| 419 |
+
|
| 420 |
+
self.add_queue()
|
| 421 |
+
self.buffer_signal = {}
|
| 422 |
+
self.buffer_noise = {}
|
| 423 |
+
self.buffer_channels_signal = {}
|
| 424 |
+
self.buffer_channels_noise = {}
|
| 425 |
+
|
| 426 |
+
def add_queue(self):
|
| 427 |
+
self.sample_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
|
| 428 |
+
self.target_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
|
| 429 |
+
self.ratio_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
|
| 430 |
+
self.signal_placeholder = tf.compat.v1.placeholder(dtype=tf.complex64, shape=None)
|
| 431 |
+
self.noise_placeholder = tf.compat.v1.placeholder(dtype=tf.complex64, shape=None)
|
| 432 |
+
self.fname_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=None)
|
| 433 |
+
self.queue = tf.queue.PaddingFIFOQueue(
|
| 434 |
+
self.queue_size,
|
| 435 |
+
['float32', 'float32', 'float32', 'complex64', 'complex64', 'string'],
|
| 436 |
+
shapes=[
|
| 437 |
+
self.config.X_shape,
|
| 438 |
+
self.config.Y_shape,
|
| 439 |
+
[],
|
| 440 |
+
self.config.signal_shape,
|
| 441 |
+
self.config.noise_shape,
|
| 442 |
+
[],
|
| 443 |
+
],
|
| 444 |
+
)
|
| 445 |
+
self.enqueue = self.queue.enqueue(
|
| 446 |
+
[
|
| 447 |
+
self.sample_placeholder,
|
| 448 |
+
self.target_placeholder,
|
| 449 |
+
self.ratio_placeholder,
|
| 450 |
+
self.signal_placeholder,
|
| 451 |
+
self.noise_placeholder,
|
| 452 |
+
self.fname_placeholder,
|
| 453 |
+
]
|
| 454 |
+
)
|
| 455 |
+
return 0
|
| 456 |
+
|
| 457 |
+
def dequeue(self, num_elements):
|
| 458 |
+
output = self.queue.dequeue_up_to(num_elements)
|
| 459 |
+
return output
|
| 460 |
+
|
| 461 |
+
def thread_main(self, sess, n_threads=1, start=0):
|
| 462 |
+
index = list(range(start, self.n_signal, n_threads))
|
| 463 |
+
for i in index:
|
| 464 |
+
np.random.seed(i)
|
| 465 |
+
|
| 466 |
+
fname = self.signal.iloc[i]['fname']
|
| 467 |
+
fname_signal = os.path.join(self.signal_dir, fname)
|
| 468 |
+
meta = np.load(fname_signal)
|
| 469 |
+
data_FT = []
|
| 470 |
+
snr = []
|
| 471 |
+
for j in range(3):
|
| 472 |
+
tmp_data = meta['data'][..., j]
|
| 473 |
+
tmp_itp = meta['itp']
|
| 474 |
+
snr.append(self.get_snr(tmp_data, tmp_itp))
|
| 475 |
+
tmp_data -= np.mean(tmp_data)
|
| 476 |
+
f, t, tmp_FT = scipy.signal.stft(
|
| 477 |
+
tmp_data, fs=self.config.fs, nperseg=self.config.nperseg, nfft=self.config.nfft, boundary='zeros'
|
| 478 |
+
)
|
| 479 |
+
data_FT.append(tmp_FT)
|
| 480 |
+
data_FT = np.stack(data_FT, axis=-1)
|
| 481 |
+
meta_signal = {'data_FT': data_FT, 'itp': tmp_itp, 'channels': meta['channels'], 'snr': snr}
|
| 482 |
+
channels = meta['channels'].tolist()
|
| 483 |
+
start_tp = meta['itp'].tolist()
|
| 484 |
+
|
| 485 |
+
if channels not in self.buffer_channels_noise:
|
| 486 |
+
self.buffer_channels_noise[channels] = self.noise[self.noise['channels'] == channels]
|
| 487 |
+
fname_noise = os.path.join(
|
| 488 |
+
self.noise_dir, self.buffer_channels_noise[channels].sample(n=1, random_state=i).iloc[0]['fname']
|
| 489 |
+
)
|
| 490 |
+
meta = np.load(fname_noise)
|
| 491 |
+
data_FT = []
|
| 492 |
+
for i in range(3):
|
| 493 |
+
tmp_data = meta['data'][: self.config.nt, i]
|
| 494 |
+
tmp_data -= np.mean(tmp_data)
|
| 495 |
+
f, t, tmp_FT = scipy.signal.stft(
|
| 496 |
+
tmp_data, fs=self.config.fs, nperseg=self.config.nperseg, nfft=self.config.nfft, boundary='zeros'
|
| 497 |
+
)
|
| 498 |
+
data_FT.append(tmp_FT)
|
| 499 |
+
data_FT = np.stack(data_FT, axis=-1)
|
| 500 |
+
meta_noise = {'data_FT': data_FT, 'channels': meta['channels']}
|
| 501 |
+
|
| 502 |
+
if self.coord.should_stop():
|
| 503 |
+
stop = True
|
| 504 |
+
break
|
| 505 |
+
|
| 506 |
+
j = np.random.choice([0, 1, 2])
|
| 507 |
+
tmp_noise = meta_noise['data_FT'][..., j]
|
| 508 |
+
if np.isinf(tmp_noise).any() or np.isnan(tmp_noise).any() or (not np.any(tmp_noise)):
|
| 509 |
+
continue
|
| 510 |
+
tmp_noise = tmp_noise / np.std(tmp_noise)
|
| 511 |
+
|
| 512 |
+
tmp_signal = np.zeros([self.X_shape[0], self.X_shape[1]], dtype=np.complex_)
|
| 513 |
+
if np.random.random() < 0.9:
|
| 514 |
+
shift = np.random.randint(-self.X_shape[1], 1, None, 'int')
|
| 515 |
+
tmp_signal[:, -shift:] = meta_signal['data_FT'][:, self.X_shape[1] : 2 * self.X_shape[1] + shift, j]
|
| 516 |
+
if np.isinf(tmp_signal).any() or np.isnan(tmp_signal).any() or (not np.any(tmp_signal)):
|
| 517 |
+
continue
|
| 518 |
+
tmp_signal = tmp_signal / np.std(tmp_signal)
|
| 519 |
+
# tmp_signal = self.add_event(tmp_signal, channels, j)
|
| 520 |
+
# if np.random.random() < 0.2:
|
| 521 |
+
# tmp_signal = np.fliplr(tmp_signal)
|
| 522 |
+
|
| 523 |
+
ratio = 0
|
| 524 |
+
while ratio <= 0:
|
| 525 |
+
ratio = self.config.noise_mean + np.random.randn() * self.config.noise_std
|
| 526 |
+
tmp_noisy_signal = tmp_signal + ratio * tmp_noise
|
| 527 |
+
noisy_signal = np.stack([tmp_noisy_signal.real, tmp_noisy_signal.imag], axis=-1)
|
| 528 |
+
if np.isnan(noisy_signal).any() or np.isinf(noisy_signal).any():
|
| 529 |
+
continue
|
| 530 |
+
std_noisy_signal = np.std(noisy_signal)
|
| 531 |
+
noisy_signal = noisy_signal / std_noisy_signal
|
| 532 |
+
tmp_mask = np.abs(tmp_signal) / (np.abs(tmp_signal) + np.abs(ratio * tmp_noise) + 1e-4)
|
| 533 |
+
tmp_mask[tmp_mask >= 1] = 1
|
| 534 |
+
tmp_mask[tmp_mask <= 0] = 0
|
| 535 |
+
mask = np.zeros([tmp_mask.shape[0], tmp_mask.shape[1], self.n_class])
|
| 536 |
+
mask[:, :, 0] = tmp_mask
|
| 537 |
+
mask[:, :, 1] = 1 - tmp_mask
|
| 538 |
+
|
| 539 |
+
sess.run(
|
| 540 |
+
self.enqueue,
|
| 541 |
+
feed_dict={
|
| 542 |
+
self.sample_placeholder: noisy_signal,
|
| 543 |
+
self.target_placeholder: mask,
|
| 544 |
+
self.ratio_placeholder: std_noisy_signal,
|
| 545 |
+
self.signal_placeholder: tmp_signal,
|
| 546 |
+
self.noise_placeholder: ratio * tmp_noise,
|
| 547 |
+
self.fname_placeholder: fname,
|
| 548 |
+
},
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
class DataReader_pred_queue(DataReader):
|
| 553 |
+
def __init__(self, signal_dir, signal_list, queue_size, coord, config=Config()):
|
| 554 |
+
self.config = config
|
| 555 |
+
signal_list = pd.read_csv(signal_list)
|
| 556 |
+
self.signal = signal_list
|
| 557 |
+
self.n_signal = len(self.signal)
|
| 558 |
+
self.n_class = config.n_class
|
| 559 |
+
self.X_shape = config.X_shape
|
| 560 |
+
self.Y_shape = config.Y_shape
|
| 561 |
+
self.signal_dir = signal_dir
|
| 562 |
+
|
| 563 |
+
self.coord = coord
|
| 564 |
+
self.threads = []
|
| 565 |
+
self.queue_size = queue_size
|
| 566 |
+
self.add_placeholder()
|
| 567 |
+
|
| 568 |
+
def add_placeholder(self):
|
| 569 |
+
self.sample_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
|
| 570 |
+
self.ratio_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
|
| 571 |
+
self.fname_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=None)
|
| 572 |
+
self.queue = tf.queue.PaddingFIFOQueue(
|
| 573 |
+
self.queue_size, ['float32', 'float32', 'string'], shapes=[self.config.X_shape, [], []]
|
| 574 |
+
)
|
| 575 |
+
self.enqueue = self.queue.enqueue([self.sample_placeholder, self.ratio_placeholder, self.fname_placeholder])
|
| 576 |
+
|
| 577 |
+
def dequeue(self, num_elements):
|
| 578 |
+
output = self.queue.dequeue_up_to(num_elements)
|
| 579 |
+
return output
|
| 580 |
+
|
| 581 |
+
def thread_main(self, sess, n_threads=1, start=0):
|
| 582 |
+
index = list(range(start, self.n_signal, n_threads))
|
| 583 |
+
shift = 0
|
| 584 |
+
for i in index:
|
| 585 |
+
fname = self.signal.iloc[i]['fname']
|
| 586 |
+
data_signal = np.load(os.path.join(self.signal_dir, fname))
|
| 587 |
+
f, t, tmp_signal = scipy.signal.stft(
|
| 588 |
+
scipy.signal.detrend(np.squeeze(data_signal['data'][shift : self.config.nt + shift])),
|
| 589 |
+
fs=self.config.fs,
|
| 590 |
+
nperseg=self.config.nperseg,
|
| 591 |
+
nfft=self.config.nfft,
|
| 592 |
+
boundary='zeros',
|
| 593 |
+
)
|
| 594 |
+
noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1)
|
| 595 |
+
if np.isnan(noisy_signal).any() or np.isinf(noisy_signal).any() or (not np.any(noisy_signal)):
|
| 596 |
+
continue
|
| 597 |
+
std_noisy_signal = np.std(noisy_signal)
|
| 598 |
+
if std_noisy_signal == 0:
|
| 599 |
+
continue
|
| 600 |
+
noisy_signal = noisy_signal / std_noisy_signal
|
| 601 |
+
sess.run(
|
| 602 |
+
self.enqueue,
|
| 603 |
+
feed_dict={
|
| 604 |
+
self.sample_placeholder: noisy_signal,
|
| 605 |
+
self.ratio_placeholder: std_noisy_signal,
|
| 606 |
+
self.fname_placeholder: fname,
|
| 607 |
+
},
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
class DataReader_pred:
|
| 612 |
+
def __init__(self, signal_dir, signal_list, format="numpy", sampling_rate=100, config=Config()):
|
| 613 |
+
self.buffer = {}
|
| 614 |
+
self.config = config
|
| 615 |
+
self.format = format
|
| 616 |
+
self.dtype = "float32"
|
| 617 |
+
try:
|
| 618 |
+
signal_list = pd.read_csv(signal_list, sep="\t")["fname"]
|
| 619 |
+
except:
|
| 620 |
+
signal_list = pd.read_csv(signal_list)["fname"]
|
| 621 |
+
self.signal_list = signal_list
|
| 622 |
+
self.n_signal = len(self.signal_list)
|
| 623 |
+
self.signal_dir = signal_dir
|
| 624 |
+
self.sampling_rate = sampling_rate
|
| 625 |
+
self.n_class = config.n_class
|
| 626 |
+
FT_shape = self.get_data_shape()
|
| 627 |
+
self.X_shape = [*FT_shape, 2]
|
| 628 |
+
|
| 629 |
+
def get_data_shape(self):
|
| 630 |
+
# fname = self.signal_list.iloc[0]['fname']
|
| 631 |
+
# data = np.load(os.path.join(self.signal_dir, fname), allow_pickle=True)["data"]
|
| 632 |
+
# data = np.squeeze(data)
|
| 633 |
+
base_name = self.signal_list[0]
|
| 634 |
+
if self.format == "numpy":
|
| 635 |
+
meta = self.read_numpy(os.path.join(self.signal_dir, base_name))
|
| 636 |
+
elif self.format == "mseed":
|
| 637 |
+
meta = self.read_mseed(os.path.join(self.signal_dir, base_name))
|
| 638 |
+
elif self.format == "hdf5":
|
| 639 |
+
meta = self.read_hdf5(base_name)
|
| 640 |
+
|
| 641 |
+
data = meta["data"]
|
| 642 |
+
data = np.transpose(data, [2, 1, 0])
|
| 643 |
+
|
| 644 |
+
if self.sampling_rate != 100:
|
| 645 |
+
t = np.linspace(0, 1, data.shape[-1])
|
| 646 |
+
t_interp = np.linspace(0, 1, np.int(np.around(data.shape[-1] * 100.0 / self.sampling_rate)))
|
| 647 |
+
data = interp1d(t, data, kind="slinear")(t_interp)
|
| 648 |
+
f, t, tmp_signal = scipy.signal.stft(
|
| 649 |
+
data, fs=self.config.fs, nperseg=self.config.nperseg, nfft=self.config.nfft, boundary='zeros'
|
| 650 |
+
)
|
| 651 |
+
logging.info(f"Input data shape: {tmp_signal.shape} measured on file {base_name}")
|
| 652 |
+
return tmp_signal.shape
|
| 653 |
+
|
| 654 |
+
def __len__(self):
|
| 655 |
+
return self.n_signal
|
| 656 |
+
|
| 657 |
+
def read_numpy(self, fname):
|
| 658 |
+
# try:
|
| 659 |
+
if fname not in self.buffer:
|
| 660 |
+
npz = np.load(fname)
|
| 661 |
+
meta = {}
|
| 662 |
+
if len(npz['data'].shape) == 1:
|
| 663 |
+
meta["data"] = npz['data'][:, np.newaxis, np.newaxis]
|
| 664 |
+
elif len(npz['data'].shape) == 2:
|
| 665 |
+
meta["data"] = npz['data'][:, np.newaxis, :]
|
| 666 |
+
else:
|
| 667 |
+
meta["data"] = npz['data']
|
| 668 |
+
if "p_idx" in npz.files:
|
| 669 |
+
if len(npz["p_idx"].shape) == 0:
|
| 670 |
+
meta["itp"] = [[npz["p_idx"]]]
|
| 671 |
+
else:
|
| 672 |
+
meta["itp"] = npz["p_idx"]
|
| 673 |
+
if "s_idx" in npz.files:
|
| 674 |
+
if len(npz["s_idx"].shape) == 0:
|
| 675 |
+
meta["its"] = [[npz["s_idx"]]]
|
| 676 |
+
else:
|
| 677 |
+
meta["its"] = npz["s_idx"]
|
| 678 |
+
if "t0" in npz.files:
|
| 679 |
+
meta["t0"] = npz["t0"]
|
| 680 |
+
self.buffer[fname] = meta
|
| 681 |
+
else:
|
| 682 |
+
meta = self.buffer[fname]
|
| 683 |
+
return meta
|
| 684 |
+
# except:
|
| 685 |
+
# logging.error("Failed reading {}".format(fname))
|
| 686 |
+
# return None
|
| 687 |
+
|
| 688 |
+
def read_hdf5(self, fname):
|
| 689 |
+
data = self.h5_data[fname][()]
|
| 690 |
+
attrs = self.h5_data[fname].attrs
|
| 691 |
+
meta = {}
|
| 692 |
+
if len(data.shape) == 2:
|
| 693 |
+
meta["data"] = data[:, np.newaxis, :]
|
| 694 |
+
else:
|
| 695 |
+
meta["data"] = data
|
| 696 |
+
if "p_idx" in attrs:
|
| 697 |
+
if len(attrs["p_idx"].shape) == 0:
|
| 698 |
+
meta["itp"] = [[attrs["p_idx"]]]
|
| 699 |
+
else:
|
| 700 |
+
meta["itp"] = attrs["p_idx"]
|
| 701 |
+
if "s_idx" in attrs:
|
| 702 |
+
if len(attrs["s_idx"].shape) == 0:
|
| 703 |
+
meta["its"] = [[attrs["s_idx"]]]
|
| 704 |
+
else:
|
| 705 |
+
meta["its"] = attrs["s_idx"]
|
| 706 |
+
if "t0" in attrs:
|
| 707 |
+
meta["t0"] = attrs["t0"]
|
| 708 |
+
return meta
|
| 709 |
+
|
| 710 |
+
def read_s3(self, format, fname, bucket, key, secret, s3_url, use_ssl):
|
| 711 |
+
with self.s3fs.open(bucket + "/" + fname, 'rb') as fp:
|
| 712 |
+
if format == "numpy":
|
| 713 |
+
meta = self.read_numpy(fp)
|
| 714 |
+
elif format == "mseed":
|
| 715 |
+
meta = self.read_mseed(fp)
|
| 716 |
+
else:
|
| 717 |
+
raise (f"Format {format} not supported")
|
| 718 |
+
return meta
|
| 719 |
+
|
| 720 |
+
def read_mseed(self, fname):
|
| 721 |
+
|
| 722 |
+
mseed = obspy.read(fname)
|
| 723 |
+
mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
|
| 724 |
+
mseed = mseed.merge(fill_value=0)
|
| 725 |
+
starttime = min([st.stats.starttime for st in mseed])
|
| 726 |
+
endtime = max([st.stats.endtime for st in mseed])
|
| 727 |
+
mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
|
| 728 |
+
if mseed[0].stats.sampling_rate != self.sampling_rate:
|
| 729 |
+
logging.warning(f"Sampling rate {mseed[0].stats.sampling_rate} != {self.sampling_rate} Hz")
|
| 730 |
+
|
| 731 |
+
order = ['3', '2', '1', 'E', 'N', 'Z']
|
| 732 |
+
order = {key: i for i, key in enumerate(order)}
|
| 733 |
+
comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
|
| 734 |
+
|
| 735 |
+
t0 = starttime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
|
| 736 |
+
nt = len(mseed[0].data)
|
| 737 |
+
data = np.zeros([nt, 3], dtype=self.dtype)
|
| 738 |
+
ids = [x.get_id() for x in mseed]
|
| 739 |
+
if len(ids) == 3:
|
| 740 |
+
for j, id in enumerate(sorted(ids, key=lambda x: order[x[-1]])):
|
| 741 |
+
data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)
|
| 742 |
+
else:
|
| 743 |
+
if len(ids) > 3:
|
| 744 |
+
logging.warning(f"More than 3 channels {ids}!")
|
| 745 |
+
for jj, id in enumerate(ids):
|
| 746 |
+
j = comp2idx[id[-1]]
|
| 747 |
+
data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)
|
| 748 |
+
|
| 749 |
+
data = data[:, np.newaxis, :]
|
| 750 |
+
meta = {"data": data, "t0": t0}
|
| 751 |
+
return meta
|
| 752 |
+
|
| 753 |
+
def __getitem__(self, i):
|
| 754 |
+
# fname = self.signal.iloc[i]['fname']
|
| 755 |
+
# data = np.load(os.path.join(self.signal_dir, fname), allow_pickle=True)["data"]
|
| 756 |
+
# data = np.squeeze(data)
|
| 757 |
+
base_name = self.signal_list[i]
|
| 758 |
+
|
| 759 |
+
if self.format == "numpy":
|
| 760 |
+
meta = self.read_numpy(os.path.join(self.signal_dir, base_name))
|
| 761 |
+
elif self.format == "mseed":
|
| 762 |
+
meta = self.read_mseed(os.path.join(self.signal_dir, base_name))
|
| 763 |
+
elif self.format == "hdf5":
|
| 764 |
+
meta = self.read_hdf5(base_name)
|
| 765 |
+
|
| 766 |
+
data = meta["data"] # nt, 1, nch
|
| 767 |
+
data = np.transpose(data, [2, 1, 0]) # nch, 1, nt
|
| 768 |
+
if np.mod(data.shape[-1], 3000) == 1: # 3001=>3000
|
| 769 |
+
data = data[..., :-1]
|
| 770 |
+
if "t0" in meta:
|
| 771 |
+
t0 = meta["t0"]
|
| 772 |
+
else:
|
| 773 |
+
t0 = "1970-01-01T00:00:00.000"
|
| 774 |
+
|
| 775 |
+
if self.sampling_rate != 100:
|
| 776 |
+
logging.warning(f"Resample from {self.sampling_rate} to 100!")
|
| 777 |
+
t = np.linspace(0, 1, data.shape[-1])
|
| 778 |
+
t_interp = np.linspace(0, 1, np.int(np.around(data.shape[-1] * 100.0 / self.sampling_rate)))
|
| 779 |
+
data = interp1d(t, data, kind="slinear")(t_interp)
|
| 780 |
+
# sos = scipy.signal.butter(4, 0.1, 'high', fs=100, output='sos') ## for stability of long sequence
|
| 781 |
+
# data = scipy.signal.sosfilt(sos, data)
|
| 782 |
+
f, t, tmp_signal = scipy.signal.stft(
|
| 783 |
+
data, fs=self.config.fs, nperseg=self.config.nperseg, nfft=self.config.nfft, boundary='zeros'
|
| 784 |
+
) # nch, 1, nf, nt
|
| 785 |
+
noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1) # nch, 1, nf, nt, 2
|
| 786 |
+
noisy_signal[np.isnan(noisy_signal)] = 0
|
| 787 |
+
noisy_signal[np.isinf(noisy_signal)] = 0
|
| 788 |
+
# noisy_signal, std_noisy_signal = normalize(noisy_signal)
|
| 789 |
+
# return noisy_signal.astype(self.dtype), std_noisy_signal.astype(self.dtype), fname
|
| 790 |
+
|
| 791 |
+
return noisy_signal.astype(self.dtype), base_name, t0
|
| 792 |
+
|
| 793 |
+
def dataset(self, batch_size, num_parallel_calls=4):
|
| 794 |
+
dataset = dataset_map(
|
| 795 |
+
self,
|
| 796 |
+
output_types=(self.dtype, "string", "string"),
|
| 797 |
+
output_shapes=(self.X_shape, None, None),
|
| 798 |
+
num_parallel_calls=num_parallel_calls,
|
| 799 |
+
)
|
| 800 |
+
dataset = tf.compat.v1.data.make_one_shot_iterator(
|
| 801 |
+
dataset.batch(batch_size).prefetch(batch_size * 3)
|
| 802 |
+
).get_next()
|
| 803 |
+
return dataset
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
if __name__ == "__main__":
|
| 807 |
+
|
| 808 |
+
# %%
|
| 809 |
+
data_reader = DataReader_pred(signal_dir="./Dataset/yixiao/", signal_list="./Dataset/yixiao.csv")
|
| 810 |
+
noisy_signal, std_noisy_signal, fname = data_reader[0]
|
| 811 |
+
print(noisy_signal.shape, std_noisy_signal.shape, fname)
|
| 812 |
+
batch = data_reader.dataset(10)
|
| 813 |
+
init = tf.compat.v1.initialize_all_variables()
|
| 814 |
+
sess = tf.compat.v1.Session()
|
| 815 |
+
sess.run(init)
|
| 816 |
+
print(sess.run(batch))
|
deepdenoiser/model.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
|
| 6 |
+
from util import *
|
| 7 |
+
|
| 8 |
+
tf.compat.v1.disable_eager_execution()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ModelConfig:
|
| 12 |
+
|
| 13 |
+
batch_size = 20
|
| 14 |
+
depths = 6
|
| 15 |
+
filters_root = 8
|
| 16 |
+
kernel_size = [3, 3]
|
| 17 |
+
pool_size = [2, 2]
|
| 18 |
+
dilation_rate = [1, 1]
|
| 19 |
+
class_weights = [1.0, 1.0, 1.0]
|
| 20 |
+
loss_type = "cross_entropy"
|
| 21 |
+
weight_decay = 0.0
|
| 22 |
+
optimizer = "adam"
|
| 23 |
+
momentum = 0.9
|
| 24 |
+
learning_rate = 0.01
|
| 25 |
+
decay_step = 1e9
|
| 26 |
+
decay_rate = 0.9
|
| 27 |
+
drop_rate = 0.0
|
| 28 |
+
summary = True
|
| 29 |
+
|
| 30 |
+
X_shape = [31, 201, 2]
|
| 31 |
+
n_channel = X_shape[-1]
|
| 32 |
+
Y_shape = [31, 201, 2]
|
| 33 |
+
n_class = Y_shape[-1]
|
| 34 |
+
|
| 35 |
+
def __init__(self, **kwargs):
|
| 36 |
+
for k, v in kwargs.items():
|
| 37 |
+
setattr(self, k, v)
|
| 38 |
+
|
| 39 |
+
def update_args(self, args):
|
| 40 |
+
for k, v in vars(args).items():
|
| 41 |
+
setattr(self, k, v)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def crop_and_concat(net1, net2):
|
| 45 |
+
"""
|
| 46 |
+
the size(net1) <= size(net2)
|
| 47 |
+
"""
|
| 48 |
+
# net1_shape = net1.get_shape().as_list()
|
| 49 |
+
# net2_shape = net2.get_shape().as_list()
|
| 50 |
+
# # print(net1_shape)
|
| 51 |
+
# # print(net2_shape)
|
| 52 |
+
# # if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
|
| 53 |
+
# offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
|
| 54 |
+
# size = [-1, net1_shape[1], net1_shape[2], -1]
|
| 55 |
+
# net2_resize = tf.slice(net2, offsets, size)
|
| 56 |
+
# return tf.concat([net1, net2_resize], 3)
|
| 57 |
+
# # else:
|
| 58 |
+
# # offsets = [0, (net1_shape[1] - net2_shape[1]) // 2, (net1_shape[2] - net2_shape[2]) // 2, 0]
|
| 59 |
+
# # size = [-1, net2_shape[1], net2_shape[2], -1]
|
| 60 |
+
# # net1_resize = tf.slice(net1, offsets, size)
|
| 61 |
+
# # return tf.concat([net1_resize, net2], 3)
|
| 62 |
+
|
| 63 |
+
## dynamic shape
|
| 64 |
+
chn1 = net1.get_shape().as_list()[-1]
|
| 65 |
+
chn2 = net2.get_shape().as_list()[-1]
|
| 66 |
+
net1_shape = tf.shape(net1)
|
| 67 |
+
net2_shape = tf.shape(net2)
|
| 68 |
+
# print(net1_shape)
|
| 69 |
+
# print(net2_shape)
|
| 70 |
+
# if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
|
| 71 |
+
offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
|
| 72 |
+
size = [-1, net1_shape[1], net1_shape[2], -1]
|
| 73 |
+
net2_resize = tf.slice(net2, offsets, size)
|
| 74 |
+
|
| 75 |
+
out = tf.concat([net1, net2_resize], 3)
|
| 76 |
+
out.set_shape([None, None, None, chn1 + chn2])
|
| 77 |
+
return out
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def crop_only(net1, net2):
|
| 81 |
+
"""
|
| 82 |
+
the size(net1) <= size(net2)
|
| 83 |
+
"""
|
| 84 |
+
net1_shape = net1.get_shape().as_list()
|
| 85 |
+
net2_shape = net2.get_shape().as_list()
|
| 86 |
+
# print(net1_shape)
|
| 87 |
+
# print(net2_shape)
|
| 88 |
+
# if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
|
| 89 |
+
offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
|
| 90 |
+
size = [-1, net1_shape[1], net1_shape[2], -1]
|
| 91 |
+
net2_resize = tf.slice(net2, offsets, size)
|
| 92 |
+
# return tf.concat([net1, net2_resize], 3)
|
| 93 |
+
return net2_resize
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class UNet:
|
| 97 |
+
def __init__(self, config=ModelConfig(), input_batch=None, mode='train'):
|
| 98 |
+
self.depths = config.depths
|
| 99 |
+
self.filters_root = config.filters_root
|
| 100 |
+
self.kernel_size = config.kernel_size
|
| 101 |
+
self.dilation_rate = config.dilation_rate
|
| 102 |
+
self.pool_size = config.pool_size
|
| 103 |
+
self.X_shape = config.X_shape
|
| 104 |
+
self.Y_shape = config.Y_shape
|
| 105 |
+
self.n_channel = config.n_channel
|
| 106 |
+
self.n_class = config.n_class
|
| 107 |
+
self.class_weights = config.class_weights
|
| 108 |
+
self.batch_size = config.batch_size
|
| 109 |
+
self.loss_type = config.loss_type
|
| 110 |
+
self.weight_decay = config.weight_decay
|
| 111 |
+
self.optimizer = config.optimizer
|
| 112 |
+
self.decay_step = config.decay_step
|
| 113 |
+
self.decay_rate = config.decay_rate
|
| 114 |
+
self.momentum = config.momentum
|
| 115 |
+
self.learning_rate = config.learning_rate
|
| 116 |
+
self.global_step = tf.compat.v1.get_variable(name="global_step", initializer=0, dtype=tf.int32)
|
| 117 |
+
self.summary_train = []
|
| 118 |
+
self.summary_valid = []
|
| 119 |
+
|
| 120 |
+
self.build(input_batch, mode=mode)
|
| 121 |
+
|
| 122 |
+
def add_placeholders(self, input_batch=None, mode='train'):
|
| 123 |
+
if input_batch is None:
|
| 124 |
+
self.X = tf.compat.v1.placeholder(
|
| 125 |
+
dtype=tf.float32, shape=[None, None, None, self.X_shape[-1]], name='X'
|
| 126 |
+
)
|
| 127 |
+
self.Y = tf.compat.v1.placeholder(
|
| 128 |
+
dtype=tf.float32, shape=[None, None, None, self.n_class], name='y'
|
| 129 |
+
)
|
| 130 |
+
else:
|
| 131 |
+
self.X = input_batch[0]
|
| 132 |
+
if mode in ["train", "valid", "test"]:
|
| 133 |
+
self.Y = input_batch[1]
|
| 134 |
+
self.input_batch = input_batch
|
| 135 |
+
|
| 136 |
+
self.is_training = tf.compat.v1.placeholder(dtype=tf.bool, name="is_training")
|
| 137 |
+
# self.keep_prob = tf.placeholder(dtype=tf.float32, name="keep_prob")
|
| 138 |
+
self.drop_rate = tf.compat.v1.placeholder(dtype=tf.float32, name="drop_rate")
|
| 139 |
+
# self.learning_rate = tf.placeholder_with_default(tf.constant(0.01, dtype=tf.float32), shape=[], name="learning_rate")
|
| 140 |
+
# self.global_step = tf.placeholder_with_default(tf.constant(0, dtype=tf.int32), shape=[], name="global_step")
|
| 141 |
+
|
| 142 |
+
def add_prediction_op(self):
|
| 143 |
+
logging.info(
|
| 144 |
+
"Model: depths {depths}, filters {filters}, "
|
| 145 |
+
"filter size {kernel_size[0]}x{kernel_size[1]}, "
|
| 146 |
+
"pool size: {pool_size[0]}x{pool_size[1]}, "
|
| 147 |
+
"dilation rate: {dilation_rate[0]}x{dilation_rate[1]}".format(
|
| 148 |
+
depths=self.depths,
|
| 149 |
+
filters=self.filters_root,
|
| 150 |
+
kernel_size=self.kernel_size,
|
| 151 |
+
dilation_rate=self.dilation_rate,
|
| 152 |
+
pool_size=self.pool_size,
|
| 153 |
+
)
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
if self.weight_decay > 0:
|
| 157 |
+
weight_decay = tf.constant(self.weight_decay, dtype=tf.float32, name="weight_constant")
|
| 158 |
+
self.regularizer = tf.keras.regularizers.l2(l=0.5 * (weight_decay))
|
| 159 |
+
else:
|
| 160 |
+
self.regularizer = None
|
| 161 |
+
|
| 162 |
+
self.initializer = tf.compat.v1.keras.initializers.VarianceScaling(
|
| 163 |
+
scale=1.0, mode="fan_avg", distribution="uniform"
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
# down sample layers
|
| 167 |
+
convs = [None] * self.depths # store output of each depth
|
| 168 |
+
|
| 169 |
+
with tf.compat.v1.variable_scope("Input"):
|
| 170 |
+
net = self.X
|
| 171 |
+
net = tf.compat.v1.layers.conv2d(
|
| 172 |
+
net,
|
| 173 |
+
filters=self.filters_root,
|
| 174 |
+
kernel_size=self.kernel_size,
|
| 175 |
+
activation=None,
|
| 176 |
+
use_bias=False,
|
| 177 |
+
padding='same',
|
| 178 |
+
dilation_rate=self.dilation_rate,
|
| 179 |
+
kernel_initializer=self.initializer,
|
| 180 |
+
kernel_regularizer=self.regularizer,
|
| 181 |
+
# bias_regularizer=self.regularizer,
|
| 182 |
+
name="input_conv",
|
| 183 |
+
)
|
| 184 |
+
net = tf.compat.v1.layers.batch_normalization(net, training=self.is_training, name="input_bn")
|
| 185 |
+
net = tf.nn.relu(net, name="input_relu")
|
| 186 |
+
# net = tf.nn.dropout(net, self.keep_prob)
|
| 187 |
+
net = tf.compat.v1.layers.dropout(net, rate=self.drop_rate, training=self.is_training, name="input_dropout")
|
| 188 |
+
|
| 189 |
+
for depth in range(0, self.depths):
|
| 190 |
+
with tf.compat.v1.variable_scope("DownConv_%d" % depth):
|
| 191 |
+
filters = int(2 ** (depth) * self.filters_root)
|
| 192 |
+
|
| 193 |
+
net = tf.compat.v1.layers.conv2d(
|
| 194 |
+
net,
|
| 195 |
+
filters=filters,
|
| 196 |
+
kernel_size=self.kernel_size,
|
| 197 |
+
activation=None,
|
| 198 |
+
use_bias=False,
|
| 199 |
+
padding='same',
|
| 200 |
+
dilation_rate=self.dilation_rate,
|
| 201 |
+
kernel_initializer=self.initializer,
|
| 202 |
+
kernel_regularizer=self.regularizer,
|
| 203 |
+
# bias_regularizer=self.regularizer,
|
| 204 |
+
name="down_conv1_{}".format(depth + 1),
|
| 205 |
+
)
|
| 206 |
+
net = tf.compat.v1.layers.batch_normalization(
|
| 207 |
+
net, training=self.is_training, name="down_bn1_{}".format(depth + 1)
|
| 208 |
+
)
|
| 209 |
+
net = tf.nn.relu(net, name="down_relu1_{}".format(depth + 1))
|
| 210 |
+
net = tf.compat.v1.layers.dropout(
|
| 211 |
+
net, rate=self.drop_rate, training=self.is_training, name="down_dropout1_{}".format(depth + 1)
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
convs[depth] = net
|
| 215 |
+
|
| 216 |
+
if depth < self.depths - 1:
|
| 217 |
+
net = tf.compat.v1.layers.conv2d(
|
| 218 |
+
net,
|
| 219 |
+
filters=filters,
|
| 220 |
+
kernel_size=self.kernel_size,
|
| 221 |
+
strides=self.pool_size,
|
| 222 |
+
activation=None,
|
| 223 |
+
use_bias=False,
|
| 224 |
+
padding='same',
|
| 225 |
+
# dilation_rate=self.dilation_rate,
|
| 226 |
+
kernel_initializer=self.initializer,
|
| 227 |
+
kernel_regularizer=self.regularizer,
|
| 228 |
+
# bias_regularizer=self.regularizer,
|
| 229 |
+
name="down_conv3_{}".format(depth + 1),
|
| 230 |
+
)
|
| 231 |
+
net = tf.compat.v1.layers.batch_normalization(
|
| 232 |
+
net, training=self.is_training, name="down_bn3_{}".format(depth + 1)
|
| 233 |
+
)
|
| 234 |
+
net = tf.nn.relu(net, name="down_relu3_{}".format(depth + 1))
|
| 235 |
+
net = tf.compat.v1.layers.dropout(
|
| 236 |
+
net, rate=self.drop_rate, training=self.is_training, name="down_dropout3_{}".format(depth + 1)
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# up layers
|
| 240 |
+
for depth in range(self.depths - 2, -1, -1):
|
| 241 |
+
with tf.compat.v1.variable_scope("UpConv_%d" % depth):
|
| 242 |
+
filters = int(2 ** (depth) * self.filters_root)
|
| 243 |
+
net = tf.compat.v1.layers.conv2d_transpose(
|
| 244 |
+
net,
|
| 245 |
+
filters=filters,
|
| 246 |
+
kernel_size=self.kernel_size,
|
| 247 |
+
strides=self.pool_size,
|
| 248 |
+
activation=None,
|
| 249 |
+
use_bias=False,
|
| 250 |
+
padding="same",
|
| 251 |
+
kernel_initializer=self.initializer,
|
| 252 |
+
kernel_regularizer=self.regularizer,
|
| 253 |
+
# bias_regularizer=self.regularizer,
|
| 254 |
+
name="up_conv0_{}".format(depth + 1),
|
| 255 |
+
)
|
| 256 |
+
net = tf.compat.v1.layers.batch_normalization(
|
| 257 |
+
net, training=self.is_training, name="up_bn0_{}".format(depth + 1)
|
| 258 |
+
)
|
| 259 |
+
net = tf.nn.relu(net, name="up_relu0_{}".format(depth + 1))
|
| 260 |
+
net = tf.compat.v1.layers.dropout(
|
| 261 |
+
net, rate=self.drop_rate, training=self.is_training, name="up_dropout0_{}".format(depth + 1)
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# skip connection
|
| 265 |
+
net = crop_and_concat(convs[depth], net)
|
| 266 |
+
# net = crop_only(convs[depth], net)
|
| 267 |
+
|
| 268 |
+
net = tf.compat.v1.layers.conv2d(
|
| 269 |
+
net,
|
| 270 |
+
filters=filters,
|
| 271 |
+
kernel_size=self.kernel_size,
|
| 272 |
+
activation=None,
|
| 273 |
+
use_bias=False,
|
| 274 |
+
padding='same',
|
| 275 |
+
dilation_rate=self.dilation_rate,
|
| 276 |
+
kernel_initializer=self.initializer,
|
| 277 |
+
kernel_regularizer=self.regularizer,
|
| 278 |
+
# bias_regularizer=self.regularizer,
|
| 279 |
+
name="up_conv1_{}".format(depth + 1),
|
| 280 |
+
)
|
| 281 |
+
net = tf.compat.v1.layers.batch_normalization(
|
| 282 |
+
net, training=self.is_training, name="up_bn1_{}".format(depth + 1)
|
| 283 |
+
)
|
| 284 |
+
net = tf.nn.relu(net, name="up_relu1_{}".format(depth + 1))
|
| 285 |
+
net = tf.compat.v1.layers.dropout(
|
| 286 |
+
net, rate=self.drop_rate, training=self.is_training, name="up_dropout1_{}".format(depth + 1)
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# Output Map
|
| 290 |
+
with tf.compat.v1.variable_scope("Output"):
|
| 291 |
+
net = tf.compat.v1.layers.conv2d(
|
| 292 |
+
net,
|
| 293 |
+
filters=self.n_class,
|
| 294 |
+
kernel_size=(1, 1),
|
| 295 |
+
activation=None,
|
| 296 |
+
use_bias=True,
|
| 297 |
+
padding='same',
|
| 298 |
+
# dilation_rate=self.dilation_rate,
|
| 299 |
+
kernel_initializer=self.initializer,
|
| 300 |
+
kernel_regularizer=self.regularizer,
|
| 301 |
+
# bias_regularizer=self.regularizer,
|
| 302 |
+
name="output_conv",
|
| 303 |
+
)
|
| 304 |
+
# net = tf.nn.relu(net,
|
| 305 |
+
# name="output_relu")
|
| 306 |
+
# net = tf.layers.dropout(net,
|
| 307 |
+
# rate=self.drop_rate,
|
| 308 |
+
# training=self.is_training,
|
| 309 |
+
# name="output_dropout")
|
| 310 |
+
# net = tf.layers.batch_normalization(net,
|
| 311 |
+
# training=self.is_training,
|
| 312 |
+
# name="output_bn")
|
| 313 |
+
output = net
|
| 314 |
+
|
| 315 |
+
with tf.compat.v1.variable_scope("representation"):
|
| 316 |
+
self.representation = convs[-1]
|
| 317 |
+
|
| 318 |
+
with tf.compat.v1.variable_scope("logits"):
|
| 319 |
+
self.logits = output
|
| 320 |
+
tmp = tf.compat.v1.summary.histogram("logits", self.logits)
|
| 321 |
+
self.summary_train.append(tmp)
|
| 322 |
+
|
| 323 |
+
with tf.compat.v1.variable_scope("preds"):
|
| 324 |
+
self.preds = tf.nn.softmax(output)
|
| 325 |
+
tmp = tf.compat.v1.summary.histogram("preds", self.preds)
|
| 326 |
+
self.summary_train.append(tmp)
|
| 327 |
+
|
| 328 |
+
def add_loss_op(self):
|
| 329 |
+
if self.loss_type == "cross_entropy":
|
| 330 |
+
with tf.compat.v1.variable_scope("cross_entropy"):
|
| 331 |
+
flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
|
| 332 |
+
flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
|
| 333 |
+
if (np.array(self.class_weights) != 1).any():
|
| 334 |
+
class_weights = tf.constant(np.array(self.class_weights, dtype=np.float32), name="class_weights")
|
| 335 |
+
weight_map = tf.multiply(flat_labels, class_weights)
|
| 336 |
+
weight_map = tf.reduce_sum(input_tensor=weight_map, axis=1)
|
| 337 |
+
loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits, labels=flat_labels)
|
| 338 |
+
# loss_map = tf.nn.sigmoid_cross_entropy_with_logits(logits=flat_logits,
|
| 339 |
+
# labels=flat_labels)
|
| 340 |
+
weighted_loss = tf.multiply(loss_map, weight_map)
|
| 341 |
+
loss = tf.reduce_mean(input_tensor=weighted_loss)
|
| 342 |
+
else:
|
| 343 |
+
loss = tf.reduce_mean(
|
| 344 |
+
input_tensor=tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits, labels=flat_labels)
|
| 345 |
+
)
|
| 346 |
+
# loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=flat_logits,
|
| 347 |
+
# labels=flat_labels))
|
| 348 |
+
elif self.loss_type == "IOU":
|
| 349 |
+
with tf.compat.v1.variable_scope("IOU"):
|
| 350 |
+
eps = 1e-7
|
| 351 |
+
loss = 0
|
| 352 |
+
for i in range(1, self.n_class):
|
| 353 |
+
intersection = eps + tf.reduce_sum(
|
| 354 |
+
input_tensor=self.preds[:, :, :, i] * self.Y[:, :, :, i], axis=[1, 2]
|
| 355 |
+
)
|
| 356 |
+
union = (
|
| 357 |
+
eps
|
| 358 |
+
+ tf.reduce_sum(input_tensor=self.preds[:, :, :, i], axis=[1, 2])
|
| 359 |
+
+ tf.reduce_sum(input_tensor=self.Y[:, :, :, i], axis=[1, 2])
|
| 360 |
+
)
|
| 361 |
+
loss += 1 - tf.reduce_mean(input_tensor=intersection / union)
|
| 362 |
+
elif self.loss_type == "mean_squared":
|
| 363 |
+
with tf.compat.v1.variable_scope("mean_squared"):
|
| 364 |
+
flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
|
| 365 |
+
flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
|
| 366 |
+
with tf.compat.v1.variable_scope("mean_squared"):
|
| 367 |
+
loss = tf.compat.v1.losses.mean_squared_error(labels=flat_labels, predictions=flat_logits)
|
| 368 |
+
else:
|
| 369 |
+
raise ValueError("Unknown loss function: " % self.loss_type)
|
| 370 |
+
|
| 371 |
+
tmp = tf.compat.v1.summary.scalar("train_loss", loss)
|
| 372 |
+
self.summary_train.append(tmp)
|
| 373 |
+
tmp = tf.compat.v1.summary.scalar("valid_loss", loss)
|
| 374 |
+
self.summary_valid.append(tmp)
|
| 375 |
+
|
| 376 |
+
if self.weight_decay > 0:
|
| 377 |
+
with tf.compat.v1.name_scope('weight_loss'):
|
| 378 |
+
tmp = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
|
| 379 |
+
weight_loss = tf.add_n(tmp, name="weight_loss")
|
| 380 |
+
self.loss = loss + weight_loss
|
| 381 |
+
else:
|
| 382 |
+
self.loss = loss
|
| 383 |
+
|
| 384 |
+
def add_training_op(self):
|
| 385 |
+
if self.optimizer == "momentum":
|
| 386 |
+
self.learning_rate_node = tf.compat.v1.train.exponential_decay(
|
| 387 |
+
learning_rate=self.learning_rate,
|
| 388 |
+
global_step=self.global_step,
|
| 389 |
+
decay_steps=self.decay_step,
|
| 390 |
+
decay_rate=self.decay_rate,
|
| 391 |
+
staircase=True,
|
| 392 |
+
)
|
| 393 |
+
optimizer = tf.compat.v1.train.MomentumOptimizer(
|
| 394 |
+
learning_rate=self.learning_rate_node, momentum=self.momentum
|
| 395 |
+
)
|
| 396 |
+
elif self.optimizer == "adam":
|
| 397 |
+
self.learning_rate_node = tf.compat.v1.train.exponential_decay(
|
| 398 |
+
learning_rate=self.learning_rate,
|
| 399 |
+
global_step=self.global_step,
|
| 400 |
+
decay_steps=self.decay_step,
|
| 401 |
+
decay_rate=self.decay_rate,
|
| 402 |
+
staircase=True,
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate_node)
|
| 406 |
+
update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
|
| 407 |
+
with tf.control_dependencies(update_ops):
|
| 408 |
+
self.train_op = optimizer.minimize(self.loss, global_step=self.global_step)
|
| 409 |
+
tmp = tf.compat.v1.summary.scalar("learning_rate", self.learning_rate_node)
|
| 410 |
+
self.summary_train.append(tmp)
|
| 411 |
+
|
| 412 |
+
def reset_learning_rate(self, sess, learning_rate, global_step):
|
| 413 |
+
self.learning_rate = learning_rate
|
| 414 |
+
assign_op = self.global_step.assign(global_step)
|
| 415 |
+
sess.run(assign_op)
|
| 416 |
+
if self.optimizer == "momentum":
|
| 417 |
+
self.learning_rate_node = tf.compat.v1.train.exponential_decay(
|
| 418 |
+
learning_rate=learning_rate,
|
| 419 |
+
global_step=self.global_step,
|
| 420 |
+
decay_steps=self.decay_step,
|
| 421 |
+
decay_rate=self.decay_rate,
|
| 422 |
+
staircase=True,
|
| 423 |
+
)
|
| 424 |
+
optimizer = tf.compat.v1.train.MomentumOptimizer(
|
| 425 |
+
learning_rate=self.learning_rate_node, momentum=self.momentum
|
| 426 |
+
)
|
| 427 |
+
elif self.optimizer == "adam":
|
| 428 |
+
self.learning_rate_node = tf.compat.v1.train.exponential_decay(
|
| 429 |
+
learning_rate=self.learning_rate,
|
| 430 |
+
global_step=self.global_step,
|
| 431 |
+
decay_steps=self.decay_step,
|
| 432 |
+
decay_rate=self.decay_rate,
|
| 433 |
+
staircase=True,
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate_node)
|
| 437 |
+
|
| 438 |
+
def train_on_batch(self, sess, X_batch, Y_batch, summary_writer, drop_rate=0.0):
|
| 439 |
+
feed = {self.drop_rate: drop_rate, self.is_training: True, self.X: X_batch, self.Y: Y_batch}
|
| 440 |
+
_, step_summary, step, loss = sess.run(
|
| 441 |
+
[self.train_op, self.summary_train, self.global_step, self.loss], feed_dict=feed
|
| 442 |
+
)
|
| 443 |
+
summary_writer.add_summary(step_summary, step)
|
| 444 |
+
return loss
|
| 445 |
+
|
| 446 |
+
def valid_on_batch(self, sess, X_batch, Y_batch, summary_writer, drop_rate=0.0):
|
| 447 |
+
feed = {self.drop_rate: drop_rate, self.is_training: False, self.X: X_batch, self.Y: Y_batch}
|
| 448 |
+
step_summary, step, loss, preds = sess.run(
|
| 449 |
+
[self.summary_valid, self.global_step, self.loss, self.preds], feed_dict=feed
|
| 450 |
+
)
|
| 451 |
+
summary_writer.add_summary(step_summary, step)
|
| 452 |
+
return loss, preds
|
| 453 |
+
|
| 454 |
+
def test_on_batch(self, sess, summary_writer):
|
| 455 |
+
feed = {self.drop_rate: 0, self.is_training: False}
|
| 456 |
+
(
|
| 457 |
+
step_summary,
|
| 458 |
+
step,
|
| 459 |
+
loss,
|
| 460 |
+
preds,
|
| 461 |
+
X_batch,
|
| 462 |
+
Y_batch,
|
| 463 |
+
ratio_batch,
|
| 464 |
+
signal_batch,
|
| 465 |
+
noise_batch,
|
| 466 |
+
fname_batch,
|
| 467 |
+
) = sess.run(
|
| 468 |
+
[
|
| 469 |
+
self.summary_valid,
|
| 470 |
+
self.global_step,
|
| 471 |
+
self.loss,
|
| 472 |
+
self.preds,
|
| 473 |
+
self.X,
|
| 474 |
+
self.Y,
|
| 475 |
+
self.input_batch[2],
|
| 476 |
+
self.input_batch[3],
|
| 477 |
+
self.input_batch[4],
|
| 478 |
+
self.input_batch[5],
|
| 479 |
+
],
|
| 480 |
+
feed_dict=feed,
|
| 481 |
+
)
|
| 482 |
+
summary_writer.add_summary(step_summary, step)
|
| 483 |
+
|
| 484 |
+
return loss, preds, X_batch, Y_batch, ratio_batch, signal_batch, noise_batch, fname_batch
|
| 485 |
+
|
| 486 |
+
def build(self, input_batch=None, mode='train'):
|
| 487 |
+
self.add_placeholders(input_batch, mode)
|
| 488 |
+
self.add_prediction_op()
|
| 489 |
+
if mode in ["train", "valid", "test"]:
|
| 490 |
+
self.add_loss_op()
|
| 491 |
+
self.add_training_op()
|
| 492 |
+
# self.add_metrics_op()
|
| 493 |
+
self.summary_train = tf.compat.v1.summary.merge(self.summary_train)
|
| 494 |
+
self.summary_valid = tf.compat.v1.summary.merge(self.summary_valid)
|
| 495 |
+
return 0
|
deepdenoiser/predict.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
import multiprocessing
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
from functools import partial
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import tensorflow as tf
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from data_reader import DataReader_pred, normalize_batch
|
| 13 |
+
from model import UNet
|
| 14 |
+
from util import *
|
| 15 |
+
|
| 16 |
+
tf.compat.v1.disable_eager_execution()
|
| 17 |
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def read_args():
|
| 21 |
+
"""Returns args"""
|
| 22 |
+
|
| 23 |
+
parser = argparse.ArgumentParser()
|
| 24 |
+
|
| 25 |
+
parser.add_argument("--format", default="numpy", type=str, help="Input data format: numpy or mseed")
|
| 26 |
+
parser.add_argument("--batch_size", default=20, type=int, help="Batch size")
|
| 27 |
+
parser.add_argument("--output_dir", default="output", help="Output directory (default: output)")
|
| 28 |
+
parser.add_argument("--model_dir", default=None, help="Checkpoint directory (default: None)")
|
| 29 |
+
parser.add_argument("--sampling_rate", default=100, type=int, help="sampling rate of pred data")
|
| 30 |
+
parser.add_argument("--data_dir", default="./Dataset/pred/", help="Input file directory")
|
| 31 |
+
parser.add_argument("--data_list", default="./Dataset/pred.csv", help="Input csv file")
|
| 32 |
+
parser.add_argument("--plot_figure", action="store_true", help="If plot figure")
|
| 33 |
+
parser.add_argument("--save_signal", action="store_true", help="If save denoised signal")
|
| 34 |
+
parser.add_argument("--save_noise", action="store_true", help="If save denoised noise")
|
| 35 |
+
|
| 36 |
+
args = parser.parse_args()
|
| 37 |
+
return args
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def pred_fn(args, data_reader, figure_dir=None, result_dir=None, log_dir=None):
|
| 41 |
+
current_time = time.strftime("%y%m%d-%H%M%S")
|
| 42 |
+
if log_dir is None:
|
| 43 |
+
log_dir = os.path.join(args.log_dir, "pred", current_time)
|
| 44 |
+
logging.info("Pred log: %s" % log_dir)
|
| 45 |
+
# logging.info("Dataset size: {}".format(data_reader.num_data))
|
| 46 |
+
if not os.path.exists(log_dir):
|
| 47 |
+
os.makedirs(log_dir)
|
| 48 |
+
if args.plot_figure:
|
| 49 |
+
figure_dir = os.path.join(log_dir, 'figures')
|
| 50 |
+
os.makedirs(figure_dir, exist_ok=True)
|
| 51 |
+
if args.save_signal or args.save_noise:
|
| 52 |
+
result_dir = os.path.join(log_dir, 'results')
|
| 53 |
+
os.makedirs(result_dir, exist_ok=True)
|
| 54 |
+
|
| 55 |
+
with tf.compat.v1.name_scope('Input_Batch'):
|
| 56 |
+
data_batch = data_reader.dataset(args.batch_size)
|
| 57 |
+
|
| 58 |
+
# model = UNet(input_batch=data_batch, mode='pred')
|
| 59 |
+
model = UNet(mode='pred')
|
| 60 |
+
sess_config = tf.compat.v1.ConfigProto()
|
| 61 |
+
sess_config.gpu_options.allow_growth = True
|
| 62 |
+
# sess_config.log_device_placement = False
|
| 63 |
+
|
| 64 |
+
with tf.compat.v1.Session(config=sess_config) as sess:
|
| 65 |
+
|
| 66 |
+
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
|
| 67 |
+
init = tf.compat.v1.global_variables_initializer()
|
| 68 |
+
sess.run(init)
|
| 69 |
+
|
| 70 |
+
latest_check_point = tf.train.latest_checkpoint(args.model_dir)
|
| 71 |
+
logging.info(f"restoring models: {latest_check_point}")
|
| 72 |
+
saver.restore(sess, latest_check_point)
|
| 73 |
+
|
| 74 |
+
if args.plot_figure:
|
| 75 |
+
num_pool = multiprocessing.cpu_count()
|
| 76 |
+
else:
|
| 77 |
+
num_pool = 2
|
| 78 |
+
multiprocessing.set_start_method('spawn')
|
| 79 |
+
pool = multiprocessing.Pool(num_pool)
|
| 80 |
+
for _ in tqdm(range(0, data_reader.n_signal, args.batch_size), desc="Pred"):
|
| 81 |
+
X_batch, fname_batch, t0_batch = sess.run(data_batch)
|
| 82 |
+
nbt, nch, nst, nf, nt, nimg = X_batch.shape
|
| 83 |
+
X_batch_ = np.reshape(X_batch, [nbt * nch * nst, nf, nt, nimg])
|
| 84 |
+
X_batch_ = normalize_batch(X_batch_)
|
| 85 |
+
preds_batch = sess.run(
|
| 86 |
+
model.preds,
|
| 87 |
+
feed_dict={model.X: X_batch_, model.drop_rate: 0, model.is_training: False},
|
| 88 |
+
)
|
| 89 |
+
preds_batch = np.reshape(preds_batch, [nbt, nch, nst, nf, nt, preds_batch.shape[-1]])
|
| 90 |
+
# preds_batch, X_batch, ratio_batch, fname_batch = sess.run(
|
| 91 |
+
# [model.preds, data_batch[0], data_batch[1], data_batch[2]],
|
| 92 |
+
# feed_dict={model.drop_rate: 0, model.is_training: False},
|
| 93 |
+
# )
|
| 94 |
+
|
| 95 |
+
if args.save_signal or args.save_noise:
|
| 96 |
+
save_results(
|
| 97 |
+
preds_batch,
|
| 98 |
+
X_batch,
|
| 99 |
+
fname=[x.decode() for x in fname_batch],
|
| 100 |
+
t0=[x.decode() for x in t0_batch],
|
| 101 |
+
save_signal=args.save_signal,
|
| 102 |
+
save_noise=args.save_noise,
|
| 103 |
+
result_dir=result_dir,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
if args.plot_figure:
|
| 107 |
+
pool.starmap(
|
| 108 |
+
partial(
|
| 109 |
+
plot_figures,
|
| 110 |
+
figure_dir=figure_dir,
|
| 111 |
+
),
|
| 112 |
+
zip(preds_batch, X_batch, [x.decode() for x in fname_batch]),
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
pool.close()
|
| 116 |
+
|
| 117 |
+
return 0
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def main(args):
|
| 121 |
+
|
| 122 |
+
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
|
| 123 |
+
|
| 124 |
+
with tf.compat.v1.name_scope('create_inputs'):
|
| 125 |
+
data_reader = DataReader_pred(
|
| 126 |
+
format=args.format, signal_dir=args.data_dir, signal_list=args.data_list, sampling_rate=args.sampling_rate
|
| 127 |
+
)
|
| 128 |
+
logging.info("Dataset Size: {}".format(data_reader.n_signal))
|
| 129 |
+
pred_fn(args, data_reader, log_dir=args.output_dir)
|
| 130 |
+
|
| 131 |
+
return 0
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
if __name__ == '__main__':
|
| 135 |
+
args = read_args()
|
| 136 |
+
main(args)
|
deepdenoiser/train.py
ADDED
|
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#import warnings
|
| 2 |
+
#warnings.filterwarnings('ignore', category=FutureWarning)
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
tf.compat.v1.disable_eager_execution()
|
| 6 |
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
| 7 |
+
import argparse
|
| 8 |
+
import os
|
| 9 |
+
import time
|
| 10 |
+
import logging
|
| 11 |
+
from model import UNet
|
| 12 |
+
from data_reader import *
|
| 13 |
+
from util import *
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import multiprocessing
|
| 16 |
+
from functools import partial
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def read_args():
|
| 20 |
+
"""Returns args"""
|
| 21 |
+
|
| 22 |
+
parser = argparse.ArgumentParser()
|
| 23 |
+
|
| 24 |
+
parser.add_argument("--mode",
|
| 25 |
+
default="train",
|
| 26 |
+
help="train/valid/test/debug (default: train)")
|
| 27 |
+
|
| 28 |
+
parser.add_argument("--epochs",
|
| 29 |
+
default=10,
|
| 30 |
+
type=int,
|
| 31 |
+
help="Number of epochs (default: 10)")
|
| 32 |
+
|
| 33 |
+
parser.add_argument("--batch_size",
|
| 34 |
+
default=20,
|
| 35 |
+
type=int,
|
| 36 |
+
help="Batch size (default: 20)")
|
| 37 |
+
|
| 38 |
+
parser.add_argument("--learning_rate",
|
| 39 |
+
default=0.001,
|
| 40 |
+
type=float,
|
| 41 |
+
help="learning rate (default: 0.001)")
|
| 42 |
+
|
| 43 |
+
parser.add_argument("--decay_step",
|
| 44 |
+
default=-1,
|
| 45 |
+
type=int,
|
| 46 |
+
help="decay step (default: -1)")
|
| 47 |
+
|
| 48 |
+
parser.add_argument("--decay_rate",
|
| 49 |
+
default=0.9,
|
| 50 |
+
type=float,
|
| 51 |
+
help="decay rate (default: 0.9)")
|
| 52 |
+
|
| 53 |
+
parser.add_argument("--momentum",
|
| 54 |
+
default=0.9,
|
| 55 |
+
type=float,
|
| 56 |
+
help="momentum (default: 0.9)")
|
| 57 |
+
|
| 58 |
+
parser.add_argument("--filters_root",
|
| 59 |
+
default=8,
|
| 60 |
+
type=int,
|
| 61 |
+
help="filters root (default: 8)")
|
| 62 |
+
|
| 63 |
+
parser.add_argument("--depth",
|
| 64 |
+
default=6,
|
| 65 |
+
type=int,
|
| 66 |
+
help="depth (default: 6)")
|
| 67 |
+
|
| 68 |
+
parser.add_argument("--kernel_size",
|
| 69 |
+
nargs="+",
|
| 70 |
+
type=int,
|
| 71 |
+
default=[3, 3],
|
| 72 |
+
help="kernel size (default: [3, 3]")
|
| 73 |
+
|
| 74 |
+
parser.add_argument("--pool_size",
|
| 75 |
+
nargs="+",
|
| 76 |
+
type=int,
|
| 77 |
+
default=[2, 2],
|
| 78 |
+
help="pool size (default: [2, 2]")
|
| 79 |
+
|
| 80 |
+
parser.add_argument("--drop_rate",
|
| 81 |
+
default=0,
|
| 82 |
+
type=float,
|
| 83 |
+
help="drop out rate (default: 0)")
|
| 84 |
+
|
| 85 |
+
parser.add_argument("--dilation_rate",
|
| 86 |
+
nargs="+",
|
| 87 |
+
type=int,
|
| 88 |
+
default=[1, 1],
|
| 89 |
+
help="dilation_rate (default: [1, 1]")
|
| 90 |
+
|
| 91 |
+
parser.add_argument("--loss_type",
|
| 92 |
+
default="cross_entropy",
|
| 93 |
+
help="loss type: cross_entropy, IOU, mean_squared (default: cross_entropy)")
|
| 94 |
+
|
| 95 |
+
parser.add_argument("--weight_decay",
|
| 96 |
+
default=0,
|
| 97 |
+
type=float,
|
| 98 |
+
help="weight decay (default: 0)")
|
| 99 |
+
|
| 100 |
+
parser.add_argument("--optimizer",
|
| 101 |
+
default="adam",
|
| 102 |
+
help="optimizer: adam, momentum (default: adam)")
|
| 103 |
+
|
| 104 |
+
parser.add_argument("--summary",
|
| 105 |
+
default=True,
|
| 106 |
+
type=bool,
|
| 107 |
+
help="summary (default: True)")
|
| 108 |
+
|
| 109 |
+
parser.add_argument("--class_weights",
|
| 110 |
+
nargs="+",
|
| 111 |
+
default=[1, 1],
|
| 112 |
+
type=float,
|
| 113 |
+
help="class weights (default: [1, 1]")
|
| 114 |
+
|
| 115 |
+
parser.add_argument("--log_dir",
|
| 116 |
+
default="log",
|
| 117 |
+
help="Tensorboard log directory (default: log)")
|
| 118 |
+
|
| 119 |
+
parser.add_argument("--model_dir",
|
| 120 |
+
default=None,
|
| 121 |
+
help="Checkpoint directory")
|
| 122 |
+
|
| 123 |
+
parser.add_argument("--num_plots",
|
| 124 |
+
default=10,
|
| 125 |
+
type=int,
|
| 126 |
+
help="plotting trainning result (default: 10)")
|
| 127 |
+
|
| 128 |
+
parser.add_argument("--input_length",
|
| 129 |
+
default=None,
|
| 130 |
+
type=int,
|
| 131 |
+
help="input length")
|
| 132 |
+
parser.add_argument("--sampling_rate",
|
| 133 |
+
default=100,
|
| 134 |
+
type=int,
|
| 135 |
+
help="sampling rate of pred data in Hz (default: 100)")
|
| 136 |
+
|
| 137 |
+
parser.add_argument("--train_signal_dir",
|
| 138 |
+
default="./Dataset/train/",
|
| 139 |
+
help="Input file directory (default: ./Dataset/train/)")
|
| 140 |
+
parser.add_argument("--train_signal_list",
|
| 141 |
+
default="./Dataset/train.csv",
|
| 142 |
+
help="Input csv file (default: ./Dataset/train.csv)")
|
| 143 |
+
parser.add_argument("--train_noise_dir",
|
| 144 |
+
default="./Dataset/train/",
|
| 145 |
+
help="Input file directory (default: ./Dataset/train/)")
|
| 146 |
+
parser.add_argument("--train_noise_list",
|
| 147 |
+
default="./Dataset/train.csv",
|
| 148 |
+
help="Input csv file (default: ./Dataset/train.csv)")
|
| 149 |
+
|
| 150 |
+
parser.add_argument("--valid_signal_dir",
|
| 151 |
+
default="./Dataset/",
|
| 152 |
+
help="Input file directory (default: ./Dataset/)")
|
| 153 |
+
parser.add_argument("--valid_signal_list",
|
| 154 |
+
default=None,
|
| 155 |
+
help="Input csv file")
|
| 156 |
+
parser.add_argument("--valid_noise_dir",
|
| 157 |
+
default="./Dataset/",
|
| 158 |
+
help="Input file directory (default: ./Dataset/)")
|
| 159 |
+
parser.add_argument("--valid_noise_list",
|
| 160 |
+
default=None,
|
| 161 |
+
help="Input csv file")
|
| 162 |
+
|
| 163 |
+
parser.add_argument("--data_dir",
|
| 164 |
+
default="./Dataset/pred/",
|
| 165 |
+
help="Input file directory (default: ./Dataset/pred/)")
|
| 166 |
+
parser.add_argument("--data_list",
|
| 167 |
+
default="./Dataset/pred.csv",
|
| 168 |
+
help="Input csv file (default: ./Dataset/pred.csv)")
|
| 169 |
+
|
| 170 |
+
parser.add_argument("--output_dir",
|
| 171 |
+
default=None,
|
| 172 |
+
help="Output directory")
|
| 173 |
+
|
| 174 |
+
parser.add_argument("--fpred",
|
| 175 |
+
default="preds.npz",
|
| 176 |
+
help="ouput file name of test data")
|
| 177 |
+
parser.add_argument("--plot_figure",
|
| 178 |
+
action="store_true",
|
| 179 |
+
help="If plot figure for test")
|
| 180 |
+
parser.add_argument("--save_result",
|
| 181 |
+
action="store_true",
|
| 182 |
+
help="If save result for test")
|
| 183 |
+
|
| 184 |
+
args = parser.parse_args()
|
| 185 |
+
return args
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def set_config(args, data_reader):
|
| 189 |
+
config = Config()
|
| 190 |
+
|
| 191 |
+
config.X_shape = data_reader.X_shape
|
| 192 |
+
config.n_channel = config.X_shape[-1]
|
| 193 |
+
config.Y_shape = data_reader.Y_shape
|
| 194 |
+
config.n_class = config.Y_shape[-1]
|
| 195 |
+
|
| 196 |
+
config.depths = args.depth
|
| 197 |
+
config.filters_root = args.filters_root
|
| 198 |
+
config.kernel_size = args.kernel_size
|
| 199 |
+
config.pool_size = args.pool_size
|
| 200 |
+
config.dilation_rate = args.dilation_rate
|
| 201 |
+
config.batch_size = args.batch_size
|
| 202 |
+
config.class_weights = args.class_weights
|
| 203 |
+
config.loss_type = args.loss_type
|
| 204 |
+
config.weight_decay = args.weight_decay
|
| 205 |
+
config.optimizer = args.optimizer
|
| 206 |
+
|
| 207 |
+
config.learning_rate = args.learning_rate
|
| 208 |
+
if (args.decay_step == -1) and (args.mode == 'train'):
|
| 209 |
+
config.decay_step = data_reader.n_signal // args.batch_size
|
| 210 |
+
else:
|
| 211 |
+
config.decay_step = args.decay_step
|
| 212 |
+
config.decay_rate = args.decay_rate
|
| 213 |
+
config.momentum = args.momentum
|
| 214 |
+
|
| 215 |
+
config.summary = args.summary
|
| 216 |
+
config.drop_rate = args.drop_rate
|
| 217 |
+
config.class_weights = args.class_weights
|
| 218 |
+
|
| 219 |
+
return config
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def train_fn(args, data_reader, data_reader_valid=None):
|
| 223 |
+
current_time = time.strftime("%y%m%d-%H%M%S")
|
| 224 |
+
log_dir = os.path.join(args.log_dir, current_time)
|
| 225 |
+
logging.info("Training log: {}".format(log_dir))
|
| 226 |
+
if not os.path.exists(log_dir):
|
| 227 |
+
os.makedirs(log_dir)
|
| 228 |
+
figure_dir = os.path.join(log_dir, 'figures')
|
| 229 |
+
if not os.path.exists(figure_dir):
|
| 230 |
+
os.makedirs(figure_dir)
|
| 231 |
+
|
| 232 |
+
config = set_config(args, data_reader)
|
| 233 |
+
with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
|
| 234 |
+
fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
|
| 235 |
+
|
| 236 |
+
with tf.compat.v1.name_scope('Input_Batch'):
|
| 237 |
+
batch = data_reader.dequeue(args.batch_size)
|
| 238 |
+
if data_reader_valid is not None:
|
| 239 |
+
batch_valid = data_reader_valid.dequeue(args.batch_size)
|
| 240 |
+
|
| 241 |
+
model = UNet(config)
|
| 242 |
+
sess_config = tf.compat.v1.ConfigProto()
|
| 243 |
+
sess_config.gpu_options.allow_growth = True
|
| 244 |
+
sess_config.log_device_placement = False
|
| 245 |
+
|
| 246 |
+
with tf.compat.v1.Session(config=sess_config) as sess:
|
| 247 |
+
|
| 248 |
+
summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
|
| 249 |
+
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
|
| 250 |
+
init = tf.compat.v1.global_variables_initializer()
|
| 251 |
+
sess.run(init)
|
| 252 |
+
|
| 253 |
+
if args.model_dir is not None:
|
| 254 |
+
logging.info("restoring models...")
|
| 255 |
+
latest_check_point = tf.train.latest_checkpoint(args.model_dir)
|
| 256 |
+
saver.restore(sess, latest_check_point)
|
| 257 |
+
model.reset_learning_rate(sess, learning_rate=0.01, global_step=0)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
threads = data_reader.start_threads(sess, n_threads=multiprocessing.cpu_count())
|
| 261 |
+
if data_reader_valid is not None:
|
| 262 |
+
threads_valid = data_reader_valid.start_threads(sess, n_threads=multiprocessing.cpu_count())
|
| 263 |
+
flog = open(os.path.join(log_dir, 'loss.log'), 'w')
|
| 264 |
+
|
| 265 |
+
total_step = 0
|
| 266 |
+
mean_loss = 0
|
| 267 |
+
pool = multiprocessing.Pool(2)
|
| 268 |
+
for epoch in range(args.epochs):
|
| 269 |
+
progressbar = tqdm(range(0, data_reader.n_signal, args.batch_size), desc="{}: ".format(log_dir.split("/")[-1]))
|
| 270 |
+
for step in progressbar:
|
| 271 |
+
X_batch, Y_batch = sess.run(batch)
|
| 272 |
+
loss_batch = model.train_on_batch(sess, X_batch, Y_batch, summary_writer, args.drop_rate)
|
| 273 |
+
if epoch < 1:
|
| 274 |
+
mean_loss = loss_batch
|
| 275 |
+
else:
|
| 276 |
+
total_step += 1
|
| 277 |
+
mean_loss += (loss_batch-mean_loss)/total_step
|
| 278 |
+
progressbar.set_description("{}: epoch={}, loss={:.6f}, mean loss={:.6f}".format(log_dir.split("/")[-1], epoch, loss_batch, mean_loss))
|
| 279 |
+
flog.write("Epoch: {}, step: {}, loss: {}, mean loss: {}\n".format(epoch, step//args.batch_size, loss_batch, mean_loss))
|
| 280 |
+
saver.save(sess, os.path.join(log_dir, "model_{}.ckpt".format(epoch)))
|
| 281 |
+
|
| 282 |
+
## valid
|
| 283 |
+
if data_reader_valid is not None:
|
| 284 |
+
mean_loss_valid = 0
|
| 285 |
+
total_step_valid = 0
|
| 286 |
+
progressbar = tqdm(range(0, data_reader_valid.n_signal, args.batch_size), desc="Valid: ")
|
| 287 |
+
for step in progressbar:
|
| 288 |
+
X_batch, Y_batch = sess.run(batch_valid)
|
| 289 |
+
loss_batch, preds_batch = model.valid_on_batch(sess, X_batch, Y_batch, summary_writer, args.drop_rate)
|
| 290 |
+
total_step_valid += 1
|
| 291 |
+
mean_loss_valid += (loss_batch-mean_loss_valid)/total_step_valid
|
| 292 |
+
progressbar.set_description("Valid: loss={:.6f}, mean loss={:.6f}".format(loss_batch, mean_loss_valid))
|
| 293 |
+
flog.write("Valid: {}, step: {}, loss: {}, mean loss: {}\n".format(epoch, step//args.batch_size, loss_batch, mean_loss_valid))
|
| 294 |
+
|
| 295 |
+
# plot_result(epoch, args.num_plots, figure_dir, preds_batch, X_batch, Y_batch)
|
| 296 |
+
pool.map(partial(plot_result_thread,
|
| 297 |
+
epoch = epoch,
|
| 298 |
+
preds = preds_batch,
|
| 299 |
+
X = X_batch,
|
| 300 |
+
Y = Y_batch,
|
| 301 |
+
figure_dir = figure_dir),
|
| 302 |
+
range(args.num_plots))
|
| 303 |
+
|
| 304 |
+
flog.close()
|
| 305 |
+
pool.close()
|
| 306 |
+
data_reader.coord.request_stop()
|
| 307 |
+
if data_reader_valid is not None:
|
| 308 |
+
data_reader_valid.coord.request_stop()
|
| 309 |
+
try:
|
| 310 |
+
data_reader.coord.join(threads, stop_grace_period_secs=10, ignore_live_threads=True)
|
| 311 |
+
if data_reader_valid is not None:
|
| 312 |
+
data_reader_valid.coord.join(threads_valid, stop_grace_period_secs=10, ignore_live_threads=True)
|
| 313 |
+
except:
|
| 314 |
+
pass
|
| 315 |
+
sess.run(data_reader.queue.close(cancel_pending_enqueues=True))
|
| 316 |
+
if data_reader_valid is not None:
|
| 317 |
+
sess.run(data_reader_valid.queue.close(cancel_pending_enqueues=True))
|
| 318 |
+
return 0
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def test_fn(args, data_reader, figure_dir=None, result_dir=None):
|
| 322 |
+
current_time = time.strftime("%y%m%d-%H%M%S")
|
| 323 |
+
log_dir = os.path.join(args.log_dir, args.mode, current_time)
|
| 324 |
+
logging.info("{} log: {}".format(args.mode, log_dir))
|
| 325 |
+
if not os.path.exists(log_dir):
|
| 326 |
+
os.makedirs(log_dir)
|
| 327 |
+
if (args.plot_figure == True) and (figure_dir is None):
|
| 328 |
+
figure_dir = os.path.join(log_dir, 'figures')
|
| 329 |
+
if not os.path.exists(figure_dir):
|
| 330 |
+
os.makedirs(figure_dir)
|
| 331 |
+
if (args.save_result == True) and (result_dir is None):
|
| 332 |
+
result_dir = os.path.join(log_dir, 'results')
|
| 333 |
+
if not os.path.exists(result_dir):
|
| 334 |
+
os.makedirs(result_dir)
|
| 335 |
+
|
| 336 |
+
config = set_config(args, data_reader)
|
| 337 |
+
with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
|
| 338 |
+
fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
|
| 339 |
+
|
| 340 |
+
with tf.compat.v1.name_scope('Input_Batch'):
|
| 341 |
+
batch = data_reader.dequeue(args.batch_size)
|
| 342 |
+
|
| 343 |
+
model = UNet(config, input_batch=batch, mode='test')
|
| 344 |
+
sess_config = tf.compat.v1.ConfigProto()
|
| 345 |
+
sess_config.gpu_options.allow_growth = True
|
| 346 |
+
sess_config.log_device_placement = False
|
| 347 |
+
|
| 348 |
+
with tf.compat.v1.Session(config=sess_config) as sess:
|
| 349 |
+
|
| 350 |
+
summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
|
| 351 |
+
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
|
| 352 |
+
init = tf.compat.v1.global_variables_initializer()
|
| 353 |
+
sess.run(init)
|
| 354 |
+
|
| 355 |
+
logging.info("restoring models...")
|
| 356 |
+
latest_check_point = tf.train.latest_checkpoint(args.model_dir)
|
| 357 |
+
saver.restore(sess, latest_check_point)
|
| 358 |
+
|
| 359 |
+
threads = data_reader.start_threads(sess, n_threads=multiprocessing.cpu_count())
|
| 360 |
+
|
| 361 |
+
flog = open(os.path.join(log_dir, 'loss.log'), 'w')
|
| 362 |
+
total_step = 0
|
| 363 |
+
mean_loss = 0
|
| 364 |
+
progressbar = tqdm(range(0, data_reader.n_signal, args.batch_size), desc=args.mode)
|
| 365 |
+
if args.plot_figure:
|
| 366 |
+
num_pool = multiprocessing.cpu_count()*2
|
| 367 |
+
elif args.save_result:
|
| 368 |
+
num_pool = multiprocessing.cpu_count()
|
| 369 |
+
else:
|
| 370 |
+
num_pool = 2
|
| 371 |
+
pool = multiprocessing.Pool(num_pool)
|
| 372 |
+
for step in progressbar:
|
| 373 |
+
|
| 374 |
+
if step + args.batch_size >= data_reader.n_signal:
|
| 375 |
+
for t in threads:
|
| 376 |
+
t.join()
|
| 377 |
+
sess.run(data_reader.queue.close())
|
| 378 |
+
|
| 379 |
+
loss_batch, preds_batch, X_batch, Y_batch, ratio_batch, \
|
| 380 |
+
signal_batch, noise_batch, fname_batch = model.test_on_batch(sess, summary_writer)
|
| 381 |
+
total_step += 1
|
| 382 |
+
mean_loss += (loss_batch-mean_loss)/total_step
|
| 383 |
+
progressbar.set_description("{}: loss={:.6f}, mean loss={:6f}".format(args.mode, loss_batch, mean_loss))
|
| 384 |
+
flog.write("step: {}, loss: {}\n".format(step, loss_batch))
|
| 385 |
+
flog.flush()
|
| 386 |
+
|
| 387 |
+
pool.map(partial(postprocessing_test,
|
| 388 |
+
preds=preds_batch,
|
| 389 |
+
X=X_batch*ratio_batch[:,np.newaxis,np.newaxis,np.newaxis],
|
| 390 |
+
fname=fname_batch,
|
| 391 |
+
figure_dir=figure_dir,
|
| 392 |
+
result_dir=result_dir,
|
| 393 |
+
signal_FT=signal_batch,
|
| 394 |
+
noise_FT=noise_batch),
|
| 395 |
+
range(len(X_batch)))
|
| 396 |
+
|
| 397 |
+
flog.close()
|
| 398 |
+
pool.close()
|
| 399 |
+
|
| 400 |
+
return 0
|
| 401 |
+
|
| 402 |
+
def pred_fn(args, data_reader, figure_dir=None, result_dir=None, log_dir=None):
|
| 403 |
+
current_time = time.strftime("%y%m%d-%H%M%S")
|
| 404 |
+
if log_dir is None:
|
| 405 |
+
log_dir = os.path.join(args.log_dir, "pred", current_time)
|
| 406 |
+
logging.info("Pred log: %s" % log_dir)
|
| 407 |
+
# logging.info("Dataset size: {}".format(data_reader.num_data))
|
| 408 |
+
if not os.path.exists(log_dir):
|
| 409 |
+
os.makedirs(log_dir)
|
| 410 |
+
if args.plot_figure:
|
| 411 |
+
figure_dir = os.path.join(log_dir, 'figures')
|
| 412 |
+
os.makedirs(figure_dir, exist_ok=True)
|
| 413 |
+
if args.save_result:
|
| 414 |
+
result_dir = os.path.join(log_dir, 'results')
|
| 415 |
+
os.makedirs(result_dir, exist_ok=True)
|
| 416 |
+
|
| 417 |
+
config = set_config(args, data_reader)
|
| 418 |
+
with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
|
| 419 |
+
fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
|
| 420 |
+
|
| 421 |
+
with tf.compat.v1.name_scope('Input_Batch'):
|
| 422 |
+
data_batch = data_reader.dataset(args.batch_size)
|
| 423 |
+
|
| 424 |
+
# model = UNet(config, input_batch=batch, mode='pred')
|
| 425 |
+
model = UNet(config, mode='pred')
|
| 426 |
+
sess_config = tf.compat.v1.ConfigProto()
|
| 427 |
+
sess_config.gpu_options.allow_growth = True
|
| 428 |
+
#sess_config.log_device_placement = False
|
| 429 |
+
|
| 430 |
+
with tf.compat.v1.Session(config=sess_config) as sess:
|
| 431 |
+
|
| 432 |
+
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
|
| 433 |
+
init = tf.compat.v1.global_variables_initializer()
|
| 434 |
+
sess.run(init)
|
| 435 |
+
|
| 436 |
+
logging.info("restoring models...")
|
| 437 |
+
latest_check_point = tf.train.latest_checkpoint(args.model_dir)
|
| 438 |
+
saver.restore(sess, latest_check_point)
|
| 439 |
+
|
| 440 |
+
# threads = data_reader.start_threads(sess, n_threads=multiprocessing.cpu_count())
|
| 441 |
+
|
| 442 |
+
if args.plot_figure:
|
| 443 |
+
num_pool = multiprocessing.cpu_count()
|
| 444 |
+
elif args.save_result:
|
| 445 |
+
num_pool = multiprocessing.cpu_count()
|
| 446 |
+
else:
|
| 447 |
+
num_pool = 2
|
| 448 |
+
multiprocessing.set_start_method('spawn')
|
| 449 |
+
pool = multiprocessing.Pool(num_pool)
|
| 450 |
+
for step in tqdm(range(0, data_reader.n_signal, args.batch_size), desc="Pred"):
|
| 451 |
+
#if step + args.batch_size >= data_reader.n_signal:
|
| 452 |
+
# for t in threads:
|
| 453 |
+
# t.join()
|
| 454 |
+
# sess.run(data_reader.queue.close())
|
| 455 |
+
# X_batch = []
|
| 456 |
+
# ratio_batch = []
|
| 457 |
+
# fname_batch = []
|
| 458 |
+
# for i in range(step, min(step+args.batch_size, data_reader.n_signal)):
|
| 459 |
+
# X, ratio, fname = data_reader[i]
|
| 460 |
+
# if np.std(X) == 0:
|
| 461 |
+
# continue
|
| 462 |
+
# X_batch.append(X)
|
| 463 |
+
# ratio_batch.append(ratio)
|
| 464 |
+
# fname_batch.append(fname)
|
| 465 |
+
# X_batch = np.stack(X_batch, axis=0)
|
| 466 |
+
# ratio_batch = np.array(ratio_batch)
|
| 467 |
+
X_batch, ratio_batch, fname_batch = sess.run(data_batch)
|
| 468 |
+
preds_batch = sess.run(model.preds, feed_dict={model.X: X_batch,
|
| 469 |
+
model.drop_rate: 0,
|
| 470 |
+
model.is_training: False})
|
| 471 |
+
#preds_batch, X_batch, ratio_batch, fname_batch = sess.run([model.preds,
|
| 472 |
+
# batch[0],
|
| 473 |
+
# batch[1],
|
| 474 |
+
# batch[2]],
|
| 475 |
+
# feed_dict={model.drop_rate: 0,
|
| 476 |
+
# model.is_training: False})
|
| 477 |
+
|
| 478 |
+
pool.map(partial(postprocessing_pred,
|
| 479 |
+
preds = preds_batch,
|
| 480 |
+
X = X_batch*ratio_batch[:,np.newaxis,:,np.newaxis],
|
| 481 |
+
fname = [x.decode() for x in fname_batch],
|
| 482 |
+
figure_dir = figure_dir,
|
| 483 |
+
result_dir = result_dir),
|
| 484 |
+
range(len(X_batch)))
|
| 485 |
+
|
| 486 |
+
# for i in range(len(X_batch)):
|
| 487 |
+
# postprocessing_thread(i,
|
| 488 |
+
# preds = preds_batch,
|
| 489 |
+
# X = X_batch*ratio_batch[:,np.newaxis,np.newaxis,np.newaxis],
|
| 490 |
+
# fname = fname_batch,
|
| 491 |
+
# figure_dir = figure_dir,
|
| 492 |
+
# result_dir = result_dir)
|
| 493 |
+
|
| 494 |
+
pool.close()
|
| 495 |
+
|
| 496 |
+
return 0
|
| 497 |
+
|
| 498 |
+
def main(args):
|
| 499 |
+
|
| 500 |
+
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
|
| 501 |
+
|
| 502 |
+
coord = tf.train.Coordinator()
|
| 503 |
+
|
| 504 |
+
if args.mode == "train":
|
| 505 |
+
with tf.compat.v1.name_scope('create_inputs'):
|
| 506 |
+
data_reader = DataReader(
|
| 507 |
+
signal_dir=args.train_signal_dir,
|
| 508 |
+
signal_list=args.train_signal_list,
|
| 509 |
+
noise_dir=args.train_noise_dir,
|
| 510 |
+
noise_list=args.train_noise_list,
|
| 511 |
+
queue_size=args.batch_size*2,
|
| 512 |
+
coord=coord)
|
| 513 |
+
if (args.valid_signal_list is not None) and (args.valid_noise_list is not None):
|
| 514 |
+
data_reader_valid = DataReader(
|
| 515 |
+
signal_dir=args.valid_signal_dir,
|
| 516 |
+
signal_list=args.valid_signal_list,
|
| 517 |
+
noise_dir=args.valid_noise_dir,
|
| 518 |
+
noise_list=args.valid_noise_list,
|
| 519 |
+
queue_size=args.batch_size*2,
|
| 520 |
+
coord=coord)
|
| 521 |
+
logging.info("Dataset size: training %d, validation %d" % (data_reader.n_signal, data_reader_valid.n_signal))
|
| 522 |
+
else:
|
| 523 |
+
data_reader_valid = None
|
| 524 |
+
logging.info("Dataset size: training %d, validation 0" % (data_reader.n_signal))
|
| 525 |
+
train_fn(args, data_reader, data_reader_valid)
|
| 526 |
+
|
| 527 |
+
elif args.mode == "valid" or args.mode == "test":
|
| 528 |
+
with tf.compat.v1.name_scope('create_inputs'):
|
| 529 |
+
data_reader = DataReader_test(
|
| 530 |
+
signal_dir=args.valid_signal_dir,
|
| 531 |
+
signal_list=args.valid_signal_list,
|
| 532 |
+
noise_dir=args.valid_noise_dir,
|
| 533 |
+
noise_list=args.valid_noise_list,
|
| 534 |
+
queue_size=args.batch_size*2,
|
| 535 |
+
coord=coord)
|
| 536 |
+
logging.info("Dataset Size: {}".format(data_reader.n_signal))
|
| 537 |
+
test_fn(args, data_reader)
|
| 538 |
+
|
| 539 |
+
elif args.mode == "pred":
|
| 540 |
+
with tf.compat.v1.name_scope('create_inputs'):
|
| 541 |
+
data_reader = DataReader_pred(
|
| 542 |
+
signal_dir=args.data_dir,
|
| 543 |
+
signal_list=args.data_list,
|
| 544 |
+
sampling_rate=args.sampling_rate)
|
| 545 |
+
logging.info("Dataset Size: {}".format(data_reader.n_signal))
|
| 546 |
+
pred_fn(args, data_reader, log_dir=args.output_dir)
|
| 547 |
+
|
| 548 |
+
else:
|
| 549 |
+
print("mode should be: train, valid, test, debug or pred")
|
| 550 |
+
|
| 551 |
+
coord.request_stop()
|
| 552 |
+
coord.join()
|
| 553 |
+
return 0
|
| 554 |
+
|
| 555 |
+
if __name__ == '__main__':
|
| 556 |
+
args = read_args()
|
| 557 |
+
main(args)
|
deepdenoiser/util.py
ADDED
|
@@ -0,0 +1,875 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import matplotlib
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
import scipy
|
| 7 |
+
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
|
| 8 |
+
from scipy import signal
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from data_reader import Config
|
| 12 |
+
|
| 13 |
+
matplotlib.use('agg')
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def plot_result(epoch, num, figure_dir, preds, X, Y, mode="valid"):
|
| 17 |
+
config = Config()
|
| 18 |
+
for i in range(min(num, len(X))):
|
| 19 |
+
|
| 20 |
+
t, noisy_signal = scipy.signal.istft(
|
| 21 |
+
X[i, :, :, 0] + X[i, :, :, 1] * 1j, fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
|
| 22 |
+
)
|
| 23 |
+
t, ideal_denoised_signal = scipy.signal.istft(
|
| 24 |
+
(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0],
|
| 25 |
+
fs=config.fs,
|
| 26 |
+
nperseg=config.nperseg,
|
| 27 |
+
nfft=config.nfft,
|
| 28 |
+
boundary='zeros',
|
| 29 |
+
)
|
| 30 |
+
t, denoised_signal = scipy.signal.istft(
|
| 31 |
+
(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
|
| 32 |
+
fs=config.fs,
|
| 33 |
+
nperseg=config.nperseg,
|
| 34 |
+
nfft=config.nfft,
|
| 35 |
+
boundary='zeros',
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
plt.figure(i)
|
| 39 |
+
fig_size = plt.gcf().get_size_inches()
|
| 40 |
+
plt.gcf().set_size_inches(fig_size * [1.5, 1.5])
|
| 41 |
+
plt.subplot(4, 2, 1)
|
| 42 |
+
plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), vmin=0, vmax=2)
|
| 43 |
+
plt.title("Noisy signal")
|
| 44 |
+
plt.gca().set_xticklabels([])
|
| 45 |
+
plt.subplot(4, 2, 2)
|
| 46 |
+
plt.plot(t, noisy_signal, label='Noisy signal', linewidth=0.1)
|
| 47 |
+
signal_ylim = plt.gca().get_ylim()
|
| 48 |
+
plt.gca().set_xticklabels([])
|
| 49 |
+
plt.legend(loc='lower left')
|
| 50 |
+
plt.margins(x=0)
|
| 51 |
+
|
| 52 |
+
plt.subplot(4, 2, 3)
|
| 53 |
+
plt.pcolormesh(Y[i, :, :, 0], vmin=0, vmax=1)
|
| 54 |
+
plt.gca().set_xticklabels([])
|
| 55 |
+
plt.title("Y")
|
| 56 |
+
plt.subplot(4, 2, 4)
|
| 57 |
+
plt.pcolormesh(preds[i, :, :, 0], vmin=0, vmax=1)
|
| 58 |
+
plt.title("$\hat{Y}$")
|
| 59 |
+
plt.gca().set_xticklabels([])
|
| 60 |
+
|
| 61 |
+
plt.subplot(4, 2, 5)
|
| 62 |
+
plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0], vmin=0, vmax=2)
|
| 63 |
+
plt.title("Ideal denoised signal")
|
| 64 |
+
plt.gca().set_xticklabels([])
|
| 65 |
+
plt.subplot(4, 2, 6)
|
| 66 |
+
plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], vmin=0, vmax=2)
|
| 67 |
+
plt.title("Denoised signal")
|
| 68 |
+
plt.gca().set_xticklabels([])
|
| 69 |
+
|
| 70 |
+
plt.subplot(4, 2, 7)
|
| 71 |
+
plt.plot(t, ideal_denoised_signal, label='Ideal denoised signal', linewidth=0.1)
|
| 72 |
+
plt.ylim(signal_ylim)
|
| 73 |
+
plt.xlabel("Time (s)")
|
| 74 |
+
plt.legend(loc='lower left')
|
| 75 |
+
plt.margins(x=0)
|
| 76 |
+
plt.subplot(4, 2, 8)
|
| 77 |
+
plt.plot(t, denoised_signal, label='Denoised signal', linewidth=0.1)
|
| 78 |
+
plt.ylim(signal_ylim)
|
| 79 |
+
plt.xlabel("Time (s)")
|
| 80 |
+
plt.legend(loc='lower left')
|
| 81 |
+
plt.margins(x=0)
|
| 82 |
+
|
| 83 |
+
plt.tight_layout()
|
| 84 |
+
plt.gcf().align_labels()
|
| 85 |
+
plt.savefig(os.path.join(figure_dir, "epoch{:03d}_{:03d}_{:}.png".format(epoch, i, mode)), bbox_inches='tight')
|
| 86 |
+
# plt.savefig(os.path.join(figure_dir, "epoch%03d_%03d.pdf" % (epoch, i)), bbox_inches='tight')
|
| 87 |
+
plt.close(i)
|
| 88 |
+
return 0
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def plot_result_thread(i, epoch, preds, X, Y, figure_dir, mode="valid"):
|
| 92 |
+
config = Config()
|
| 93 |
+
t, noisy_signal = scipy.signal.istft(
|
| 94 |
+
X[i, :, :, 0] + X[i, :, :, 1] * 1j, fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
|
| 95 |
+
)
|
| 96 |
+
t, ideal_denoised_signal = scipy.signal.istft(
|
| 97 |
+
(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0],
|
| 98 |
+
fs=config.fs,
|
| 99 |
+
nperseg=config.nperseg,
|
| 100 |
+
nfft=config.nfft,
|
| 101 |
+
boundary='zeros',
|
| 102 |
+
)
|
| 103 |
+
t, denoised_signal = scipy.signal.istft(
|
| 104 |
+
(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
|
| 105 |
+
fs=config.fs,
|
| 106 |
+
nperseg=config.nperseg,
|
| 107 |
+
nfft=config.nfft,
|
| 108 |
+
boundary='zeros',
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
plt.figure(i)
|
| 112 |
+
fig_size = plt.gcf().get_size_inches()
|
| 113 |
+
plt.gcf().set_size_inches(fig_size * [1.5, 1.5])
|
| 114 |
+
plt.subplot(4, 2, 1)
|
| 115 |
+
plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), vmin=0, vmax=2)
|
| 116 |
+
plt.title("Noisy signal")
|
| 117 |
+
plt.gca().set_xticklabels([])
|
| 118 |
+
plt.subplot(4, 2, 2)
|
| 119 |
+
plt.plot(t, noisy_signal, 'k', label='Noisy signal', linewidth=0.5)
|
| 120 |
+
signal_ylim = plt.gca().get_ylim()
|
| 121 |
+
plt.gca().set_xticklabels([])
|
| 122 |
+
plt.legend(loc='lower left')
|
| 123 |
+
plt.margins(x=0)
|
| 124 |
+
|
| 125 |
+
plt.subplot(4, 2, 3)
|
| 126 |
+
plt.pcolormesh(Y[i, :, :, 0], vmin=0, vmax=1)
|
| 127 |
+
plt.gca().set_xticklabels([])
|
| 128 |
+
plt.title("Y")
|
| 129 |
+
plt.subplot(4, 2, 4)
|
| 130 |
+
plt.pcolormesh(preds[i, :, :, 0], vmin=0, vmax=1)
|
| 131 |
+
plt.title("$\hat{Y}$")
|
| 132 |
+
plt.gca().set_xticklabels([])
|
| 133 |
+
|
| 134 |
+
plt.subplot(4, 2, 5)
|
| 135 |
+
plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0], vmin=0, vmax=2)
|
| 136 |
+
plt.title("Ideal denoised signal")
|
| 137 |
+
plt.gca().set_xticklabels([])
|
| 138 |
+
plt.subplot(4, 2, 6)
|
| 139 |
+
plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], vmin=0, vmax=2)
|
| 140 |
+
plt.title("Denoised signal")
|
| 141 |
+
plt.gca().set_xticklabels([])
|
| 142 |
+
|
| 143 |
+
plt.subplot(4, 2, 7)
|
| 144 |
+
plt.plot(t, ideal_denoised_signal, 'k', label='Ideal denoised signal', linewidth=0.5)
|
| 145 |
+
plt.ylim(signal_ylim)
|
| 146 |
+
plt.xlabel("Time (s)")
|
| 147 |
+
plt.legend(loc='lower left')
|
| 148 |
+
plt.margins(x=0)
|
| 149 |
+
plt.subplot(4, 2, 8)
|
| 150 |
+
plt.plot(t, denoised_signal, 'k', label='Denoised signal', linewidth=0.5)
|
| 151 |
+
plt.ylim(signal_ylim)
|
| 152 |
+
plt.xlabel("Time (s)")
|
| 153 |
+
plt.legend(loc='lower left')
|
| 154 |
+
plt.margins(x=0)
|
| 155 |
+
|
| 156 |
+
plt.tight_layout()
|
| 157 |
+
plt.gcf().align_labels()
|
| 158 |
+
plt.savefig(os.path.join(figure_dir, "epoch{:03d}_{:03d}_{:}.png".format(epoch, i, mode)), bbox_inches='tight')
|
| 159 |
+
plt.close(i)
|
| 160 |
+
return 0
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def postprocessing_test(
|
| 164 |
+
i, preds, X, fname, figure_dir=None, result_dir=None, signal_FT=None, noise_FT=None, data_dir=None
|
| 165 |
+
):
|
| 166 |
+
if (figure_dir is not None) or (result_dir is not None):
|
| 167 |
+
config = Config()
|
| 168 |
+
t1, noisy_signal = scipy.signal.istft(
|
| 169 |
+
X[i, :, :, 0] + X[i, :, :, 1] * 1j, fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
|
| 170 |
+
)
|
| 171 |
+
t1, denoised_signal = scipy.signal.istft(
|
| 172 |
+
(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
|
| 173 |
+
fs=config.fs,
|
| 174 |
+
nperseg=config.nperseg,
|
| 175 |
+
nfft=config.nfft,
|
| 176 |
+
boundary='zeros',
|
| 177 |
+
)
|
| 178 |
+
t1, denoised_noise = scipy.signal.istft(
|
| 179 |
+
(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * (1 - preds[i, :, :, 0]),
|
| 180 |
+
fs=config.fs,
|
| 181 |
+
nperseg=config.nperseg,
|
| 182 |
+
nfft=config.nfft,
|
| 183 |
+
boundary='zeros',
|
| 184 |
+
)
|
| 185 |
+
t1, signal = scipy.signal.istft(
|
| 186 |
+
signal_FT[i, :, :], fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
|
| 187 |
+
)
|
| 188 |
+
t1, noise = scipy.signal.istft(
|
| 189 |
+
noise_FT[i, :, :], fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if result_dir is not None:
|
| 193 |
+
try:
|
| 194 |
+
np.savez(
|
| 195 |
+
os.path.join(result_dir, fname[i].decode()),
|
| 196 |
+
preds=preds[i],
|
| 197 |
+
X=X[i],
|
| 198 |
+
signal_FT=signal_FT[i],
|
| 199 |
+
noise_FT=noise_FT[i],
|
| 200 |
+
noisy_signal=noisy_signal,
|
| 201 |
+
denoised_signal=denoised_signal,
|
| 202 |
+
denoised_noise=denoised_noise,
|
| 203 |
+
signal=signal,
|
| 204 |
+
noise=noise,
|
| 205 |
+
)
|
| 206 |
+
except FileNotFoundError:
|
| 207 |
+
os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i].decode())), exist_ok=True)
|
| 208 |
+
np.savez(
|
| 209 |
+
os.path.join(result_dir, fname[i].decode()),
|
| 210 |
+
preds=preds[i],
|
| 211 |
+
X=X[i],
|
| 212 |
+
signal_FT=signal_FT[i],
|
| 213 |
+
noise_FT=noise_FT[i],
|
| 214 |
+
noisy_signal=noisy_signal,
|
| 215 |
+
denoised_signal=denoised_signal,
|
| 216 |
+
denoised_noise=denoised_noise,
|
| 217 |
+
signal=signal,
|
| 218 |
+
noise=noise,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
if figure_dir is not None:
|
| 222 |
+
t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[2])
|
| 223 |
+
f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[1])
|
| 224 |
+
|
| 225 |
+
raw_data = None
|
| 226 |
+
if data_dir is not None:
|
| 227 |
+
raw_data = np.load(os.path.join(data_dir, fname[i].decode().split('/')[-1]))
|
| 228 |
+
itp = raw_data['itp']
|
| 229 |
+
its = raw_data['its']
|
| 230 |
+
ix1 = (750 - 50) / 100
|
| 231 |
+
ix2 = (750 + (its - itp) + 50) / 100
|
| 232 |
+
if ix2 - ix1 > 3:
|
| 233 |
+
ix2 = ix1 + 3
|
| 234 |
+
|
| 235 |
+
box = dict(boxstyle='round', facecolor='white', alpha=1)
|
| 236 |
+
|
| 237 |
+
text_loc = [0.05, 0.8]
|
| 238 |
+
plt.figure(i)
|
| 239 |
+
fig_size = plt.gcf().get_size_inches()
|
| 240 |
+
plt.gcf().set_size_inches(fig_size * [1, 2])
|
| 241 |
+
plt.subplot(511)
|
| 242 |
+
plt.pcolormesh(t_FT, f_FT, np.abs(signal_FT[i, :, :]), vmin=0, vmax=1)
|
| 243 |
+
plt.gca().set_xticklabels([])
|
| 244 |
+
plt.text(
|
| 245 |
+
text_loc[0],
|
| 246 |
+
text_loc[1],
|
| 247 |
+
'(i)',
|
| 248 |
+
horizontalalignment='center',
|
| 249 |
+
transform=plt.gca().transAxes,
|
| 250 |
+
fontsize="medium",
|
| 251 |
+
fontweight="bold",
|
| 252 |
+
bbox=box,
|
| 253 |
+
)
|
| 254 |
+
plt.subplot(512)
|
| 255 |
+
plt.pcolormesh(t_FT, f_FT, np.abs(noise_FT[i, :, :]), vmin=0, vmax=1)
|
| 256 |
+
plt.gca().set_xticklabels([])
|
| 257 |
+
plt.text(
|
| 258 |
+
text_loc[0],
|
| 259 |
+
text_loc[1],
|
| 260 |
+
'(ii)',
|
| 261 |
+
horizontalalignment='center',
|
| 262 |
+
transform=plt.gca().transAxes,
|
| 263 |
+
fontsize="medium",
|
| 264 |
+
fontweight="bold",
|
| 265 |
+
bbox=box,
|
| 266 |
+
)
|
| 267 |
+
plt.subplot(513)
|
| 268 |
+
plt.pcolormesh(t_FT, f_FT, np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), vmin=0, vmax=1)
|
| 269 |
+
plt.ylabel("Frequency (Hz)", fontsize='large')
|
| 270 |
+
plt.gca().set_xticklabels([])
|
| 271 |
+
plt.text(
|
| 272 |
+
text_loc[0],
|
| 273 |
+
text_loc[1],
|
| 274 |
+
'(iii)',
|
| 275 |
+
horizontalalignment='center',
|
| 276 |
+
transform=plt.gca().transAxes,
|
| 277 |
+
fontsize="medium",
|
| 278 |
+
fontweight="bold",
|
| 279 |
+
bbox=box,
|
| 280 |
+
)
|
| 281 |
+
plt.subplot(514)
|
| 282 |
+
plt.pcolormesh(t_FT, f_FT, np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], vmin=0, vmax=1)
|
| 283 |
+
plt.gca().set_xticklabels([])
|
| 284 |
+
plt.text(
|
| 285 |
+
text_loc[0],
|
| 286 |
+
text_loc[1],
|
| 287 |
+
'(iv)',
|
| 288 |
+
horizontalalignment='center',
|
| 289 |
+
transform=plt.gca().transAxes,
|
| 290 |
+
fontsize="medium",
|
| 291 |
+
fontweight="bold",
|
| 292 |
+
bbox=box,
|
| 293 |
+
)
|
| 294 |
+
plt.subplot(515)
|
| 295 |
+
plt.pcolormesh(t_FT, f_FT, np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 1], vmin=0, vmax=1)
|
| 296 |
+
plt.xlabel("Time (s)", fontsize='large')
|
| 297 |
+
plt.text(
|
| 298 |
+
text_loc[0],
|
| 299 |
+
text_loc[1],
|
| 300 |
+
'(v)',
|
| 301 |
+
horizontalalignment='center',
|
| 302 |
+
transform=plt.gca().transAxes,
|
| 303 |
+
fontsize="medium",
|
| 304 |
+
fontweight="bold",
|
| 305 |
+
bbox=box,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
try:
|
| 309 |
+
plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.png'), bbox_inches='tight')
|
| 310 |
+
# plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz')+'_FT.pdf'), bbox_inches='tight')
|
| 311 |
+
except FileNotFoundError:
|
| 312 |
+
os.makedirs(
|
| 313 |
+
os.path.dirname(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.png')), exist_ok=True
|
| 314 |
+
)
|
| 315 |
+
plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.png'), bbox_inches='tight')
|
| 316 |
+
# plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz')+'_FT.pdf'), bbox_inches='tight')
|
| 317 |
+
plt.close(i)
|
| 318 |
+
|
| 319 |
+
text_loc = [0.05, 0.8]
|
| 320 |
+
plt.figure(i)
|
| 321 |
+
fig_size = plt.gcf().get_size_inches()
|
| 322 |
+
plt.gcf().set_size_inches(fig_size * [1, 2])
|
| 323 |
+
|
| 324 |
+
ax3 = plt.subplot(513)
|
| 325 |
+
plt.plot(t1, noisy_signal, 'k', linewidth=0.5, label='Noisy signal')
|
| 326 |
+
plt.legend(loc='lower left', fontsize='medium')
|
| 327 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
| 328 |
+
plt.ylim([-np.max(np.abs(noisy_signal)), np.max(np.abs(noisy_signal))])
|
| 329 |
+
signal_ylim = [-np.max(np.abs(noisy_signal[100:-100])), np.max(np.abs(noisy_signal[100:-100]))]
|
| 330 |
+
plt.ylim(signal_ylim)
|
| 331 |
+
plt.ylabel("Amplitude", fontsize='large')
|
| 332 |
+
plt.gca().set_xticklabels([])
|
| 333 |
+
plt.text(
|
| 334 |
+
text_loc[0],
|
| 335 |
+
text_loc[1],
|
| 336 |
+
'(iii)',
|
| 337 |
+
horizontalalignment='center',
|
| 338 |
+
transform=plt.gca().transAxes,
|
| 339 |
+
fontsize="medium",
|
| 340 |
+
fontweight="bold",
|
| 341 |
+
bbox=box,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
ax1 = plt.subplot(511)
|
| 345 |
+
plt.plot(t1, signal, 'k', linewidth=0.5, label='Signal')
|
| 346 |
+
plt.legend(loc='lower left', fontsize='medium')
|
| 347 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
| 348 |
+
plt.ylim(signal_ylim)
|
| 349 |
+
plt.gca().set_xticklabels([])
|
| 350 |
+
plt.text(
|
| 351 |
+
text_loc[0],
|
| 352 |
+
text_loc[1],
|
| 353 |
+
'(i)',
|
| 354 |
+
horizontalalignment='center',
|
| 355 |
+
transform=plt.gca().transAxes,
|
| 356 |
+
fontsize="medium",
|
| 357 |
+
fontweight="bold",
|
| 358 |
+
bbox=box,
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
plt.subplot(512)
|
| 362 |
+
plt.plot(t1, noise, 'k', linewidth=0.5, label='Noise')
|
| 363 |
+
plt.legend(loc='lower left', fontsize='medium')
|
| 364 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
| 365 |
+
plt.ylim([-np.max(np.abs(noise)), np.max(np.abs(noise))])
|
| 366 |
+
noise_ylim = [-np.max(np.abs(noise[100:-100])), np.max(np.abs(noise[100:-100]))]
|
| 367 |
+
plt.ylim(noise_ylim)
|
| 368 |
+
plt.gca().set_xticklabels([])
|
| 369 |
+
plt.text(
|
| 370 |
+
text_loc[0],
|
| 371 |
+
text_loc[1],
|
| 372 |
+
'(ii)',
|
| 373 |
+
horizontalalignment='center',
|
| 374 |
+
transform=plt.gca().transAxes,
|
| 375 |
+
fontsize="medium",
|
| 376 |
+
fontweight="bold",
|
| 377 |
+
bbox=box,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
ax4 = plt.subplot(514)
|
| 381 |
+
plt.plot(t1, denoised_signal, 'k', linewidth=0.5, label='Recovered signal')
|
| 382 |
+
plt.legend(loc='lower left', fontsize='medium')
|
| 383 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
| 384 |
+
plt.ylim(signal_ylim)
|
| 385 |
+
plt.gca().set_xticklabels([])
|
| 386 |
+
plt.text(
|
| 387 |
+
text_loc[0],
|
| 388 |
+
text_loc[1],
|
| 389 |
+
'(iv)',
|
| 390 |
+
horizontalalignment='center',
|
| 391 |
+
transform=plt.gca().transAxes,
|
| 392 |
+
fontsize="medium",
|
| 393 |
+
fontweight="bold",
|
| 394 |
+
bbox=box,
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
plt.subplot(515)
|
| 398 |
+
plt.plot(t1, denoised_noise, 'k', linewidth=0.5, label='Recovered noise')
|
| 399 |
+
plt.legend(loc='lower left', fontsize='medium')
|
| 400 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
| 401 |
+
plt.xlabel("Time (s)", fontsize='large')
|
| 402 |
+
plt.ylim(noise_ylim)
|
| 403 |
+
plt.text(
|
| 404 |
+
text_loc[0],
|
| 405 |
+
text_loc[1],
|
| 406 |
+
'(v)',
|
| 407 |
+
horizontalalignment='center',
|
| 408 |
+
transform=plt.gca().transAxes,
|
| 409 |
+
fontsize="medium",
|
| 410 |
+
fontweight="bold",
|
| 411 |
+
bbox=box,
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
if data_dir is not None:
|
| 415 |
+
axins = inset_axes(
|
| 416 |
+
ax1, width=2.0, height=1.0, loc='upper right', bbox_to_anchor=(1, 0.5), bbox_transform=ax1.transAxes
|
| 417 |
+
)
|
| 418 |
+
axins.plot(t1, signal, 'k', linewidth=0.5)
|
| 419 |
+
x1, x2 = ix1, ix2
|
| 420 |
+
y1 = -np.max(np.abs(signal[(t1 > ix1) & (t1 < ix2)]))
|
| 421 |
+
y2 = -y1
|
| 422 |
+
axins.set_xlim(x1, x2)
|
| 423 |
+
axins.set_ylim(y1, y2)
|
| 424 |
+
plt.xticks(visible=False)
|
| 425 |
+
plt.yticks(visible=False)
|
| 426 |
+
mark_inset(ax1, axins, loc1=1, loc2=3, fc="none", ec="0.5")
|
| 427 |
+
|
| 428 |
+
axins = inset_axes(
|
| 429 |
+
ax3, width=2.0, height=1.0, loc='upper right', bbox_to_anchor=(1, 0.3), bbox_transform=ax3.transAxes
|
| 430 |
+
)
|
| 431 |
+
axins.plot(t1, noisy_signal, 'k', linewidth=0.5)
|
| 432 |
+
x1, x2 = ix1, ix2
|
| 433 |
+
axins.set_xlim(x1, x2)
|
| 434 |
+
axins.set_ylim(y1, y2)
|
| 435 |
+
plt.xticks(visible=False)
|
| 436 |
+
plt.yticks(visible=False)
|
| 437 |
+
mark_inset(ax3, axins, loc1=1, loc2=3, fc="none", ec="0.5")
|
| 438 |
+
|
| 439 |
+
axins = inset_axes(
|
| 440 |
+
ax4, width=2.0, height=1.0, loc='upper right', bbox_to_anchor=(1, 0.5), bbox_transform=ax4.transAxes
|
| 441 |
+
)
|
| 442 |
+
axins.plot(t1, denoised_signal, 'k', linewidth=0.5)
|
| 443 |
+
x1, x2 = ix1, ix2
|
| 444 |
+
axins.set_xlim(x1, x2)
|
| 445 |
+
axins.set_ylim(y1, y2)
|
| 446 |
+
plt.xticks(visible=False)
|
| 447 |
+
plt.yticks(visible=False)
|
| 448 |
+
mark_inset(ax4, axins, loc1=1, loc2=3, fc="none", ec="0.5")
|
| 449 |
+
|
| 450 |
+
plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_wave.png'), bbox_inches='tight')
|
| 451 |
+
# plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz')+'_wave.pdf'), bbox_inches='tight')
|
| 452 |
+
plt.close(i)
|
| 453 |
+
|
| 454 |
+
return
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def postprocessing_pred(i, preds, X, fname, figure_dir=None, result_dir=None):
|
| 458 |
+
|
| 459 |
+
if (result_dir is not None) or (figure_dir is not None):
|
| 460 |
+
config = Config()
|
| 461 |
+
|
| 462 |
+
t1, noisy_signal = scipy.signal.istft(
|
| 463 |
+
(X[i, :, :, 0] + X[i, :, :, 1] * 1j),
|
| 464 |
+
fs=config.fs,
|
| 465 |
+
nperseg=config.nperseg,
|
| 466 |
+
nfft=config.nfft,
|
| 467 |
+
boundary='zeros',
|
| 468 |
+
)
|
| 469 |
+
t1, denoised_signal = scipy.signal.istft(
|
| 470 |
+
(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
|
| 471 |
+
fs=config.fs,
|
| 472 |
+
nperseg=config.nperseg,
|
| 473 |
+
nfft=config.nfft,
|
| 474 |
+
boundary='zeros',
|
| 475 |
+
)
|
| 476 |
+
t1, denoised_noise = scipy.signal.istft(
|
| 477 |
+
(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 1],
|
| 478 |
+
fs=config.fs,
|
| 479 |
+
nperseg=config.nperseg,
|
| 480 |
+
nfft=config.nfft,
|
| 481 |
+
boundary='zeros',
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
if result_dir is not None:
|
| 485 |
+
try:
|
| 486 |
+
np.savez(
|
| 487 |
+
os.path.join(result_dir, fname[i]),
|
| 488 |
+
noisy_signal=noisy_signal,
|
| 489 |
+
denoised_signal=denoised_signal,
|
| 490 |
+
denoised_noise=denoised_noise,
|
| 491 |
+
t=t1,
|
| 492 |
+
)
|
| 493 |
+
except FileNotFoundError:
|
| 494 |
+
os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i])))
|
| 495 |
+
np.savez(
|
| 496 |
+
os.path.join(result_dir, fname[i]),
|
| 497 |
+
noisy_signal=noisy_signal,
|
| 498 |
+
denoised_signal=denoised_signal,
|
| 499 |
+
denoised_noise=denoised_noise,
|
| 500 |
+
t=t1,
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
if figure_dir is not None:
|
| 504 |
+
|
| 505 |
+
t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[2])
|
| 506 |
+
f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[1])
|
| 507 |
+
|
| 508 |
+
box = dict(boxstyle='round', facecolor='white', alpha=1)
|
| 509 |
+
text_loc = [0.05, 0.77]
|
| 510 |
+
|
| 511 |
+
plt.figure(i)
|
| 512 |
+
fig_size = plt.gcf().get_size_inches()
|
| 513 |
+
plt.gcf().set_size_inches(fig_size * [1, 1.2])
|
| 514 |
+
vmax = np.std(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j)) * 1.8
|
| 515 |
+
|
| 516 |
+
plt.subplot(311)
|
| 517 |
+
plt.pcolormesh(
|
| 518 |
+
t_FT,
|
| 519 |
+
f_FT,
|
| 520 |
+
np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j),
|
| 521 |
+
vmin=0,
|
| 522 |
+
vmax=vmax,
|
| 523 |
+
shading='auto',
|
| 524 |
+
label='Noisy signal',
|
| 525 |
+
)
|
| 526 |
+
plt.gca().set_xticklabels([])
|
| 527 |
+
plt.text(
|
| 528 |
+
text_loc[0],
|
| 529 |
+
text_loc[1],
|
| 530 |
+
'(i)',
|
| 531 |
+
horizontalalignment='center',
|
| 532 |
+
transform=plt.gca().transAxes,
|
| 533 |
+
fontsize="medium",
|
| 534 |
+
fontweight="bold",
|
| 535 |
+
bbox=box,
|
| 536 |
+
)
|
| 537 |
+
plt.subplot(312)
|
| 538 |
+
plt.pcolormesh(
|
| 539 |
+
t_FT,
|
| 540 |
+
f_FT,
|
| 541 |
+
np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
|
| 542 |
+
vmin=0,
|
| 543 |
+
vmax=vmax,
|
| 544 |
+
shading='auto',
|
| 545 |
+
label='Recovered signal',
|
| 546 |
+
)
|
| 547 |
+
plt.gca().set_xticklabels([])
|
| 548 |
+
plt.ylabel("Frequency (Hz)", fontsize='large')
|
| 549 |
+
plt.text(
|
| 550 |
+
text_loc[0],
|
| 551 |
+
text_loc[1],
|
| 552 |
+
'(ii)',
|
| 553 |
+
horizontalalignment='center',
|
| 554 |
+
transform=plt.gca().transAxes,
|
| 555 |
+
fontsize="medium",
|
| 556 |
+
fontweight="bold",
|
| 557 |
+
bbox=box,
|
| 558 |
+
)
|
| 559 |
+
plt.subplot(313)
|
| 560 |
+
plt.pcolormesh(
|
| 561 |
+
t_FT,
|
| 562 |
+
f_FT,
|
| 563 |
+
np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 1],
|
| 564 |
+
vmin=0,
|
| 565 |
+
vmax=vmax,
|
| 566 |
+
shading='auto',
|
| 567 |
+
label='Recovered noise',
|
| 568 |
+
)
|
| 569 |
+
plt.xlabel("Time (s)", fontsize='large')
|
| 570 |
+
plt.text(
|
| 571 |
+
text_loc[0],
|
| 572 |
+
text_loc[1],
|
| 573 |
+
'(iii)',
|
| 574 |
+
horizontalalignment='center',
|
| 575 |
+
transform=plt.gca().transAxes,
|
| 576 |
+
fontsize="medium",
|
| 577 |
+
fontweight="bold",
|
| 578 |
+
bbox=box,
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
try:
|
| 582 |
+
plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz') + '_FT.png'), bbox_inches='tight')
|
| 583 |
+
# plt.savefig(os.path.join(figure_dir, fname[i].split('/')[-1].rstrip('.npz')+'_FT.pdf'), bbox_inches='tight')
|
| 584 |
+
except FileNotFoundError:
|
| 585 |
+
os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i].rstrip('.npz') + '_FT.png')), exist_ok=True)
|
| 586 |
+
plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz') + '_FT.png'), bbox_inches='tight')
|
| 587 |
+
# plt.savefig(os.path.join(figure_dir, fname[i].split('/')[-1].rstrip('.npz')+'_FT.pdf'), bbox_inches='tight')
|
| 588 |
+
plt.close(i)
|
| 589 |
+
|
| 590 |
+
plt.figure(i)
|
| 591 |
+
fig_size = plt.gcf().get_size_inches()
|
| 592 |
+
plt.gcf().set_size_inches(fig_size * [1, 1.2])
|
| 593 |
+
|
| 594 |
+
ax4 = plt.subplot(311)
|
| 595 |
+
plt.plot(t1, noisy_signal, 'k', label='Noisy signal', linewidth=0.5)
|
| 596 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
| 597 |
+
signal_ylim = [-np.max(np.abs(noisy_signal[100:-100])), np.max(np.abs(noisy_signal[100:-100]))]
|
| 598 |
+
plt.ylim(signal_ylim)
|
| 599 |
+
plt.gca().set_xticklabels([])
|
| 600 |
+
plt.legend(loc='lower left', fontsize='medium')
|
| 601 |
+
plt.text(
|
| 602 |
+
text_loc[0],
|
| 603 |
+
text_loc[1],
|
| 604 |
+
'(i)',
|
| 605 |
+
horizontalalignment='center',
|
| 606 |
+
transform=plt.gca().transAxes,
|
| 607 |
+
fontsize="medium",
|
| 608 |
+
fontweight="bold",
|
| 609 |
+
bbox=box,
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
ax5 = plt.subplot(312)
|
| 613 |
+
plt.plot(t1, denoised_signal, 'k', label='Recovered signal', linewidth=0.5)
|
| 614 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
| 615 |
+
plt.ylim(signal_ylim)
|
| 616 |
+
plt.gca().set_xticklabels([])
|
| 617 |
+
plt.ylabel("Amplitude", fontsize='large')
|
| 618 |
+
plt.legend(loc='lower left', fontsize='medium')
|
| 619 |
+
plt.text(
|
| 620 |
+
text_loc[0],
|
| 621 |
+
text_loc[1],
|
| 622 |
+
'(ii)',
|
| 623 |
+
horizontalalignment='center',
|
| 624 |
+
transform=plt.gca().transAxes,
|
| 625 |
+
fontsize="medium",
|
| 626 |
+
fontweight="bold",
|
| 627 |
+
bbox=box,
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
plt.subplot(313)
|
| 631 |
+
plt.plot(t1, denoised_noise, 'k', label='Recovered noise', linewidth=0.5)
|
| 632 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
| 633 |
+
plt.ylim(signal_ylim)
|
| 634 |
+
plt.xlabel("Time (s)", fontsize='large')
|
| 635 |
+
plt.legend(loc='lower left', fontsize='medium')
|
| 636 |
+
plt.text(
|
| 637 |
+
text_loc[0],
|
| 638 |
+
text_loc[1],
|
| 639 |
+
'(iii)',
|
| 640 |
+
horizontalalignment='center',
|
| 641 |
+
transform=plt.gca().transAxes,
|
| 642 |
+
fontsize="medium",
|
| 643 |
+
fontweight="bold",
|
| 644 |
+
bbox=box,
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz') + '_wave.png'), bbox_inches='tight')
|
| 648 |
+
# plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz')+'_wave.pdf'), bbox_inches='tight')
|
| 649 |
+
plt.close(i)
|
| 650 |
+
|
| 651 |
+
return
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def save_results(mask, X, fname, t0, save_signal=True, save_noise=True, result_dir="results"):
|
| 655 |
+
|
| 656 |
+
config = Config()
|
| 657 |
+
|
| 658 |
+
if save_signal:
|
| 659 |
+
_, denoised_signal = scipy.signal.istft(
|
| 660 |
+
(X[..., 0] + X[..., 1] * 1j) * mask[..., 0],
|
| 661 |
+
fs=config.fs,
|
| 662 |
+
nperseg=config.nperseg,
|
| 663 |
+
nfft=config.nfft,
|
| 664 |
+
boundary='zeros',
|
| 665 |
+
) # nbt, nch, nst, nt
|
| 666 |
+
denoised_signal = np.transpose(denoised_signal, [0, 3, 2, 1]) # nbt, nt, nst, nch,
|
| 667 |
+
if save_noise:
|
| 668 |
+
_, denoised_noise = scipy.signal.istft(
|
| 669 |
+
(X[..., 0] + X[..., 1] * 1j) * mask[..., 1],
|
| 670 |
+
fs=config.fs,
|
| 671 |
+
nperseg=config.nperseg,
|
| 672 |
+
nfft=config.nfft,
|
| 673 |
+
boundary='zeros',
|
| 674 |
+
)
|
| 675 |
+
denoised_noise = np.transpose(denoised_noise, [0, 3, 2, 1])
|
| 676 |
+
|
| 677 |
+
if not os.path.exists(result_dir):
|
| 678 |
+
os.makedirs(result_dir)
|
| 679 |
+
|
| 680 |
+
for i in range(len(X)):
|
| 681 |
+
np.savez(
|
| 682 |
+
os.path.join(result_dir, fname[i]),
|
| 683 |
+
data=denoised_signal[i] if save_signal else None,
|
| 684 |
+
noise=denoised_noise[i] if save_noise else None,
|
| 685 |
+
t0=t0[i],
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
def plot_figures(mask, X, fname, figure_dir="figures"):
|
| 690 |
+
|
| 691 |
+
config = Config()
|
| 692 |
+
|
| 693 |
+
# plot the last channel
|
| 694 |
+
mask = mask[-1, -1, ...] # nch, nst, nf, nt, 2 => nf, nt, 2
|
| 695 |
+
X = X[-1, -1, ...]
|
| 696 |
+
|
| 697 |
+
t1, noisy_signal = scipy.signal.istft(
|
| 698 |
+
(X[..., 0] + X[..., 1] * 1j),
|
| 699 |
+
fs=config.fs,
|
| 700 |
+
nperseg=config.nperseg,
|
| 701 |
+
nfft=config.nfft,
|
| 702 |
+
boundary='zeros',
|
| 703 |
+
)
|
| 704 |
+
t1, denoised_signal = scipy.signal.istft(
|
| 705 |
+
(X[..., 0] + X[..., 1] * 1j) * mask[..., 0],
|
| 706 |
+
fs=config.fs,
|
| 707 |
+
nperseg=config.nperseg,
|
| 708 |
+
nfft=config.nfft,
|
| 709 |
+
boundary='zeros',
|
| 710 |
+
)
|
| 711 |
+
t1, denoised_noise = scipy.signal.istft(
|
| 712 |
+
(X[..., 0] + X[..., 1] * 1j) * mask[..., 1],
|
| 713 |
+
fs=config.fs,
|
| 714 |
+
nperseg=config.nperseg,
|
| 715 |
+
nfft=config.nfft,
|
| 716 |
+
boundary='zeros',
|
| 717 |
+
)
|
| 718 |
+
|
| 719 |
+
if not os.path.exists(figure_dir):
|
| 720 |
+
os.makedirs(figure_dir)
|
| 721 |
+
|
| 722 |
+
t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[1])
|
| 723 |
+
f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[0])
|
| 724 |
+
|
| 725 |
+
box = dict(boxstyle='round', facecolor='white', alpha=1)
|
| 726 |
+
text_loc = [0.05, 0.77]
|
| 727 |
+
|
| 728 |
+
plt.figure()
|
| 729 |
+
fig_size = plt.gcf().get_size_inches()
|
| 730 |
+
plt.gcf().set_size_inches(fig_size * [1, 1.2])
|
| 731 |
+
vmax = np.std(np.abs(X[:, :, 0] + X[:, :, 1] * 1j)) * 1.8
|
| 732 |
+
|
| 733 |
+
plt.subplot(311)
|
| 734 |
+
plt.pcolormesh(
|
| 735 |
+
t_FT,
|
| 736 |
+
f_FT,
|
| 737 |
+
np.abs(X[:, :, 0] + X[:, :, 1] * 1j),
|
| 738 |
+
vmin=0,
|
| 739 |
+
vmax=vmax,
|
| 740 |
+
shading='auto',
|
| 741 |
+
label='Noisy signal',
|
| 742 |
+
)
|
| 743 |
+
plt.gca().set_xticklabels([])
|
| 744 |
+
plt.text(
|
| 745 |
+
text_loc[0],
|
| 746 |
+
text_loc[1],
|
| 747 |
+
'(i)',
|
| 748 |
+
horizontalalignment='center',
|
| 749 |
+
transform=plt.gca().transAxes,
|
| 750 |
+
fontsize="medium",
|
| 751 |
+
fontweight="bold",
|
| 752 |
+
bbox=box,
|
| 753 |
+
)
|
| 754 |
+
plt.subplot(312)
|
| 755 |
+
plt.pcolormesh(
|
| 756 |
+
t_FT,
|
| 757 |
+
f_FT,
|
| 758 |
+
np.abs(X[:, :, 0] + X[:, :, 1] * 1j) * mask[:, :, 0],
|
| 759 |
+
vmin=0,
|
| 760 |
+
vmax=vmax,
|
| 761 |
+
shading='auto',
|
| 762 |
+
label='Recovered signal',
|
| 763 |
+
)
|
| 764 |
+
plt.gca().set_xticklabels([])
|
| 765 |
+
plt.ylabel("Frequency (Hz)", fontsize='large')
|
| 766 |
+
plt.text(
|
| 767 |
+
text_loc[0],
|
| 768 |
+
text_loc[1],
|
| 769 |
+
'(ii)',
|
| 770 |
+
horizontalalignment='center',
|
| 771 |
+
transform=plt.gca().transAxes,
|
| 772 |
+
fontsize="medium",
|
| 773 |
+
fontweight="bold",
|
| 774 |
+
bbox=box,
|
| 775 |
+
)
|
| 776 |
+
plt.subplot(313)
|
| 777 |
+
plt.pcolormesh(
|
| 778 |
+
t_FT,
|
| 779 |
+
f_FT,
|
| 780 |
+
np.abs(X[:, :, 0] + X[:, :, 1] * 1j) * mask[:, :, 1],
|
| 781 |
+
vmin=0,
|
| 782 |
+
vmax=vmax,
|
| 783 |
+
shading='auto',
|
| 784 |
+
label='Recovered noise',
|
| 785 |
+
)
|
| 786 |
+
plt.xlabel("Time (s)", fontsize='large')
|
| 787 |
+
plt.text(
|
| 788 |
+
text_loc[0],
|
| 789 |
+
text_loc[1],
|
| 790 |
+
'(iii)',
|
| 791 |
+
horizontalalignment='center',
|
| 792 |
+
transform=plt.gca().transAxes,
|
| 793 |
+
fontsize="medium",
|
| 794 |
+
fontweight="bold",
|
| 795 |
+
bbox=box,
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
try:
|
| 799 |
+
plt.savefig(os.path.join(figure_dir, fname.rstrip('.npz') + '_FT.png'), bbox_inches='tight')
|
| 800 |
+
# plt.savefig(os.path.join(figure_dir, fname[i].split('/')[-1].rstrip('.npz')+'_FT.pdf'), bbox_inches='tight')
|
| 801 |
+
except FileNotFoundError:
|
| 802 |
+
os.makedirs(os.path.dirname(os.path.join(figure_dir, fname.rstrip('.npz') + '_FT.png')), exist_ok=True)
|
| 803 |
+
plt.savefig(os.path.join(figure_dir, fname.rstrip('.npz') + '_FT.png'), bbox_inches='tight')
|
| 804 |
+
# plt.savefig(os.path.join(figure_dir, fname[i].split('/')[-1].rstrip('.npz')+'_FT.pdf'), bbox_inches='tight')
|
| 805 |
+
plt.close()
|
| 806 |
+
|
| 807 |
+
plt.figure()
|
| 808 |
+
fig_size = plt.gcf().get_size_inches()
|
| 809 |
+
plt.gcf().set_size_inches(fig_size * [1, 1.2])
|
| 810 |
+
|
| 811 |
+
ax4 = plt.subplot(311)
|
| 812 |
+
plt.plot(t1, noisy_signal, 'k', label='Noisy signal', linewidth=0.5)
|
| 813 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
| 814 |
+
signal_ylim = [-np.max(np.abs(noisy_signal)), np.max(np.abs(noisy_signal))]
|
| 815 |
+
if signal_ylim[0] != signal_ylim[1]:
|
| 816 |
+
plt.ylim(signal_ylim)
|
| 817 |
+
plt.gca().set_xticklabels([])
|
| 818 |
+
plt.legend(loc='lower left', fontsize='medium')
|
| 819 |
+
plt.text(
|
| 820 |
+
text_loc[0],
|
| 821 |
+
text_loc[1],
|
| 822 |
+
'(i)',
|
| 823 |
+
horizontalalignment='center',
|
| 824 |
+
transform=plt.gca().transAxes,
|
| 825 |
+
fontsize="medium",
|
| 826 |
+
fontweight="bold",
|
| 827 |
+
bbox=box,
|
| 828 |
+
)
|
| 829 |
+
|
| 830 |
+
ax5 = plt.subplot(312)
|
| 831 |
+
plt.plot(t1, denoised_signal, 'k', label='Recovered signal', linewidth=0.5)
|
| 832 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
| 833 |
+
if signal_ylim[0] != signal_ylim[1]:
|
| 834 |
+
plt.ylim(signal_ylim)
|
| 835 |
+
plt.gca().set_xticklabels([])
|
| 836 |
+
plt.ylabel("Amplitude", fontsize='large')
|
| 837 |
+
plt.legend(loc='lower left', fontsize='medium')
|
| 838 |
+
plt.text(
|
| 839 |
+
text_loc[0],
|
| 840 |
+
text_loc[1],
|
| 841 |
+
'(ii)',
|
| 842 |
+
horizontalalignment='center',
|
| 843 |
+
transform=plt.gca().transAxes,
|
| 844 |
+
fontsize="medium",
|
| 845 |
+
fontweight="bold",
|
| 846 |
+
bbox=box,
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
plt.subplot(313)
|
| 850 |
+
plt.plot(t1, denoised_noise, 'k', label='Recovered noise', linewidth=0.5)
|
| 851 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
| 852 |
+
if signal_ylim[0] != signal_ylim[1]:
|
| 853 |
+
plt.ylim(signal_ylim)
|
| 854 |
+
plt.xlabel("Time (s)", fontsize='large')
|
| 855 |
+
plt.legend(loc='lower left', fontsize='medium')
|
| 856 |
+
plt.text(
|
| 857 |
+
text_loc[0],
|
| 858 |
+
text_loc[1],
|
| 859 |
+
'(iii)',
|
| 860 |
+
horizontalalignment='center',
|
| 861 |
+
transform=plt.gca().transAxes,
|
| 862 |
+
fontsize="medium",
|
| 863 |
+
fontweight="bold",
|
| 864 |
+
bbox=box,
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
plt.savefig(os.path.join(figure_dir, fname.rstrip('.npz') + '_wave.png'), bbox_inches='tight')
|
| 868 |
+
# plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz')+'_wave.pdf'), bbox_inches='tight')
|
| 869 |
+
plt.close()
|
| 870 |
+
|
| 871 |
+
return
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
if __name__ == "__main__":
|
| 875 |
+
pass
|
docs/README.md
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: DeepDenoiser
|
| 3 |
+
emoji: 🌊
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# DeepDenoiser: Seismic Signal Denoising and Decomposition Using Deep Neural Networks
|
| 11 |
+
|
| 12 |
+
[](https://ai4eps.github.io/DeepDenoiser)
|
| 13 |
+
## 1. Install [miniconda](https://docs.conda.io/en/latest/miniconda.html) and requirements
|
| 14 |
+
- Download DeepDenoiser repository
|
| 15 |
+
```bash
|
| 16 |
+
git clone https://github.com/wayneweiqiang/DeeoDenoiser.git
|
| 17 |
+
cd DeepDenoiser
|
| 18 |
+
```
|
| 19 |
+
- Install to default environment
|
| 20 |
+
```bash
|
| 21 |
+
conda env update -f=env.yml -n base
|
| 22 |
+
```
|
| 23 |
+
- Install to "deepdenoiser" virtual envirionment
|
| 24 |
+
```bash
|
| 25 |
+
conda env create -f env.yml
|
| 26 |
+
conda activate deepdenoiser
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## 2. Pre-trained model
|
| 30 |
+
Located in directory: **model/190614-104802**
|
| 31 |
+
|
| 32 |
+
## 3. Related papers
|
| 33 |
+
- Zhu, Weiqiang, S. Mostafa Mousavi, and Gregory C. Beroza. "Seismic Signal Denoising and Decomposition Using Deep Neural Networks." arXiv preprint arXiv:1811.02695 (2018).
|
| 34 |
+
|
| 35 |
+
## 4. Interactive example
|
| 36 |
+
See details in the [notebook](https://github.com/wayneweiqiang/DeepDenoiser/blob/master/docs/example_interactive.ipynb): [example_interactive.ipynb](example_interactive.ipynb)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
## 5. Batch prediction
|
| 40 |
+
See details in the [notebook](https://github.com/wayneweiqiang/DeepDenoiser/blob/master/docs/example_batch_prediction.ipynb): [example_batch_prediction.ipynb](example_batch_prediction.ipynb)
|
| 41 |
+
## 6. Train
|
| 42 |
+
### Data format
|
| 43 |
+
|
| 44 |
+
Required: two csv files for signal and noise, corresponding directories of the npz files.
|
| 45 |
+
|
| 46 |
+
The csv file contains four columns: "fname", "itp", "channels"
|
| 47 |
+
|
| 48 |
+
The npz file contains four variable: "data", "itp", "channels"
|
| 49 |
+
|
| 50 |
+
The shape of "data" variables has a shape of 9001 x 3
|
| 51 |
+
|
| 52 |
+
The variables "itp" is the data points of first P arrival times.
|
| 53 |
+
|
| 54 |
+
Note: In the demo data, for simplicity we use the waveform before itp as noise samples, so the train_noise_list is same as train_signal_list here.
|
| 55 |
+
|
| 56 |
+
~~~bash
|
| 57 |
+
python deepdenoiser/train.py --mode=train --train_signal_dir=./Dataset/train --train_signal_list=./Dataset/train.csv --train_noise_dir=./Dataset/train --train_noise_list=./Dataset/train.csv --batch_size=20
|
| 58 |
+
~~~
|
| 59 |
+
|
| 60 |
+
Please let us know of any bugs found in the code. Suggestions and collaborations are welcomed
|
docs/example_batch_prediction.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
docs/example_interactive.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
env.yml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: deepdenoiser
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
- conda-forge
|
| 5 |
+
dependencies:
|
| 6 |
+
- python=3.7
|
| 7 |
+
- numpy
|
| 8 |
+
- scipy
|
| 9 |
+
- matplotlib
|
| 10 |
+
- pandas
|
| 11 |
+
- scikit-learn
|
| 12 |
+
- tqdm
|
| 13 |
+
- obspy
|
| 14 |
+
- uvicorn
|
| 15 |
+
- fastapi
|
| 16 |
+
- kafka-python
|
| 17 |
+
- tensorflow
|
| 18 |
+
|
| 19 |
+
|
mkdocs.yml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
site_name: "DeepDenoiser"
|
| 2 |
+
site_description: 'DeepDenoiser: Seismic Signal Denoising and Decomposition Using Deep Neural Networks'
|
| 3 |
+
site_author: 'Weiqiang Zhu'
|
| 4 |
+
docs_dir: docs/
|
| 5 |
+
repo_name: 'wayneweiqiang/DeepDenoiser'
|
| 6 |
+
repo_url: 'https://github.com/wayneweiqiang/DeepDenoiser'
|
| 7 |
+
nav:
|
| 8 |
+
- Overview: README.md
|
| 9 |
+
- Interactive Example: example_interactive.ipynb
|
| 10 |
+
- Batch Prediction: example_batch_prediction.ipynb
|
| 11 |
+
theme:
|
| 12 |
+
name: 'material'
|
| 13 |
+
plugins:
|
| 14 |
+
- mkdocs-jupyter
|
| 15 |
+
extra:
|
| 16 |
+
analytics:
|
| 17 |
+
provider: google
|
| 18 |
+
property: G-FMMP8CQRDZ
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tensorflow
|
| 2 |
+
matplotlib
|
| 3 |
+
scipy
|
| 4 |
+
pandas
|
| 5 |
+
tqdm
|