Update README.md
Browse files
README.md
CHANGED
|
@@ -126,12 +126,12 @@ This model was trained using the [Keras](https://keras.io/) framework. All train
|
|
| 126 |
import keras
|
| 127 |
import tensorflow as tf
|
| 128 |
import numpy as np
|
|
|
|
| 129 |
|
| 130 |
-
|
| 131 |
-
"bert_mlm.h5", custom_objects={"MaskedLanguageModel": MaskedLanguageModel}
|
| 132 |
-
)
|
| 133 |
|
| 134 |
MAX_LEN = 32
|
|
|
|
| 135 |
|
| 136 |
def inference(sequence):
|
| 137 |
sequence = " ".join([c if c != "e" else "[mask]" for c in sequence])
|
|
@@ -140,10 +140,10 @@ def inference(sequence):
|
|
| 140 |
|
| 141 |
tokens = tokens + pad
|
| 142 |
input_ids = tf.convert_to_tensor(np.array([tokens]))
|
| 143 |
-
prediction =
|
| 144 |
|
| 145 |
# find masked idx token
|
| 146 |
-
masked_index = np.where(input_ids ==
|
| 147 |
masked_index = masked_index[1]
|
| 148 |
|
| 149 |
# get prediction at those masked index only
|
|
|
|
| 126 |
import keras
|
| 127 |
import tensorflow as tf
|
| 128 |
import numpy as np
|
| 129 |
+
from huggingface_hub import from_pretrained_keras
|
| 130 |
|
| 131 |
+
model = from_pretrained_keras("bookbot/id-g2p-bert")
|
|
|
|
|
|
|
| 132 |
|
| 133 |
MAX_LEN = 32
|
| 134 |
+
MASK_TOKEN_ID = 30
|
| 135 |
|
| 136 |
def inference(sequence):
|
| 137 |
sequence = " ".join([c if c != "e" else "[mask]" for c in sequence])
|
|
|
|
| 140 |
|
| 141 |
tokens = tokens + pad
|
| 142 |
input_ids = tf.convert_to_tensor(np.array([tokens]))
|
| 143 |
+
prediction = model.predict(input_ids)
|
| 144 |
|
| 145 |
# find masked idx token
|
| 146 |
+
masked_index = np.where(input_ids == MASK_TOKEN_ID)
|
| 147 |
masked_index = masked_index[1]
|
| 148 |
|
| 149 |
# get prediction at those masked index only
|