# Arabic_Finetuned_ASR_Nemo / continue_finetuning_nemo.py
import os
import io
import json
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel
from omegaconf import open_dict, DictConfig
# ============================================================
# Environment Fixes (Windows / CUDA)
# ============================================================
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["NUMBA_CUDA_USE_NVIDIA_BINDING"] = "1"
os.environ["NUMBA_DISABLE_JIT"] = "0"
os.environ["NUMBA_CUDA_DRIVER"] = "cuda"
# Pin training to GPU 0 (recommended for an RTX 3070-class card)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# ============================================================
# UTF-8 Fix for Manifest
# ============================================================
manifest_path = "train_manifest.jsonl"
with io.open(manifest_path, 'r', encoding='utf-8', errors='ignore') as f:
    content = f.read()
with io.open(manifest_path, 'w', encoding='utf-8') as f:
    f.write(content)
print("✅ train_manifest.jsonl converted to UTF-8")
# Patch builtins.open for UTF-8
import builtins
_old_open = open
def open_utf8(file, *args, **kwargs):
    if isinstance(file, str) and file.endswith('.jsonl') and 'encoding' not in kwargs:
        kwargs['encoding'] = 'utf-8'
    return _old_open(file, *args, **kwargs)
builtins.open = open_utf8
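# Why patch open(): NeMo reads .jsonl manifests internally without passing an
# explicit encoding, and on Windows the locale default (often cp1252) mangles
# Arabic text. Forcing UTF-8 only for .jsonl files keeps the patch narrow.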
# ============================================================
# Validate Manifest
# ============================================================
def validate_manifest(manifest_path):
    count = 0
    with open(manifest_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            try:
                item = json.loads(line.strip())
                assert os.path.exists(item["audio_filepath"]), f"Missing: {item['audio_filepath']}"
                assert "text" in item and item["text"].strip(), "Empty text"
                count += 1
            except Exception as e:
                print(f"❌ Line {i} error: {e}")
                print(f" Content: {line[:100]}")
    print(f"✅ Valid entries: {count}")
    return count
valid_count = validate_manifest(manifest_path)
if valid_count == 0:
raise ValueError("No valid training samples found!")
# ============================================================
# Paths and Hyperparameters
# ============================================================
BASE_MODEL_PATH = "stt_ar_fastconformer_hybrid_large_pcd_v1.0.nemo"
SAVE_DIR = "output_finetuned"
LAST_CKPT = os.path.join(SAVE_DIR, "last.ckpt")
BATCH_SIZE = 4
ADDITIONAL_EPOCHS = 50
LEARNING_RATE = 1e-5
WARMUP_STEPS = 500
WEIGHT_DECAY = 0.00001
os.makedirs(SAVE_DIR, exist_ok=True)
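# The effective batch size is BATCH_SIZE * accumulate_grad_batches (4 * 4 = 16
# with the Trainer below); the small LR and long warmup suit continued
# fine-tuning of an already-converged model.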
# ============================================================
# Load Model
# ============================================================
print("🔹 Loading pretrained or last fine-tuned model...")
model = EncDecHybridRNNTCTCBPEModel.restore_from(BASE_MODEL_PATH)
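# restore_from() unpacks the .nemo archive (weights, config, and tokenizer) and
# rebuilds the hybrid RNNT+CTC FastConformer exactly as it was exported.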
# ============================================================
# Tokenizer Fix
# ============================================================
with open_dict(model.cfg):
    tokenizer_dir = os.path.join(os.path.dirname(BASE_MODEL_PATH), "tokenizer")
    os.makedirs(tokenizer_dir, exist_ok=True)
    model.cfg.tokenizer.dir = tokenizer_dir
    model.cfg.tokenizer.type = "bpe"
    if 'validation_ds' in model.cfg:
        model.cfg.validation_ds.manifest_filepath = None
    if 'test_ds' in model.cfg:
        model.cfg.test_ds.manifest_filepath = None
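# Pointing cfg.tokenizer at a stable local dir avoids depending on the temp dir
# the .nemo archive was unpacked into; validation/test manifests are cleared
# because this script trains without an eval set.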
# ============================================================
# Setup Training Data
# ============================================================
train_ds_config = {
    "manifest_filepath": manifest_path,
    "batch_size": BATCH_SIZE,
    "shuffle": True,
    "num_workers": 0,
    "pin_memory": False,
    "sample_rate": 16000,
    "max_duration": 20.0,
    "min_duration": 0.5,
    "trim_silence": True,
    "use_start_end_token": True,
    "normalize_transcripts": True,
    "parser": "ar",
}
# Wrap in DictConfig so NeMo handles the config like its usual OmegaConf objects
model.setup_training_data(DictConfig(train_ds_config))
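# The resulting loader resamples audio to 16 kHz, drops clips outside 0.5-20 s,
# trims leading/trailing silence, and normalizes transcripts with the Arabic
# ("ar") parser before BPE tokenization.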
# ============================================================
# Optimizer & Scheduler
# ============================================================
with open_dict(model.cfg):
    model.cfg.optim.name = "adamw"
    model.cfg.optim.lr = LEARNING_RATE
    model.cfg.optim.betas = [0.9, 0.98]
    model.cfg.optim.weight_decay = WEIGHT_DECAY
    model.cfg.optim.eps = 1e-8
    model.cfg.optim.sched = {
        "name": "CosineAnnealing",
        "warmup_steps": WARMUP_STEPS,
        "min_lr": 1e-7,
        "last_epoch": -1,
    }
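# AdamW with warmup + cosine decay: the LR ramps up over WARMUP_STEPS, then
# anneals from LEARNING_RATE toward min_lr=1e-7 over the remaining steps.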
# ============================================================
# Callbacks
# ============================================================
checkpoint_callback = ModelCheckpoint(
    dirpath=SAVE_DIR,
    filename='continued-{epoch:02d}-{train_loss:.4f}',
    save_top_k=3,
    monitor='train_loss',
    mode='min',
    save_last=True,
)
early_stop_callback = EarlyStopping(
    monitor='train_loss',
    patience=20,
    mode='min',
    verbose=True,
)
lr_monitor = LearningRateMonitor(logging_interval='step')
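# Both callbacks monitor train_loss because validation was disabled above; the
# 3 best checkpoints plus last.ckpt (used for resuming) are kept in SAVE_DIR.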
# ============================================================
# Determine Max Epochs Based on Last Checkpoint
# ============================================================
# Allow torch.load to unpickle the full NeMo checkpoint (trusted source);
# DictConfig must be allow-listed where newer PyTorch defaults to weights_only.
torch.serialization.add_safe_globals([DictConfig])
if os.path.exists(LAST_CKPT):
    ckpt_data = torch.load(LAST_CKPT, map_location="cpu", weights_only=False)
    last_epoch = ckpt_data.get("epoch", 0)
    new_max_epochs = last_epoch + ADDITIONAL_EPOCHS
    print(f"🧩 Last checkpoint epoch: {last_epoch} → continuing up to {new_max_epochs} epochs total.")
else:
    new_max_epochs = ADDITIONAL_EPOCHS
# ============================================================
# Trainer
# ============================================================
trainer = Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=new_max_epochs,
    log_every_n_steps=1,
    enable_checkpointing=True,
    default_root_dir=SAVE_DIR,
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
    gradient_clip_val=1.0,
    accumulate_grad_batches=4,
)
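# accumulate_grad_batches=4 simulates a larger batch (effective 16) within the
# 8 GB of an RTX 3070, and gradient_clip_val=1.0 guards against exploding
# RNNT gradients.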
# ============================================================
# Continue Training
# ============================================================
if os.path.exists(LAST_CKPT):
    print(f"🚀 Continuing training from checkpoint: {LAST_CKPT}")
    trainer.fit(model, ckpt_path=LAST_CKPT)
else:
    print("⚠️ No checkpoint found, training from base model...")
    trainer.fit(model)
# ============================================================
# Save Final Model
# ============================================================
final_model_path = os.path.join(SAVE_DIR, "finetuned_model_continued.nemo")
model.save_to(final_model_path)
print(f"\n✅ Continued fine-tuned model saved to: {final_model_path}")