Spaces:
Sleeping
Sleeping
| import os | |
| from collections import defaultdict | |
| from datetime import datetime, timedelta | |
| from typing import Any, AnyStr, Dict, List, NamedTuple, Optional, Union | |
| import numpy as np | |
| import tensorflow as tf | |
| from fastapi import FastAPI, WebSocket | |
| from postprocess import extract_picks | |
| from pydantic import BaseModel | |
| from scipy.interpolate import interp1d | |
| from model import UNet | |
# Repository root: the parent of the directory that contains this file.
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
# This module runs the model through tf.compat.v1.Session (graph mode), so eager
# execution is switched off before any graph is constructed.
tf.compat.v1.disable_eager_execution()
# Only surface TensorFlow log messages at ERROR level or above.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Loose type aliases for JSON-like payloads (not referenced in the visible code).
JSONObject = Dict[AnyStr, Any]
JSONArray = List[Any]
JSONStructure = Union[JSONArray, JSONObject]
# FastAPI application instance.
app = FastAPI()
# Model input layout (samples, 1, channels): format_data builds
# (X_SHAPE[0], X_SHAPE[-1]) arrays per station, and preprocess inserts the
# singleton axis.
X_SHAPE = [3000, 1, 3]
# Samples per second assumed for incoming waveforms (dt = 1 / SAMPLING_RATE).
SAMPLING_RATE = 100
# load model
# Build the UNet graph in prediction mode (UNet comes from the local `model` module).
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
# Allocate GPU memory on demand rather than reserving it all up front.
sess_config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=sess_config)
# The Saver is created after the model so that it captures the model's variables.
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
# Overwrite the freshly initialised variables with the newest checkpoint found in
# PROJECT_ROOT/model/190703-214543.
latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543")
print(f"restoring model {latest_check_point}")
# NOTE(review): if no checkpoint is found, latest_checkpoint presumably returns None
# and this restore fails at import time; confirm the model directory ships with the app.
saver.restore(sess, latest_check_point)
def normalize_batch(data, window=3000):
    """Remove a sliding-window mean and divide by a sliding-window std.

    Statistics are measured on ``window``-sample windows placed every
    ``window // 2`` samples along the time axis. The trace is reflect-padded
    by ``window // 2`` on both ends, so edge windows are full length. The
    statistics are then linearly interpolated to every sample.

    Args:
        data: array of shape (nsta, nt, nch).
        window: length of the statistics window, in samples.

    Returns:
        Array with the same shape as ``data``, normalized per station and
        channel. Zero standard deviations are replaced by 1 so flat traces
        map to zeros instead of NaN.
    """
    hop = window // 2
    n_sta, n_t, n_ch = data.shape
    padded = np.pad(data, ((0, 0), (hop, hop), (0, 0)), mode="reflect")

    centers = np.arange(0, n_t, hop, dtype="int")
    std = np.zeros([n_sta, len(centers) + 1, n_ch])
    mean = np.zeros_like(std)
    for k in range(1, len(centers)):
        # Padded span [k*hop, k*hop + window) covers original samples
        # [centers[k] - hop, centers[k] + window - hop).
        lo = k * hop
        segment = padded[:, lo : lo + window, :]
        std[:, k, :] = segment.std(axis=1)
        mean[:, k, :] = segment.mean(axis=1)
    knots = np.append(centers, n_t)

    # The first and last knots are never measured; copy their neighbours.
    for stats in (std, mean):
        stats[:, -1, :] = stats[:, -2, :]
        stats[:, 0, :] = stats[:, 1, :]
    std[std == 0] = 1

    # Normalize with statistics interpolated to every sample.
    samples = np.arange(n_t, dtype="int")
    std_at = interp1d(knots, std, axis=1, kind="slinear")(samples)
    mean_at = interp1d(knots, mean, axis=1, kind="slinear")(samples)
    return (data - mean_at) / std_at
def preprocess(data):
    """Normalize waveforms and add the singleton axis of the model input layout.

    Args:
        data: waveform array, normalized via ``normalize_batch``.

    Returns:
        ``(normalized, raw)``: the normalized waveforms and an untouched copy
        of ``data``. When the normalized array is 3-D, both gain a new axis at
        position 2, i.e. (nsta, nt, nch) -> (nsta, nt, 1, nch).
    """
    untouched = data.copy()
    normalized = normalize_batch(data)
    if len(normalized.shape) != 3:
        return normalized, untouched
    return np.expand_dims(normalized, 2), np.expand_dims(untouched, 2)
