File size: 1,933 Bytes
3bb7bc8 f9c2c31 3bb7bc8 f9c2c31 3bb7bc8 f9c2c31 3bb7bc8 f9c2c31 3bb7bc8 f9c2c31 3bb7bc8 f9c2c31 3bb7bc8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
# import tensorflow as tf
# from google.protobuf.struct_pb2 import Struct
# import struct2tensor.ops.gen_decode_proto_sparse
# def create_proto_message(text):
# message = Struct()
# message.fields["task_data"].string_value = text
# return message.SerializeToString()
# class TFProtoModel:
# def __init__(self, model_path):
# self.model = tf.saved_model.load(model_path)
# self.infer = self.model.signatures['serving_default']
# def predict(self, text):
# proto_data = create_proto_message(text)
# input_tensor = tf.constant([proto_data], dtype=tf.string)
# result = self.infer(inputs=input_tensor)
# return result['outputs'].numpy()
# # Initialize model when the file is loaded
# model = TFProtoModel("model")
# # This is the function Hugging Face will call
# def pipeline(text):
# return model.predict(text)
import tensorflow as tf
from google.protobuf.struct_pb2 import Struct
from transformers import Pipeline
import struct2tensor.ops.gen_decode_proto_sparse
def create_proto_message(text):
    """Pack *text* into a ``Struct`` protobuf and return its wire bytes.

    The text becomes the string value of the ``"task_data"`` field.

    Args:
        text: The string to embed.

    Returns:
        The serialized ``Struct`` message as ``bytes``.
    """
    payload = Struct()
    # Indexing a message-valued map inserts the entry and returns it by
    # reference, so assigning through ``task_field`` mutates ``payload``.
    task_field = payload.fields["task_data"]
    task_field.string_value = text
    serialized = payload.SerializeToString()
    return serialized
class TFProtoModel(Pipeline):
    """Run a TensorFlow SavedModel that consumes serialized ``Struct`` protos.

    Each input text is wrapped in a ``google.protobuf.Struct`` by
    ``create_proto_message``, then fed to the SavedModel's ``serving_default``
    signature. The first entry of that signature's ``outputs`` tensor is
    returned as ``{"score": float}``.

    NOTE: ``__init__`` does not call ``Pipeline.__init__``, which expects a
    transformers model object rather than a raw SavedModel. The per-instance
    state that the inherited ``Pipeline.__call__`` reads (``_num_workers``,
    ``_preprocess_params``, ``call_count``, ...) is therefore never created,
    and calling an instance through it fails with ``AttributeError``. This
    class defines its own ``__call__`` that chains the pipeline stages
    directly.
    """

    def __init__(self, model_path="model"):
        """Load the SavedModel at *model_path* and bind its serving signature.

        Args:
            model_path: Directory of the TensorFlow SavedModel to load.
        """
        self.model = tf.saved_model.load(model_path)
        self.infer = self.model.signatures['serving_default']

    def __call__(self, inputs, **kwargs):
        """Score one text, or each text in a list/tuple.

        Args:
            inputs: A single text, or a list/tuple of texts.
            **kwargs: Passed to ``_sanitize_parameters``, which currently
                ignores them.

        Returns:
            ``{"score": float}`` for a single text, or a list of such dicts
            (in input order) for a list/tuple of texts.
        """
        preprocess_params, forward_params, postprocess_params = (
            self._sanitize_parameters(**kwargs)
        )

        def run_one(text):
            model_inputs = self.preprocess(text, **preprocess_params)
            model_outputs = self._forward(model_inputs, **forward_params)
            return self.postprocess(model_outputs, **postprocess_params)

        if isinstance(inputs, (list, tuple)):
            return [run_one(text) for text in inputs]
        return run_one(inputs)

    def _sanitize_parameters(self, **kwargs):
        """Return empty (preprocess, forward, postprocess) kwargs.

        No call-time options are supported, so any *kwargs* are ignored.
        """
        return {}, {}, {}

    def preprocess(self, text):
        """Serialize *text* and wrap it as a 1-element ``tf.string`` tensor."""
        proto_data = create_proto_message(text)
        # The list wrapper adds a leading batch dimension of size 1.
        return tf.constant([proto_data], dtype=tf.string)

    def _forward(self, input_tensor):
        """Invoke the serving signature and return its ``outputs`` as NumPy.

        The signature is called with keyword ``inputs`` and is read back via
        the ``'outputs'`` key; both names must match the exported model.
        """
        result = self.infer(inputs=input_tensor)
        return result['outputs'].numpy()

    def postprocess(self, model_outputs):
        """Convert the first (only) batch entry into ``{"score": float}``.

        NOTE(review): assumes ``model_outputs[0]`` holds exactly one value
        (e.g. outputs shaped ``(1,)`` or ``(1, 1)``); a multi-value row would
        make ``float()`` fail. Confirm against the exported model.
        """
        return {"score": float(model_outputs[0])}
# Task type advertised for this pipeline; change it if another task fits better.
task = "text-classification"

# Module-level entry point: the pipeline class itself (not an instance).
# NOTE(review): presumably the hosting runtime looks this name up — confirm.
pipeline = TFProtoModel