Add config/default.yaml
Browse files- config/default.yaml +104 -0
config/default.yaml
ADDED
|
@@ -0,0 +1,104 @@
---
# Default configuration for ASR training
# BILINGUAL TRAINING: Vietnamese + English
# Optimized for full_merged_dataset with bilingual support
# Dataset: ~194k training samples (77% Vietnamese, 23% English)

# Model architecture - Transformer Seq2Seq ASR with Language Embedding
# OPTIMIZED FOR BILINGUAL (Vietnamese + English) - ~30M parameters
model_name: "VietnameseASR_Transformer_Bilingual_30M"
d_model: 256  # Model dimension (reduced from 320)
num_encoder_layers: 14  # Number of encoder layers (kept same)
num_decoder_layers: 6  # Number of decoder layers (kept same)
num_heads: 8  # Number of attention heads (kept same)
d_ff: 2048  # Feed-forward dimension (reduced from 3120)
dropout: 0.2  # Dropout rate

# Audio processing
sample_rate: 16000
n_mels: 80
n_fft: 400
hop_length: 160
win_length: 400

# Tokenization - SentencePiece BPE for Bilingual (Vietnamese + English)
tokenizer_type: "sentencepiece"  # Changed from "bpe" to "sentencepiece"
bpe_vocab_path: "models/tokenizer_vi_en_3500.model"  # SentencePiece .model file
vocab_size: 3500

# Training hyperparameters
batch_size: 32  # Reduced to 32 to avoid OOM with CTC loss - effective batch size: 128 (32 * 4)
val_batch_size: 64  # Validation batch size (reduced proportionally)
# Raised total epoch count so a resumed run can go past epoch 20
num_epochs: 50
learning_rate: 0.0003
weight_decay: 0.0001
grad_clip: 0.5
gradient_accumulation_steps: 4  # Kept same - effective batch size: 128 (32 * 4)
warmup_pct: 0.03  # Reduced from 10% to 3% so the model starts learning sooner
use_constant_lr_on_resume: false

# Optimization
use_amp: true  # Enable mixed precision
use_bf16: true  # Use bfloat16 (better numerical stability than float16; supported on RTX 5060 Ti)
num_workers: 2  # Reduced from 12 to 2 to avoid BrokenPipeError (observed at epochs 6 and 18)
pin_memory: true
use_gradient_checkpointing: false  # Temporarily disabled due to a conflict with the CTC output
prefetch_factor: 4
persistent_workers: true
sort_by_length: true
cache_in_ram: false
use_bucketing: false

# Data
dataset_root: "data/processed/full_merged_dataset"
language_filter: null

# Decoding - Seq2Seq Autoregressive Generation
# Using autoregressive generation with teacher forcing during training

# Hybrid CTC/Attention Training (FIXES: Forces encoder to learn alignment)
use_ctc_loss: true  # Enable CTC loss to help encoder learn audio-text alignment
ctc_weight: 0.2  # Weight for CTC loss (0.2 = 20% CTC, 80% Attention) - reduced to save memory

# Scheduled Sampling (FIXES: Reduces teacher forcing, forces model to use encoder)
use_scheduled_sampling: true  # Enable scheduled sampling to reduce teacher forcing
teacher_forcing_initial: 1.0  # Start with 100% teacher forcing
teacher_forcing_final: 0.5  # End with 50% teacher forcing (gradual decay)

# Checkpointing
checkpoint_dir: "checkpoints"
save_every: 1  # Save checkpoint after every epoch

# Logging
log_file: "logs/training.log"

# Training run
run_name: "vietnamese_asr_transformer_bilingual_30m"

# Auto-Rollback
auto_rollback:
  enabled: true
  threshold_ratio: 1.3
  patience: 1

# Curriculum Learning
curriculum_learning:
  enabled: true
  required_wer: 0.70
  initial_ts_weight: 0.01
  short_sentence_epochs: 3
  max_duration_seconds: 4.0

# Validation decoding controls
# Limit decode length to speed up validation (prevents infinite loops)
val_max_len: 128
# Validate on a subset of validation batches (set null to disable)
val_subset_pct: null
# Hard-cap number of validation batches (set null to validate on full val set)
val_max_batches: null
# Use autoregressive generation for validation (slower but more accurate)
# If false, uses greedy decoding from logits (faster, ~2x speedup, avoids second forward pass)
use_autoregressive_validation: false
# Calculate WER/CER during validation (set to false to skip prediction and speed up validation)
calculate_val_wer: false  # Disabled so validation runs faster (loss only, no WER)