zh_align / app.py
sam-ezai's picture
Upload app.py
f4c2eca verified
import io
import jiwer
import gradio as gr
import sys
from contextlib import contextmanager
@contextmanager
def switch_to_stdout():
    """Context manager that hands the current ``sys.stdout`` to the ``with`` body.

    The stream is looked up when the context is entered; it is neither
    replaced nor closed when the context exits.
    """
    yield sys.stdout
# Number of leading characters on each alignment row that are never translated
# and are only shown on the first wrapped chunk (presumably jiwer's row label
# such as "REF: " -- confirm against the installed jiwer version).
_LABEL_WIDTH = 5


def _write_alignment(visl, col, f):
    """Write the summary lines, then the three alignment rows wrapped at ``col``."""
    # Everything after the three alignment rows is emitted first, unchanged.
    for line in visl[4:]:
        print(line, file=f)

    rows = visl[1:4]
    # Wrap on the alignment text only. Ranging over the full row length
    # (label included) produced an extra, almost empty trailing chunk whenever
    # the text length was at or just below a multiple of ``col``.
    body_len = len(rows[0]) - _LABEL_WIDTH
    for start in range(0, max(body_len, 1), col):
        if start == 0:
            # First chunk keeps the row label in front of the text.
            chunks = [row[:col + _LABEL_WIDTH] for row in rows]
        else:
            lo = start + _LABEL_WIDTH
            # NOTE(review): continuation rows get a single-space prefix while
            # the first chunk carries a 5-character label; confirm whether
            # five spaces were intended here.
            chunks = [" " + row[lo:lo + col] for row in rows]
        for chunk in chunks:
            print(chunk, file=f)
        print('', file=f)


def ali(ref, hyp, col=80, remove_punc=False, file=None):
    """Align ``hyp`` against ``ref`` character by character and write a wrapped report.

    The alignment is computed with ``jiwer.process_characters`` and rendered
    with ``jiwer.visualize_alignment``. Lines 4 and onward of that rendering
    are written first, followed by rows 1-3 cut into chunks of ``col``
    characters.

    Args:
        ref: Reference text.
        hyp: Hypothesis text.
        col: Number of alignment characters per wrapped row; must be positive.
        remove_punc: When true, delete the characters ``, ?.`%()⋯。、`` (and
            spaces) from both texts before aligning.
        file: Destination of the report. ``None`` writes to ``sys.stdout``; an
            object with a ``write`` method (for example ``io.StringIO``) is
            written to directly; anything else is treated as a path and opened
            for writing as UTF-8.

    Returns:
        The object returned by ``jiwer.process_characters``.

    Raises:
        ValueError: If ``col`` is not positive.
    """
    if col <= 0:
        raise ValueError(f"col must be a positive integer, got {col!r}")

    if remove_punc:
        # Three-argument maketrans: every character of the last string is
        # mapped to None, i.e. removed by translate().
        strip_table = str.maketrans('', '', ', ?.`%()⋯。、')
        ref = ref.translate(strip_table)
        hyp = hyp.translate(strip_table)

    out = jiwer.process_characters(ref, hyp)
    # skip_correct=False keeps the alignment rows even when the two texts match
    # exactly; the fixed row indices used below rely on those rows existing.
    vis = jiwer.visualize_alignment(out, skip_correct=False)

    # NOTE(review): source and target differ only in the apostrophe and the
    # ellipsis here; presumably the target was meant to hold fixed-width
    # (full-width) forms -- confirm against the original file.
    tr = str.maketrans(
        '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz*, ?.’%()…|',
        "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz*, ?.`%()⋯|",
    )
    visl = vis.splitlines()
    # Translate only the text after the label on the three alignment rows.
    for idx in (1, 2, 3):
        visl[idx] = visl[idx][:_LABEL_WIDTH] + visl[idx][_LABEL_WIDTH:].translate(tr)

    if file is None:
        _write_alignment(visl, col, sys.stdout)
    elif hasattr(file, "write"):
        _write_alignment(visl, col, file)
    else:
        with open(file, 'w', encoding="utf-8") as handle:
            _write_alignment(visl, col, handle)
    return out
def process_ours(path):
    """Read our transcript file and flatten it into one space-separated string.

    Each line is split once on the ideographic space (U+3000) and only the
    part after it is kept; a line without that character is kept whole.
    (Presumably the dropped prefix is a speaker/time label -- confirm against
    a sample export.)

    Args:
        path: Path of the transcript text file, read as UTF-8.

    Returns:
        The kept line parts joined with single ASCII spaces; ``''`` for an
        empty file.
    """
    # Explicit encoding: the transcript is CJK text and must not depend on the
    # platform's locale default.
    with open(path, encoding="utf-8") as f:
        lines = f.read().splitlines()
    return ' '.join(line.split('\u3000', maxsplit=1)[-1] for line in lines)
