"""Gradio app that aligns two transcripts character by character.

"Their" transcript is the reference and "our" transcript is the hypothesis;
the jiwer character alignment is rendered as fixed-width text wrapped to a
user-chosen column width.
"""

import io
import sys
from contextlib import contextmanager

import gradio as gr
import jiwer

# Width of the row label jiwer puts in front of each alignment row
# (assumed to be 5 characters, e.g. "REF: " / "HYP: " — the slicing below relies on it).
_PREFIX_WIDTH = 5

# Characters deleted from both texts when remove_punc is set. With three
# arguments, str.maketrans maps every character of the third string to None,
# overriding the (identity) mapping given by the first two.
_PUNCT_DELETE_TABLE = str.maketrans(
    ', ?.`%()⋯',
    ', ?.`%()⋯',
    ', ?.`%()⋯。、',
)

# Display-only substitutions applied to the aligned rows (after the label):
# maps '’' -> '`' and '…' -> '⋯'; every other listed character maps to itself.
_DISPLAY_TABLE = str.maketrans(
    '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz*, ?.’%()…|',
    "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz*, ?.`%()⋯|",
)


@contextmanager
def switch_to_stdout():
    """Yield ``sys.stdout`` for use as the target of a ``with`` block.

    The stream is not closed on exit.
    """
    f = sys.stdout
    yield f


def _write_alignment(rows, col, f):
    """Write the alignment *rows* to the stream *f*, wrapped to *col* characters.

    ``rows[1:4]`` are the reference, hypothesis and edit-marker rows, each
    starting with a ``_PREFIX_WIDTH``-character label; ``rows[4:]``
    (presumably jiwer's summary measures) are written first, verbatim.
    ``rows[0]`` (the sentence header) is not written.
    """
    for line in rows[4:]:
        print(line, file=f)

    ref_row, hyp_row, marker_row = rows[1:4]
    content_width = max(len(ref_row), len(hyp_row), len(marker_row)) - _PREFIX_WIDTH
    # Iterate over the content only (not the label), so no trailing empty
    # chunk is emitted; max(..., 1) still prints the labels for empty content.
    for start in range(0, max(content_width, 1), col):
        if start == 0:
            # First chunk keeps the row labels in front of the content.
            chunk = slice(0, _PREFIX_WIDTH + col)
            indent = ''
        else:
            chunk = slice(_PREFIX_WIDTH + start, _PREFIX_WIDTH + start + col)
            indent = ' '
        for row in (ref_row, hyp_row, marker_row):
            print(indent + row[chunk], file=f)
        print('', file=f)


def ali(ref, hyp, col=80, remove_punc=False, file=None):
    """Align *hyp* against *ref* character by character and write a wrapped view.

    Args:
        ref: Reference text (must be non-empty for jiwer).
        hyp: Hypothesis text.
        col: Number of content characters per wrapped chunk; must be positive.
        remove_punc: If true, delete the characters of ``_PUNCT_DELETE_TABLE``
            from both texts before aligning.
        file: Destination: any writable file-like object (has ``write``),
            a path that is (over)written as UTF-8, or ``None`` for stdout.

    Returns:
        The object returned by ``jiwer.process_characters(ref, hyp)``.

    Raises:
        ValueError: If *col* is not positive (jiwer may also raise
            ``ValueError`` for an empty reference).
    """
    if col <= 0:
        raise ValueError(f"col must be a positive integer, got {col!r}")
    if remove_punc:
        ref = ref.translate(_PUNCT_DELETE_TABLE)
        hyp = hyp.translate(_PUNCT_DELETE_TABLE)

    out = jiwer.process_characters(ref, hyp)
    # skip_correct=False: by default jiwer omits sentences that match exactly,
    # which would leave no REF/HYP/marker rows at indices 1-3.
    rows = jiwer.visualize_alignment(out, skip_correct=False).splitlines()
    for i in (1, 2, 3):
        rows[i] = rows[i][:_PREFIX_WIDTH] + rows[i][_PREFIX_WIDTH:].translate(_DISPLAY_TABLE)

    if file is None:
        with switch_to_stdout() as f:
            _write_alignment(rows, col, f)
    elif hasattr(file, 'write'):
        _write_alignment(rows, col, file)
    else:
        with open(file, 'w', encoding='utf-8') as f:
            _write_alignment(rows, col, f)
    return out


def process_ours(path):
    """Read our transcript at *path* and return its lines joined by spaces.

    From each line only the text after the first ideographic space (U+3000)
    is kept, dropping a leading label (presumably speaker/timestamp — confirm
    against a real export); lines without U+3000 are kept whole.
    """
    # utf-8-sig also strips a leading BOM, which would otherwise pollute line 1.
    with open(path, encoding='utf-8-sig') as f:
        lines = f.read().splitlines()
    return ' '.join(line.split('\u3000', maxsplit=1)[-1] for line in lines)


def process_theirs(path):
    """Read their transcript at *path* and return its body joined by spaces.

    Only the lines after the first line starting with ``'會議記錄:'`` are
    kept. Returns ``''`` when no such line exists.
    """
    with open(path, encoding='utf-8-sig') as f:
        lines = f.read().splitlines()
    start = next(
        (idx + 1 for idx, line in enumerate(lines) if line.startswith('會議記錄:')),
        len(lines),
    )
    return ' '.join(lines[start:])


def compare_transcripts(ours_file, theirs_file, remove_punc, number_box):
    """Gradio callback: align the two uploaded transcripts and return the view.

    *theirs_file* is the reference and *ours_file* the hypothesis; all spaces
    are removed from both before alignment. *number_box* is the column width
    as typed by the user.

    Raises:
        gr.Error: If a file is missing, the column width is not a positive
            integer, or their transcript has no content after parsing.
    """
    if ours_file is None or theirs_file is None:
        raise gr.Error("Please upload both transcripts.")
    try:
        col = int(number_box)
    except (TypeError, ValueError):
        col = 0
    if col <= 0:
        raise gr.Error("Column size must be a positive integer.")

    ours = process_ours(ours_file).replace(' ', '')
    theirs = process_theirs(theirs_file).replace(' ', '')
    if not theirs:
        # jiwer rejects an empty reference with an unhelpful message.
        raise gr.Error(
            "Their transcript is empty: expected a line starting with "
            "'會議記錄:' followed by the transcript text."
        )

    output = io.StringIO()
    ali(theirs, ours, remove_punc=remove_punc, file=output, col=col)
    return output.getvalue()


custom_css = """
textarea[data-testid="textbox"] {
    font-family: monospace !important;
}
"""

with gr.Blocks(title="Transcript Alignment Viewer", css=custom_css) as demo:
    gr.Markdown("## Transcript Alignment Viewer")
    gr.Markdown("上傳請確認你用的是 `不分段會議紀錄`")

    with gr.Row(equal_height=True):
        ours_file = gr.File(label="Our Transcript", file_types=[".txt"], scale=1)
        theirs_file = gr.File(label="Their Transcript", file_types=[".txt"], scale=1)

    with gr.Row():
        compare_btn = gr.Button("Generate Alignment", scale=1)
        remove_punc = gr.Checkbox(label="Remove Punctuation", scale=1)
        gr.Markdown("**Column size:**")
        number_box = gr.Textbox(show_label=False, value="80", max_lines=1, scale=1)

    output_text = gr.Textbox(
        label="Alignment Output",
        lines=30,
        max_lines=100,
        show_copy_button=True,
        interactive=False,
        elem_id="mono",
    )

    compare_btn.click(
        fn=compare_transcripts,
        inputs=[ours_file, theirs_file, remove_punc, number_box],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()