# NOTE(review): removed paste/export artifacts ("Spaces:" and two
# "Runtime error" header lines) that were not part of the Python source.
| #!/usr/bin/env python | |
| import os | |
| import re | |
| import string | |
| os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" | |
| os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" | |
| from simpletransformers.ner import NERModel | |
class BERTmodel:
    """Restore Urdu punctuation with a BERT-based token-classification model.

    Each word is tagged with one of ``valid_labels``: "F" (full stop ۔),
    "C" (comma ،), "Q" (question mark ؟) or "O" (no punctuation).  Long
    inputs are split into overlapping word chunks that fit the model's
    512-token limit, predicted chunk by chunk, then stitched back together.
    """

    def __init__(self, normalization="full", wrds_per_pred=256):
        # "partial" strips only diacritics/Urdu sentence marks via regex;
        # any other value uses the "full" character blocklist below.
        self.normalization = normalization
        # Number of "core" words per prediction chunk.
        self.wrds_per_pred = wrds_per_pred
        # Extra words appended to each chunk so predictions near a chunk
        # boundary still see right-hand context.
        self.overlap_wrds = 32
        self.valid_labels = ["O", "F", "C", "Q"]
        self.label_to_punct = {"F": "۔", "C": "،", "Q": "؟", "O": ""}
        self.model = NERModel(
            "bert",
            "/code/models/urdu",
            use_cuda=False,
            labels=self.valid_labels,
            args={"silent": True, "max_seq_length": 512},
        )
        self.patterns = {
            "partial": r"[ً-٠ٰ۟-ۤۧ-۪ۨ-ۭ،۔؟]+",
            "full": string.punctuation + "،؛؟۔٪ء‘’",
        }

    def punctuation_removal(self, text: str) -> str:
        """Strip existing punctuation according to the normalization mode."""
        if self.normalization == "partial":
            return re.sub(self.patterns[self.normalization], "", text).strip()
        # Full mode: drop every character present in the blocklist string.
        return "".join(ch for ch in text if ch not in self.patterns[self.normalization])

    def punctuate(self, text: str):
        """Return *text* with model-predicted punctuation inserted."""
        text = self.punctuation_removal(text)
        splits = self.split_on_tokens(text)
        full_preds_lst = [self.predict(chunk["text"]) for chunk in splits]
        # predict() wraps NERModel.predict, which returns (predictions,
        # raw_outputs); keep the word-level predictions of the single
        # sentence submitted per chunk.
        preds_lst = [pred[0][0] for pred in full_preds_lst]
        combined_preds = self.combine_results(text, preds_lst)
        return self.punctuate_texts(combined_preds)

    def predict(self, input_slice):
        """Run the NER model on one chunk of text."""
        return self.model.predict([input_slice])

    def split_on_tokens(self, text):
        """Split *text* into overlapping word chunks.

        Each chunk holds up to ``wrds_per_pred`` core words plus up to
        ``overlap_wrds`` context words.  Returns a list of dicts with the
        chunk text and the character offsets (start_idx/end_idx) of the
        chunk's core part within the normalized text.
        """
        wrds = text.replace("\n", " ").split()
        response = []
        lst_chunk_idx = 0
        i = 0
        while True:
            wrds_len = wrds[i * self.wrds_per_pred : (i + 1) * self.wrds_per_pred]
            wrds_ovlp = wrds[
                (i + 1) * self.wrds_per_pred : (i + 1) * self.wrds_per_pred + self.overlap_wrds
            ]
            wrds_split = wrds_len + wrds_ovlp
            if not wrds_split:
                break
            response_obj = {
                "text": " ".join(wrds_split),
                "start_idx": lst_chunk_idx,
                "end_idx": lst_chunk_idx + len(" ".join(wrds_len)),
            }
            response.append(response_obj)
            # BUG FIX: the next chunk starts one character after this
            # chunk's core words end.  The original `+=` added end_idx
            # (which already contains lst_chunk_idx) on top of the running
            # offset, double-counting it from the second chunk onward.
            lst_chunk_idx = response_obj["end_idx"] + 1
            i += 1
        return response

    def combine_results(self, full_text: str, text_slices):
        """Merge per-chunk predictions into one (word, label) list.

        ``text_slices`` is a list of chunk predictions, each a list of
        single-entry ``{word: label}`` dicts.  Words are matched against
        the normalized full text in order; the tail of every non-final
        chunk is skipped (guard ``ix <= slice_wrds - 3``) so overlap words
        are taken from the chunk that has them as core context.
        NOTE(review): the guard drops only the last two predictions, not
        the full ``overlap_wrds`` tail — confirm this is intended.
        """
        split_full_text = full_text.replace("\n", " ").split(" ")
        split_full_text = [w for w in split_full_text if w]
        split_full_text_len = len(split_full_text)
        output_text = []
        index = 0
        # Drop a trailing chunk too short to carry useful predictions.
        # BUG FIX: length test reordered so an empty text_slices no longer
        # raises IndexError on text_slices[-1].
        if len(text_slices) > 1 and len(text_slices[-1]) <= 3:
            text_slices = text_slices[:-1]
        for pred_slice in text_slices:  # renamed: `slice` shadowed a builtin
            slice_wrds = len(pred_slice)
            is_last_slice = text_slices[-1] == pred_slice
            for ix, wrd in enumerate(pred_slice):
                if index == split_full_text_len:
                    break
                word = str(list(wrd.keys())[0])  # computed once per word
                if split_full_text[index] == word and (
                    (ix <= slice_wrds - 3 and not is_last_slice) or is_last_slice
                ):
                    index += 1
                    output_text.append(list(wrd.items())[0])
        # Sanity check: every word of the input must be accounted for.
        assert [w for w, _ in output_text] == split_full_text
        return output_text

    def punctuate_texts(self, full_pred: list):
        """Render (word, label) pairs into a single punctuated string."""
        punct_resp = []
        for punct_wrd, label in full_pred:
            punct_wrd += self.label_to_punct[label]
            # NOTE(review): since the label mark was already appended, this
            # branch only fires for label "O" on a word ending in ‘‘ and is
            # then a no-op; it was likely meant to move punctuation inside
            # a closing quote — confirm before changing output.
            if punct_wrd.endswith("‘‘"):
                punct_wrd = punct_wrd[:-2] + self.label_to_punct[label] + "‘‘"
            punct_resp.append(punct_wrd)
        punct_resp = " ".join(punct_resp)
        # Guarantee a sentence terminator.  BUG FIX: guard the empty case —
        # the original raised IndexError on an empty prediction list.
        if punct_resp and punct_resp[-1].isalnum():
            punct_resp += "۔"
        return punct_resp
class Urdu:
    """Public entry point for Urdu punctuation restoration."""

    def __init__(self):
        # Heavyweight: constructing BERTmodel loads the BERT checkpoint.
        self.model = BERTmodel()

    def punctuate(self, data: str):
        """Delegate punctuation restoration of *data* to the wrapped model."""
        restored = self.model.punctuate(data)
        return restored