| from typing import Dict, List |
| import numpy as np |
| import tensorflow as tf |
| import os |
|
|
| from phasenet.model import ModelConfig, UNet |
| from phasenet.postprocess import extract_picks |
|
|
# The pipeline below builds its model with the TF1 graph/session API
# (tf.compat.v1.Session, Saver, feed_dict), so eager execution is disabled
# at import time, before any graph is constructed.
_tf_v1 = tf.compat.v1
_tf_v1.disable_eager_execution()
# Restrict TensorFlow's own logger to ERROR-level records.
_tf_v1.logging.set_verbosity(_tf_v1.logging.ERROR)
|
|
class PreTrainedPipeline:
    """Inference pipeline that restores a PhaseNet ``UNet`` from a checkpoint.

    The model is built in prediction mode inside a fresh default graph and
    its weights are restored into a ``tf.compat.v1.Session`` that is kept on
    the instance and reused for every call.

    Attributes are plain instance fields:
        sess:  the ``tf.compat.v1.Session`` holding the restored variables.
        model: the ``UNet`` instance whose ``X``, ``drop_rate``,
               ``is_training`` and ``preds`` tensors are used at call time.
    """

    # Checkpoint directory (relative to ``path``) used when the caller does
    # not pass ``checkpoint_subdir`` explicitly.
    DEFAULT_CHECKPOINT_SUBDIR = "model/190703-214543"

    def __init__(self, path="", checkpoint_subdir=DEFAULT_CHECKPOINT_SUBDIR):
        """Build the graph and restore the latest checkpoint.

        Args:
            path: Root directory that contains the checkpoint sub-directory.
            checkpoint_subdir: Directory, relative to ``path``, that is
                searched for the most recent checkpoint.

        Raises:
            ValueError: If no checkpoint can be found in the resolved
                directory.
        """
        # Resolve the checkpoint before allocating any TensorFlow resources:
        # latest_checkpoint() returns None when nothing is found, and failing
        # here gives a clear message instead of an opaque restore error.
        checkpoint_dir = os.path.join(path, checkpoint_subdir)
        latest_check_point = tf.train.latest_checkpoint(checkpoint_dir)
        if latest_check_point is None:
            raise ValueError(
                f"no TensorFlow checkpoint found in {checkpoint_dir!r}"
            )

        # Start from an empty default graph so the Saver below only sees the
        # variables created by this UNet instance.
        tf.compat.v1.reset_default_graph()
        model = UNet(mode="pred")

        sess_config = tf.compat.v1.ConfigProto()
        # Let the GPU allocation grow on demand instead of being reserved
        # up front.
        sess_config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=sess_config)

        try:
            saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
            sess.run(tf.compat.v1.global_variables_initializer())
            print(f"restoring model {latest_check_point}")
            saver.restore(sess, latest_check_point)
        except Exception:
            # Do not leak the session when initialisation or restore fails.
            sess.close()
            raise

        self.sess = sess
        self.model = model

    def __call__(self, inputs: List[List[float]]) -> List[List[Dict[str, float]]]:
        """Run the network on one waveform.

        Args:
            inputs: A 2-D nested list of numbers. It is converted to an array
                and expanded to ``(1, n_rows, 1, n_cols)`` before being fed to
                the model. The script entry point of this file passes a
                ``(1000, 3)`` input; the meaning of the two axes
                (presumably samples x channels) should be confirmed against
                ``UNet``.

        Returns:
            A list containing one list of ``{"label": ..., "score": ...}``
            dictionaries. NOTE: the current implementation always returns the
            fixed placeholder ``[[{"label": "debug", "score": 0.1}]]``; the
            extracted picks are computed but not yet part of the result.
        """
        # Insert a leading axis and a singleton third axis around the 2-D
        # input: (n_rows, n_cols) -> (1, n_rows, 1, n_cols).
        batch = np.array(inputs)[np.newaxis, :, np.newaxis, :]

        # Inference settings: no dropout, not in training mode.
        feed = {
            self.model.X: batch,
            self.model.drop_rate: 0,
            self.model.is_training: False,
        }
        preds = self.sess.run(self.model.preds, feed_dict=feed)

        # TODO(review): `picks` is not returned yet; the call is kept so that
        # post-processing is still exercised on every request.
        picks = extract_picks(preds)  # noqa: F841

        return [[{"label": "debug", "score": 0.1}]]
|
|
|
|
def _run_smoke_test():
    """Instantiate the pipeline and run it once on random data."""
    demo_pipeline = PreTrainedPipeline()
    # Random 1000 x 3 input, converted to nested lists before the call.
    waveform = np.random.rand(1000, 3).tolist()
    return demo_pipeline(waveform)


if __name__ == "__main__":
    picks = _run_smoke_test()
|
|