def calc_timestamp(timestamp, sec):
    """Shift an ISO-like timestamp string by ``sec`` seconds.

    Args:
        timestamp: string in ``%Y-%m-%dT%H:%M:%S.%f`` format.
        sec: offset in seconds (may be fractional or negative).

    Returns:
        The shifted time in the same format, truncated to milliseconds.
    """
    fmt = "%Y-%m-%dT%H:%M:%S.%f"
    shifted = datetime.strptime(timestamp, fmt) + timedelta(seconds=sec)
    text = shifted.strftime(fmt)
    # %f yields microseconds; drop the last three digits to keep milliseconds.
    return text[:-3]
def format_picks(picks, dt, amplitudes):
    """Flatten per-station P and S picks into a list of JSON-ready dicts.

    Args:
        picks: iterable of pick records exposing ``fname``, ``t0`` and the
            nested ``p_idx``/``p_prob`` and ``s_idx``/``s_prob`` sequences.
        dt: sample interval in seconds, used to turn sample indices into
            time offsets from ``t0``.
        amplitudes: iterable parallel to ``picks`` exposing ``p_amp`` and
            ``s_amp`` nested like the index sequences.

    Returns:
        A list of dicts with keys ``id``, ``timestamp``, ``prob``, ``amp`` and
        ``type``. For each pick, all "p" entries come before its "s" entries.
    """

    def _phase_entries(pick, indices, scores, amps, phase):
        # One dict per (index, probability, amplitude) triple, flattened over
        # the outer grouping level.
        return [
            {
                "id": pick.fname,
                "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                "prob": prob,
                "amp": amp,
                "type": phase,
            }
            for group_idx, group_prob, group_amp in zip(indices, scores, amps)
            for idx, prob, amp in zip(group_idx, group_prob, group_amp)
        ]

    flattened = []
    for pick, amplitude in zip(picks, amplitudes):
        flattened.extend(_phase_entries(pick, pick.p_idx, pick.p_prob, amplitude.p_amp, "p"))
        flattened.extend(_phase_entries(pick, pick.s_idx, pick.s_prob, amplitude.s_amp, "s"))
    return flattened
def format_data(data):
    """Group single-channel traces by station into fixed-length 3-channel arrays.

    Each ``data.id[i]`` is a channel id whose last character is the component
    code (E/N/Z or 3/2/1); the remaining characters form the station key.
    ``data.timestamp[i]`` is the start time of ``data.vec[i]`` in
    ``%Y-%m-%dT%H:%M:%S.%f`` format.

    Traces of one station are aligned on the station's earliest start time,
    in whole samples at SAMPLING_RATE. Each is demeaned and written into an
    ``(X_SHAPE[0], X_SHAPE[-1])`` zero array; anything past X_SHAPE[0]
    samples is dropped.

    Returns:
        A ``data`` named tuple with fields:
        ``id`` -- station keys.
        ``timestamp`` -- earliest start per station, millisecond precision.
        ``vec`` -- nested lists of shape (X_SHAPE[0], X_SHAPE[-1]).
        ``dt`` -- 1 / SAMPLING_RATE.
    """
    # Component code -> column. "123" instruments map 3->0, 2->1, 1->2.
    chn2idx = {"E": 0, "N": 1, "Z": 2, "3": 0, "2": 1, "1": 2}
    FormattedData = NamedTuple("data", [("id", list), ("timestamp", list), ("vec", list), ("dt", float)])
    n_samples = X_SHAPE[0]

    # Group channels by station.
    chn_ = defaultdict(list)
    t0_ = defaultdict(list)
    vv_ = defaultdict(list)
    for i, channel_id in enumerate(data.id):
        key = channel_id[:-1]
        chn_[key].append(channel_id[-1])
        # Start time as a (float) sample count since the epoch.
        # NOTE(review): naive datetimes are interpreted in local time here and
        # converted back with fromtimestamp below; confirm inputs are not UTC-sensitive.
        t0_[key].append(datetime.strptime(data.timestamp[i], "%Y-%m-%dT%H:%M:%S.%f").timestamp() * SAMPLING_RATE)
        vv_[key].append(np.array(data.vec[i]))

    # Merge each station into one fixed-size array.
    id_ = []
    timestamp_ = []
    vec_ = []
    for key in chn_:
        id_.append(key)
        min_t0 = min(t0_[key])
        timestamp_.append(datetime.fromtimestamp(min_t0 / SAMPLING_RATE).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3])
        vec = np.zeros([n_samples, X_SHAPE[-1]])
        for chn, t0, trace in zip(chn_[key], t0_[key], vv_[key]):
            # Round rather than truncate: t0 values are ~1e11-sized floats, so a
            # one-sample offset can come out as 0.9999... and int() would drop it.
            shift = int(round(t0 - min_t0))
            n_keep = n_samples - shift
            if n_keep <= 0:
                # Trace starts at or after the end of the window. Slicing with a
                # negative stop would select most of the trace and fail to
                # broadcast into the empty target slice.
                continue
            segment = trace[:n_keep]
            if segment.size == 0:
                continue
            vec[shift : shift + len(segment), chn2idx[chn]] = segment - np.mean(segment)
        vec_.append(vec.tolist())
    return FormattedData(id=id_, timestamp=timestamp_, vec=vec_, dt=1 / SAMPLING_RATE)
