vi-en-transformer-25m / src /shared_vocab_utils.py
Cong123779's picture
Upload model source code
51b3b77 verified
"""
UTILITY FUNCTIONS CHO SHARED VOCABULARY
Helper functions để load và sử dụng shared tokenizer
"""
import json
import pickle
import re
from pathlib import Path
from tokenizers import Tokenizer
from typing import Tuple, Optional
# ============================================================================
# PATH CONFIGURATION
# ============================================================================
PROJECT_ROOT = Path(__file__).resolve().parent.parent
PROCESSED_DATA_DIR = PROJECT_ROOT / 'data' / 'processed'
# ============================================================================
# LOAD SHARED VOCABULARY
# ============================================================================
def load_shared_tokenizer(tokenizer_path: Optional[Path] = None) -> Tokenizer:
"""
Load shared tokenizer từ file
Args:
tokenizer_path: Đường dẫn đến tokenizer_shared.json (mặc định tự tìm)
Returns:
tokenizer: Loaded tokenizer
"""
if tokenizer_path is None:
tokenizer_path = PROCESSED_DATA_DIR / 'tokenizer_shared.json'
if not tokenizer_path.exists():
raise FileNotFoundError(
f"Không tìm thấy tokenizer_shared.json tại {tokenizer_path}!\n"
f"Vui lòng chạy: python src/1_build_shared_vocab.py trước"
)
tokenizer = Tokenizer.from_file(str(tokenizer_path))
return tokenizer
def load_shared_vocab_info(info_path: Optional[Path] = None) -> dict:
"""
Load thông tin về shared vocabulary
Args:
info_path: Đường dẫn đến shared_vocab_info.json (mặc định tự tìm)
Returns:
info: Dictionary chứa vocab_size, sos_id, eos_id, pad_id, unk_id
"""
if info_path is None:
info_path = PROCESSED_DATA_DIR / 'shared_vocab_info.json'
if not info_path.exists():
# Tạo từ tokenizer nếu chưa có
tokenizer = load_shared_tokenizer()
info = {
'vocab_size': tokenizer.get_vocab_size(),
'sos_id': tokenizer.token_to_id("<sos>"),
'eos_id': tokenizer.token_to_id("<eos>"),
'pad_id': tokenizer.token_to_id("<pad>"),
'unk_id': tokenizer.token_to_id("<unk>"),
'tokenizer_path': str(PROCESSED_DATA_DIR / 'tokenizer_shared.json')
}
return info
with open(info_path, 'r', encoding='utf-8') as f:
info = json.load(f)
return info
def load_shared_processed_data(data_path: Optional[Path] = None, use_bidirectional: bool = False) -> dict:
"""
Load processed data đã encode với shared vocabulary
Args:
data_path: Đường dẫn đến processed_data file (mặc định tự tìm)
use_bidirectional: Nếu True, load bidirectional dataset (có cả vi→en và en→vi)
Returns:
processed_data: Dict với keys 'train', 'validation', 'test'
Mỗi value là list of (src_ids, tgt_ids) tuples
"""
if data_path is None:
if use_bidirectional:
data_path = PROCESSED_DATA_DIR / 'processed_data_bidirectional.pkl'
if not data_path.exists():
# Fallback to regular dataset if bidirectional doesn't exist
print("⚠️ Bidirectional dataset không tồn tại, dùng dataset thường")
data_path = PROCESSED_DATA_DIR / 'processed_data_shared.pkl'
else:
data_path = PROCESSED_DATA_DIR / 'processed_data_shared.pkl'
if not data_path.exists():
raise FileNotFoundError(
f"Không tìm thấy data file tại {data_path}!\n"
f"Vui lòng chạy: python src/2_encode_data.py trước"
)
with open(data_path, 'rb') as f:
processed_data = pickle.load(f)
if use_bidirectional and 'bidirectional' in str(data_path):
print(f"✓ Loaded bidirectional dataset từ {data_path.name}")
else:
print(f"✓ Loaded dataset từ {data_path.name}")
return processed_data
# ============================================================================
# POST-PROCESSING: CLEAN DECODED OUTPUT
# ============================================================================
def clean_decoded_output(text: str) -> str:
"""
Clean decoded output để loại bỏ padding artifacts và các lỗi kỹ thuật
Fixes:
1. Loại bỏ khoảng trắng thừa
2. Loại bỏ ký tự lạ (như ‹¶, <pad>, etc.)
3. Normalize spacing quanh dấu câu
4. Strip leading/trailing whitespace
Args:
text: Raw decoded text
Returns:
cleaned: Cleaned text
"""
if not text:
return ""
# Loại bỏ các ký tự đặc biệt không mong muốn
# Loại bỏ các ký tự control và non-printable (trừ space, newline, tab)
text = re.sub(r'[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F-\x9F]', '', text)
# Loại bỏ các ký tự lạ như ‹¶ và các ký tự Unicode không hợp lệ
text = re.sub(r'[‹¶]', '', text)
# Loại bỏ các token đặc biệt nếu còn sót
text = text.replace('<pad>', '').replace('</s>', '').replace('<s>', '')
text = text.replace('<eos>', '').replace('<sos>', '').replace('<unk>', '')
# Normalize spacing quanh dấu câu (loại bỏ space trước dấu câu)
text = re.sub(r'\s+([,.!?;:])', r'\1', text)
# Thêm space sau dấu câu nếu chưa có (trừ khi là cuối câu)
text = re.sub(r'([,.!?;:])([^\s])', r'\1 \2', text)
# Loại bỏ multiple spaces
text = re.sub(r' +', ' ', text)
# Loại bỏ space ở đầu và cuối
text = text.strip()
# Loại bỏ space thừa ở đầu câu (nếu có)
text = text.lstrip()
return text
# ============================================================================
# ENCODE/DECODE HELPERS
# ============================================================================
class SharedVocabulary:
"""
Wrapper class để dùng shared tokenizer như Vocabulary cũ
Tương thích với code hiện tại
"""
def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
self.PAD_TOKEN = "<pad>"
self.SOS_TOKEN = "<sos>"
self.EOS_TOKEN = "<eos>"
self.UNK_TOKEN = "<unk>"
self.PAD_IDX = tokenizer.token_to_id(self.PAD_TOKEN) or 0
self.SOS_IDX = tokenizer.token_to_id(self.SOS_TOKEN) or 1
self.EOS_IDX = tokenizer.token_to_id(self.EOS_TOKEN) or 2
self.UNK_IDX = tokenizer.token_to_id(self.UNK_TOKEN) or 3
def encode(self, sentence: str) -> list:
"""
Encode câu thành list of token IDs
Args:
sentence: Input sentence (string)
Returns:
tokens: List of token IDs [SOS_ID, ...token_ids..., EOS_ID]
"""
# Clean sentence (lowercase, normalize spaces)
sentence = sentence.lower().strip()
# Encode với tokenizer
encoded = self.tokenizer.encode(sentence)
token_ids = encoded.ids
# Thêm SOS và EOS
return [self.SOS_IDX] + token_ids + [self.EOS_IDX]
def decode(self, token_ids: list) -> str:
"""
Decode list of token IDs thành câu
Args:
token_ids: List of token IDs
Returns:
sentence: Decoded sentence (string)
"""
# Loại bỏ SOS, EOS, PAD
filtered_ids = [
idx for idx in token_ids
if idx not in [self.PAD_IDX, self.SOS_IDX, self.EOS_IDX]
]
if not filtered_ids:
return ""
# Decode với tokenizer
decoded = self.tokenizer.decode(filtered_ids, skip_special_tokens=True)
# Post-processing: Clean output
decoded = clean_decoded_output(decoded)
return decoded
def __len__(self):
"""Kích thước vocabulary"""
return self.tokenizer.get_vocab_size()
def token_to_id(self, token: str) -> int:
"""Convert token string to ID"""
return self.tokenizer.token_to_id(token) or self.UNK_IDX
def id_to_token(self, idx: int) -> str:
"""Convert ID to token string"""
return self.tokenizer.id_to_token(idx) or self.UNK_TOKEN
def create_shared_vocab_wrapper() -> Tuple[SharedVocabulary, SharedVocabulary]:
"""
Tạo wrapper cho shared vocabulary (tương thích với code cũ)
Returns:
vi_vocab, en_vocab: Cả 2 đều là SharedVocabulary (cùng tokenizer)
"""
tokenizer = load_shared_tokenizer()
# Cả 2 ngôn ngữ dùng chung tokenizer
vi_vocab = SharedVocabulary(tokenizer)
en_vocab = SharedVocabulary(tokenizer)
return vi_vocab, en_vocab
# ============================================================================
# CHECK SHARED VOCAB SETUP
# ============================================================================
def check_shared_vocab_setup() -> bool:
"""
Kiểm tra xem shared vocabulary đã được setup chưa
Returns:
is_setup: True nếu đã setup đầy đủ
"""
tokenizer_path = PROCESSED_DATA_DIR / 'tokenizer_shared.json'
data_path = PROCESSED_DATA_DIR / 'processed_data_shared.pkl'
return tokenizer_path.exists() and data_path.exists()
def print_shared_vocab_status():
"""
In trạng thái setup của shared vocabulary
"""
print("="*70)
print("KIỂM TRA SHARED VOCABULARY SETUP")
print("="*70)
tokenizer_path = PROCESSED_DATA_DIR / 'tokenizer_shared.json'
data_path = PROCESSED_DATA_DIR / 'processed_data_shared.pkl'
info_path = PROCESSED_DATA_DIR / 'shared_vocab_info.json'
print(f"\n1. Tokenizer:")
if tokenizer_path.exists():
print(f" ✓ {tokenizer_path}")
try:
tokenizer = load_shared_tokenizer()
print(f" ✓ Vocab size: {tokenizer.get_vocab_size()}")
except Exception as e:
print(f" ✗ Lỗi khi load: {e}")
else:
print(f" ✗ Không tìm thấy: {tokenizer_path}")
print(f" → Chạy: python src/1_build_shared_vocab.py")
print(f"\n2. Processed Data:")
if data_path.exists():
print(f" ✓ {data_path}")
try:
data = load_shared_processed_data()
print(f" ✓ Train: {len(data.get('train', []))} cặp câu")
print(f" ✓ Validation: {len(data.get('validation', []))} cặp câu")
print(f" ✓ Test: {len(data.get('test', []))} cặp câu")
except Exception as e:
print(f" ✗ Lỗi khi load: {e}")
else:
print(f" ✗ Không tìm thấy: {data_path}")
print(f" → Chạy: python src/2_encode_data.py")
print(f"\n3. Info File:")
if info_path.exists():
print(f" ✓ {info_path}")
else:
print(f" ⚠️ Chưa có (sẽ tự tạo khi cần)")
print("\n" + "="*70)
if check_shared_vocab_setup():
print("✓ SHARED VOCABULARY ĐÃ SẴN SÀNG!")
else:
print("⚠️ SHARED VOCABULARY CHƯA ĐƯỢC SETUP")
print("\n📝 Các bước cần làm:")
print(" 1. python src/1_build_shared_vocab.py")
print(" 2. python src/2_encode_data.py")
print("="*70)
# ============================================================================
# MAIN
# ============================================================================
if __name__ == "__main__":
print_shared_vocab_status()