Spaces:
Runtime error
Runtime error
Commit ·
47ae719
0
Parent(s):
Duplicated from wldmr/punct-tube-gr
Browse files- .gitattributes +34 -0
- README.md +14 -0
- app.py +53 -0
- myrpunct/__init__.py +2 -0
- myrpunct/punctuate.py +174 -0
- myrpunct/utils.py +34 -0
- requirements.txt +4 -0
.gitattributes
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Punct Tube Gr
|
| 3 |
+
emoji: 💻
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 3.12.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
duplicated_from: wldmr/punct-tube-gr
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from myrpunct import RestorePuncts
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
def predict(input_text):
|
| 6 |
+
rpunct = RestorePuncts()
|
| 7 |
+
output_text = rpunct.punctuate(input_text)
|
| 8 |
+
print("Punctuation finished...")
|
| 9 |
+
|
| 10 |
+
# restore the carrige returns
|
| 11 |
+
srt_file = input_text
|
| 12 |
+
punctuated = output_text
|
| 13 |
+
|
| 14 |
+
srt_file_strip=srt_file.strip()
|
| 15 |
+
srt_file_sub=re.sub('\s*\n\s*','# ',srt_file_strip)
|
| 16 |
+
srt_file_array=srt_file_sub.split(' ')
|
| 17 |
+
pcnt_file_array=punctuated.split(' ')
|
| 18 |
+
|
| 19 |
+
# goal: restore the break points i.e. the same number of lines as the srt file
|
| 20 |
+
# this is necessary, because each line in the srt file corresponds to a frame from the video
|
| 21 |
+
if len(srt_file_array)!=len(pcnt_file_array):
|
| 22 |
+
return "AssertError: The length of the transcript and the punctuated file should be the same: ",len(srt_file_array),len(pcnt_file_array)
|
| 23 |
+
pcnt_file_array_hash = []
|
| 24 |
+
for idx, item in enumerate(srt_file_array):
|
| 25 |
+
if item.endswith('#'):
|
| 26 |
+
pcnt_file_array_hash.append(pcnt_file_array[idx]+'#')
|
| 27 |
+
else:
|
| 28 |
+
pcnt_file_array_hash.append(pcnt_file_array[idx])
|
| 29 |
+
|
| 30 |
+
# assemble the array back to a string
|
| 31 |
+
pcnt_file_cr=' '.join(pcnt_file_array_hash).replace('#','\n')
|
| 32 |
+
|
| 33 |
+
return pcnt_file_cr
|
| 34 |
+
|
| 35 |
+
if __name__ == "__main__":
    # Gradio front-end wiring: one text box in, one text box out.
    demo_title = "Rpunct App"
    demo_description = """
<b>Description</b>: <br>
Model restores punctuation and case i.e. of the following punctuations -- [! ? . , - : ; ' ] and also the upper-casing of words. <br>
"""
    demo_examples = ["my name is clara and i live in berkeley california"]

    interface = gr.Interface(
        fn=predict,
        inputs=["text"],
        outputs=["text"],
        title=demo_title,
        description=demo_description,
        examples=demo_examples,
        allow_flagging="never",
    )
    interface.launch()
