| | """
|
| | GRADIO UI - VIETNAMESE TO ENGLISH TRANSLATION
|
| | Modern web interface for Transformer translation model
|
| | """
|
| |
|
| | import glob
|
| | import importlib
|
| | import os
|
| | import pickle
|
| | import sys
|
| | from pathlib import Path
|
| |
|
| | import gradio as gr
|
| | import torch
|
| | from shared_vocab_utils import (
|
| | create_shared_vocab_wrapper,
|
| | load_shared_vocab_info,
|
| | )
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| | if str(PROJECT_ROOT) not in sys.path:
|
| | sys.path.append(str(PROJECT_ROOT))
|
| | SRC_DIR = PROJECT_ROOT / 'src'
|
| | if str(SRC_DIR) not in sys.path:
|
| | sys.path.append(str(SRC_DIR))
|
| |
|
| | DATA_PROCESSED_DIR = PROJECT_ROOT / 'data' / 'processed'
|
| | CHECKPOINT_DIR = PROJECT_ROOT / 'checkpoints'
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | data_preprocessing = importlib.import_module('src.data_preprocessing')
|
| | sys.modules.setdefault('data_preprocessing', data_preprocessing)
|
| | Vocabulary = data_preprocessing.Vocabulary
|
| | clean_text = data_preprocessing.clean_text
|
| |
|
| | complete_transformer = importlib.import_module('src.complete_transformer')
|
| | create_model = complete_transformer.create_model
|
| |
|
| | inference_module = importlib.import_module('src.inference_evaluation')
|
| | translate_sentence = inference_module.translate_sentence
|
| | beam_search_decode = inference_module.beam_search_decode
|
| | greedy_decode = inference_module.greedy_decode
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | class TranslationModel:
|
| | """Wrapper class để quản lý model và vocabularies"""
|
| |
|
| | def __init__(self):
|
| | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| | self.model = None
|
| | self.vi_vocab = None
|
| | self.en_vocab = None
|
| | self.model_info = {}
|
| |
|
| | def load_latest_model(self):
|
| | """Load model mới nhất từ checkpoints"""
|
| | print("🔄 Loading model và vocabularies...")
|
| |
|
| |
|
| | try:
|
| | vocab_info = load_shared_vocab_info()
|
| | self.vi_vocab, self.en_vocab = create_shared_vocab_wrapper()
|
| | self.pad_idx = vocab_info.get("pad_id", 0)
|
| | self.sos_idx = vocab_info.get("sos_id", 1)
|
| | self.eos_idx = vocab_info.get("eos_id", 2)
|
| | vocab_size = vocab_info.get("vocab_size")
|
| | print(f"✓ Loaded shared vocabulary (size={vocab_size})")
|
| | except Exception as shared_err:
|
| | print(f"⚠️ Shared vocab load failed, fallback to pickled vocabs: {shared_err}")
|
| |
|
| | vi_vocab_path = DATA_PROCESSED_DIR / 'vi_vocab.pkl'
|
| | en_vocab_path = DATA_PROCESSED_DIR / 'en_vocab.pkl'
|
| | try:
|
| | with open(vi_vocab_path, 'rb') as f:
|
| | self.vi_vocab = pickle.load(f)
|
| | with open(en_vocab_path, 'rb') as f:
|
| | self.en_vocab = pickle.load(f)
|
| | self.pad_idx = getattr(self.vi_vocab, "PAD_IDX", 0)
|
| | vocab_size = len(self.vi_vocab)
|
| | print(f"✓ Loaded vocabularies (VI: {len(self.vi_vocab)}, EN: {len(self.en_vocab)})")
|
| | except Exception as e:
|
| | raise Exception(f"❌ Không thể load vocabularies (shared & pickled đều lỗi): {e}")
|
| |
|
| |
|
| | checkpoint_dir = CHECKPOINT_DIR
|
| | best_model_path = checkpoint_dir / 'best_model.pt'
|
| |
|
| |
|
| | if os.path.exists(best_model_path):
|
| | checkpoint_path = best_model_path
|
| | print(f"✓ Found best model: {best_model_path}")
|
| | else:
|
| |
|
| | checkpoints = glob.glob(str(checkpoint_dir / 'checkpoint_epoch_*.pt'))
|
| | if not checkpoints:
|
| | raise Exception(f"❌ Không tìm thấy checkpoint nào trong {checkpoint_dir}")
|
| |
|
| |
|
| | checkpoint_path = max(checkpoints, key=os.path.getmtime)
|
| | print(f"✓ Found latest checkpoint: {checkpoint_path}")
|
| |
|
| |
|
| | checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
| |
|
| |
|
| | print("🔨 Creating model...")
|
| |
|
| | path_str = str(checkpoint_path).lower()
|
| | model_size = 'custom_25m'
|
| | for size in ['custom_25m', 'medium', 'base', 'small', 'tiny', 'large']:
|
| | if size in path_str or ('25m' in path_str and size == 'custom_25m'):
|
| | model_size = size
|
| | break
|
| |
|
| | self.model, model_config = create_model(
|
| | src_vocab_size=len(self.vi_vocab),
|
| | tgt_vocab_size=len(self.en_vocab),
|
| | model_size=model_size,
|
| | pad_idx=getattr(self, "pad_idx", 0),
|
| | use_shared_vocab=True,
|
| | use_weight_tying=True,
|
| | )
|
| |
|
| |
|
| | self.model.load_state_dict(checkpoint['model_state_dict'])
|
| | self.model = self.model.to(self.device)
|
| | self.model.eval()
|
| |
|
| |
|
| | self.model_info = {
|
| | 'checkpoint': checkpoint_path,
|
| | 'epoch': checkpoint.get('epoch', 'N/A'),
|
| | 'val_loss': checkpoint.get('val_loss', 'N/A'),
|
| | 'val_ppl': checkpoint.get('val_ppl', 'N/A'),
|
| | 'model_size': model_size,
|
| | 'device': str(self.device),
|
| | 'parameters': sum(p.numel() for p in self.model.parameters()),
|
| | 'vocab_size': len(self.vi_vocab),
|
| | }
|
| |
|
| | print(f"✓ Model loaded successfully!")
|
| | print(f" - Epoch: {self.model_info['epoch']}")
|
| | print(f" - Val Loss: {self.model_info['val_loss']}")
|
| | print(f" - Device: {self.model_info['device']}")
|
| |
|
| | return self.model_info
|
| |
|
| | def translate(self, text, direction='vi2en', use_beam_search=True, beam_size=5):
|
| | """Dịch một câu theo hướng chỉ định"""
|
| | if not self.model or not self.vi_vocab or not self.en_vocab:
|
| | return "❌ Model chưa được load. Vui lòng reload model."
|
| |
|
| | try:
|
| |
|
| | src_lang = 'en' if direction == 'en2vi' else 'vi'
|
| |
|
| |
|
| | text = clean_text(text, src_lang)
|
| |
|
| | if not text.strip():
|
| | lang_name = "tiếng Anh" if direction == 'en2vi' else "tiếng Việt"
|
| | return f"⚠️ Vui lòng nhập văn bản {lang_name}."
|
| |
|
| |
|
| |
|
| |
|
| | translation = translate_sentence(
|
| | self.model, text, self.vi_vocab, self.en_vocab,
|
| | self.device, use_beam_search, beam_size,
|
| | src_lang=src_lang,
|
| | repetition_penalty=1.3,
|
| | no_repeat_ngram_size=3
|
| | )
|
| |
|
| | return translation
|
| |
|
| | except Exception as e:
|
| | import traceback
|
| | error_detail = traceback.format_exc()
|
| | return f"❌ Lỗi khi dịch: {str(e)}\n\nChi tiết:\n{error_detail}"
|
| |
|
| |
|
| | translator = TranslationModel()
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | def translate_text(text, direction, use_beam_search, beam_size):
|
| | """Wrapper function cho Gradio"""
|
| | return translator.translate(text, direction, use_beam_search, beam_size)
|
| |
|
| | def reload_model():
|
| | """Reload model mới nhất"""
|
| | try:
|
| | info = translator.load_latest_model()
|
| | return f"""
|
| | ✅ **Model đã được load thành công!**
|
| |
|
| | 📊 **Thông tin model:**
|
| | - Checkpoint: `{info['checkpoint']}`
|
| | - Epoch: {info['epoch']}
|
| | - Validation Loss: {f"{info['val_loss']:.4f}" if isinstance(info['val_loss'], float) else info['val_loss']}
|
| | - Model Size: {info['model_size'].upper()}
|
| | - Device: {info['device']}
|
| | - Parameters: {info['parameters']:,}
|
| | - Vocab size: {info.get('vocab_size', 'N/A')}
|
| | """
|
| | except Exception as e:
|
| | return f"❌ Lỗi khi load model: {str(e)}"
|
| |
|
| | def get_model_info():
|
| | """Lấy thông tin model hiện tại"""
|
| | if not translator.model:
|
| | return "⚠️ Model chưa được load. Click 'Reload Model' để load model."
|
| |
|
| | info = translator.model_info
|
| | return f"""
|
| | 📊 **Model hiện tại:**
|
| | - Checkpoint: `{info.get('checkpoint', 'N/A')}`
|
| | - Epoch: {info.get('epoch', 'N/A')}
|
| | - Validation Loss: {f"{info['val_loss']:.4f}" if isinstance(info.get('val_loss'), float) else info.get('val_loss', 'N/A')}
|
| | - Model Size: {str(info.get('model_size', 'N/A')).upper()}
|
| | - Device: {info.get('device', 'N/A')}
|
| | - Parameters: {info.get('parameters', 0):,}
|
| | - Vocab size: {info.get('vocab_size', 'N/A')}
|
| | """
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | def create_app():
|
| | """Tạo Gradio app"""
|
| |
|
| | with gr.Blocks(
|
| | title="Bidirectional Translation: English ↔ Vietnamese"
|
| | ) as app:
|
| |
|
| |
|
| | gr.Markdown("""
|
| | # 🌐 Bidirectional Translation: English ↔ Vietnamese
|
| | ### Powered by Transformer Neural Network
|
| |
|
| | Dịch văn bản hai chiều giữa **Tiếng Anh** và **Tiếng Việt** sử dụng mô hình Transformer được train từ đầu.
|
| | """)
|
| |
|
| |
|
| | with gr.Row():
|
| | with gr.Column(scale=1):
|
| |
|
| | gr.Markdown("### 🔄 Chọn hướng dịch")
|
| | direction = gr.Radio(
|
| | choices=[
|
| | ("Vietnamese → English", "vi2en"),
|
| | ("English → Vietnamese", "en2vi")
|
| | ],
|
| | value="vi2en",
|
| | label="Hướng dịch",
|
| | info="Chọn ngôn ngữ đầu vào và đầu ra"
|
| | )
|
| |
|
| |
|
| | gr.Markdown("### 📝 Nhập văn bản")
|
| | input_text = gr.Textbox(
|
| | label="Văn bản đầu vào",
|
| | placeholder="Nhập câu cần dịch...\nVí dụ: Xin chào, tôi là sinh viên.\nHoặc: Hello, I am a student.",
|
| | lines=5,
|
| | max_lines=10
|
| | )
|
| |
|
| |
|
| | with gr.Row():
|
| | use_beam_search = gr.Checkbox(
|
| | label="🔍 Beam Search",
|
| | value=True,
|
| | info="Kết quả tốt hơn nhưng chậm hơn"
|
| | )
|
| | beam_size = gr.Slider(
|
| | minimum=1,
|
| | maximum=10,
|
| | value=5,
|
| | step=1,
|
| | label="Beam Size",
|
| | info="Số candidates (cao hơn = tốt hơn nhưng chậm hơn)"
|
| | )
|
| |
|
| |
|
| | with gr.Row():
|
| | translate_btn = gr.Button("🚀 Dịch", variant="primary", size="lg")
|
| | clear_btn = gr.ClearButton([input_text], value="🗑️ Xóa")
|
| |
|
| | with gr.Column(scale=1):
|
| |
|
| | gr.Markdown("### 🎯 Kết quả dịch")
|
| | output_text = gr.Textbox(
|
| | label="Văn bản đã dịch",
|
| | placeholder="Kết quả dịch sẽ hiển thị ở đây...",
|
| | lines=5,
|
| | max_lines=10,
|
| | interactive=False
|
| | )
|
| |
|
| |
|
| | gr.Markdown("### 💡 Ví dụ")
|
| | gr.Examples(
|
| | examples=[
|
| | ["vi2en", "Xin chào, tôi tên là Nam.", True, 5],
|
| | ["vi2en", "Hôm nay thời tiết đẹp.", True, 5],
|
| | ["vi2en", "Tôi đang học tiếng Anh.", True, 5],
|
| | ["vi2en", "Bạn khỏe không?", True, 5],
|
| | ["en2vi", "Hello, my name is Nam.", True, 5],
|
| | ["en2vi", "The weather is beautiful today.", True, 5],
|
| | ["en2vi", "I am learning English.", True, 5],
|
| | ["en2vi", "How are you?", True, 5],
|
| | ],
|
| | inputs=[direction, input_text, use_beam_search, beam_size],
|
| | outputs=output_text,
|
| | fn=translate_text,
|
| | cache_examples=False,
|
| | )
|
| |
|
| |
|
| | with gr.Accordion("⚙️ Thông tin Model & Cài đặt", open=False):
|
| | model_info_display = gr.Markdown(get_model_info())
|
| |
|
| | with gr.Row():
|
| | reload_btn = gr.Button("🔄 Reload Model", variant="secondary")
|
| | info_btn = gr.Button("ℹ️ Xem thông tin", variant="secondary")
|
| |
|
| | reload_output = gr.Markdown()
|
| |
|
| |
|
| | translate_btn.click(
|
| | fn=translate_text,
|
| | inputs=[input_text, direction, use_beam_search, beam_size],
|
| | outputs=output_text
|
| | )
|
| |
|
| | reload_btn.click(
|
| | fn=reload_model,
|
| | outputs=reload_output
|
| | )
|
| |
|
| | info_btn.click(
|
| | fn=get_model_info,
|
| | outputs=model_info_display
|
| | )
|
| |
|
| |
|
| | gr.Markdown("""
|
| | ---
|
| | ### 📚 Hướng dẫn sử dụng
|
| | 1. **Chọn hướng dịch**: Vietnamese → English hoặc English → Vietnamese
|
| | 2. **Nhập văn bản** vào ô bên trái
|
| | 3. **Chọn phương pháp dịch**:
|
| | - ✅ Beam Search: Chất lượng cao hơn (khuyên dùng)
|
| | - ❌ Greedy Search: Nhanh hơn nhưng kém chính xác hơn
|
| | 4. **Click "Dịch"** để xem kết quả
|
| | 5. **Reload Model** nếu bạn vừa train xong model mới
|
| |
|
| | ### 🔧 Tips
|
| | - Model hỗ trợ dịch hai chiều: Việt ↔ Anh
|
| | - Beam size càng cao thì kết quả càng tốt nhưng chậm hơn (khuyên dùng 5)
|
| | - Câu ngắn dịch nhanh hơn câu dài
|
| | - Model hoạt động tốt nhất với câu có độ dài 5-20 từ
|
| | - Model sử dụng shared vocabulary và checkpoint hiện có, không cần train lại
|
| | """)
|
| |
|
| | return app
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | if __name__ == "__main__":
|
| | print("="*70)
|
| | print("KHỞI ĐỘNG TRANSLATION WEB APP")
|
| | print("="*70)
|
| |
|
| |
|
| | try:
|
| | translator.load_latest_model()
|
| | except Exception as e:
|
| | print(f"⚠️ Warning: {e}")
|
| | print("💡 Bạn có thể reload model sau trong giao diện web.")
|
| |
|
| |
|
| | app = create_app()
|
| |
|
| | print("\n" + "="*70)
|
| | print("🚀 LAUNCHING WEB APP...")
|
| | print("="*70)
|
| |
|
| | import os
|
| | server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
|
| |
|
| | app.launch(
|
| | server_name="0.0.0.0",
|
| | server_port=server_port,
|
| | share=False,
|
| | inbrowser=True,
|
| | ) |