File size: 7,321 Bytes
b5e57ee | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 | import os
import io
import json
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel
from omegaconf import open_dict , DictConfig
# ============================================================
# Environment Fixes (Windows / CUDA)
# ============================================================
# Process-wide CUDA / Numba settings, applied in one call.
# NOTE(review): these are assigned after torch and NeMo have already been
# imported at the top of the file; settings that a library reads at import
# time may not take effect from here - confirm, and move above the imports
# if needed.
os.environ.update(
    {
        "CUDA_LAUNCH_BLOCKING": "1",
        "NUMBA_CUDA_USE_NVIDIA_BINDING": "1",
        "NUMBA_DISABLE_JIT": "0",
        "NUMBA_CUDA_DRIVER": "cuda",
        # Expose only the first GPU (device index 0) to this process.
        "CUDA_VISIBLE_DEVICES": "0",
    }
)
# ============================================================
# UTF-8 Fix for Manifest
# ============================================================
manifest_path = "train_manifest.jsonl"
# Re-encode the manifest in place as BOM-free UTF-8.
# "utf-8-sig" reads ordinary UTF-8 and additionally strips a leading byte
# order mark; json.loads() rejects a string that starts with a BOM, so a
# manifest saved "UTF-8 with BOM" would otherwise fail on its first line.
try:
    with io.open(manifest_path, "r", encoding="utf-8-sig") as f:
        content = f.read()
except UnicodeDecodeError as err:
    # Best-effort fallback (the original behaviour): drop undecodable bytes.
    # This is lossy and the file is overwritten below, so say so loudly
    # instead of discarding characters silently.
    print(
        f"⚠️ {manifest_path} contains bytes that are not valid UTF-8 "
        f"({err}); they will be dropped"
    )
    with io.open(manifest_path, "r", encoding="utf-8-sig", errors="ignore") as f:
        content = f.read()
with io.open(manifest_path, "w", encoding="utf-8") as f:
    f.write(content)
print("✅ train_manifest.jsonl converted to UTF-8")
# Patch builtins.open for UTF-8
import builtins

# Keep a handle on the real open() before it is replaced below.
_old_open = open


def open_utf8(file, *args, **kwargs):
    """Open *file* like the builtin ``open``, defaulting ``.jsonl`` to UTF-8.

    The UTF-8 default is applied only when all of the following hold:

    * *file* is a ``str`` or ``os.PathLike`` whose path ends with ``.jsonl``;
    * the file is opened in text mode (binary modes reject ``encoding``);
    * the caller did not pass an encoding, by keyword or positionally.

    In every other case the call is forwarded unchanged.

    Args:
        file: Path, path-like object or file descriptor, as for ``open``.
        *args: Positional arguments of ``open`` after *file*
            (``mode``, ``buffering``, ``encoding``, ...).
        **kwargs: Keyword arguments of ``open``.

    Returns:
        The file object returned by the original builtin ``open``.
    """
    mode = args[0] if args else kwargs.get("mode", "r")
    # open(file, mode, buffering, encoding, ...): a third extra positional
    # argument is the encoding, so adding the keyword would be a duplicate.
    encoding_given = len(args) >= 3 or "encoding" in kwargs
    text_mode = isinstance(mode, str) and "b" not in mode
    if text_mode and not encoding_given and isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
        if isinstance(path, str) and path.endswith(".jsonl"):
            kwargs["encoding"] = "utf-8"
    return _old_open(file, *args, **kwargs)


