Spaces:
Build error
Build error
Eason Lu commited on
Commit ·
bd8dd84
1
Parent(s): e3f9642
add unit test for remove punc
Browse filesFormer-commit-id: 2a7749049106a57f8954db71ed014e287a48501a
- .gitignore +2 -1
- configs/task_config.yaml +1 -1
- src/srt_util/srt.py +31 -30
- tests/test_remove_punc.py +21 -0
.gitignore
CHANGED
|
@@ -13,4 +13,5 @@ log_*.csv
|
|
| 13 |
log.csv
|
| 14 |
.chroma
|
| 15 |
*.ini
|
| 16 |
-
local_dump/
|
|
|
|
|
|
| 13 |
log.csv
|
| 14 |
.chroma
|
| 15 |
*.ini
|
| 16 |
+
local_dump/
|
| 17 |
+
.pytest_cache/
|
configs/task_config.yaml
CHANGED
|
@@ -30,6 +30,6 @@ post_process:
|
|
| 30 |
output_type:
|
| 31 |
subtitle: srt
|
| 32 |
video: False
|
| 33 |
-
bilingal:
|
| 34 |
|
| 35 |
|
|
|
|
| 30 |
output_type:
|
| 31 |
subtitle: srt
|
| 32 |
video: False
|
| 33 |
+
bilingal: True
|
| 34 |
|
| 35 |
|
src/srt_util/srt.py
CHANGED
|
@@ -11,42 +11,42 @@ from tqdm import tqdm
|
|
| 11 |
# punctuation dictionary for supported languages
|
| 12 |
punctuation_dict = {
|
| 13 |
"EN": {
|
| 14 |
-
"punc_str": ". , ? ! : ; - ( ) [ ] { }
|
| 15 |
"comma": ", ",
|
| 16 |
"sentence_end": [".", "!", "?", ";"]
|
| 17 |
},
|
| 18 |
"ES": {
|
| 19 |
-
"punc_str": ". , ? ! : ; - ( ) [ ] { }
|
| 20 |
"comma": ", ",
|
| 21 |
"sentence_end": [".", "!", "?", ";", "¡", "¿"]
|
| 22 |
},
|
| 23 |
"FR": {
|
| 24 |
-
"punc_str": ".
|
| 25 |
"comma": ", ",
|
| 26 |
"sentence_end": [".", "!", "?", ";"]
|
| 27 |
},
|
| 28 |
"DE": {
|
| 29 |
-
"punc_str": ".
|
| 30 |
"comma": ", ",
|
| 31 |
"sentence_end": [".", "!", "?", ";"]
|
| 32 |
},
|
| 33 |
"RU": {
|
| 34 |
-
"punc_str": ".
|
| 35 |
"comma": ", ",
|
| 36 |
"sentence_end": [".", "!", "?", ";"]
|
| 37 |
},
|
| 38 |
"ZH": {
|
| 39 |
-
"punc_str": "。
|
| 40 |
"comma": ",",
|
| 41 |
"sentence_end": ["。", "!", "?"]
|
| 42 |
},
|
| 43 |
"JA": {
|
| 44 |
-
"punc_str": "。
|
| 45 |
"comma": "、",
|
| 46 |
"sentence_end": ["。", "!", "?"]
|
| 47 |
},
|
| 48 |
"AR": {
|
| 49 |
-
"punc_str": ".
|
| 50 |
"comma": "، ",
|
| 51 |
"sentence_end": [".", "!", "?", ";", "؟"]
|
| 52 |
},
|
|
@@ -100,6 +100,7 @@ class SrtSegment(object):
|
|
| 100 |
self.translation = ""
|
| 101 |
else:
|
| 102 |
self.translation = args[0][3]
|
|
|
|
| 103 |
|
| 104 |
def merge_seg(self, seg):
|
| 105 |
"""
|
|
@@ -132,9 +133,11 @@ class SrtSegment(object):
|
|
| 132 |
remove punctuations in translation text
|
| 133 |
:return: None
|
| 134 |
"""
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
| 138 |
|
| 139 |
def __str__(self) -> str:
|
| 140 |
return f'{self.duration}\n{self.source_text}\n\n'
|
|
@@ -233,19 +236,20 @@ class SrtScript(object):
|
|
| 233 |
src_text += '\n\n'
|
| 234 |
|
| 235 |
def inner_func(target, input_str):
|
| 236 |
-
#
|
| 237 |
response = openai.ChatCompletion.create(
|
| 238 |
model="gpt-4",
|
| 239 |
messages=[
|
| 240 |
{"role": "system",
|
| 241 |
-
"content": "
|
| 242 |
-
{"role": "system", "content": "
|
| 243 |
-
{"role": "user", "content": '
|
| 244 |
],
|
| 245 |
temperature=0.15
|
| 246 |
)
|
| 247 |
return response['choices'][0]['message']['content'].strip()
|
| 248 |
|
|
|
|
| 249 |
lines = translate.split('\n\n')
|
| 250 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
| 251 |
count = 0
|
|
@@ -253,6 +257,7 @@ class SrtScript(object):
|
|
| 253 |
while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
|
| 254 |
count += 1
|
| 255 |
print("Solving Unmatched Lines|iteration {}".format(count))
|
|
|
|
| 256 |
|
| 257 |
flag = True
|
| 258 |
while flag:
|
|
@@ -262,13 +267,17 @@ class SrtScript(object):
|
|
| 262 |
except Exception as e:
|
| 263 |
print("An error has occurred during solving unmatched lines:", e)
|
| 264 |
print("Retrying...")
|
|
|
|
|
|
|
| 265 |
flag = True
|
| 266 |
lines = translate.split('\n')
|
| 267 |
|
| 268 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
| 269 |
solved = False
|
| 270 |
print("Failed Solving unmatched lines, Manually parse needed")
|
|
|
|
| 271 |
|
|
|
|
| 272 |
if not os.path.exists("./logs"):
|
| 273 |
os.mkdir("./logs")
|
| 274 |
if video_link:
|
|
@@ -287,7 +296,7 @@ class SrtScript(object):
|
|
| 287 |
log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
|
| 288 |
log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
|
| 289 |
len(self.segments)) + ',' + video_name + "\n")
|
| 290 |
-
print(lines)
|
| 291 |
|
| 292 |
for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
|
| 293 |
# naive way to due with merge translation problem
|
|
@@ -337,19 +346,13 @@ class SrtScript(object):
|
|
| 337 |
trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 else trans_commas[
|
| 338 |
len(trans_commas) // 2 - 1]
|
| 339 |
else:
|
| 340 |
-
|
| 341 |
-
trans_space = [m.start() for m in re.finditer(' ', translation)]
|
| 342 |
-
if len(trans_space) > 0:
|
| 343 |
-
trans_split_idx = trans_space[len(trans_space) // 2] if len(trans_space) % 2 == 1 else trans_space[
|
| 344 |
-
len(trans_space) // 2 - 1]
|
| 345 |
-
else:
|
| 346 |
-
trans_split_idx = len(translation) // 2
|
| 347 |
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
|
| 354 |
# split the time duration based on text length
|
| 355 |
time_split_ratio = trans_split_idx / (len(seg.translation) - 1)
|
|
@@ -405,8 +408,6 @@ class SrtScript(object):
|
|
| 405 |
self.segments = segments
|
| 406 |
logging.info("check_len_and_split finished")
|
| 407 |
|
| 408 |
-
pass
|
| 409 |
-
|
| 410 |
def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
|
| 411 |
# DEPRECATED
|
| 412 |
# if sentence length >= text_threshold, split this segments to two
|
|
|
|
| 11 |
# punctuation dictionary for supported languages
|
| 12 |
punctuation_dict = {
|
| 13 |
"EN": {
|
| 14 |
+
"punc_str": ". , ? ! : ; - ( ) [ ] { }",
|
| 15 |
"comma": ", ",
|
| 16 |
"sentence_end": [".", "!", "?", ";"]
|
| 17 |
},
|
| 18 |
"ES": {
|
| 19 |
+
"punc_str": ". , ? ! : ; - ( ) [ ] { } ¡ ¿",
|
| 20 |
"comma": ", ",
|
| 21 |
"sentence_end": [".", "!", "?", ";", "¡", "¿"]
|
| 22 |
},
|
| 23 |
"FR": {
|
| 24 |
+
"punc_str": ".,?!:;«»—",
|
| 25 |
"comma": ", ",
|
| 26 |
"sentence_end": [".", "!", "?", ";"]
|
| 27 |
},
|
| 28 |
"DE": {
|
| 29 |
+
"punc_str": ".,?!:;„“–",
|
| 30 |
"comma": ", ",
|
| 31 |
"sentence_end": [".", "!", "?", ";"]
|
| 32 |
},
|
| 33 |
"RU": {
|
| 34 |
+
"punc_str": ".,?!:;-«»—",
|
| 35 |
"comma": ", ",
|
| 36 |
"sentence_end": [".", "!", "?", ";"]
|
| 37 |
},
|
| 38 |
"ZH": {
|
| 39 |
+
"punc_str": "。,?!:;()",
|
| 40 |
"comma": ",",
|
| 41 |
"sentence_end": ["。", "!", "?"]
|
| 42 |
},
|
| 43 |
"JA": {
|
| 44 |
+
"punc_str": "。、?!:;()",
|
| 45 |
"comma": "、",
|
| 46 |
"sentence_end": ["。", "!", "?"]
|
| 47 |
},
|
| 48 |
"AR": {
|
| 49 |
+
"punc_str": ".,?!:;-()[]،؛ ؟ «»",
|
| 50 |
"comma": "، ",
|
| 51 |
"sentence_end": [".", "!", "?", ";", "؟"]
|
| 52 |
},
|
|
|
|
| 100 |
self.translation = ""
|
| 101 |
else:
|
| 102 |
self.translation = args[0][3]
|
| 103 |
+
|
| 104 |
|
| 105 |
def merge_seg(self, seg):
|
| 106 |
"""
|
|
|
|
| 133 |
remove punctuations in translation text
|
| 134 |
:return: None
|
| 135 |
"""
|
| 136 |
+
punc_str = punctuation_dict[self.tgt_lang]["punc_str"]
|
| 137 |
+
for punc in punc_str:
|
| 138 |
+
self.translation = self.translation.replace(punc, ' ')
|
| 139 |
+
# translator = str.maketrans(punc, ' ' * len(punc))
|
| 140 |
+
# self.translation = self.translation.translate(translator)
|
| 141 |
|
| 142 |
def __str__(self) -> str:
|
| 143 |
return f'{self.duration}\n{self.source_text}\n\n'
|
|
|
|
| 236 |
src_text += '\n\n'
|
| 237 |
|
| 238 |
def inner_func(target, input_str):
|
| 239 |
+
# handling merge sentences issue.
|
| 240 |
response = openai.ChatCompletion.create(
|
| 241 |
model="gpt-4",
|
| 242 |
messages=[
|
| 243 |
{"role": "system",
|
| 244 |
+
"content": "Your task is to merge or split sentences into a specified number of lines as required. You need to ensure the meaning of the sentences as much as possible, but when necessary, a sentence can be divided into two lines for output"},
|
| 245 |
+
{"role": "system", "content": "Note: You only need to output the processed {} sentences. If you need to output a sequence number, please separate it with a colon.".format(self.tgt_lang)},
|
| 246 |
+
{"role": "user", "content": 'Please split or combine the following sentences into {} sentences:\n{}'.format(target, input_str)}
|
| 247 |
],
|
| 248 |
temperature=0.15
|
| 249 |
)
|
| 250 |
return response['choices'][0]['message']['content'].strip()
|
| 251 |
|
| 252 |
+
# handling merge sentences issue.
|
| 253 |
lines = translate.split('\n\n')
|
| 254 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
| 255 |
count = 0
|
|
|
|
| 257 |
while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
|
| 258 |
count += 1
|
| 259 |
print("Solving Unmatched Lines|iteration {}".format(count))
|
| 260 |
+
logging.error("Solving Unmatched Lines|iteration {}".format(count))
|
| 261 |
|
| 262 |
flag = True
|
| 263 |
while flag:
|
|
|
|
| 267 |
except Exception as e:
|
| 268 |
print("An error has occurred during solving unmatched lines:", e)
|
| 269 |
print("Retrying...")
|
| 270 |
+
logging.error("An error has occurred during solving unmatched lines:", e)
|
| 271 |
+
logging.error("Retrying...")
|
| 272 |
flag = True
|
| 273 |
lines = translate.split('\n')
|
| 274 |
|
| 275 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
| 276 |
solved = False
|
| 277 |
print("Failed Solving unmatched lines, Manually parse needed")
|
| 278 |
+
logging.error("Failed Solving unmatched lines, Manually parse needed")
|
| 279 |
|
| 280 |
+
# FIXME: put the error log in our log file
|
| 281 |
if not os.path.exists("./logs"):
|
| 282 |
os.mkdir("./logs")
|
| 283 |
if video_link:
|
|
|
|
| 296 |
log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
|
| 297 |
log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
|
| 298 |
len(self.segments)) + ',' + video_name + "\n")
|
| 299 |
+
# print(lines)
|
| 300 |
|
| 301 |
for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
|
| 302 |
# naive way to due with merge translation problem
|
|
|
|
| 346 |
trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 else trans_commas[
|
| 347 |
len(trans_commas) // 2 - 1]
|
| 348 |
else:
|
| 349 |
+
trans_split_idx = len(translation) // 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
+
# to avoid split English word
|
| 352 |
+
for i in range(trans_split_idx, len(translation)):
|
| 353 |
+
if not translation[i].encode('utf-8').isalpha():
|
| 354 |
+
trans_split_idx = i
|
| 355 |
+
break
|
| 356 |
|
| 357 |
# split the time duration based on text length
|
| 358 |
time_split_ratio = trans_split_idx / (len(seg.translation) - 1)
|
|
|
|
| 408 |
self.segments = segments
|
| 409 |
logging.info("check_len_and_split finished")
|
| 410 |
|
|
|
|
|
|
|
| 411 |
def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
|
| 412 |
# DEPRECATED
|
| 413 |
# if sentence length >= text_threshold, split this segments to two
|
tests/test_remove_punc.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('./src')
|
| 3 |
+
from srt_util.srt import SrtScript, SrtSegment
|
| 4 |
+
|
| 5 |
+
zh_test1 = "再次,如果你对一些福利感兴趣,你也可以。"
|
| 6 |
+
zh_en_test1 = "GG。Classic在我今年解说的最奇葩的系列赛中获得了胜利。"
|
| 7 |
+
|
| 8 |
+
def form_srt_class(src_lang, tgt_lang, source_text="", translation="", duration="00:00:00,740 --> 00:00:08,779"):
|
| 9 |
+
segment = [0, duration, source_text, translation, ""]
|
| 10 |
+
return SrtScript(src_lang, tgt_lang, [segment])
|
| 11 |
+
|
| 12 |
+
def test_zh():
|
| 13 |
+
srt = form_srt_class(src_lang="EN", tgt_lang="ZH", translation=zh_test1)
|
| 14 |
+
srt.remove_trans_punctuation()
|
| 15 |
+
assert srt.segments[0].translation == "再次 如果你对一些福利感兴趣 你也可以 "
|
| 16 |
+
|
| 17 |
+
def test_zh_en():
|
| 18 |
+
srt = form_srt_class(src_lang="EN", tgt_lang="ZH", translation=zh_en_test1)
|
| 19 |
+
srt.remove_trans_punctuation()
|
| 20 |
+
assert srt.segments[0].translation == "GG Classic在我今年解说的最奇葩的系列赛中获得了胜利 "
|
| 21 |
+
|