|
|
| from fire import Fire |
| import string |
| import tensorflow as tf |
| from transformers import AutoTokenizer |
| from hazm import * |
| from transformers import pipeline |
| from transformers import TextClassificationPipeline |
# Hugging Face hub id of the base pretrained model; used to build the tokenizer.
original_model = "HooshvareLab/bert-fa-base-uncased"
# Path to the fine-tuned model directory, loaded via tf.keras.models.load_model.
model_path = 'models'
def remove_punctuation(input_string):
    """Return *input_string* with every ASCII punctuation character removed."""
    # str.translate with a deletion table does the strip in one C-level pass.
    deletion_table = str.maketrans("", "", string.punctuation)
    return input_string.translate(deletion_table)
def predict(file_path):
    """Classify the Persian text stored in *file_path*.

    Reads the file, strips ASCII punctuation, normalizes the text with
    hazm, tokenizes it with the base pretrained tokenizer, runs the
    fine-tuned Keras model, and prints the raw prediction for the input.

    Args:
        file_path: Path to a UTF-8 text file containing the input text.

    Returns:
        The model's raw prediction array (batch of one), so callers can
        use the result programmatically as well.
    """
    normalizer = Normalizer()
    tokenizer = AutoTokenizer.from_pretrained(original_model)

    # Fix: decode explicitly as UTF-8 — Persian text would break under a
    # locale-dependent default encoding (e.g. cp1252 on Windows).
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()

    text = remove_punctuation(text)
    text = normalizer.normalize(text)

    input_tokens = tokenizer.batch_encode_plus(
        [text],
        padding=True,
        truncation=True,
        return_tensors="tf",
        max_length=128,
    )
    input_ids = input_tokens["input_ids"]
    attention_mask = input_tokens["attention_mask"]

    # NOTE(review): the model is reloaded on every call; hoist to module
    # level if predict() is ever called more than once per process.
    new_model = tf.keras.models.load_model(model_path)

    features = {"input_ids": input_ids, "attention_mask": attention_mask}
    print(features)
    # Fix: pass the named-input dict directly — the documented Keras form —
    # instead of wrapping it in a single-element list.
    predictions = new_model.predict(features)
    print(predictions[0])
    return predictions
| |
if __name__ == '__main__':
    # Expose predict() as a CLI via python-fire: `python <script> <file_path>`.
    Fire(predict)