|
| 53 |
+
|
myrpunct/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""myrpunct: punctuation-restoration package (public API: RestorePuncts)."""

from .punctuate import RestorePuncts

# Explicit public API. The former import-time debug print was removed:
# importing a module should stay free of side effects.
__all__ = ["RestorePuncts"]
|
myrpunct/punctuate.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# 💾⚙️🔮
|
| 3 |
+
|
| 4 |
+
__author__ = "Daulet N."
|
| 5 |
+
__email__ = "daulet.nurmanbetov@gmail.com"
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from langdetect import detect
|
| 9 |
+
from simpletransformers.ner import NERModel, NERArgs
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RestorePuncts:
    """Restore punctuation and casing in unpunctuated English text.

    Wraps a BERT token-classification model whose labels encode both the
    punctuation mark to append after a word ('.', ',', '!', ...; 'O' for
    none) and the word's casing ('U' = capitalize first letter, 'O' = keep).
    """

    def __init__(self, wrds_per_pred=250, use_cuda=False):
        # Words fed to the model per slice; 30 extra overlapping words give
        # the model right-hand context across slice boundaries.
        self.wrds_per_pred = wrds_per_pred
        self.overlap_wrds = 30
        # Labels are "<punct><case>": first char is the punctuation to
        # append ('O' = none), second char is the casing flag.
        self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U',
                             ':O', ';O', ':U', "'O", '-O', '?O', '?U']
        self.model_hf = "wldmr/felflare-bert-restore-punctuation"
        self.model_args = NERArgs()
        self.model_args.silent = True
        # Hard transformer limit for this checkpoint.
        self.model_args.max_seq_length = 512
        self.model = NERModel("bert", self.model_hf, labels=self.valid_labels,
                              use_cuda=use_cuda, args=self.model_args)
        print("class init ...")
        print("use_multiprocessing: ", self.model_args.use_multiprocessing)

    def status(self):
        """Debug helper: confirm the instance is alive and callable."""
        print("function called")

    def punctuate(self, text: str, lang: str = ''):
        """
        Performs punctuation restoration on arbitrarily large text.
        Detects if input is not English, if non-English was detected
        terminates predictions. Override by supplying `lang='en'`.

        Args:
            - text (str): Text to punctuate, can be few words to as large as you want.
            - lang (str): Explicit language of input text.

        Raises:
            Exception: when the detected (or defaulted) language is not 'en'.
        """
        if not lang and len(text) > 10:
            lang = detect(text)
        # NOTE(review): if len(text) <= 10 and lang was not supplied, lang
        # stays '' and this raises even for English input — confirm intended.
        if lang != 'en':
            raise Exception(F"""Non English text detected. Restore Punctuation works only for English.
            If you are certain the input is English, pass argument lang='en' to this function.
            Punctuate received: {text}""")

        # Split up large text into bert-digestible chunks.
        splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds)
        # Predict slices; full_preds_lst contains tuples of labels and logits.
        full_preds_lst = [self.predict(i['text']) for i in splits]
        # Extract predictions, and discard logits.
        preds_lst = [i[0][0] for i in full_preds_lst]
        # Join text slices.
        combined_preds = self.combine_results(text, preds_lst)
        # Create punctuated prediction.
        punct_text = self.punctuate_texts(combined_preds)
        return punct_text

    def predict(self, input_slice):
        """
        Passes the unpunctuated text to the model for punctuation.
        """
        predictions, raw_outputs = self.model.predict([input_slice])
        return predictions, raw_outputs

    @staticmethod
    def split_on_toks(text, length, overlap):
        """
        Splits text into predefined slices of overlapping text with indexes
        (offsets) that tie back to the original text.
        This is done to bypass the 512 token limit on transformer models by
        sequentially feeding chunks of < 512 toks.
        Example output:
        [{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]
        """
        wrds = text.replace('\n', ' ').split(" ")
        resp = []
        lst_chunk_idx = 0
        i = 0

        while True:
            # Words in the chunk and the overlapping portion.
            wrds_len = wrds[(length * i):(length * (i + 1))]
            wrds_ovlp = wrds[(length * (i + 1)):((length * (i + 1)) + overlap)]
            wrds_split = wrds_len + wrds_ovlp

            # Break loop if no more words.
            if not wrds_split:
                break

            wrds_str = " ".join(wrds_split)
            # Character offsets: the next chunk starts right after the
            # non-overlapping part of this one.
            nxt_chunk_start_idx = len(" ".join(wrds_len))
            lst_char_idx = len(" ".join(wrds_split))

            resp_obj = {
                "text": wrds_str,
                "start_idx": lst_chunk_idx,
                "end_idx": lst_char_idx + lst_chunk_idx,
            }

            resp.append(resp_obj)
            lst_chunk_idx += nxt_chunk_start_idx + 1
            i += 1
        logging.info(f"Sliced transcript into {len(resp)} slices.")
        return resp

    @staticmethod
    def combine_results(full_text: str, text_slices):
        """
        Given a full text and predictions of each slice, combines the
        predictions into a single text again.
        Performs validation of whether the text was combined correctly.
        """
        split_full_text = full_text.replace('\n', ' ').split(" ")
        split_full_text = [i for i in split_full_text if i]
        split_full_text_len = len(split_full_text)
        output_text = []
        index = 0

        # A tiny trailing slice (pure overlap residue) carries no new words.
        if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
            text_slices = text_slices[:-1]

        for _slice in text_slices:
            slice_wrds = len(_slice)
            for ix, wrd in enumerate(_slice):
                if index == split_full_text_len:
                    break

                # Skip the last few (overlap) words of every slice except
                # the final one; those words are re-predicted in the next
                # slice with better left context.
                if split_full_text[index] == str(list(wrd.keys())[0]) and \
                        ix <= slice_wrds - 3 and text_slices[-1] != _slice:
                    index += 1
                    pred_item_tuple = list(wrd.items())[0]
                    output_text.append(pred_item_tuple)
                elif split_full_text[index] == str(list(wrd.keys())[0]) and text_slices[-1] == _slice:
                    index += 1
                    pred_item_tuple = list(wrd.items())[0]
                    output_text.append(pred_item_tuple)
        assert [i[0] for i in output_text] == split_full_text
        return output_text

    @staticmethod
    def punctuate_texts(full_pred: list):
        """
        Given a list of predictions from the model, applies the predictions
        to the text, thus punctuating it.
        """
        punct_resp = ""
        for i in full_pred:
            word, label = i
            # Second label char: 'U' = capitalize the word.
            if label[-1] == "U":
                punct_wrd = word.capitalize()
            else:
                punct_wrd = word
            # First label char: punctuation mark to append ('O' = none).
            if label[0] != "O":
                punct_wrd += label[0]

            punct_resp += punct_wrd + " "
        punct_resp = punct_resp.strip()
        # Append trailing period if it doesn't exist. Guard against an empty
        # prediction list: the unguarded punct_resp[-1] raised IndexError.
        if punct_resp and punct_resp[-1].isalnum():
            punct_resp += "."
        return punct_resp
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
if __name__ == "__main__":
    # Smoke test: punctuate a sample transcript and print the result.
    model = RestorePuncts()
    # Read the test file.
    with open('../tests/sample_text.txt', 'r') as fp:
        sample = fp.read()
    # Predict text and print.
    result = model.punctuate(sample)
    print(result)
|
myrpunct/utils.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# 💾⚙️🔮
|
| 3 |
+
|
| 4 |
+
__author__ = "Daulet N."
|
| 5 |
+
__email__ = "daulet.nurmanbetov@gmail.com"
|
| 6 |
+
|
| 7 |
+
def prepare_unpunct_text(text):
    """
    Given a text, normalizes it to subsequently restore punctuation:
    lower-cases it, strips punctuation from each word, and collapses
    whitespace into single spaces.
    """
    # Replace newlines with a space, not '': deleting them glued the last
    # word of one line to the first word of the next. This also matches how
    # RestorePuncts.split_on_toks treats newlines.
    formatted_txt = text.replace('\n', ' ').strip()
    formatted_txt = formatted_txt.lower()
    formatted_txt_lst = formatted_txt.split(" ")
    punct_strp_txt = [strip_punct(i) for i in formatted_txt_lst]
    # Drop empty entries (runs of spaces, words that were all punctuation).
    normalized_txt = " ".join([i for i in punct_strp_txt if i])
    return normalized_txt
|
| 17 |
+
|
| 18 |
+
def strip_punct(wrd):
    """
    Given a word, strips non-alphanumeric characters that precede and
    follow it. A word consisting only of such characters collapses to ''.
    """
    if not wrd:
        return wrd

    # Trim the tail first, then the head; each loop stops as soon as the
    # word is empty or the boundary character is alphanumeric.
    while wrd and not wrd[-1].isalnum():
        wrd = wrd[:-1]

    while wrd and not wrd[0].isalnum():
        wrd = wrd[1:]
    return wrd
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers
|
| 2 |
+
torch
|
| 3 |
+
langdetect
|
| 4 |
+
simpletransformers
|