9Dome committed on
Commit
8a4038f
·
verified ·
1 Parent(s): 4d86c95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -63
app.py CHANGED
@@ -2,76 +2,55 @@ import gc
2
  import logging
3
  import os
4
  import re
5
-
6
-
7
- import spaces
8
-
9
  import torch
10
  from cleantext import clean
11
  import gradio as gr
12
  from tqdm.auto import tqdm
13
  from transformers import pipeline
14
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
15
 
16
  logging.basicConfig(level=logging.INFO)
17
  logging.info(f"torch version:\t{torch.__version__}")
18
 
 
 
 
 
 
19
  device = 0 if torch.cuda.is_available() else -1
 
20
 
 
21
  checker = pipeline(
22
  "text-classification",
23
  model=checker_model_name,
24
- device=device, # แก้จาก device_map="cuda" เป็น device
25
  )
26
-
27
  corrector = pipeline(
28
  "text2text-generation",
29
  model=corrector_model_name,
30
- device=device, # แก้จาก device_map="cuda" เป็น device
31
- )
32
-
33
- corrector = pipeline(
34
- "text2text-generation",
35
- corrector_model_name,
36
- device_map="cuda",
37
  )
38
 
 
39
  def split_text(text: str) -> list:
40
- # Split the text into sentences using regex
41
  sentences = re.split(r"(?<=[^A-Z].[.?]) +(?=[A-Z])", text)
42
-
43
- # Initialize lists for batching
44
  sentence_batches = []
45
  temp_batch = []
46
-
47
- # Create batches of 2-3 sentences
48
  for sentence in sentences:
49
  temp_batch.append(sentence)
50
- if len(temp_batch) >= 2 and len(temp_batch) <= 3 or sentence == sentences[-1]:
51
  sentence_batches.append(temp_batch)
52
  temp_batch = []
53
-
54
  return sentence_batches
55
 
56
-
57
  def correct_text(text: str, separator: str = " ") -> str:
58
-
59
- # Split the text into sentence batches
60
  sentence_batches = split_text(text)
61
-
62
- # Initialize a list to store the corrected text
63
  corrected_text = []
64
-
65
- # Process each batch
66
- for batch in tqdm(
67
- sentence_batches, total=len(sentence_batches), desc="correcting text.."
68
- ):
69
  raw_text = " ".join(batch)
70
-
71
- # Check grammar quality
72
  results = checker(raw_text)
73
-
74
- # Correct text if needed
75
  if results[0]["label"] != "LABEL_1" or (
76
  results[0]["label"] == "LABEL_1" and results[0]["score"] < 0.9
77
  ):
@@ -79,42 +58,19 @@ def correct_text(text: str, separator: str = " ") -> str:
79
  corrected_text.append(corrected_batch[0]["generated_text"])
80
  else:
81
  corrected_text.append(raw_text)
82
-
83
- # Join the corrected text
84
  return separator.join(corrected_text)
85
 
86
-
87
  def update(text: str):
88
- # Clean and truncate input text
89
  text = clean(text[:4000], lower=False)
90
  return correct_text(text)
91
 
92
-
93
- # Create the Gradio interface
94
  with gr.Blocks() as demo:
95
- gr.Markdown("# <center>Robust Grammar Correction with FLAN-T5</center>")
96
- gr.Markdown(
97
- "**Instructions:** Enter the text you want to correct in the textbox below (_text will be truncated to 4000 characters_). Click 'Process' to run."
98
- )
99
- gr.Markdown(
100
- """Models:
101
- - `textattack/roberta-base-CoLA` for grammar quality detection
102
- - `pszemraj/flan-t5-large-grammar-synthesis` for grammar correction
103
- """
104
- )
105
  with gr.Row():
106
- inp = gr.Textbox(
107
- label="input",
108
- placeholder="Enter text to check & correct",
109
- value="I wen to the store yesturday to bye some food. I needd milk, bread, and a few otter things. The store was really crowed and I had a hard time finding everyting I needed. I finaly made it to the check out line and payed for my stuff.",
110
- )
111
- out = gr.Textbox(label="output", interactive=False)
112
  btn = gr.Button("Process")
