upload phasenet
Browse files- pipeline.py +3 -2
pipeline.py
CHANGED
|
@@ -2,7 +2,7 @@ from typing import Dict, List
|
|
| 2 |
import numpy as np
|
| 3 |
import tensorflow as tf
|
| 4 |
import os
|
| 5 |
-
|
| 6 |
from phasenet.model import ModelConfig, UNet
|
| 7 |
from phasenet.postprocess import extract_picks
|
| 8 |
|
|
@@ -51,7 +51,7 @@ class PreTrainedPipeline():
|
|
| 51 |
# "Please implement PreTrainedPipeline __call__ function"
|
| 52 |
# )
|
| 53 |
|
| 54 |
-
vec = np.
|
| 55 |
|
| 56 |
feed = {self.model.X: vec, self.model.drop_rate: 0, self.model.is_training: False}
|
| 57 |
preds = self.sess.run(self.model.preds, feed_dict=feed)
|
|
@@ -67,4 +67,5 @@ class PreTrainedPipeline():
|
|
| 67 |
if __name__ == "__main__":
|
| 68 |
pipeline = PreTrainedPipeline()
|
| 69 |
inputs = np.random.rand(1000, 3).tolist()
|
|
|
|
| 70 |
picks = pipeline(inputs)
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import tensorflow as tf
|
| 4 |
import os
|
| 5 |
+
import json
|
| 6 |
from phasenet.model import ModelConfig, UNet
|
| 7 |
from phasenet.postprocess import extract_picks
|
| 8 |
|
|
|
|
| 51 |
# "Please implement PreTrainedPipeline __call__ function"
|
| 52 |
# )
|
| 53 |
|
| 54 |
+
vec = np.asarray(json.loads(inputs))[np.newaxis, :, np.newaxis, :]
|
| 55 |
|
| 56 |
feed = {self.model.X: vec, self.model.drop_rate: 0, self.model.is_training: False}
|
| 57 |
preds = self.sess.run(self.model.preds, feed_dict=feed)
|
|
|
|
| 67 |
if __name__ == "__main__":
|
| 68 |
pipeline = PreTrainedPipeline()
|
| 69 |
inputs = np.random.rand(1000, 3).tolist()
|
| 70 |
+
inputs = json.dumps(inputs)
|
| 71 |
picks = pipeline(inputs)
|