|
|
| import numpy as np |
| import onnxruntime as ort |
| from transformers import Pipeline |
| from tensorflow.keras.datasets import imdb |
| from tensorflow.keras.preprocessing.sequence import pad_sequences |
|
|
|
|
class ImdbCnnPipeline(Pipeline):
    """Sentiment-analysis pipeline wrapping an ONNX-exported IMDB CNN model.

    ``model`` is expected to be an ``onnxruntime.InferenceSession`` — the
    ``_forward`` step calls ``model.get_inputs()`` / ``model.run()`` — not a
    ``transformers`` ``PreTrainedModel``.  Input text is encoded with the
    Keras IMDB word index and classified as POSITIVE/NEGATIVE.
    """

    # Sequence length fed to the model; ``preprocess`` pads/truncates to this.
    # Hoisted from a hard-coded literal so subclasses/instances can override.
    MAX_LEN = 500

    def __init__(self, model, tokenizer=None, **kwargs):
        # NOTE(review): ``Pipeline.__init__`` normally expects a HF model for
        # framework detection — confirm it accepts an ONNX session cleanly.
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

        # Build the Keras IMDB vocabulary.  Raw indices are shifted by +3 so
        # ids 0-3 are free for the special tokens, matching the encoding the
        # dataset uses when loaded with the default ``index_from=3``.
        word_index = imdb.get_word_index()
        word_index = {word: idx + 3 for word, idx in word_index.items()}
        word_index["<PAD>"] = 0
        word_index["<START>"] = 1
        word_index["<UNK>"] = 2
        word_index["<UNUSED>"] = 3
        self.word_index = word_index

    def _sanitize_parameters(self, **kwargs):
        # No runtime parameters are supported; all kwargs are ignored.
        return {}, {}, {}

    def preprocess(self, text):
        """Encode raw text into a padded int32 id batch of shape (1, MAX_LEN).

        NOTE(review): tokenization is a bare ``lower().split()`` with no
        punctuation stripping, unlike Keras' default text filtering —
        confirm this matches how the model was trained.
        """
        tokens = text.lower().split()
        # Fix: prepend the <START> id (1).  Keras IMDB training sequences
        # begin with ``start_char=1``; the vocabulary above reserves id 1 for
        # exactly that, but the previous code defined it without using it, so
        # encodings were shifted relative to the training data.
        encoded = [self.word_index["<START>"]] + [
            self.word_index.get(word, self.word_index["<UNK>"]) for word in tokens
        ]
        # NOTE(review): ``pad_sequences`` defaults to padding='pre'; confirm
        # the model was trained with post-padding before relying on this.
        padded = pad_sequences([encoded], maxlen=self.MAX_LEN, value=0,
                               padding='post')
        return {"input": padded.astype(np.int32)}

    def _forward(self, model_inputs):
        """Run the ONNX session on the encoded batch; return its raw output."""
        input_ids = model_inputs["input"]
        ort_inputs = {self.model.get_inputs()[0].name: input_ids}
        logits = self.model.run(None, ort_inputs)[0]
        return {"logits": logits}

    def postprocess(self, model_outputs):
        """Map the model's scalar output to ``{"label", "confidence"}``.

        Assumes the model emits a sigmoid probability in [0, 1] (hence the
        0.5 threshold); if it actually emits raw logits despite the key name,
        a sigmoid must be applied first — TODO confirm against the export.
        """
        pred = float(model_outputs["logits"][0][0])
        label = "POSITIVE" if pred > 0.5 else "NEGATIVE"
        confidence = pred if pred > 0.5 else 1.0 - pred
        return {"label": label, "confidence": confidence}
|
|