builtins.open = open_utf8
# ============================================================
# Validate Manifest
# ============================================================
def validate_manifest(manifest_path):
    """Count the usable entries of a JSON-lines manifest.

    A line is valid when it parses as JSON, its ``audio_filepath`` points to
    an existing path, and its ``text`` is a non-blank string. Every invalid
    line is reported on stdout with its 1-based line number and the first
    100 characters of its content; invalid lines never abort the scan.

    Args:
        manifest_path: Path of the manifest, read as UTF-8.

    Returns:
        int: Number of valid lines.
    """
    count = 0
    with open(manifest_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            try:
                item = json.loads(line.strip())
                audio_path = item["audio_filepath"]
                # Explicit checks instead of ``assert``: assertions are
                # removed under ``python -O``, which would silently count
                # broken entries as valid.
                if not os.path.exists(audio_path):
                    raise ValueError(f"Missing: {audio_path}")
                text = item["text"] if "text" in item else ""
                if not isinstance(text, str):
                    raise TypeError(
                        f"text must be a string, got {type(text).__name__}"
                    )
                if not text.strip():
                    raise ValueError("Empty text")
            except (ValueError, KeyError, TypeError, AttributeError) as e:
                # json.JSONDecodeError is a ValueError; KeyError/TypeError
                # cover a missing key or a non-object JSON value.
                print(f"❌ Line {i} error: {e}")
                print(f" Content: {line[:100]}")
            else:
                count += 1
    print(f"✅ Valid entries: {count}")
    return count
# Abort before any model loading when the manifest has nothing usable.
valid_count = validate_manifest(manifest_path)
if not valid_count:
    raise ValueError("No valid training samples found!")
# ============================================================
# Paths and Hyperparameters
# ============================================================
# --- Locations -------------------------------------------------------------
BASE_MODEL_PATH = "stt_ar_fastconformer_hybrid_large_pcd_v1.0.nemo"
SAVE_DIR = "output_finetuned"
os.makedirs(SAVE_DIR, exist_ok=True)
# Checkpoint that a later run resumes from, when it exists.
LAST_CKPT = os.path.join(SAVE_DIR, "last.ckpt")

# --- Training schedule -----------------------------------------------------
BATCH_SIZE = 4
# Number of epochs to add on top of whatever LAST_CKPT already recorded.
ADDITIONAL_EPOCHS = 50

# --- Optimizer -------------------------------------------------------------
LEARNING_RATE = 1e-5
WEIGHT_DECAY = 1e-5
WARMUP_STEPS = 500
# ============================================================
# Load Model
# ============================================================
print("🔹 Loading pretrained or last fine-tuned model...")
# Weights always come from BASE_MODEL_PATH here; resuming from LAST_CKPT is
# handled later through the ckpt_path argument of trainer.fit().
model = EncDecHybridRNNTCTCBPEModel.restore_from(BASE_MODEL_PATH)
# ============================================================
# Tokenizer Fix
# ============================================================
# A "tokenizer" directory next to the base model file (created on demand).
tokenizer_dir = os.path.join(os.path.dirname(BASE_MODEL_PATH), "tokenizer")
os.makedirs(tokenizer_dir, exist_ok=True)
# open_dict lets the config accept the writes below.
with open_dict(model.cfg):
    model.cfg.tokenizer.dir = tokenizer_dir
    model.cfg.tokenizer.type = "bpe"
    # No validation / test manifests are used in this run: clear whatever
    # paths the restored config still carries.
    for split in ("validation_ds", "test_ds"):
        if split in model.cfg:
            model.cfg[split].manifest_filepath = None
# ============================================================
# Setup Training Data
# ============================================================
# Training dataloader settings handed to the model's setup_training_data().
train_ds_config = dict(
    # Data source and batching.
    manifest_filepath=manifest_path,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
    # Audio handling: clips outside [min_duration, max_duration] seconds
    # are presumably filtered out by the dataset - confirm against NeMo.
    sample_rate=16000,
    max_duration=20.0,
    min_duration=0.5,
    trim_silence=True,
    # Text handling.
    use_start_end_token=True,
    normalize_transcripts=True,
    parser="ar",
)
model.setup_training_data(train_ds_config)
# ============================================================
# Optimizer & Scheduler
# ============================================================
# Override the optimizer section of the restored config with the
# fine-tuning hyperparameters defined above.
with open_dict(model.cfg):
    optim_cfg = model.cfg.optim
    optim_cfg.name = "adamw"
    optim_cfg.lr = LEARNING_RATE
    optim_cfg.betas = [0.9, 0.98]
    optim_cfg.weight_decay = WEIGHT_DECAY
    optim_cfg.eps = 1e-8
    # Cosine schedule with warmup.
    # NOTE(review): WARMUP_STEPS is a fixed step count; confirm it is
    # smaller than the total number of optimizer steps for this dataset.
    optim_cfg.sched = dict(
        name="CosineAnnealing",
        warmup_steps=WARMUP_STEPS,
        min_lr=1e-7,
        last_epoch=-1,
    )
# ============================================================
# Callbacks
# ============================================================
# Keep the three checkpoints with the lowest train_loss, plus last.ckpt
# (save_last=True), which is the file a later run resumes from.
checkpoint_callback = ModelCheckpoint(
    monitor="train_loss",
    mode="min",
    save_top_k=3,
    save_last=True,
    dirpath=SAVE_DIR,
    filename="continued-{epoch:02d}-{train_loss:.4f}",
)
# Stop once train_loss has failed to improve for 20 consecutive checks.
early_stop_callback = EarlyStopping(
    monitor="train_loss",
    mode="min",
    patience=20,
    verbose=True,
)
# Log the learning rate at every step.
lr_monitor = LearningRateMonitor(logging_interval="step")
# ============================================================
# Determine Max Epochs Based on Last Checkpoint
# ============================================================
# ============================================================
# Allow loading full NeMo checkpoint (trusted source)
# ============================================================
# Allow-list DictConfig for torch's restricted ("weights_only") unpickler.
# add_safe_globals only exists on newer torch releases; skip the call when
# it is absent instead of failing with AttributeError.
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([DictConfig])

if os.path.exists(LAST_CKPT):
    # SECURITY: weights_only=False performs a full pickle load, which can
    # execute arbitrary code. Only acceptable because LAST_CKPT is written
    # by this script itself; never point it at an untrusted file.
    ckpt_data = torch.load(LAST_CKPT, map_location="cpu", weights_only=False)
    last_epoch = ckpt_data.get("epoch", 0)
    # Only the epoch counter is needed; release the loaded checkpoint so it
    # does not stay referenced in host memory for the whole training run.
    del ckpt_data
    # NOTE(review): this treats the stored "epoch" as the number of epochs
    # already completed - confirm against the Lightning version in use.
    new_max_epochs = last_epoch + ADDITIONAL_EPOCHS
    print(f"🧩 Last checkpoint epoch: {last_epoch} → continuing up to {new_max_epochs} epochs total.")
else:
    # Fresh run: train for exactly the configured number of epochs.
    new_max_epochs = ADDITIONAL_EPOCHS
# ============================================================
# Trainer
# ============================================================
# Single-device trainer; falls back to CPU when CUDA is unavailable.
trainer = Trainer(
    # Hardware
    accelerator="cpu" if not torch.cuda.is_available() else "gpu",
    devices=1,
    # Schedule
    max_epochs=new_max_epochs,
    # Gradients are accumulated over 4 batches per optimizer step and
    # clipped at 1.0.
    accumulate_grad_batches=4,
    gradient_clip_val=1.0,
    # Logging / checkpointing
    log_every_n_steps=1,
    enable_checkpointing=True,
    default_root_dir=SAVE_DIR,
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
)
# ============================================================
# Continue Training
# ============================================================
# Resume from last.ckpt when present; ckpt_path=None means a fresh fit from
# the already-loaded base model weights.
resume_from = LAST_CKPT if os.path.exists(LAST_CKPT) else None
if resume_from is None:
    print("⚠️ No checkpoint found, training from base model...")
else:
    print(f"🚀 Continuing training from checkpoint: {LAST_CKPT}")
trainer.fit(model, ckpt_path=resume_from)
# ============================================================
# Save Final Model
# ============================================================
# Export the fine-tuned model as a single .nemo file inside SAVE_DIR.
final_model_path = os.path.join(SAVE_DIR, "finetuned_model_continued.nemo")
model.save_to(final_model_path)
print(f"\n✅ Continued fine-tuned model saved to: {final_model_path}")
|