| import torch | |
| import torch.nn.functional as F | |
| from transformers import BertTokenizer, BertForTokenClassification | |
| import re | |
| import string | |
def preprocess_input_text(text):
    """
    Prepare raw text for token-classification inference.

    Inserts a space before every punctuation mark, collapses runs of
    whitespace, lowercases the text, and appends a "[MASK]" token after
    every word.

    Parameters
    ----------
    text : str
        Raw input sentence.

    Returns
    -------
    tuple[list[str], str]
        A pair of (original-case word/punctuation tokens, lowercased
        text with a "[MASK]" after every token).
    """
    # re.escape makes the character class safe no matter which
    # punctuation characters would otherwise need escaping inside [...];
    # the original relied on the accidental ordering of string.punctuation.
    text = re.sub(r'([' + re.escape(string.punctuation) + r'])', r' \1', text)
    # Collapse any whitespace run (spaces, tabs, newlines) to a single
    # space and strip the edges, so split(" ") never yields empty-string
    # "words" (e.g. when the text starts with punctuation, the
    # substitution above would otherwise leave a leading space).
    text = re.sub(r'\s+', ' ', text).strip()
    words = text.split(" ")
    text = text.lower()
    # Interleave each lowercased token with a [MASK] placeholder.
    output = []
    for word in text.split(" "):
        output.append(word)
        output.append("[MASK]")
    return words, " ".join(output)
def predict_using_trained_model_old(input_text, model_dir, device):
    """
    Label every word of *input_text* as correct (0) or incorrect (1).

    Loads a BERT token-classification model and tokenizer from
    *model_dir*, runs the preprocessed (mask-interleaved) text through
    it, and rebuilds a string in which each original word is followed by
    the predicted label of the "[MASK]" token that follows it.

    Parameters
    ----------
    input_text : str
        Sentence to check.
    model_dir : str
        Directory containing the pretrained tokenizer and model weights.
    device : torch.device
        Device to run inference on.

    Returns
    -------
    str
        Original words interleaved with "0"/"1" prediction labels.
    """
    words, masked_text = preprocess_input_text(input_text)

    tokenizer = BertTokenizer.from_pretrained(model_dir)
    model = BertForTokenClassification.from_pretrained(model_dir, num_labels=2)
    model.to(device)
    model.eval()

    encoded = tokenizer(masked_text, max_length=128, padding='max_length',
                        truncation=True, return_tensors="pt")
    ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(ids, attention_mask=mask).logits
    labels = torch.argmax(logits, dim=-1).squeeze().cpu().numpy()
    token_strings = tokenizer.convert_ids_to_tokens(ids.squeeze().cpu().numpy())

    # Walk the token stream: each [MASK] carries the verdict for the
    # word that precedes it; everything else (except special tokens) is
    # echoed back from the original-case word list.
    special = ("[CLS]", "[SEP]", "[PAD]")
    pieces = []
    word_pos = 0
    for tok, label in zip(token_strings, labels):
        if tok == "[MASK]":
            pieces.append(str(label))
            word_pos += 1
        elif tok not in special:
            # NOTE(review): this assumes every word maps to exactly one
            # wordpiece token; a word split into subwords would append
            # words[word_pos] once per piece — confirm against the
            # tokenizer's vocabulary.
            pieces.append(words[word_pos])
    return " ".join(pieces)
if __name__ == '__main__':
    # Sample Croatian sentence containing deliberate spelling errors.
    input_text = "Model u tekstu prepoznije riječi u kojima se nalazaju pogreške."
    model_dir = "."

    # Prefer the fastest available backend: CUDA, then Apple MPS, then CPU.
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using device: {device}")

    model_output_text = predict_using_trained_model_old(input_text, model_dir, device)
    print(f"Model output: {model_output_text}")