Spaces:
Sleeping
Sleeping
| import os | |
| from collections import defaultdict, namedtuple | |
| from datetime import datetime, timedelta | |
| from json import dumps | |
| from typing import Any, AnyStr, Dict, List, NamedTuple, Union | |
| import numpy as np | |
| import requests | |
| import tensorflow as tf | |
| from fastapi import FastAPI | |
| from kafka import KafkaProducer | |
| from pydantic import BaseModel | |
| import scipy | |
| from scipy.interpolate import interp1d | |
| from model import UNet | |
# --- TensorFlow 1.x-style session setup and model restore (module import side effects) ---
# The model uses graph mode, so eager execution is disabled explicitly.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Repository root (one directory above this file); checkpoints live under <root>/model/.
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
# Convenience aliases for JSON-shaped payloads.
JSONObject = Dict[AnyStr, Any]
JSONArray = List[Any]
JSONStructure = Union[JSONArray, JSONObject]
app = FastAPI()
# Expected model input shape [nt, 1, channels] -- presumably matches UNet's placeholder; TODO confirm.
X_SHAPE = [3000, 1, 3]
SAMPLING_RATE = 100
# load model
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
# Grow GPU memory on demand instead of grabbing it all at startup.
sess_config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
# Restore the latest checkpoint from the pinned training run directory.
latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190614-104802")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)
# Kafka producer: try the in-cluster (k8s) broker first, then fall back to a
# local broker; if both fail, the service runs without Kafka (use_kafka=False).
use_kafka = False
# BROKER_URL = 'localhost:9092'
# BROKER_URL = 'my-kafka-headless:9092'
try:
    print("Connecting to k8s kafka")
    BROKER_URL = "quakeflow-kafka-headless:9092"
    producer = KafkaProducer(
        bootstrap_servers=[BROKER_URL],
        key_serializer=lambda x: dumps(x).encode("utf-8"),
        value_serializer=lambda x: dumps(x).encode("utf-8"),
    )
    use_kafka = True
    print("k8s kafka connection success!")
# FIX: was `except BaseException`, which also swallows KeyboardInterrupt and
# SystemExit during startup; Exception is the widest safe net here.
except Exception:
    print("k8s Kafka connection error")
    try:
        print("Connecting to local kafka")
        producer = KafkaProducer(
            bootstrap_servers=["localhost:9092"],
            key_serializer=lambda x: dumps(x).encode("utf-8"),
            value_serializer=lambda x: dumps(x).encode("utf-8"),
        )
        use_kafka = True
        print("local kafka connection success!")
    except Exception:
        print("local Kafka connection error")
def normalize_batch(data, window=200):
    """Normalize a batch of spectrograms by a sliding-window mean/std.

    Local statistics are computed in half-overlapping windows along the
    time axis, linearly interpolated back to every time frame, and used
    to whiten the data.

    Parameters
    ----------
    data : np.ndarray
        Shape (batch, nf, nt, 2) -- real/imag channels of an STFT.
    window : int
        Moving-window length (in time frames) for the local statistics.

    Returns
    -------
    np.ndarray
        Normalized array with the same shape as ``data``.
    """
    assert data.ndim == 4
    nbt, nf, nt, nimg = data.shape
    half = window // 2

    # Reflect-pad the time axis so every window is fully populated.
    padded = np.pad(data, ((0, 0), (0, 0), (half, half), (0, 0)), mode="reflect")

    # Window start offsets: 0, half, 2*half, ... covering [0, nt).
    anchors = np.arange(0, nt + half - 1, half, dtype="int")
    std = np.zeros([nbt, len(anchors)])
    mean = np.zeros([nbt, len(anchors)])
    for i, a in enumerate(anchors):
        seg = padded[:, :, a : a + window, :]
        std[:, i] = np.std(seg, axis=(1, 2, 3))
        mean[:, i] = np.mean(seg, axis=(1, 2, 3))

    # The first/last windows straddle the reflected pad; reuse neighbors.
    std[:, -1], mean[:, -1] = std[:, -2], mean[:, -2]
    std[:, 0], mean[:, 0] = std[:, 1], mean[:, 1]

    # Interpolate the per-window statistics to every time frame.
    frames = np.arange(nt, dtype="int")
    std_interp = interp1d(anchors, std, kind="slinear")(frames)
    std_interp[std_interp == 0] = 1.0  # guard against divide-by-zero
    mean_interp = interp1d(anchors, mean, kind="slinear")(frames)

    data = (data - mean_interp[:, np.newaxis, :, np.newaxis]) / std_interp[:, np.newaxis, :, np.newaxis]

    # Kept from the original: long traces get an extra 2x damping to match
    # how the model was trained (normalization quirk flagged upstream).
    if len(anchors) > 3:
        data /= 2.0
    return data
