File size: 3,341 Bytes
4e5872e
 
 
 
 
 
 
 
 
030b185
 
 
 
 
 
 
 
 
 
 
f8bfbac
 
030b185
f8bfbac
 
030b185
f8bfbac
 
030b185
 
f8bfbac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
030b185
 
f8bfbac
030b185
 
 
f8bfbac
030b185
f8bfbac
030b185
f8bfbac
 
030b185
f8bfbac
030b185
f8bfbac
 
030b185
f8bfbac
 
030b185
f8bfbac
030b185
 
f8bfbac
 
 
030b185
 
91a69e0
f8bfbac
 
 
030b185
 
 
 
 
 
f8bfbac
030b185
 
91a69e0
030b185
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import torch

# CRITICAL: Redirect cache to temporary storage
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
os.environ['HUB_DIR'] = '/tmp/torch_hub'
os.environ['TMPDIR'] = '/tmp'
torch.hub.set_dir('/tmp/torch_hub')


import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    T5Tokenizer,
    T5ForConditionalGeneration
)
import re

HF_USERNAME = "Tin113"
# -----------------------------------------

BART_MODEL_REPO = f"{HF_USERNAME}/bart_model"
VIT5_MODEL_REPO = f"{HF_USERNAME}/vit5_model"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Thiết bị sử dụng: {DEVICE}")

# Tải các model
try:
    print(f"Đang tải model BART từ {BART_MODEL_REPO}...")
    tokenizer_bart = AutoTokenizer.from_pretrained(BART_MODEL_REPO)
    model_bart = AutoModelForSeq2SeqLM.from_pretrained(BART_MODEL_REPO).to(DEVICE)
    model_bart.eval()
    print("Tải model BART thành công.")
except Exception as e:
    print(f"Lỗi khi tải model BART: {e}")
    model_bart, tokenizer_bart = None, None

try:
    print(f"Đang tải model ViT5 từ {VIT5_MODEL_REPO}...")
    tokenizer_vit5 = T5Tokenizer.from_pretrained(VIT5_MODEL_REPO)
    model_vit5 = T5ForConditionalGeneration.from_pretrained(VIT5_MODEL_REPO).to(DEVICE)
    model_vit5.eval()
    print("Tải model ViT5 thành công.")
except Exception as e:
    print(f"Lỗi khi tải model ViT5: {e}")
    model_vit5, tokenizer_vit5 = None, None

def clean_text(text):
    if not isinstance(text, str): return ""
    return re.sub(r'\s+', ' ', text).strip()

def correct_grammar(sentence, model_choice):
    if not sentence.strip(): return "Vui lòng nhập một câu."

    model, tokenizer, prefix = None, None, ""
    if model_choice == "BARTpho-syllable":
        if model_bart:
            model, tokenizer, prefix = model_bart, tokenizer_bart, "Fix: "
        else:
            return "Lỗi: Model BART không khả dụng. Vui lòng kiểm tra lại Space."
    elif model_choice == "ViT5-base":
        if model_vit5:
            model, tokenizer, prefix = model_vit5, tokenizer_vit5, "sửa lỗi: "
        else:
            return "Lỗi: Model ViT5 không khả dụng. Vui lòng kiểm tra lại Space."
    
    input_text = prefix + sentence
    input_ids = tokenizer(input_text, return_tensors="pt", max_length=256, truncation=True, padding=True).input_ids.to(DEVICE)

    with torch.no_grad():
        outputs = model.generate(input_ids, max_length=276, num_beams=2, early_stopping=True, repetition_penalty=1.05, no_repeat_ngram_size=2)
    
    return clean_text(tokenizer.decode(outputs[0], skip_special_tokens=True))

description = """
Demo sửa lỗi chính tả tiếng Việt sử dụng hai model: BARTpho-syllable và ViT5-base.
1. Nhập câu lỗi vào ô bên dưới.
2. Chọn model bạn muốn dùng.
3. Nhấn "Submit" để xem kết quả.
"""

demo = gr.Interface(
    fn=correct_grammar,
    inputs=[
        gr.Textbox(lines=5, label="Nhập câu tiếng Việt bị lỗi"),
        gr.Radio(choices=["BARTpho-syllable", "ViT5-base"], value="ViT5-base", label="Chọn Model")
    ],
    outputs=gr.Textbox(label="Câu đã được sửa"),
    title="Sửa lỗi chính tả Tiếng Việt",
    description=description,
)

if __name__ == "__main__":
    demo.launch()