def get_prediction(data, return_preds=False):
    """Run the UNet on ``data.vec`` and extract phase picks.

    Args:
        data: request object exposing ``vec`` (waveforms), ``id`` (passed as
            station ids) and ``timestamp`` (passed as begin times).
        return_preds: when true, also return the raw ``model.preds`` output.

    Returns:
        A list of pick dicts restricted to whichever of the keys station_id,
        phase_time, phase_score, phase_type and dt are present. When
        ``return_preds`` is true, returns ``(picks, preds)`` instead.
    """
    normalized, raw = preprocess(np.array(data.vec))
    preds = sess.run(
        model.preds,
        feed_dict={model.X: normalized, model.drop_rate: 0, model.is_training: False},
    )
    keep = ("station_id", "phase_time", "phase_score", "phase_type", "dt")
    picks = []
    for pick in extract_picks(preds, station_ids=data.id, begin_times=data.timestamp, waveforms=raw):
        picks.append({key: value for key, value in pick.items() if key in keep})
    return (picks, preds) if return_preds else picks
class Data(BaseModel):
    """Request body for the prediction handlers.

    Only ``id``, ``timestamp`` and ``vec`` are read by get_prediction; nested
    shapes are not validated beyond the type annotations.
    """

    # Station identifiers, passed to extract_picks as ``station_ids``.
    id: List[List[str]]
    # Begin time of each waveform, passed to extract_picks as ``begin_times``.
    timestamp: List[Union[str, float, datetime]]
    # Waveforms as a 3-level or 2-level nested list.
    # NOTE(review): preprocess -> normalize_batch unpacks a 3-D shape
    # (nsta, nt, nch), so the 2-level form appears to fail there; confirm
    # whether it is still meant to be supported.
    vec: Union[List[List[List[float]]], List[List[float]]]
    # Presumably the sample interval in seconds (default equals 1 / SAMPLING_RATE);
    # not read by get_prediction.
    dt: Optional[float] = 0.01
    ## gamma
    # Extra inputs, presumably for GaMMA association; not read by the visible code.
    stations: Optional[List[Dict[str, Union[float, str]]]] = None
    config: Optional[Dict[str, Union[List[float], List[int], List[str], float, int, str]]] = None
| # @app.on_event("startup") | |
| # def set_default_executor(): | |
| # from concurrent.futures import ThreadPoolExecutor | |
| # import asyncio | |
| # | |
| # loop = asyncio.get_running_loop() | |
| # loop.set_default_executor( | |
| # ThreadPoolExecutor(max_workers=2) | |
| # ) | |
def predict(data: Data):
    """Return the picks for a single request.

    NOTE(review): no route decorator is present in this source, and another
    ``predict`` is defined immediately after, which rebinds the module-level
    name. Confirm the decorators (and route registration) were not lost.
    """
    return get_prediction(data)
def predict(data: Data):
    """Return the picks plus the raw model output, as nested lists, for one request.

    NOTE(review): this shares its name with the preceding ``predict`` and has
    no route decorator in this source, so it shadows that function at module
    level. Confirm against the deployed file.
    """
    result, probabilities = get_prediction(data, return_preds=True)
    return result, probabilities.tolist()
async def websocket_endpoint(websocket: WebSocket):
    """Serve picks over a WebSocket: one JSON request in, one JSON reply out.

    Each incoming message is validated into ``Data`` and answered with the
    picks from ``get_prediction``. The session ends quietly when the client
    disconnects.

    NOTE(review): no route decorator is present in this source, so this
    handler is not attached to ``app`` as shown; confirm against the deployed file.
    """
    import asyncio

    from fastapi import WebSocketDisconnect

    await websocket.accept()
    loop = asyncio.get_running_loop()
    try:
        while True:
            data = await websocket.receive_json()
            # data = json.loads(data)
            data = Data(**data)
            # get_prediction blocks on sess.run. Run it in the loop's default
            # executor so this coroutine does not stall every other connection.
            picks = await loop.run_in_executor(None, get_prediction, data)
            await websocket.send_json(picks)
            print("PhaseNet Updating...")
    except WebSocketDisconnect:
        # Normal client hang-up; end the session without raising.
        return
def healthz():
    """Liveness probe payload; always reports the service as up."""
    return dict(status="ok")
def greet_json():
    """Static greeting payload identifying the service."""
    return dict(Hello="PhaseNet!")