Spaces:
Build error
Build error
Eason Lu committed on
Commit ·
fccaaea
1
Parent(s): bd8dd84
add domain structures
Browse files
Former-commit-id: 185f41371688069edd8f858df08097bd487ca488
- configs/local_launch.yaml +1 -2
- configs/task_config.yaml +1 -1
- src/srt_util/srt.py +23 -1
- src/task.py +1 -1
configs/local_launch.yaml
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
# launch config for local environment
|
| 2 |
-
model: "gpt-4"
|
| 3 |
local_dump: ./local_dump
|
| 4 |
-
|
| 5 |
environ: local
|
|
|
|
| 1 |
# launch config for local environment
|
|
|
|
| 2 |
local_dump: ./local_dump
|
| 3 |
+
# dictionary_path: ./domain_dict
|
| 4 |
environ: local
|
configs/task_config.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
# configuration for each task
|
| 2 |
source_lang: EN
|
| 3 |
target_lang: ZH
|
| 4 |
-
field:
|
| 5 |
|
| 6 |
# ASR config
|
| 7 |
ASR:
|
|
|
|
| 1 |
# configuration for each task
|
| 2 |
source_lang: EN
|
| 3 |
target_lang: ZH
|
| 4 |
+
field: General
|
| 5 |
|
| 6 |
# ASR config
|
| 7 |
ASR:
|
src/srt_util/srt.py
CHANGED
|
@@ -52,6 +52,8 @@ punctuation_dict = {
|
|
| 52 |
},
|
| 53 |
}
|
| 54 |
|
|
|
|
|
|
|
| 55 |
class SrtSegment(object):
|
| 56 |
def __init__(self, src_lang, tgt_lang, *args) -> None:
|
| 57 |
self.src_lang = src_lang
|
|
@@ -150,11 +152,19 @@ class SrtSegment(object):
|
|
| 150 |
|
| 151 |
|
| 152 |
class SrtScript(object):
|
| 153 |
-
def __init__(self, src_lang, tgt_lang, segments) -> None:
|
|
|
|
| 154 |
self.src_lang = src_lang
|
| 155 |
self.tgt_lang = tgt_lang
|
| 156 |
self.segments = [SrtSegment(self.src_lang, self.tgt_lang, seg) for seg in segments]
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
@classmethod
|
| 159 |
def parse_from_srt_file(cls, src_lang, tgt_lang, path: str):
|
| 160 |
with open(path, 'r', encoding="utf-8") as f:
|
|
@@ -429,6 +439,12 @@ class SrtScript(object):
|
|
| 429 |
def correct_with_force_term(self):
|
| 430 |
## force term correction
|
| 431 |
logging.info("performing force term correction")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
# load term dictionary
|
| 433 |
with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
|
| 434 |
term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
|
|
@@ -478,6 +494,12 @@ class SrtScript(object):
|
|
| 478 |
|
| 479 |
def spell_check_term(self):
|
| 480 |
logging.info("performing spell check")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
import enchant
|
| 482 |
dict = enchant.Dict('en_US')
|
| 483 |
term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
|
|
|
|
| 52 |
},
|
| 53 |
}
|
| 54 |
|
| 55 |
+
dict_path = "./domain_dict"
|
| 56 |
+
|
| 57 |
class SrtSegment(object):
|
| 58 |
def __init__(self, src_lang, tgt_lang, *args) -> None:
|
| 59 |
self.src_lang = src_lang
|
|
|
|
| 152 |
|
| 153 |
|
| 154 |
class SrtScript(object):
|
| 155 |
+
def __init__(self, src_lang, tgt_lang, segments, domain="General") -> None:
|
| 156 |
+
self.domain = domain
|
| 157 |
self.src_lang = src_lang
|
| 158 |
self.tgt_lang = tgt_lang
|
| 159 |
self.segments = [SrtSegment(self.src_lang, self.tgt_lang, seg) for seg in segments]
|
| 160 |
|
| 161 |
+
if self.domain != "General":
|
| 162 |
+
if os.path.exists(f"{dict_path}/{self.domain}"):
|
| 163 |
+
# TODO: load dictionary
|
| 164 |
+
...
|
| 165 |
+
else:
|
| 166 |
+
logging.error(f"domain {self.domain} doesn't exist")
|
| 167 |
+
|
| 168 |
@classmethod
|
| 169 |
def parse_from_srt_file(cls, src_lang, tgt_lang, path: str):
|
| 170 |
with open(path, 'r', encoding="utf-8") as f:
|
|
|
|
| 439 |
def correct_with_force_term(self):
|
| 440 |
## force term correction
|
| 441 |
logging.info("performing force term correction")
|
| 442 |
+
|
| 443 |
+
# check domain
|
| 444 |
+
if self.domain == "General":
|
| 445 |
+
logging.info("General domain could not perform correct_with_force_term. skip this step.")
|
| 446 |
+
pass
|
| 447 |
+
|
| 448 |
# load term dictionary
|
| 449 |
with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
|
| 450 |
term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
|
|
|
|
| 494 |
|
| 495 |
def spell_check_term(self):
|
| 496 |
logging.info("performing spell check")
|
| 497 |
+
|
| 498 |
+
# check domain
|
| 499 |
+
if self.domain == "General":
|
| 500 |
+
logging.info("General domain could not perform spell_check_term. skip this step.")
|
| 501 |
+
pass
|
| 502 |
+
|
| 503 |
import enchant
|
| 504 |
dict = enchant.Dict('en_US')
|
| 505 |
term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
|
src/task.py
CHANGED
|
@@ -157,7 +157,7 @@ class Task:
|
|
| 157 |
# after get the transcript, release the gpu resource
|
| 158 |
torch.cuda.empty_cache()
|
| 159 |
|
| 160 |
-
self.SRT_Script = SrtScript(self.source_lang, self.target_lang, transcript['segments'])
|
| 161 |
# save the srt script to local
|
| 162 |
self.SRT_Script.write_srt_file_src(src_srt_path)
|
| 163 |
|
|
|
|
| 157 |
# after get the transcript, release the gpu resource
|
| 158 |
torch.cuda.empty_cache()
|
| 159 |
|
| 160 |
+
self.SRT_Script = SrtScript(self.source_lang, self.target_lang, transcript['segments'], self.field)
|
| 161 |
# save the srt script to local
|
| 162 |
self.SRT_Script.write_srt_file_src(src_srt_path)
|
| 163 |
|