Cong123779 committed on
Commit
d81700b
·
verified ·
1 Parent(s): c9034a1

Add config/default.yaml

Browse files
Files changed (1) hide show
  1. config/default.yaml +104 -0
config/default.yaml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Default configuration for ASR training
2
+ # BILINGUAL TRAINING: Vietnamese + English
3
+ # Optimized for full_merged_dataset with bilingual support
4
+ # Dataset: ~194k training samples (77% Vietnamese, 23% English)
5
+
6
+ # Model architecture - Transformer Seq2Seq ASR with Language Embedding
7
+ # OPTIMIZED FOR BILINGUAL (Vietnamese + English) - ~30M parameters
8
+ model_name: "VietnameseASR_Transformer_Bilingual_30M"
9
+ d_model: 256 # Model dimension (reduced from 320)
10
+ num_encoder_layers: 14 # Number of encoder layers (kept same)
11
+ num_decoder_layers: 6 # Number of decoder layers (kept same)
12
+ num_heads: 8 # Number of attention heads (kept same)
13
+ d_ff: 2048 # Feed-forward dimension (reduced from 3120)
14
+ dropout: 0.2 # Dropout rate
15
+
16
+ # Audio processing
17
+ sample_rate: 16000
18
+ n_mels: 80
19
+ n_fft: 400
20
+ hop_length: 160
21
+ win_length: 400
22
+
23
+ # Tokenization - SentencePiece BPE for Bilingual (Vietnamese + English)
24
+ tokenizer_type: "sentencepiece" # Changed from "bpe" to "sentencepiece"
25
+ bpe_vocab_path: "models/tokenizer_vi_en_3500.model" # SentencePiece .model file
26
+ vocab_size: 3500
27
+
28
+ # Training hyperparameters
29
+ batch_size: 32 # Giảm xuống 32 để tránh OOM với CTC loss - effective batch size: 128 (32 * 4)
30
+ val_batch_size: 64 # Validation batch size (reduced proportionally)
31
+ # Tăng tổng số epoch để resume vượt mốc 20
32
+ num_epochs: 50
33
+ learning_rate: 0.0003
34
+ weight_decay: 0.0001
35
+ grad_clip: 0.5
36
 + gradient_accumulation_steps: 4 # Kept same - effective batch size: 128 (32 * 4)
37
+ warmup_pct: 0.03 # Giảm từ 10% xuống 3% để model học nhanh hơn
38
+ use_constant_lr_on_resume: false
39
+
40
+ # Optimization
41
+ use_amp: true # Bật mixed precision
42
+ use_bf16: true # Sử dụng bfloat16 (tốt hơn float16 về numerical stability, RTX 5060TI hỗ trợ)
43
+ num_workers: 2 # Giảm từ 12 xuống 2 để tránh BrokenPipeError (như đã thấy ở Epoch 6, 18)
44
+ pin_memory: true
45
+ use_gradient_checkpointing: false # Tắt tạm thời vì có conflict với CTC output
46
+ prefetch_factor: 4
47
+ persistent_workers: true
48
+ sort_by_length: true
49
+ cache_in_ram: false
50
+ use_bucketing: false
51
+
52
+ # Data
53
+ dataset_root: "data/processed/full_merged_dataset"
54
+ language_filter: null
55
+
56
+ # Decoding - Seq2Seq Autoregressive Generation
57
+ # Using autoregressive generation with teacher forcing during training
58
+
59
+ # Hybrid CTC/Attention Training (FIXES: Forces encoder to learn alignment)
60
+ use_ctc_loss: true # Enable CTC loss to help encoder learn audio-text alignment
61
+ ctc_weight: 0.2 # Weight for CTC loss (0.2 = 20% CTC, 80% Attention) - Giảm để tiết kiệm memory
62
+
63
+ # Scheduled Sampling (FIXES: Reduces teacher forcing, forces model to use encoder)
64
+ use_scheduled_sampling: true # Enable scheduled sampling to reduce teacher forcing
65
+ teacher_forcing_initial: 1.0 # Start with 100% teacher forcing
66
+ teacher_forcing_final: 0.5 # End with 50% teacher forcing (gradual decay)
67
+
68
+ # Checkpointing
69
+ checkpoint_dir: "checkpoints"
70
+ save_every: 1 # Save checkpoint after every epoch
71
+
72
+ # Logging
73
+ log_file: "logs/training.log"
74
+
75
+ # Training run
76
+ run_name: "vietnamese_asr_transformer_bilingual_30m"
77
+
78
+ # Auto-Rollback
79
+ auto_rollback:
80
+ enabled: true
81
+ threshold_ratio: 1.3
82
+ patience: 1
83
+
84
+ # Curriculum Learning
85
+ curriculum_learning:
86
+ enabled: true
87
+ required_wer: 0.70
88
+ initial_ts_weight: 0.01
89
+ short_sentence_epochs: 3
90
+ max_duration_seconds: 4.0
91
+
92
+ # Validation decoding controls
93
+ # Limit decode length to speed up validation (prevents infinite loops)
94
+ val_max_len: 128
95
+ # Validate on a subset of validation batches (set null to disable)
96
+ val_subset_pct: null
97
+ # Hard-cap number of validation batches (set null to validate on full val set)
98
+ val_max_batches: null
99
+ # Use autoregressive generation for validation (slower but more accurate)
100
+ # If false, uses greedy decoding from logits (faster, ~2x speedup, avoids second forward pass)
101
+ use_autoregressive_validation: false
102
+ # Calculate WER/CER during validation (set to false to skip prediction and speed up validation)
103
+ calculate_val_wer: false # Tắt tính WER để validation nhanh hơn (chỉ tính loss)
104
+