SnowFlash383935 committed on
Commit
4d26c39
·
verified ·
1 Parent(s): 985af55

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. config.json +4 -0
  2. pipeline_imdb_cnn.py +41 -0
  3. tokenizer_config.json +3 -0
config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "model_type": "custom-onnx",
3
+ "pipeline_tag": "text-classification"
4
+ }
pipeline_imdb_cnn.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ import onnxruntime as ort
4
+ from transformers import Pipeline
5
+ from tensorflow.keras.datasets import imdb
6
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
7
+
8
+
9
class ImdbCnnPipeline(Pipeline):
    """Sentiment-classification pipeline for an IMDB CNN exported to ONNX.

    Note that ``model`` is used as an ``onnxruntime.InferenceSession``
    (the forward pass calls ``model.get_inputs()`` and ``model.run``),
    not as a transformers model object.
    """

    def __init__(self, model, tokenizer=None, **kwargs):
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

        # Build the vocabulary: the Keras IMDB word index is shifted by 3
        # to reserve ids 0-3 for the dataset's special tokens.
        vocab = {word: idx + 3 for word, idx in imdb.get_word_index().items()}
        vocab["<PAD>"] = 0
        vocab["<START>"] = 1
        vocab["<UNK>"] = 2
        vocab["<UNUSED>"] = 3
        self.word_index = vocab

    def _sanitize_parameters(self, **kwargs):
        # No configurable parameters: each pipeline stage gets empty kwargs.
        return {}, {}, {}

    def preprocess(self, text):
        """Tokenize by whitespace and encode to a padded int32 batch of 1."""
        # Unknown words map to the <UNK> id (2).
        ids = [self.word_index.get(token, 2) for token in text.lower().split()]
        # NOTE(review): sequences are post-padded/truncated to length 500;
        # no <START> token is prepended — confirm this matches how the
        # ONNX model was trained.
        batch = pad_sequences([ids], maxlen=500, value=0, padding='post')
        return {"input": batch.astype(np.int32)}

    def _forward(self, model_inputs):
        """Run the ONNX session on the encoded batch and return raw logits."""
        input_name = self.model.get_inputs()[0].name
        outputs = self.model.run(None, {input_name: model_inputs["input"]})
        return {"logits": outputs[0]}

    def postprocess(self, model_outputs):
        """Map the single sigmoid score to a label plus confidence."""
        score = model_outputs["logits"][0][0]
        # Score above 0.5 reads as positive; confidence is the probability
        # assigned to the chosen label.
        if score > 0.5:
            return {"label": "POSITIVE", "confidence": float(score)}
        return {"label": "NEGATIVE", "confidence": 1 - float(score)}
tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "tokenizer_class": "CustomTokenizer"
3
+ }