Tin113 commited on
Commit
030b185
·
verified ·
1 Parent(s): eaae444

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Nội dung file app.py
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoModelForSeq2SeqLM,
8
+ T5Tokenizer,
9
+ T5ForConditionalGeneration
10
+ )
11
+ import re
12
+
13
+ # --- THAY ĐỔI CÁC THÔNG TIN SAU CHO ĐÚNG VỚI REPO CỦA BẠN ---
14
+ BART_MODEL_REPO = "Tin113/bart_model"
15
+ VIT5_MODEL_REPO = "Tin113/vit5_model"
16
+ # -------------------------------------------------------------
17
+
18
+ # Chọn thiết bị (GPU nếu có)
19
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+ print(f"Bắt đầu tải model lên thiết bị: {DEVICE}...")
22
+
23
+ # Hàm tải model, với cơ chế thử lại để tránh lỗi tải tạm thời
24
+ def load_model(model_class, tokenizer_class, repo_id):
25
+ try:
26
+ tokenizer = tokenizer_class.from_pretrained(repo_id)
27
+ model = model_class.from_pretrained(repo_id).to(DEVICE)
28
+ model.eval()
29
+ print(f"Tải thành công model: {repo_id}")
30
+ return model, tokenizer
31
+ except Exception as e:
32
+ print(f"Lỗi khi tải model {repo_id}: {e}")
33
+ # Trả về None nếu có lỗi để xử lý ở giao diện
34
+ return None, None
35
+
36
+ # Tải các model
37
+ model_bart, tokenizer_bart = load_model(AutoModelForSeq2SeqLM, AutoTokenizer, BART_MODEL_REPO)
38
+ model_vit5, tokenizer_vit5 = load_model(T5ForConditionalGeneration, T5Tokenizer, VIT5_MODEL_REPO)
39
+
40
+ # Hàm clean text, lấy từ notebook của bạn
41
+ def clean_text(text):
42
+ if not isinstance(text, str):
43
+ return ""
44
+ return re.sub(r'\s+', ' ', text).strip()
45
+
46
+ # Hàm xử lý việc sửa lỗi
47
+ def correct_grammar(sentence, model_choice):
48
+ if not sentence.strip():
49
+ return "Vui lòng nhập một câu."
50
+
51
+ model = None
52
+ tokenizer = None
53
+ prefix = ""
54
+
55
+ if model_choice == "BARTpho-syllable":
56
+ if model_bart and tokenizer_bart:
57
+ model = model_bart
58
+ tokenizer = tokenizer_bart
59
+ prefix = "Fix: "
60
+ else:
61
+ return "Lỗi: Model BART không khả dụng. Vui lòng thử lại sau."
62
+
63
+ elif model_choice == "ViT5-base":
64
+ if model_vit5 and tokenizer_vit5:
65
+ model = model_vit5
66
+ tokenizer = tokenizer_vit5
67
+ prefix = "sửa lỗi: "
68
+ else:
69
+ return "Lỗi: Model ViT5 không khả dụng. Vui lòng thử lại sau."
70
+
71
+ input_text = prefix + sentence
72
+ input_ids = tokenizer(
73
+ input_text,
74
+ return_tensors="pt",
75
+ max_length=256,
76
+ truncation=True,
77
+ padding=True
78
+ ).input_ids.to(DEVICE)
79
+
80
+ with torch.no_grad():
81
+ outputs = model.generate(
82
+ input_ids,
83
+ max_length=256 + 20,
84
+ num_beams=2,
85
+ early_stopping=True,
86
+ repetition_penalty=1.05,
87
+ no_repeat_ngram_size=2
88
+ )
89
+
90
+ corrected_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
91
+ return clean_text(corrected_sentence)
92
+
93
+ # Ví dụ cho giao diện Gradio
94
+ examples = [
95
+ ["chương trỉnhnh được páht sóng vào lúc 19h", "ViT5-base"],
96
+ ["công nghề thônngg tin đáng phát chiển rất nhanh", "ViT5-base"],
97
+ ["Học hok tốt thì kho mak đc điểm cao.", "BARTpho-syllable"],
98
+ ["dù rất mệt nhưng anh ấy vẫn cố hoàn thành công việc", "BARTpho-syllable"],
99
+ ]
100
+
101
+ # Mô tả cho ứng dụng
102
+ description = """
103
+ Đây là ứng dụng demo cho việc sửa lỗi ngữ pháp tiếng Việt (Vietnamese Grammatical Error Correction).
104
+ Ứng dụng sử dụng hai model đã được fine-tune:
105
+ 1. **BARTpho-syllable**: Dựa trên kiến trúc BART, được tối ưu cho tiếng Việt ở cấp độ âm tiết.
106
+ 2. **ViT5-base**: Dựa trên kiến trúc T5, một model mạnh mẽ cho các tác vụ Text-to-Text.
107
+
108
+ **Cách sử dụng:**
109
+ 1. Nhập câu tiếng Việt có lỗi vào ô bên dưới.
110
+ 2. Chọn một trong hai model để thực hiện sửa lỗi.
111
+ 3. Nhấn "Submit" và xem kết quả.
112
+ """
113
+
114
+ # Tạo giao diện Gradio
115
+ demo = gr.Interface(
116
+ fn=correct_grammar,
117
+ inputs=[
118
+ gr.Textbox(lines=5, label="Nhập câu tiếng Việt bị lỗi"),
119
+ gr.Radio(
120
+ choices=["BARTpho-syllable", "ViT5-base"],
121
+ value="ViT5-base", # Model mặc định
122
+ label="Chọn Model"
123
+ )
124
+ ],
125
+ outputs=gr.Textbox(label="Câu đã được sửa"),
126
+ title="Sửa lỗi Ngữ pháp Tiếng Việt",
127
+ description=description,
128
+ examples=examples
129
+ )
130
+
131
+ # Chạy ứng dụng
132
+ if __name__ == "__main__":
133
+ demo.launch()