import os import torch # CRITICAL: Redirect cache to temporary storage os.environ['TORCH_HOME'] = '/tmp/torch_cache' os.environ['HUB_DIR'] = '/tmp/torch_hub' os.environ['TMPDIR'] = '/tmp' torch.hub.set_dir('/tmp/torch_hub') import gradio as gr import torch from transformers import ( AutoTokenizer, AutoModelForSeq2SeqLM, T5Tokenizer, T5ForConditionalGeneration ) import re HF_USERNAME = "Tin113" # ----------------------------------------- BART_MODEL_REPO = f"{HF_USERNAME}/bart_model" VIT5_MODEL_REPO = f"{HF_USERNAME}/vit5_model" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"Thiết bị sử dụng: {DEVICE}") # Tải các model try: print(f"Đang tải model BART từ {BART_MODEL_REPO}...") tokenizer_bart = AutoTokenizer.from_pretrained(BART_MODEL_REPO) model_bart = AutoModelForSeq2SeqLM.from_pretrained(BART_MODEL_REPO).to(DEVICE) model_bart.eval() print("Tải model BART thành công.") except Exception as e: print(f"Lỗi khi tải model BART: {e}") model_bart, tokenizer_bart = None, None try: print(f"Đang tải model ViT5 từ {VIT5_MODEL_REPO}...") tokenizer_vit5 = T5Tokenizer.from_pretrained(VIT5_MODEL_REPO) model_vit5 = T5ForConditionalGeneration.from_pretrained(VIT5_MODEL_REPO).to(DEVICE) model_vit5.eval() print("Tải model ViT5 thành công.") except Exception as e: print(f"Lỗi khi tải model ViT5: {e}") model_vit5, tokenizer_vit5 = None, None def clean_text(text): if not isinstance(text, str): return "" return re.sub(r'\s+', ' ', text).strip() def correct_grammar(sentence, model_choice): if not sentence.strip(): return "Vui lòng nhập một câu." model, tokenizer, prefix = None, None, "" if model_choice == "BARTpho-syllable": if model_bart: model, tokenizer, prefix = model_bart, tokenizer_bart, "Fix: " else: return "Lỗi: Model BART không khả dụng. Vui lòng kiểm tra lại Space." elif model_choice == "ViT5-base": if model_vit5: model, tokenizer, prefix = model_vit5, tokenizer_vit5, "sửa lỗi: " else: return "Lỗi: Model ViT5 không khả dụng. Vui lòng kiểm tra lại Space." input_text = prefix + sentence input_ids = tokenizer(input_text, return_tensors="pt", max_length=256, truncation=True, padding=True).input_ids.to(DEVICE) with torch.no_grad(): outputs = model.generate(input_ids, max_length=276, num_beams=2, early_stopping=True, repetition_penalty=1.05, no_repeat_ngram_size=2) return clean_text(tokenizer.decode(outputs[0], skip_special_tokens=True)) description = """ Demo sửa lỗi chính tả tiếng Việt sử dụng hai model: BARTpho-syllable và ViT5-base. 1. Nhập câu lỗi vào ô bên dưới. 2. Chọn model bạn muốn dùng. 3. Nhấn "Submit" để xem kết quả. """ demo = gr.Interface( fn=correct_grammar, inputs=[ gr.Textbox(lines=5, label="Nhập câu tiếng Việt bị lỗi"), gr.Radio(choices=["BARTpho-syllable", "ViT5-base"], value="ViT5-base", label="Chọn Model") ], outputs=gr.Textbox(label="Câu đã được sửa"), title="Sửa lỗi chính tả Tiếng Việt", description=description, ) if __name__ == "__main__": demo.launch()