def get_prediction(meta):
    """Denoise the waveforms in ``meta`` using the loaded UNet model.

    The traces are resampled to 100 Hz if needed, transformed with an STFT,
    normalized, passed through the model to predict a signal mask, and
    reconstructed with an inverse STFT.

    Parameters
    ----------
    meta : Data
        Request payload with ``vec`` of shape [batch, nt, chn] and sample
        spacing ``dt`` (seconds).

    Returns
    -------
    Data
        A copy of ``meta`` with ``vec`` replaced by the denoised traces
        (same [batch, nt, chn] layout).
    """
    FS = 100      # model sampling rate (Hz)
    NPERSEG = 30  # STFT segment length
    NFFT = 60     # STFT FFT size

    vec = np.array(meta.vec)  # [batch, nt, chn]
    nbt, nt, nch = vec.shape
    vec = np.transpose(vec, [0, 2, 1])      # [batch, chn, nt]
    vec = np.reshape(vec, [nbt * nch, nt])  # [batch * chn, nt]

    # Drop a trailing sample so common 3001-sample traces become 3000.
    if np.mod(vec.shape[-1], 3000) == 1:
        vec = vec[..., :-1]

    # Resample to 100 Hz when the input sample spacing differs.
    if meta.dt != 0.01:
        # FIX: the original used len(vec) here, which after the reshape is the
        # number of rows (batch * chn), not the trace length; interp1d operates
        # along the last axis, so the time axis length is what is needed.
        nt_in = vec.shape[-1]
        t = np.linspace(0, 1, nt_in)
        # FIX: np.int was removed in NumPy >= 1.20; use the builtin int.
        t_interp = np.linspace(0, 1, int(np.around(nt_in * meta.dt * FS)))
        vec = interp1d(t, vec, kind="slinear")(t_interp)

    f, t, tmp_signal = scipy.signal.stft(vec, fs=FS, nperseg=NPERSEG, nfft=NFFT, boundary='zeros')
    # Stack real/imag as two channels: [batch * chn, nf, nt, 2].
    noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1)
    noisy_signal[np.isnan(noisy_signal)] = 0
    noisy_signal[np.isinf(noisy_signal)] = 0

    X_input = normalize_batch(noisy_signal)
    feed = {model.X: X_input, model.drop_rate: 0, model.is_training: False}
    preds = sess.run(model.preds, feed_dict=feed)

    # Apply the predicted signal mask (preds[..., 0]) to the noisy STFT and
    # invert; preds[..., 1] would be the noise mask.
    _, denoised_signal = scipy.signal.istft(
        (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 0],
        fs=FS,
        nperseg=NPERSEG,
        nfft=NFFT,
        boundary='zeros',
    )
    denoised_signal = np.reshape(denoised_signal, [nbt, nch, nt])
    denoised_signal = np.transpose(denoised_signal, [0, 2, 1])

    result = meta.copy()
    result.vec = denoised_signal.tolist()
    return result
class Data(BaseModel):
    # Request payload schema for the denoising endpoint.
    # Earlier, more permissive declarations kept for reference:
    # id: Union[List[str], str]
    # timestamp: Union[List[str], str]
    # vec: Union[List[List[List[float]]], List[List[float]]]
    id: List[str]  # one identifier per batch entry
    timestamp: List[str]  # one timestamp string per batch entry
    vec: List[List[List[float]]]  # waveforms, [batch, nt, chn] (see get_prediction)
    dt: float = 0.01  # sample spacing in seconds (0.01 s = 100 Hz)
def predict(data: Data):
    """Run the denoiser over the request payload and return the result.

    NOTE(review): this reads like a FastAPI handler, but no route decorator
    is visible in this chunk -- confirm it is registered on ``app`` elsewhere.
    """
    return get_prediction(data)
def healthz():
    """Liveness-probe endpoint body: report the service as healthy."""
    return dict(status="ok")