113
  btn.click(fn=update, inputs=inp, outputs=out)
114
- gr.Markdown("---")
115
- gr.Markdown(
116
- "- See the [model card](https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis) for more info"
117
- )
118
 
119
- # Launch the demo
120
- demo.launch(debug=True)
 
2
  import logging
3
  import os
4
  import re
 
 
 
 
5
  import torch
6
  from cleantext import clean
7
  import gradio as gr
8
  from tqdm.auto import tqdm
9
  from transformers import pipeline
 
10
 
11
  logging.basicConfig(level=logging.INFO)
12
  logging.info(f"torch version:\t{torch.__version__}")
13
 
14
+ # --- 1. ต้องประกาศชื่อ Model ไว้ตรงนี้ก่อน (ห้ามย้ายไปไว้ข้างล่าง) ---
15
+ checker_model_name = "textattack/roberta-base-CoLA"
16
+ corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"
17
+
18
+ # --- 2. เช็ค Device (ป้องกัน RuntimeError เรื่อง NVIDIA) ---
19
  device = 0 if torch.cuda.is_available() else -1
20
+ logging.info(f"Using device: {'cuda' if device == 0 else 'cpu'}")
21
 
22
+ # --- 3. สร้าง Pipeline (ดึงตัวแปรจากข้อ 1 มาใช้) ---
23
  checker = pipeline(
24
  "text-classification",
25
  model=checker_model_name,
26
+ device=device,
27
  )
 
28
  corrector = pipeline(
29
  "text2text-generation",
30
  model=corrector_model_name,
31
+ device=device,
 
 
 
 
 
 
32
  )
33
 
34
+ # --- ฟังก์ชันการทำงานอื่นๆ ---
35
  def split_text(text: str) -> list:
 
36
  sentences = re.split(r"(?<=[^A-Z].[.?]) +(?=[A-Z])", text)
 
 
37
  sentence_batches = []
38
  temp_batch = []
 
 
39
  for sentence in sentences:
40
  temp_batch.append(sentence)
41
+ if (len(temp_batch) >= 2 and len(temp_batch) <= 3) or sentence == sentences[-1]:
42
  sentence_batches.append(temp_batch)
43
  temp_batch = []
 
44
  return sentence_batches
45
 
 
46
  def correct_text(text: str, separator: str = " ") -> str:
 
 
47
  sentence_batches = split_text(text)
 
 
48
  corrected_text = []
49
+ for batch in tqdm(sentence_batches, desc="correcting text.."):
 
 
 
 
50
  raw_text = " ".join(batch)
 
 
51
  results = checker(raw_text)
52
+
53
+ # ตรวจสอบคุณภาพไวยากรณ์
54
  if results[0]["label"] != "LABEL_1" or (
55
  results[0]["label"] == "LABEL_1" and results[0]["score"] < 0.9
56
  ):
 
58
  corrected_text.append(corrected_batch[0]["generated_text"])
59
  else:
60
  corrected_text.append(raw_text)
 
 
61
  return separator.join(corrected_text)
62
 
 
63
  def update(text: str):
 
64
  text = clean(text[:4000], lower=False)
65
  return correct_text(text)
66
 
67
+ # --- 4. Interface ---
 
68
  with gr.Blocks() as demo:
69
+ gr.Markdown("# <center>Robust Grammar Correction</center>")
 
 
 
 
 
 
 
 
 
70
  with gr.Row():
71
+ inp = gr.Textbox(label="Input", placeholder="Enter text here...")
72
+ out = gr.Textbox(label="Output", interactive=False)
 
 
 
 
73
  btn = gr.Button("Process")
74
  btn.click(fn=update, inputs=inp, outputs=out)
 
 
 
 
75
 
76
+ demo.launch()