|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import tensorflow as tf |
|
|
from google.protobuf.struct_pb2 import Struct |
|
|
from transformers import Pipeline |
|
|
import struct2tensor.ops.gen_decode_proto_sparse |
|
|
|
|
|
|
|
|
def create_proto_message(text):
    """Serialize *text* into a protobuf ``Struct`` message.

    The text is stored under the ``"task_data"`` field key, matching
    what the downstream SavedModel's decoder expects.

    Args:
        text: The string payload to embed in the message.

    Returns:
        bytes: Wire-format serialization of the ``Struct``.
    """
    payload = Struct()
    payload.fields["task_data"].string_value = text
    return payload.SerializeToString()
|
|
|
|
|
class TFProtoModel(Pipeline):
    """Hugging Face ``Pipeline`` wrapper around a TensorFlow SavedModel
    that consumes serialized protobuf ``Struct`` messages.

    NOTE(review): ``Pipeline.__init__`` is intentionally not invoked —
    the TF SavedModel stands in for the usual transformers model and
    tokenizer pair. Confirm this matches how the pipeline is registered
    and loaded (e.g. via ``trust_remote_code``).
    """

    def __init__(self, model_path="model"):
        # Load the SavedModel once and keep its default serving
        # signature as the inference callable.
        loaded = tf.saved_model.load(model_path)
        self.model = loaded
        self.infer = loaded.signatures['serving_default']

    def _sanitize_parameters(self, **kwargs):
        # This pipeline accepts no extra parameters for any stage, so
        # hand back three independent empty kwarg dicts
        # (preprocess, forward, postprocess).
        return dict(), dict(), dict()

    def preprocess(self, text):
        # Serialize the input into a protobuf Struct and wrap it as a
        # rank-1 string tensor: a batch of one serialized message.
        serialized = create_proto_message(text)
        return tf.constant([serialized], dtype=tf.string)

    def _forward(self, input_tensor):
        # Invoke the serving signature; its result dict exposes the
        # prediction under the 'outputs' key.
        prediction = self.infer(inputs=input_tensor)
        return prediction['outputs'].numpy()

    def postprocess(self, model_outputs):
        # Report the first batch element as a plain Python float.
        score = float(model_outputs[0])
        return {"score": score}
|
|
|
|
|
# Module-level entry points. NOTE(review): these names appear to be the
# hook read by a custom-pipeline loader (e.g. transformers loading this
# module with trust_remote_code): `pipeline` names the Pipeline subclass
# to instantiate and `task` names the task it serves — confirm against
# the actual registration mechanism.
pipeline = TFProtoModel




task = "text-classification"