Commit c87fc44 (verified) by Cong123779 · Parent: d81700b · Add README.md
---
language:
- vi
- en
license: apache-2.0
tags:
- asr
- automatic-speech-recognition
- transformer
- vietnamese
- english
- bilingual
datasets:
- Cong123779/AI2Text-Bilingual-ASR-Dataset
metrics:
- wer
- cer
---

# AI2Text – Bilingual ASR (Vietnamese + English)

A **~30M-parameter** Transformer Seq2Seq Automatic Speech Recognition model
trained on **~224k** bilingual (Vietnamese + English) audio samples.

## Model Description

| Attribute | Value |
|---|---|
| Architecture | Encoder-Decoder Transformer |
| Parameters | ~30,325,164 |
| d_model | 256 |
| Encoder layers | 14 (RoPE + Flash Attention) |
| Decoder layers | 6 (causal, cross-attention) |
| Vocabulary size | 3,500 (SentencePiece BPE) |
| Language embedding | Yes (Vietnamese=0, English=1) |
| Normalization | RMSNorm |
| Activation | SiLU (Swish) |
| Positional encoding | Rotary (RoPE) |

### Modern Components
- **RMSNorm** – simpler and faster than LayerNorm (no mean-centering)
- **SiLU (Swish)** activation
- **Rotary Positional Embedding (RoPE)** – better length generalization
- **Flash Attention (SDPA)** – memory-efficient attention
- **Hybrid CTC / Attention loss** – helps the encoder learn alignments

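Of the components above, RMSNorm is the simplest to illustrate. A minimal PyTorch sketch (the class name and signature are illustrative, not the repo's actual module): it rescales each feature vector by its root-mean-square instead of subtracting the mean and dividing by the standard deviation, saving one reduction per layer.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root-mean-square norm: scale by 1/RMS(x), no mean subtraction."""
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))  # learned gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the RMS over the feature dimension, then apply the gain
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

x = torch.randn(2, 10, 256)
y = RMSNorm(256)(x)
print(y.shape)  # torch.Size([2, 10, 256])
```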
## Training Data

Trained on `Cong123779/AI2Text-Bilingual-ASR-Dataset`:
- **Train**: ~194,167 samples (77% Vietnamese, 23% English)
- **Validation**: ~30,123 samples

Audio format: 16 kHz mono WAV, 80-dim Mel-spectrogram features.

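An 80-dim log-Mel front end like the one described can be sketched in plain PyTorch. This is an assumption-laden reconstruction, not the repo's `AudioProcessor`: the 25 ms window / 10 ms hop are common ASR defaults and are not confirmed by the README.

```python
import math
import torch

def mel_filterbank(n_mels: int, n_fft: int, sr: int) -> torch.Tensor:
    """Triangular Mel filterbank mapping |STFT|^2 bins -> n_mels bands."""
    def hz_to_mel(f): return 2595.0 * math.log10(1.0 + f / 700.0)
    def mel_to_hz(m): return 700.0 * (10.0 ** (m / 2595.0) - 1.0)
    mels = torch.linspace(hz_to_mel(0.0), hz_to_mel(sr / 2), n_mels + 2)
    hz = torch.tensor([mel_to_hz(m.item()) for m in mels])
    bins = torch.floor((n_fft + 1) * hz / sr).long().tolist()
    fb = torch.zeros(n_mels, n_fft // 2 + 1)
    for i in range(n_mels):
        lo, ctr, hi = bins[i], bins[i + 1], bins[i + 2]
        for j in range(lo, ctr):                 # rising slope
            fb[i, j] = (j - lo) / max(ctr - lo, 1)
        for j in range(ctr, hi):                 # falling slope
            fb[i, j] = (hi - j) / max(hi - ctr, 1)
    return fb

# 25 ms window / 10 ms hop at 16 kHz (assumed defaults)
sr, n_fft, hop = 16000, 400, 160
wave = torch.randn(sr)                           # 1 s of dummy mono audio
spec = torch.stft(wave, n_fft, hop, window=torch.hann_window(n_fft),
                  return_complex=True)
power = spec.abs() ** 2                          # (freq, time)
mel = mel_filterbank(80, n_fft, sr) @ power      # (80, time)
features = mel.clamp(min=1e-10).log().transpose(0, 1)  # (time, 80)
print(features.shape)  # torch.Size([101, 80])
```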
## Training Configuration

| Hyperparameter | Value |
|---|---|
| Batch size | 32 (effective 128 with grad accumulation × 4) |
| Learning rate | 3e-4 |
| Epochs | 50 |
| Warmup | 3% of training steps |
| Mixed precision | bfloat16 (AMP) |
| Gradient clipping | 0.5 |
| CTC weight | 0.2 |
| Scheduled sampling | 1.0 → 0.5 (linear) |

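A hybrid CTC/attention objective with CTC weight 0.2 is conventionally a weighted sum of a CTC loss on the encoder output and a cross-entropy loss on the decoder output. A sketch with dummy tensors (function name, shapes, and the blank/pad ids are assumptions; real decoder targets would also carry sos/eos shifts, omitted here for brevity):

```python
import torch
import torch.nn.functional as F

def hybrid_loss(enc_logits, dec_logits, targets, input_lengths, target_lengths,
                ctc_weight=0.2, blank_id=0, pad_id=0):
    """(1 - w) * attention cross-entropy + w * CTC, with w = 0.2 as in the table."""
    # CTC over encoder frames: F.ctc_loss expects (time, batch, vocab) log-probs
    log_probs = enc_logits.log_softmax(-1).transpose(0, 1)
    ctc = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                     blank=blank_id, zero_infinity=True)
    # Cross-entropy over decoder steps, ignoring padding positions
    ce = F.cross_entropy(dec_logits.reshape(-1, dec_logits.size(-1)),
                         targets.reshape(-1), ignore_index=pad_id)
    return (1.0 - ctc_weight) * ce + ctc_weight * ctc

# Dummy batch: 2 utterances, 50 encoder frames, 10 target tokens, vocab 3500
B, T, U, V = 2, 50, 10, 3500
loss = hybrid_loss(
    torch.randn(B, T, V), torch.randn(B, U, V),
    torch.randint(1, V, (B, U)),
    input_lengths=torch.full((B,), T), target_lengths=torch.full((B,), U),
)
print(float(loss))
```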
## Usage

```python
import sys
import torch

# Clone the repo and add it to the path
sys.path.insert(0, "AI2Text")

from models.asr_base import ASRModel
from preprocessing.sentencepiece_tokenizer import SentencePieceTokenizer
from preprocessing.audio_processing import AudioProcessor

# Load tokenizer
tokenizer = SentencePieceTokenizer("models/tokenizer_vi_en_3500.model")

# Load model
checkpoint = torch.load("best_model.pt", map_location="cpu")
config = checkpoint.get("config", {})  # saved hyperparameters, if present

model = ASRModel(
    input_dim=80,
    vocab_size=3500,
    d_model=256,
    num_encoder_layers=14,
    num_decoder_layers=6,
    num_heads=8,
    d_ff=2048,
    num_languages=2,
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Transcribe
audio_processor = AudioProcessor(sample_rate=16000, n_mels=80)
features = audio_processor.process("audio.wav")  # (time, 80)
features = features.unsqueeze(0)                 # (1, time, 80)
lengths = torch.tensor([features.size(1)])

with torch.no_grad():
    tokens = model.generate(
        features,
        lengths=lengths,
        language_ids=torch.tensor([0]),  # 0 = Vietnamese, 1 = English
        max_len=128,
        sos_token_id=tokenizer.sos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
text = tokenizer.decode(tokens[0].tolist())
print(text)
```

## Framework
Built with PyTorch. Optimized for an **RTX 5060 Ti 16 GB / Ryzen 9 9990X / 64 GB RAM** workstation.

## License
Apache 2.0