File size: 3,357 Bytes
43ef4da
9de1a48
819e96f
fadf4a0
43ef4da
fadf4a0
 
 
 
43ef4da
b07e91d
 
 
fadf4a0
9de1a48
 
fadf4a0
 
 
9de1a48
 
 
 
 
ab1bdfa
9de1a48
 
 
ab1bdfa
 
 
 
9de1a48
 
 
 
 
 
 
ab1bdfa
9de1a48
819e96f
 
f315b35
 
ab1bdfa
819e96f
f315b35
819e96f
 
ab1bdfa
819e96f
 
ab1bdfa
819e96f
ab1bdfa
9de1a48
ab1bdfa
 
 
 
fadf4a0
ab1bdfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fadf4a0
b07e91d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import difflib
import html
import json
import re

import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def load_model(checkpoint_path):
    """Load the tokenizer and seq2seq model stored under *checkpoint_path*.

    Args:
        checkpoint_path: Identifier forwarded unchanged to both
            ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple, model first.
    """
    # The tokenizer is loaded before the model, as in the original order.
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path)
    return loaded_model, loaded_tokenizer

# Checkpoint identifier passed to load_model().
checkpoint_path = "Diezu/Batpho_v2"
# Loaded once at import time; correct_spelling() reads these two globals.
model, tokenizer = load_model(checkpoint_path)

