Spaces:
Sleeping
Sleeping
File size: 3,341 Bytes
4e5872e 030b185 f8bfbac 030b185 f8bfbac 030b185 f8bfbac 030b185 f8bfbac 030b185 f8bfbac 030b185 f8bfbac 030b185 f8bfbac 030b185 f8bfbac 030b185 f8bfbac 030b185 f8bfbac 030b185 f8bfbac 030b185 f8bfbac 030b185 f8bfbac 030b185 91a69e0 f8bfbac 030b185 f8bfbac 030b185 91a69e0 030b185 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 | import os
import torch
# CRITICAL: Redirect cache to temporary storage
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
os.environ['HUB_DIR'] = '/tmp/torch_hub'
os.environ['TMPDIR'] = '/tmp'
torch.hub.set_dir('/tmp/torch_hub')
import gradio as gr
import torch
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
T5Tokenizer,
T5ForConditionalGeneration
)
import re
HF_USERNAME = "Tin113"
# -----------------------------------------
BART_MODEL_REPO = f"{HF_USERNAME}/bart_model"
VIT5_MODEL_REPO = f"{HF_USERNAME}/vit5_model"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Thiết bị sử dụng: {DEVICE}")
# Tải các model
try:
print(f"Đang tải model BART từ {BART_MODEL_REPO}...")
tokenizer_bart = AutoTokenizer.from_pretrained(BART_MODEL_REPO)
model_bart = AutoModelForSeq2SeqLM.from_pretrained(BART_MODEL_REPO).to(DEVICE)
model_bart.eval()
print("Tải model BART thành công.")
except Exception as e:
print(f"Lỗi khi tải model BART: {e}")
model_bart, tokenizer_bart = None, None
try:
print(f"Đang tải model ViT5 từ {VIT5_MODEL_REPO}...")
tokenizer_vit5 = T5Tokenizer.from_pretrained(VIT5_MODEL_REPO)
model_vit5 = T5ForConditionalGeneration.from_pretrained(VIT5_MODEL_REPO).to(DEVICE)
model_vit5.eval()
print("Tải model ViT5 thành công.")
except Exception as e:
print(f"Lỗi khi tải model ViT5: {e}")
model_vit5, tokenizer_vit5 = None, None
def clean_text(text):
if not isinstance(text, str): return ""
return re.sub(r'\s+', ' ', text).strip()
def correct_grammar(sentence, model_choice):
if not sentence.strip(): return "Vui lòng nhập một câu."
model, tokenizer, prefix = None, None, ""
if model_choice == "BARTpho-syllable":
if model_bart:
model, tokenizer, prefix = model_bart, tokenizer_bart, "Fix: "
else:
return "Lỗi: Model BART không khả dụng. Vui lòng kiểm tra lại Space."
elif model_choice == "ViT5-base":
if model_vit5:
model, tokenizer, prefix = model_vit5, tokenizer_vit5, "sửa lỗi: "
else:
return "Lỗi: Model ViT5 không khả dụng. Vui lòng kiểm tra lại Space."
input_text = prefix + sentence
input_ids = tokenizer(input_text, return_tensors="pt", max_length=256, truncation=True, padding=True).input_ids.to(DEVICE)
with torch.no_grad():
outputs = model.generate(input_ids, max_length=276, num_beams=2, early_stopping=True, repetition_penalty=1.05, no_repeat_ngram_size=2)
return clean_text(tokenizer.decode(outputs[0], skip_special_tokens=True))
description = """
Demo sửa lỗi chính tả tiếng Việt sử dụng hai model: BARTpho-syllable và ViT5-base.
1. Nhập câu lỗi vào ô bên dưới.
2. Chọn model bạn muốn dùng.
3. Nhấn "Submit" để xem kết quả.
"""
demo = gr.Interface(
fn=correct_grammar,
inputs=[
gr.Textbox(lines=5, label="Nhập câu tiếng Việt bị lỗi"),
gr.Radio(choices=["BARTpho-syllable", "ViT5-base"], value="ViT5-base", label="Chọn Model")
],
outputs=gr.Textbox(label="Câu đã được sửa"),
title="Sửa lỗi chính tả Tiếng Việt",
description=description,
)
if __name__ == "__main__":
demo.launch() |