| | import io |
| | import sys |
| | import gradio as gr |
| | import srt |
| | import jiwer |
| |
|
| | from dataclasses import dataclass |
| | from dataclasses_json import dataclass_json |
| | from datetime import timedelta |
| |
|
| |
|
@dataclass_json
@dataclass
class ZHTW_Sub:
    """A single validated subtitle cue carrying a ``zh`` / ``tw`` text pair.

    The ``@dataclass_json`` decorator layers JSON (de)serialisation helpers
    over the plain dataclass; fields are positional in the order below.
    """

    # Cue timing, taken from the parsed subtitle's start/end offsets.
    start: timedelta
    end: timedelta
    # The two text variants stored for this cue.
    zh: str
    tw: str
| |
|
def read_srt(p):
    """Read the SubRip file at path *p* and return its parsed cues as a list.

    Args:
        p: filesystem path of the ``.srt`` file.

    Returns:
        ``list`` of the subtitle objects produced by ``srt.parse``.
    """
    # Decode explicitly as UTF-8 instead of the locale-dependent default
    # (which breaks on non-UTF-8 systems for Chinese text); "utf-8-sig"
    # additionally strips a leading byte-order mark if the file has one.
    with open(p, encoding="utf-8-sig") as f:
        return list(srt.parse(f.read()))
| | |
def merge_sub(subs):
    """Merge runs of back-to-back subtitles in place and return *subs*.

    Adjacent entries are compared pairwise (the list is presumably in time
    order -- confirm with callers). When an entry's ``start`` equals the
    previous (possibly already merged) entry's ``end``, it is folded into
    that previous entry: ``end`` is extended and the ``zh`` / ``tw`` texts
    are appended, separated by a single space.

    Args:
        subs: list of objects with mutable ``start``, ``end``, ``zh`` and
            ``tw`` attributes (e.g. ``ZHTW_Sub``).

    Returns:
        The same list object, now holding only the merged entries.
    """
    merged = []
    for cur in subs:
        prev = merged[-1] if merged else None
        if prev is not None and prev.end == cur.start:
            prev.end = cur.end
            prev.zh += f" {cur.zh}"
            prev.tw += f" {cur.tw}"
        else:
            merged.append(cur)
    # One linear pass instead of list.pop(i) per merge (each pop is O(n),
    # making the old loop O(n^2)); slice assignment keeps the caller's list.
    subs[:] = merged
    return subs
| |
|
def merge_sub2(subs, delta):
    """Merge subtitles separated by a gap of at most *delta*, in place.

    Each entry is compared with the previous (possibly already merged)
    entry. If ``entry.start - previous.end <= delta``, the entry is folded
    into the previous one: ``end`` is extended and the ``zh`` / ``tw`` texts
    are appended, separated by a single space. Overlapping entries (negative
    gap) therefore merge as well.

    Args:
        subs: list of objects with mutable ``start``, ``end``, ``zh`` and
            ``tw`` attributes (e.g. ``ZHTW_Sub``).
        delta: maximum allowed gap, compared against ``start - end``
            (a ``timedelta`` when the entries use ``timedelta`` times).

    Returns:
        The same list object, now holding only the merged entries.
    """
    merged = []
    for cur in subs:
        prev = merged[-1] if merged else None
        if prev is not None and cur.start - prev.end <= delta:
            prev.end = cur.end
            prev.zh += f" {cur.zh}"
            prev.tw += f" {cur.tw}"
        else:
            merged.append(cur)
    # Single pass instead of an O(n) list.pop(i) per merge (O(n^2) overall);
    # slice assignment preserves the in-place contract for callers.
    subs[:] = merged
    return subs
| |
|
def _split_tw_zh(s, line, log):
    """Split one subtitle line into a ``(tw, zh)`` pair, or log why and return None.

    A line is either ``"<tw> | <zh>"`` (exactly one ``|``) or an even number
    of whitespace-separated tokens whose first half is ``tw`` and second
    half is ``zh``.
    """
    if '|' in line:
        parts = line.split('|')
        # Exactly two sides are required; the old "% 2" test let four parts
        # through and then crashed on the two-name unpack.
        if len(parts) != 2:
            print('修:多槓', parts, file=log)
            return None
        tw, zh = (part.strip() for part in parts)
        return tw, zh

    words = line.split()
    if len(words) % 2 != 0:
        print('修:不均', s.start, s.end, words, file=log)
        return None
    mid = len(words) // 2
    return ' '.join(words[:mid]), ' '.join(words[mid:])