def correct_spelling(text):
    """Run *text* through the module-level model and return the decoded output.

    The input is tokenized with truncation at 512 tokens, generation uses
    beam search with 5 beams and a maximum length of 512, and the first
    returned sequence is decoded with special tokens skipped.

    Args:
        text: The text to correct.

    Returns:
        The decoded text of the first generated sequence.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    generated = model.generate(
        **encoded,
        max_length=512,
        num_beams=5,
        early_stopping=True,
    )
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

def find_spelling_errors(s1, s2):
    """Diff two texts word by word and render the differences as HTML.

    Both strings are split on whitespace and compared with
    ``difflib.SequenceMatcher``. Each opcode is rendered as follows:

    * ``replace`` -- a red underlined ``<button>`` showing the original
      words; its ``onclick`` calls ``replaceWord(this, <corrected>)``.
    * ``delete``  -- an orange struck-through ``<span>`` with the original
      words.
    * ``insert``  -- a green bold ``<span>`` with the added words.
    * ``equal``   -- the original words, unstyled.

    All text placed into the markup is HTML-escaped, because it comes from
    user input and model output and is rendered as HTML by the caller. The
    JavaScript argument is serialised with ``json.dumps`` and then escaped
    for the attribute context, so quotes in the corrected text can no longer
    terminate the ``onclick`` attribute early.

    NOTE(review): ``replaceWord`` is not defined anywhere in this file; the
    page must provide it for the buttons to do anything.

    Args:
        s1: The original text.
        s2: The corrected text.

    Returns:
        A ``(highlighted_html, corrected_words)`` tuple. ``corrected_words``
        maps each replaced original phrase (raw, unescaped) to its
        replacement phrase (raw, unescaped).
    """
    a = s1.split()
    b = s2.split()
    matcher = difflib.SequenceMatcher(None, a, b)
    highlighted_text = []
    corrected_words = {}

    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        original_chunk = " ".join(a[i1:i2])
        corrected_chunk = " ".join(b[j1:j2])

        if tag == 'replace':
            # JSON gives a valid JS string literal; html.escape then makes
            # its quotes safe inside the double-quoted onclick attribute.
            js_argument = html.escape(json.dumps(corrected_chunk), quote=True)
            highlighted_text.append(
                f'<button onclick="replaceWord(this, {js_argument})" '
                'style="color: red; text-decoration: underline; '
                'border: none; background: none;">'
                f'{html.escape(original_chunk)}</button>'
            )
            # Keep the raw strings here: callers match them against raw text.
            corrected_words[original_chunk] = corrected_chunk
        elif tag == 'delete':
            highlighted_text.append(
                '<span style="color: orange; text-decoration: line-through;">'
                f'{html.escape(original_chunk)}</span>'
            )
        elif tag == 'insert':
            highlighted_text.append(
                '<span style="color: green; font-weight: bold;">'
                f'{html.escape(corrected_chunk)}</span>'
            )
        else:
            highlighted_text.append(html.escape(original_chunk))

    return " ".join(highlighted_text), corrected_words

def process_text(text):
    """Correct *text* sentence by sentence and collect the results.

    The text is split on whitespace that follows ``.``, ``!`` or ``?``.
    Pieces that are empty or whitespace-only are skipped; every other piece
    is passed to ``correct_spelling`` and then diffed against its correction
    with ``find_spelling_errors``.

    Args:
        text: The full input text.

    Returns:
        A 3-tuple of:
        * the highlighted HTML of all pieces joined with ``<br>``,
        * the corrected pieces joined with newlines,
        * a dict merging every piece's original-to-corrected mapping (later
          pieces overwrite earlier ones on identical keys).
    """
    html_parts = []
    fixed_parts = []
    replacements = {}

    for piece in re.split(r'(?<=[.!?])\s+', text):
        if not piece.strip():
            continue
        fixed = correct_spelling(piece)
        marked_up, piece_replacements = find_spelling_errors(piece, fixed)
        fixed_parts.append(fixed)
        html_parts.append(marked_up)
        replacements.update(piece_replacements)

    joined_html = "<br>".join(html_parts)
    joined_fixed = "\n".join(fixed_parts)
    return joined_html, joined_fixed, replacements

def apply_full_correction(text, corrected_words):
    """Return *text* as HTML with every known error replaced and highlighted.

    Each phrase in *corrected_words* that occurs in *text* as a run of whole
    whitespace-delimited tokens is replaced by its correction, wrapped in a
    blue bold ``<span>``. All replacements happen in one pass over the
    original text, so:

    * a short key such as ``"a"`` no longer rewrites letters inside other
      words;
    * the markup inserted for one replacement is never re-scanned and
      corrupted by another key (e.g. a key equal to ``"span"``).

    Multi-word keys match regardless of how much whitespace separates the
    words in *text*. Unchanged text and corrections are HTML-escaped, since
    the result is rendered as HTML.

    Args:
        text: The original, uncorrected text.
        corrected_words: Mapping of original phrase to corrected phrase.
            ``None`` or an empty mapping leaves the text unchanged apart
            from escaping.

    Returns:
        The HTML string with corrections applied.
    """
    # Normalise keys to single-space form so matches can be looked up again.
    lookup = {}
    for incorrect, correct in (corrected_words or {}).items():
        normalised = " ".join(incorrect.split())
        if normalised:
            lookup[normalised] = correct
    if not lookup:
        return html.escape(text)

    # Longest phrases first so a multi-word key wins over a key it contains.
    alternatives = "|".join(
        r"\s+".join(re.escape(word) for word in phrase.split())
        for phrase in sorted(lookup, key=len, reverse=True)
    )
    # (?<!\S) / (?!\S) restrict matches to whole whitespace-delimited tokens.
    pattern = re.compile(r"(?<!\S)(?:" + alternatives + r")(?!\S)")

    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()]))
        correction = lookup[" ".join(match.group(0).split())]
        pieces.append(
            '<span style="color: blue; font-weight: bold;">'
            f'{html.escape(correction)}</span>'
        )
        cursor = match.end()
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)

# Build the Gradio interface: one input box, an HTML pane for the highlighted
# diff, a textbox for the corrected text, and a state slot that carries the
# original-to-corrected mapping from the "check" button to the "fix all" button.
with gr.Blocks() as demo:
    input_text = gr.Textbox(placeholder="Nhập văn bản có lỗi chính tả...")
    output_html = gr.HTML()
    corrected_textbox = gr.Textbox()
    corrected_words_state = gr.State({})

    btn_check = gr.Button("Phát hiện lỗi")
    btn_correct_all = gr.Button("Sửa toàn bộ lỗi")

    def update_text(text):
        """Return process_text's (highlighted, corrected, mapping) triple."""
        return process_text(text)

    # "Check" fills the HTML pane, the corrected textbox and the state.
    btn_check.click(
        update_text,
        inputs=input_text,
        outputs=[output_html, corrected_textbox, corrected_words_state],
    )
    # "Fix all" re-renders the input text using the stored mapping.
    btn_correct_all.click(
        apply_full_correction,
        inputs=[input_text, corrected_words_state],
        outputs=output_html,
    )

demo.launch()