import gc
import logging
import os
import re

import torch
from cleantext import clean
import gradio as gr
from tqdm.auto import tqdm
from transformers import pipeline

logging.basicConfig(level=logging.INFO)
logging.info(f"torch version:\t{torch.__version__}")

# --- 1. Model names. These must be declared here, before the pipelines below
# that read them (do not move them further down). ---
checker_model_name = "textattack/roberta-base-CoLA"
corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"

# --- 2. Device selection: GPU 0 when CUDA is available, otherwise CPU (-1).
# This avoids a CUDA/NVIDIA RuntimeError on machines without a GPU. ---
device = 0 if torch.cuda.is_available() else -1
logging.info(f"Using device: {'cuda' if device == 0 else 'cpu'}")

# --- 3. Build the pipelines from the model names declared in step 1. ---
# Classifier whose label/score decides whether a chunk of text gets corrected.
checker = pipeline(
    "text-classification",
    model=checker_model_name,
    device=device,
)
# Text-to-text model whose output replaces chunks the checker flags.
corrector = pipeline(
    "text2text-generation",
    model=corrector_model_name,
    device=device,
)

# Sentence boundary: one or more spaces that follow "." or "?" (where the
# punctuation is preceded by a non-uppercase character plus any one character)
# and that precede an uppercase letter. Note that "!" is not a boundary.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[^A-Z].[.?]) +(?=[A-Z])")

# Longest input (in characters) accepted per request; the rest is dropped.
MAX_INPUT_CHARS = 4000


# --- Processing functions ---
def split_text(text: str, batch_size: int = 2) -> list:
    """Split *text* into sentences and group them into consecutive batches.

    Args:
        text: Input text. Sentences are separated using ``_SENTENCE_BOUNDARY``.
        batch_size: Number of consecutive sentences per batch (default 2).
            The final batch may be shorter.

    Returns:
        A list of batches, each a list of sentence strings, in input order.
        Empty input yields ``[[""]]`` because ``re.split`` returns ``[""]``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    sentences = _SENTENCE_BOUNDARY.split(text)
    # Chunk by position. The original compared each sentence's *value* to the
    # last one, so a repeated sentence could flush a batch early.
    return [
        sentences[start : start + batch_size]
        for start in range(0, len(sentences), batch_size)
    ]


def _needs_correction(result: dict, threshold: float) -> bool:
    """Return True unless the checker labelled the chunk LABEL_1 with score >= threshold."""
    return result["label"] != "LABEL_1" or result["score"] < threshold


def correct_text(text: str, separator: str = " ", threshold: float = 0.9) -> str:
    """Grammar-correct *text* batch by batch.

    Each batch of sentences from ``split_text`` is joined with a space and
    scored by ``checker``. Batches that ``_needs_correction`` flags are
    rewritten by ``corrector``; the others are kept verbatim.

    Args:
        text: Text to correct.
        separator: String used to join the processed batches.
        threshold: Minimum LABEL_1 score for a batch to be left unchanged.

    Returns:
        The processed batches joined by ``separator``. Empty or
        whitespace-only input is returned unchanged without calling the models.
    """
    if not text.strip():
        return text

    corrected_text = []
    for batch in tqdm(split_text(text), desc="correcting text.."):
        raw_text = " ".join(batch)
        # Check grammatical quality. truncation=True keeps an over-long chunk
        # (e.g. text with no detected sentence boundaries) within the
        # checker's maximum input length instead of raising.
        result = checker(raw_text, truncation=True)[0]
        if _needs_correction(result, threshold):
            corrected_text.append(corrector(raw_text)[0]["generated_text"])
        else:
            corrected_text.append(raw_text)
    return separator.join(corrected_text)


def update(text: str) -> str:
    """Gradio callback: truncate, clean, and grammar-correct the user's input.

    The input is cut to ``MAX_INPUT_CHARS`` characters and passed through
    ``cleantext.clean`` with ``lower=False`` before correction. ``None`` is
    treated as an empty string.
    """
    text = clean((text or "")[:MAX_INPUT_CHARS], lower=False)
    return correct_text(text)


# --- 4. Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# Robust Grammar Correction")
    with gr.Row():
        inp = gr.Textbox(label="Input", placeholder="Enter text here...")
        out = gr.Textbox(label="Output", interactive=False)
    btn = gr.Button("Process")
    btn.click(fn=update, inputs=inp, outputs=out)

demo.launch()