def filter_sub(subs):
    """Convert parsed subtitles into ``ZHTW_Sub`` pairs, logging any problems.

    Each subtitle's ``content`` must hold both texts on one line (see
    ``_split_tw_zh``). A subtitle is skipped, with a log line, when it:

    - contains ``#``;
    - spans several lines;
    - cannot be split; or
    - has an empty ``tw`` or ``zh`` side.

    A multi-line subtitle also sets a flag so that the next accepted
    subtitle is merged into the previously accepted one (its ``end`` is
    extended and its texts appended). When there is no previous entry yet,
    the subtitle is added normally.

    Args:
        subs: iterable of parsed subtitles exposing ``start``, ``end`` and
            ``content`` (as produced by ``srt.parse``).

    Returns:
        ``(new_subs, buffer)``: the list of ``ZHTW_Sub`` objects and an
        ``io.StringIO`` containing the log lines.
    """
    # Log to an explicit buffer instead of swapping the global sys.stdout:
    # the old swap was never undone if anything raised, and it is unsafe
    # when the Gradio server handles requests concurrently.
    log = io.StringIO()
    new_subs = []
    carry_next = False
    for s in subs:
        content = s.content
        if '#' in content:
            print('註:標記', s.start, s.end, content, file=log)
            continue
        if '\n' in content:
            print('修:分行', '\\n', s.start, content, file=log)
            carry_next = True
            continue

        pair = _split_tw_zh(s, content, log)
        if pair is None:
            continue
        tw, zh = pair
        # jiwer.cer raises on empty strings, and an empty side is not a
        # usable pair anyway.
        if not tw or not zh:
            print('修:空白', s.start, s.end, 'tw:', tw, 'zh:', zh, file=log)
            continue
        if jiwer.cer(tw, zh) > 1:
            print('註:差距', s.start, s.end, 'tw:', tw, 'zh:', zh, file=log)

        # Guard on new_subs: a multi-line first cue used to hit new_subs[-1]
        # on an empty list.
        if carry_next and new_subs:
            prev = new_subs[-1]
            prev.zh += f" {zh}"
            prev.tw += f" {tw}"
            prev.end = s.end
        else:
            new_subs.append(ZHTW_Sub(s.start, s.end, zh, tw))
        carry_next = False
    return new_subs, log
| |
|
def update_yield():
    """Create an accumulating log appender for streaming UI updates.

    The returned callable takes any value, appends ``str(value)`` to the
    history as a new line and returns the entire history so far, so each
    call's result can be yielded straight to a textbox.
    """
    log_text = None

    def update_print(inp):
        nonlocal log_text
        piece = str(inp)
        log_text = piece if log_text is None else f"{log_text}\n{piece}"
        return log_text

    return update_print
| |
|
def parse_srt(file):
    """Validate an uploaded ``.srt`` file, streaming a growing log.

    This is a generator event handler: every ``yield`` is the full log
    accumulated so far.

    The log covers, in order:

    - the file path;
    - the raw cue count;
    - the ``filter_sub`` log;
    - the cue counts after filtering and after merging;
    - every merged cue longer than 30 seconds;
    - the total duration.

    Args:
        file: the ``gr.File`` value. ``None`` means no file is present.
            Either an object with a ``.name`` path attribute or a plain
            path string is accepted.

    Yields:
        The accumulated log text as a single string.
    """
    if file is None:
        # `return value` inside a generator only sets StopIteration.value
        # and is never displayed, so the message must be yielded.
        yield "No file uploaded."
        return

    path = getattr(file, "name", file)
    upd = update_yield()
    yield upd(path)
    subs = read_srt(path)
    yield upd(len(subs))
    new_subs, logs = filter_sub(subs)
    yield upd(logs.getvalue())
    yield upd(len(new_subs))
    new_subs = merge_sub(new_subs)
    yield upd(len(new_subs))

    total_dur = 0
    for i, it in enumerate(new_subs):
        dur = (it.end - it.start).total_seconds()
        if dur > 30:
            yield upd(i)
            yield upd(str(('too long', it.end, dur, it.tw)))
        # NOTE(review): cues flagged "too long" are still added to the
        # "usable" total below -- confirm that is intended.
        total_dur += dur
    yield upd("可用時長 " + str(timedelta(seconds=int(total_dur))))
| |
|
# Gradio UI: upload an .srt file and stream the validation log into a textbox.
with gr.Blocks() as demo:
    gr.Markdown("## SRT File Validator")

    with gr.Column():
        file_input = gr.File(label="Upload .srt File", file_types=[".srt"])
        output_log = gr.Textbox(label="Parsing Log", lines=10, max_lines=120)

    # parse_srt is a generator; each value it yields replaces the textbox text.
    file_input.change(fn=parse_srt, inputs=file_input, outputs=output_log)

# Generator (streaming) handlers require the request queue. It is enabled by
# default in Gradio 4 but must be turned on explicitly in Gradio 3.x, where
# launch() otherwise rejects generator handlers; queue() returns the Blocks.
demo.queue().launch()
| |
|