def process_theirs(path):
    """Read their transcript file and return the text that follows the header marker.

    The file is scanned for the first line starting with ``'會議記錄:'``;
    every line after it is joined with single ASCII spaces. When no line
    carries the marker the result is the empty string.

    Args:
        path: Path of the transcript text file, read as UTF-8.

    Returns:
        The joined transcript body as one string.
    """
    # Explicit encoding: the transcript is CJK text and must not depend on the
    # platform's locale default.
    with open(path, encoding="utf-8") as f:
        lines = f.read().splitlines()
    # Index of the marker line, or len(lines) when absent so the slice below
    # comes out empty.
    marker_at = next(
        (idx for idx, line in enumerate(lines) if line.startswith('會議記錄:')),
        len(lines),
    )
    return ' '.join(lines[marker_at + 1:])
def compare_transcripts(ours_file, theirs_file, remove_punc, number_box):
    """Build the character-alignment report for two uploaded transcripts.

    Their transcript is used as the reference and ours as the hypothesis; all
    ASCII spaces are removed from both before aligning.

    Args:
        ours_file: Path of our transcript file, or ``None`` if nothing was
            uploaded.
        theirs_file: Path of their transcript file, or ``None`` if nothing was
            uploaded.
        remove_punc: Whether punctuation is stripped before aligning.
        number_box: Column width for wrapping the alignment; anything ``int()``
            accepts (the UI passes a string).

    Returns:
        The alignment report as a single string.

    Raises:
        ValueError: If a file is missing or the column width is not a positive
            whole number.
    """
    # Validate before touching the files so the user gets a clear message
    # instead of an opaque error from open() or range().
    if ours_file is None or theirs_file is None:
        raise ValueError("Both transcript files must be uploaded before comparing.")
    try:
        col = int(number_box)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"Column size must be a whole number, got {number_box!r}."
        ) from err
    if col <= 0:
        raise ValueError(f"Column size must be positive, got {col}.")

    ours = process_ours(ours_file)
    theirs = process_theirs(theirs_file)
    output = io.StringIO()
    ali(
        theirs.replace(' ', ''),
        ours.replace(' ', ''),
        remove_punc=remove_punc,
        file=output,
        col=col,
    )
    return output.getvalue()
# Stylesheet handed to gr.Blocks: forces a monospace font on every textarea
# matching the data-testid="textbox" selector.
custom_css = """
textarea[data-testid="textbox"] {
font-family: monospace !important;
}
"""

# NOTE(review): the original indentation of this section was not available;
# the nesting below (which components sit inside which Row) is inferred from
# the `scale=` arguments -- confirm against the deployed layout.
with gr.Blocks(title="Transcript Alignment Viewer",css=custom_css) as demo:
    gr.Markdown("## Transcript Alignment Viewer")
    # Reminder to the user (in Chinese) to upload the non-segmented meeting
    # record export.
    gr.Markdown("上傳請確認你用的是 `不分段會議紀錄`")
    # The two transcript uploads, restricted to .txt files.
    with gr.Row(equal_height=True):
        ours_file = gr.File(label="Our Transcript", file_types=[".txt"], scale=1)
        theirs_file = gr.File(label="Their Transcript", file_types=[".txt"], scale=1)
    # Controls: trigger button, punctuation toggle and wrap width.
    with gr.Row():
        compare_btn = gr.Button("Generate Alignment", scale=1)
        remove_punc = gr.Checkbox(label="Remove Punctuation", scale=1)
        gr.Markdown("**Column size:**")
        # The width is collected as text; compare_transcripts converts it with int().
        number_box = gr.Textbox(show_label=False,value="80",max_lines=1,scale=1)
    # Read-only box that receives the alignment report.
    output_text = gr.Textbox(
        label="Alignment Output",
        lines=30,
        max_lines=100,
        show_copy_button=True,
        interactive=False,
        elem_id="mono"
    )
    # Clicking the button runs compare_transcripts on the four inputs and puts
    # its return value into output_text.
    compare_btn.click(
        fn=compare_transcripts,
        inputs=[ours_file, theirs_file, remove_punc, number_box],
        outputs=output_text
    )
# demo.launch(
# # css="""
# # #output-box textarea {
# # font-family: monospace !important;
# # white-space: pre !important;
# # overflow-y: scroll;
# # height: 70vh !important;
# # }
# # """
# )
# Launch the Gradio app only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()