Spaces:
Runtime error
Runtime error
Delete myrpunct
Browse files- myrpunct/__init__.py +0 -2
- myrpunct/punctuate.py +0 -174
- myrpunct/utils.py +0 -34
myrpunct/__init__.py
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
from .punctuate import RestorePuncts
|
| 2 |
-
print("init executed ...")
|
|
|
|
|
|
|
|
|
myrpunct/punctuate.py
DELETED
|
@@ -1,174 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
# 💾⚙️🔮
|
| 3 |
-
|
| 4 |
-
__author__ = "Daulet N."
|
| 5 |
-
__email__ = "daulet.nurmanbetov@gmail.com"
|
| 6 |
-
|
| 7 |
-
import logging
|
| 8 |
-
from langdetect import detect
|
| 9 |
-
from simpletransformers.ner import NERModel, NERArgs
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class RestorePuncts:
|
| 13 |
-
def __init__(self, wrds_per_pred=250, use_cuda=False):
|
| 14 |
-
self.wrds_per_pred = wrds_per_pred
|
| 15 |
-
self.overlap_wrds = 30
|
| 16 |
-
self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U']
|
| 17 |
-
self.model_hf = "wldmr/felflare-bert-restore-punctuation"
|
| 18 |
-
self.model_args = NERArgs()
|
| 19 |
-
self.model_args.silent = True
|
| 20 |
-
self.model_args.max_seq_length = 512
|
| 21 |
-
#self.model_args.use_multiprocessing = False
|
| 22 |
-
self.model = NERModel("bert", self.model_hf, labels=self.valid_labels, use_cuda=use_cuda, args=self.model_args)
|
| 23 |
-
#self.model = NERModel("bert", self.model_hf, labels=self.valid_labels, use_cuda=use_cuda, args={"silent": True, "max_seq_length": 512, "use_multiprocessing": False})
|
| 24 |
-
print("class init ...")
|
| 25 |
-
print("use_multiprocessing: ",self.model_args.use_multiprocessing)
|
| 26 |
-
|
| 27 |
-
def status(self):
|
| 28 |
-
print("function called")
|
| 29 |
-
|
| 30 |
-
def punctuate(self, text: str, lang:str=''):
|
| 31 |
-
"""
|
| 32 |
-
Performs punctuation restoration on arbitrarily large text.
|
| 33 |
-
Detects if input is not English, if non-English was detected terminates predictions.
|
| 34 |
-
Overrride by supplying `lang='en'`
|
| 35 |
-
|
| 36 |
-
Args:
|
| 37 |
-
- text (str): Text to punctuate, can be few words to as large as you want.
|
| 38 |
-
- lang (str): Explicit language of input text.
|
| 39 |
-
"""
|
| 40 |
-
if not lang and len(text) > 10:
|
| 41 |
-
lang = detect(text)
|
| 42 |
-
if lang != 'en':
|
| 43 |
-
raise Exception(F"""Non English text detected. Restore Punctuation works only for English.
|
| 44 |
-
If you are certain the input is English, pass argument lang='en' to this function.
|
| 45 |
-
Punctuate received: {text}""")
|
| 46 |
-
|
| 47 |
-
# plit up large text into bert digestable chunks
|
| 48 |
-
splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds)
|
| 49 |
-
# predict slices
|
| 50 |
-
# full_preds_lst contains tuple of labels and logits
|
| 51 |
-
full_preds_lst = [self.predict(i['text']) for i in splits]
|
| 52 |
-
# extract predictions, and discard logits
|
| 53 |
-
preds_lst = [i[0][0] for i in full_preds_lst]
|
| 54 |
-
# join text slices
|
| 55 |
-
combined_preds = self.combine_results(text, preds_lst)
|
| 56 |
-
# create punctuated prediction
|
| 57 |
-
punct_text = self.punctuate_texts(combined_preds)
|
| 58 |
-
return punct_text
|
| 59 |
-
|
| 60 |
-
def predict(self, input_slice):
|
| 61 |
-
"""
|
| 62 |
-
Passes the unpunctuated text to the model for punctuation.
|
| 63 |
-
"""
|
| 64 |
-
predictions, raw_outputs = self.model.predict([input_slice])
|
| 65 |
-
return predictions, raw_outputs
|
| 66 |
-
|
| 67 |
-
@staticmethod
|
| 68 |
-
def split_on_toks(text, length, overlap):
|
| 69 |
-
"""
|
| 70 |
-
Splits text into predefined slices of overlapping text with indexes (offsets)
|
| 71 |
-
that tie-back to original text.
|
| 72 |
-
This is done to bypass 512 token limit on transformer models by sequentially
|
| 73 |
-
feeding chunks of < 512 toks.
|
| 74 |
-
Example output:
|
| 75 |
-
[{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]
|
| 76 |
-
"""
|
| 77 |
-
wrds = text.replace('\n', ' ').split(" ")
|
| 78 |
-
resp = []
|
| 79 |
-
lst_chunk_idx = 0
|
| 80 |
-
i = 0
|
| 81 |
-
|
| 82 |
-
while True:
|
| 83 |
-
# words in the chunk and the overlapping portion
|
| 84 |
-
wrds_len = wrds[(length * i):(length * (i + 1))]
|
| 85 |
-
wrds_ovlp = wrds[(length * (i + 1)):((length * (i + 1)) + overlap)]
|
| 86 |
-
wrds_split = wrds_len + wrds_ovlp
|
| 87 |
-
|
| 88 |
-
# Break loop if no more words
|
| 89 |
-
if not wrds_split:
|
| 90 |
-
break
|
| 91 |
-
|
| 92 |
-
wrds_str = " ".join(wrds_split)
|
| 93 |
-
nxt_chunk_start_idx = len(" ".join(wrds_len))
|
| 94 |
-
lst_char_idx = len(" ".join(wrds_split))
|
| 95 |
-
|
| 96 |
-
resp_obj = {
|
| 97 |
-
"text": wrds_str,
|
| 98 |
-
"start_idx": lst_chunk_idx,
|
| 99 |
-
"end_idx": lst_char_idx + lst_chunk_idx,
|
| 100 |
-
}
|
| 101 |
-
|
| 102 |
-
resp.append(resp_obj)
|
| 103 |
-
lst_chunk_idx += nxt_chunk_start_idx + 1
|
| 104 |
-
i += 1
|
| 105 |
-
logging.info(f"Sliced transcript into {len(resp)} slices.")
|
| 106 |
-
return resp
|
| 107 |
-
|
| 108 |
-
@staticmethod
|
| 109 |
-
def combine_results(full_text: str, text_slices):
|
| 110 |
-
"""
|
| 111 |
-
Given a full text and predictions of each slice combines predictions into a single text again.
|
| 112 |
-
Performs validataion wether text was combined correctly
|
| 113 |
-
"""
|
| 114 |
-
split_full_text = full_text.replace('\n', ' ').split(" ")
|
| 115 |
-
split_full_text = [i for i in split_full_text if i]
|
| 116 |
-
split_full_text_len = len(split_full_text)
|
| 117 |
-
output_text = []
|
| 118 |
-
index = 0
|
| 119 |
-
|
| 120 |
-
if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
|
| 121 |
-
text_slices = text_slices[:-1]
|
| 122 |
-
|
| 123 |
-
for _slice in text_slices:
|
| 124 |
-
slice_wrds = len(_slice)
|
| 125 |
-
for ix, wrd in enumerate(_slice):
|
| 126 |
-
# print(index, "|", str(list(wrd.keys())[0]), "|", split_full_text[index])
|
| 127 |
-
if index == split_full_text_len:
|
| 128 |
-
break
|
| 129 |
-
|
| 130 |
-
if split_full_text[index] == str(list(wrd.keys())[0]) and \
|
| 131 |
-
ix <= slice_wrds - 3 and text_slices[-1] != _slice:
|
| 132 |
-
index += 1
|
| 133 |
-
pred_item_tuple = list(wrd.items())[0]
|
| 134 |
-
output_text.append(pred_item_tuple)
|
| 135 |
-
elif split_full_text[index] == str(list(wrd.keys())[0]) and text_slices[-1] == _slice:
|
| 136 |
-
index += 1
|
| 137 |
-
pred_item_tuple = list(wrd.items())[0]
|
| 138 |
-
output_text.append(pred_item_tuple)
|
| 139 |
-
assert [i[0] for i in output_text] == split_full_text
|
| 140 |
-
return output_text
|
| 141 |
-
|
| 142 |
-
@staticmethod
|
| 143 |
-
def punctuate_texts(full_pred: list):
|
| 144 |
-
"""
|
| 145 |
-
Given a list of Predictions from the model, applies the predictions to text,
|
| 146 |
-
thus punctuating it.
|
| 147 |
-
"""
|
| 148 |
-
punct_resp = ""
|
| 149 |
-
for i in full_pred:
|
| 150 |
-
word, label = i
|
| 151 |
-
if label[-1] == "U":
|
| 152 |
-
punct_wrd = word.capitalize()
|
| 153 |
-
else:
|
| 154 |
-
punct_wrd = word
|
| 155 |
-
|
| 156 |
-
if label[0] != "O":
|
| 157 |
-
punct_wrd += label[0]
|
| 158 |
-
|
| 159 |
-
punct_resp += punct_wrd + " "
|
| 160 |
-
punct_resp = punct_resp.strip()
|
| 161 |
-
# Append trailing period if doesnt exist.
|
| 162 |
-
if punct_resp[-1].isalnum():
|
| 163 |
-
punct_resp += "."
|
| 164 |
-
return punct_resp
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
if __name__ == "__main__":
|
| 168 |
-
punct_model = RestorePuncts()
|
| 169 |
-
# read test file
|
| 170 |
-
with open('../tests/sample_text.txt', 'r') as fp:
|
| 171 |
-
test_sample = fp.read()
|
| 172 |
-
# predict text and print
|
| 173 |
-
punctuated = punct_model.punctuate(test_sample)
|
| 174 |
-
print(punctuated)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
myrpunct/utils.py
DELETED
|
@@ -1,34 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
# 💾⚙️🔮
|
| 3 |
-
|
| 4 |
-
__author__ = "Daulet N."
|
| 5 |
-
__email__ = "daulet.nurmanbetov@gmail.com"
|
| 6 |
-
|
| 7 |
-
def prepare_unpunct_text(text):
|
| 8 |
-
"""
|
| 9 |
-
Given a text, normalizes it to subsequently restore punctuation
|
| 10 |
-
"""
|
| 11 |
-
formatted_txt = text.replace('\n', '').strip()
|
| 12 |
-
formatted_txt = formatted_txt.lower()
|
| 13 |
-
formatted_txt_lst = formatted_txt.split(" ")
|
| 14 |
-
punct_strp_txt = [strip_punct(i) for i in formatted_txt_lst]
|
| 15 |
-
normalized_txt = " ".join([i for i in punct_strp_txt if i])
|
| 16 |
-
return normalized_txt
|
| 17 |
-
|
| 18 |
-
def strip_punct(wrd):
|
| 19 |
-
"""
|
| 20 |
-
Given a word, strips non aphanumeric characters that precede and follow it
|
| 21 |
-
"""
|
| 22 |
-
if not wrd:
|
| 23 |
-
return wrd
|
| 24 |
-
|
| 25 |
-
while not wrd[-1:].isalnum():
|
| 26 |
-
if not wrd:
|
| 27 |
-
break
|
| 28 |
-
wrd = wrd[:-1]
|
| 29 |
-
|
| 30 |
-
while not wrd[:1].isalnum():
|
| 31 |
-
if not wrd:
|
| 32 |
-
break
|
| 33 |
-
wrd = wrd[1:]
|
| 34 |
-
return wrd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|