# Test_gradio / app.py
# Author: Diezu
# Last commit: "Update app.py" (b07e91d, verified)
import difflib
import html
import json
import re

import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def load_model(checkpoint_path):
    """Load a seq2seq model and its tokenizer from *checkpoint_path*.

    The tokenizer is loaded first, then the model, both via the
    ``transformers`` ``from_pretrained`` factories.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path)
    return loaded_model, loaded_tokenizer
# Checkpoint to load (presumably a Hugging Face Hub repo id). The model and
# tokenizer are loaded once, at import time, and reused by every request.
checkpoint_path = "Diezu/Batpho_v2"
model, tokenizer = load_model(checkpoint_path=checkpoint_path)
def correct_spelling(text):
    """Return the model's corrected version of *text*.

    The input is tokenized with truncation at 512 tokens. Generation uses
    beam search (5 beams, early stopping) with a 512-token output cap. The
    first returned sequence is decoded with special tokens stripped.
    """
    encoded = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    generated_ids = model.generate(
        **encoded,
        max_length=512,
        num_beams=5,
        early_stopping=True,
    )
    first_sequence = generated_ids[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def find_spelling_errors(s1, s2):
    """Diff original text *s1* against corrected text *s2*, word by word.

    Both strings are tokenized with ``str.split()`` and compared with
    ``difflib.SequenceMatcher``. Each opcode becomes an HTML fragment:

    * ``replace``: a red, underlined ``<button>`` showing the original words.
      Its ``onclick`` calls ``replaceWord(this, "<correction>")``.
      NOTE(review): ``replaceWord`` is not defined in this file; confirm it
      is provided elsewhere.
    * ``delete``: the original words in orange strike-through.
    * ``insert``: the added words in bold green.
    * ``equal``: the original words as plain text.

    All displayed text is HTML-escaped. The correction passed to
    ``replaceWord`` is encoded as a JSON (valid JavaScript) string literal
    and then attribute-escaped. Quotes or markup in the text therefore
    cannot break the generated HTML or inject script.

    Returns:
        ``(fragments_html, corrected_words)``. The first element is the
        fragments joined with single spaces. The second is a dict mapping
        each replaced span of *s1* (space-joined, unescaped) to its
        replacement from *s2* (unescaped).
    """
    a = s1.split()
    b = s2.split()
    matcher = difflib.SequenceMatcher(None, a, b)
    fragments = []
    corrected_words = {}
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        original = " ".join(a[i1:i2])
        replacement = " ".join(b[j1:j2])
        if tag == "replace":
            # json.dumps produces a JS string literal. html.escape then turns
            # its double quotes into &quot; so they no longer terminate the
            # double-quoted onclick attribute; the browser decodes them back
            # before running the JS.
            js_arg = html.escape(json.dumps(replacement))
            fragments.append(
                f'<button onclick="replaceWord(this, {js_arg})" '
                'style="color: red; text-decoration: underline; border: none; background: none;">'
                f"{html.escape(original)}</button>"
            )
            corrected_words[original] = replacement
        elif tag == "delete":
            fragments.append(
                '<span style="color: orange; text-decoration: line-through;">'
                f"{html.escape(original)}</span>"
            )
        elif tag == "insert":
            fragments.append(
                '<span style="color: green; font-weight: bold;">'
                f"{html.escape(replacement)}</span>"
            )
        else:
            fragments.append(html.escape(original))
    return " ".join(fragments), corrected_words
def process_text(text):
    """Spell-check *text* one sentence at a time.

    The text is split wherever whitespace follows '.', '!' or '?', and
    blank pieces are skipped. Each remaining sentence is corrected with
    ``correct_spelling`` and diffed against its correction with
    ``find_spelling_errors``.

    Returns:
        A tuple ``(highlighted_html, corrected_text, corrected_words)``.
        Highlighted sentences are joined with ``<br>`` and corrected
        sentences with newlines. ``corrected_words`` merges every
        sentence's ``{incorrect: correct}`` mapping; later sentences
        overwrite earlier ones on key clashes.
    """
    highlighted_parts = []
    corrected_parts = []
    merged_words = {}
    pieces = re.split(r'(?<=[.!?])\s+', text)
    # filter(str.strip, ...) keeps exactly the pieces with non-blank content.
    for sentence in filter(str.strip, pieces):
        fixed = correct_spelling(sentence)
        markup, sentence_words = find_spelling_errors(sentence, fixed)
        corrected_parts.append(fixed)
        highlighted_parts.append(markup)
        merged_words.update(sentence_words)
    return "<br>".join(highlighted_parts), "\n".join(corrected_parts), merged_words
def apply_full_correction(text, corrected_words):
    """Return *text* as HTML with every known misspelling replaced.

    Each key of *corrected_words* is matched as a whole whitespace-delimited
    token sequence, because the keys are built from ``str.split()`` tokens
    (see ``find_spelling_errors``). A match is replaced by its correction
    wrapped in a bold blue ``<span>``. All replacements happen in one pass
    over the original text. As a result, a correction never re-matches
    inside inserted markup or inside another correction, and a short key
    never matches inside a longer word ("an" does not touch "bank"). Text
    between matches and the corrections themselves are HTML-escaped, since
    the result is rendered as HTML.

    Args:
        text: The original user text. A falsy value yields ``""``.
        corrected_words: Mapping of incorrect phrase to correct phrase.
            ``None`` is treated as empty, and blank keys are ignored.

    Returns:
        The escaped HTML string.
    """
    if not text:
        return ""
    # Key each correction by its whitespace-normalized form. A match that
    # spans irregular spacing (e.g. "x   y") then still finds its entry.
    lookup = {
        " ".join(wrong.split()): right
        for wrong, right in (corrected_words or {}).items()
        if wrong.strip()
    }
    if not lookup:
        return html.escape(text)
    # Longest phrases first, so a multi-word correction beats a shorter one
    # that starts at the same position.
    alternatives = "|".join(
        r"\s+".join(re.escape(token) for token in phrase.split())
        for phrase in sorted(lookup, key=len, reverse=True)
    )
    # (?<!\S) and (?!\S) restrict matches to whole whitespace-delimited tokens.
    pattern = re.compile(rf"(?<!\S)(?:{alternatives})(?!\S)")
    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()]))
        replacement = lookup[" ".join(match.group(0).split())]
        pieces.append(
            f'<span style="color: blue; font-weight: bold;">{html.escape(replacement)}</span>'
        )
        cursor = match.end()
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)
# Gradio UI. "Phát hiện lỗi" (detect errors) highlights differences between
# the input and the model output, fills the corrected textbox, and stores the
# {incorrect: correct} mapping in session state. "Sửa toàn bộ lỗi" (fix all
# errors) applies that stored mapping to the current input and shows the
# result in the HTML panel.
with gr.Blocks() as demo:
    input_text = gr.Textbox(placeholder="Nhập văn bản có lỗi chính tả...")
    output_html = gr.HTML()
    corrected_textbox = gr.Textbox()
    corrected_words_state = gr.State({})
    btn_check = gr.Button("Phát hiện lỗi")
    btn_correct_all = gr.Button("Sửa toàn bộ lỗi")

    def update_text(text):
        """Return ``(highlighted_html, corrected_text, corrected_words)`` for *text*."""
        return process_text(text)

    btn_check.click(
        update_text,
        inputs=input_text,
        outputs=[output_html, corrected_textbox, corrected_words_state],
    )
    btn_correct_all.click(
        apply_full_correction,
        inputs=[input_text, corrected_words_state],
        outputs=output_html,
    )

demo.launch()