File size: 2,704 Bytes
c0710a3
 
 
 
 
 
 
 
 
 
 
 
 
8a4038f
 
 
 
 
84e7748
8a4038f
c0710a3
8a4038f
c0710a3
 
84e7748
8a4038f
84e7748
 
 
 
8a4038f
c0710a3
 
8a4038f
c0710a3
 
 
 
 
 
8a4038f
c0710a3
 
 
 
 
 
 
8a4038f
c0710a3
 
8a4038f
 
c0710a3
 
 
 
 
 
 
 
 
 
 
 
 
8a4038f
c0710a3
8a4038f
c0710a3
8a4038f
 
c0710a3
 
 
8a4038f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gc
import logging
import os
import re
import torch
from cleantext import clean
import gradio as gr
from tqdm.auto import tqdm
from transformers import pipeline

# Configure root logging at INFO level and record the installed torch version.
logging.basicConfig(level=logging.INFO)
logging.info(f"torch version:\t{torch.__version__}")

# --- 1. The model names must be declared here first (do not move them further down) ---
# They are read by the pipeline constructors in step 3 below.
checker_model_name = "textattack/roberta-base-CoLA"
corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"

# --- 2. Check the device (guards against a RuntimeError about NVIDIA) ---
# 0 when CUDA is available, otherwise -1; the log line below reports these
# two values as "cuda" and "cpu" respectively.
device = 0 if torch.cuda.is_available() else -1
logging.info(f"Using device: {'cuda' if device == 0 else 'cpu'}")

# --- 3. Create the pipelines (using the variables from step 1) ---
# `checker` is a text-classification pipeline; correct_text() reads the
# "label" and "score" of its first result.
checker = pipeline(
    "text-classification",
    model=checker_model_name,
    device=device,
)
# `corrector` is a text2text-generation pipeline; correct_text() reads the
# "generated_text" of its first result.
corrector = pipeline(
    "text2text-generation",
    model=corrector_model_name,
    device=device,
)

# --- Other helper functions ---
# Sentence boundary: a run of spaces that follows "." or "?" and precedes an
# uppercase ASCII letter. The lookbehind additionally requires that the
# character two positions before the punctuation is not an uppercase letter.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[^A-Z].[.?]) +(?=[A-Z])")


def split_text(text: str, batch_size: int = 2) -> list:
    """Split ``text`` into sentences and group them into consecutive batches.

    Sentences are found by splitting on ``_SENTENCE_BOUNDARY``. They are then
    grouped in order, ``batch_size`` at a time; the final batch holds whatever
    is left over and may be shorter.

    Batches are formed by position, not by comparing sentence text, so a
    sentence that happens to be identical to the last one no longer causes a
    batch to be cut short.

    Args:
        text: The text to split.
        batch_size: Maximum number of sentences per batch. Defaults to 2.

    Returns:
        A list of batches, each a list of sentence strings. Empty or
        whitespace-only ``text`` yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    if not text.strip():
        # Nothing to batch; avoids handing an empty "sentence" to callers.
        return []

    sentences = _SENTENCE_BOUNDARY.split(text)
    return [
        sentences[start:start + batch_size]
        for start in range(0, len(sentences), batch_size)
    ]

def correct_text(text: str, separator: str = " ", min_score: float = 0.9) -> str:
    """Check ``text`` batch by batch and rewrite the batches that fail the check.

    The text is split with ``split_text``; the sentences of each batch are
    joined with a single space and classified by ``checker``. A batch is kept
    unchanged only when the top result is ``"LABEL_1"`` with a score of at
    least ``min_score``; any other batch is replaced by the ``corrector``
    output. ``"LABEL_1"`` is therefore treated as the "acceptable" class
    (NOTE(review): presumably CoLA label 1 -- confirm against the model card).

    Args:
        text: The text to check and correct.
        separator: String placed between the processed batches in the result.
        min_score: Minimum ``"LABEL_1"`` score for a batch to be kept as is.

    Returns:
        The processed batches joined with ``separator``.
    """
    corrected_text = []
    for batch in tqdm(split_text(text), desc="correcting text.."):
        raw_text = " ".join(batch)

        if not raw_text.strip():
            # Blank batch: there is nothing to classify or rewrite, so skip
            # both models and pass it through unchanged.
            corrected_text.append(raw_text)
            continue

        # truncation=True is forwarded to the tokenizer so that an over-long
        # batch is cut to the tokenizer's maximum length instead of being fed
        # to the classifier at full length.
        verdict = checker(raw_text, truncation=True)[0]

        # Check the grammar quality: rewrite unless the batch is confidently
        # classified as LABEL_1.
        needs_correction = (
            verdict["label"] != "LABEL_1" or verdict["score"] < min_score
        )
        if needs_correction:
            corrected_text.append(corrector(raw_text)[0]["generated_text"])
        else:
            corrected_text.append(raw_text)
    return separator.join(corrected_text)

def update(text: str):
    """Button callback: clean the input text and return its corrected form.

    Only the first 4000 characters of ``text`` are used. They are passed
    through ``clean`` with ``lower=False`` and the result is handed to
    ``correct_text``, whose return value is returned.
    """
    head = text[:4000]
    cleaned = clean(head, lower=False)
    return correct_text(cleaned)

# --- 4. Interface ---
# Build the Gradio UI: a title, an input and an output textbox side by side
# in one row, and a button that runs `update` on the input.
with gr.Blocks() as demo:
    gr.Markdown("# <center>Robust Grammar Correction</center>")
    with gr.Row():
        inp = gr.Textbox(label="Input", placeholder="Enter text here...")
        # The output box is created with interactive=False.
        out = gr.Textbox(label="Output", interactive=False)
    btn = gr.Button("Process")
    # On click, call update() with the input box value and write the returned
    # string into the output box.
    btn.click(fn=update, inputs=inp, outputs=out)

# Launch the app; this runs at import time because it is not guarded by
# `if __name__ == "__main__"`.
demo.launch()