import os
import io
import json
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel
from omegaconf import open_dict, DictConfig

# ============================================================
# Environment Fixes (Windows / CUDA)
# ============================================================
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["NUMBA_CUDA_USE_NVIDIA_BINDING"] = "1"
os.environ["NUMBA_DISABLE_JIT"] = "0"
os.environ["NUMBA_CUDA_DRIVER"] = "cuda"

# Use the first GPU (recommended for an RTX 3070)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# ============================================================
# UTF-8 Fix for Manifest
# ============================================================
manifest_path = "train_manifest.jsonl"

# Re-read the manifest ignoring undecodable bytes, then rewrite it as clean UTF-8.
with io.open(manifest_path, 'r', encoding='utf-8', errors='ignore') as f:
    content = f.read()
with io.open(manifest_path, 'w', encoding='utf-8') as f:
    f.write(content)
print("✅ train_manifest.jsonl converted to UTF-8")

# Patch builtins.open so any .jsonl file opened without an explicit encoding
# defaults to UTF-8 (works around the cp1252 default on Windows).
import builtins

_old_open = open

def open_utf8(file, *args, **kwargs):
    if isinstance(file, str) and file.endswith('.jsonl') and 'encoding' not in kwargs:
        kwargs['encoding'] = 'utf-8'
    return _old_open(file, *args, **kwargs)

builtins.open = open_utf8

# ============================================================
# Validate Manifest
# ============================================================
def validate_manifest(manifest_path):
    count = 0
    with open(manifest_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            try:
                item = json.loads(line.strip())
                assert os.path.exists(item["audio_filepath"]), f"Missing: {item['audio_filepath']}"
                assert "text" in item and item["text"].strip(), "Empty text"
                count += 1
            except Exception as e:
                print(f"❌ Line {i} error: {e}")
                print(f"   Content: {line[:100]}")
    print(f"✅ Valid entries: {count}")
    return count

valid_count = validate_manifest(manifest_path)
if valid_count == 0:
    raise ValueError("No valid training samples found!")

# ============================================================
# Paths and Hyperparameters
# ============================================================
BASE_MODEL_PATH = "stt_ar_fastconformer_hybrid_large_pcd_v1.0.nemo"
SAVE_DIR = "output_finetuned"
LAST_CKPT = os.path.join(SAVE_DIR, "last.ckpt")

BATCH_SIZE = 4
ADDITIONAL_EPOCHS = 50
LEARNING_RATE = 1e-5
WARMUP_STEPS = 500
WEIGHT_DECAY = 0.00001

os.makedirs(SAVE_DIR, exist_ok=True)

# ============================================================
# Load Model
# ============================================================
print("🔹 Loading pretrained or last fine-tuned model...")
model = EncDecHybridRNNTCTCBPEModel.restore_from(BASE_MODEL_PATH)

# ============================================================
# Tokenizer Fix
# ============================================================
with open_dict(model.cfg):
    # Point the tokenizer config at a local directory (the .nemo archive
    # already carries the tokenizer; this only repairs a stale path in cfg).
    tokenizer_dir = os.path.join(os.path.dirname(BASE_MODEL_PATH), "tokenizer")
    os.makedirs(tokenizer_dir, exist_ok=True)
    model.cfg.tokenizer.dir = tokenizer_dir
    model.cfg.tokenizer.type = "bpe"
    # Drop stale validation/test manifests inherited from the base model.
    if 'validation_ds' in model.cfg:
        model.cfg.validation_ds.manifest_filepath = None
    if 'test_ds' in model.cfg:
        model.cfg.test_ds.manifest_filepath = None
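# ------------------------------------------------------------
# Optional sanity check (a sketch, not part of the original flow):
# the restored .nemo model should already carry its SentencePiece
# tokenizer, so printing the vocabulary size is a cheap way to
# confirm it loaded before touching the training data. This assumes
# the tokenizer exposes a `vocab_size` attribute, which holds for
# NeMo's bundled SentencePiece tokenizers.
# ------------------------------------------------------------
try:
    print(f"Tokenizer check: vocab size = {model.tokenizer.vocab_size}")
except AttributeError:
    print("Could not query tokenizer vocab size; check the tokenizer config above.")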
# ============================================================
# Setup Training Data
# ============================================================
train_ds_config = {
    "manifest_filepath": manifest_path,
    "batch_size": BATCH_SIZE,
    "shuffle": True,
    "num_workers": 0,   # keep 0 on Windows to avoid multiprocessing issues
    "pin_memory": False,
    "sample_rate": 16000,
    "max_duration": 20.0,
    "min_duration": 0.5,
    "trim_silence": True,
    "use_start_end_token": True,
    # normalize_transcripts / parser are primarily char-dataset options
    # and may be ignored by BPE datasets; kept here for completeness.
    "normalize_transcripts": True,
    "parser": "ar",
}
model.setup_training_data(train_ds_config)

# ============================================================
# Optimizer & Scheduler
# ============================================================
with open_dict(model.cfg):
    model.cfg.optim.name = "adamw"
    model.cfg.optim.lr = LEARNING_RATE
    model.cfg.optim.betas = [0.9, 0.98]
    model.cfg.optim.weight_decay = WEIGHT_DECAY
    model.cfg.optim.eps = 1e-8
    model.cfg.optim.sched = {
        "name": "CosineAnnealing",
        "warmup_steps": WARMUP_STEPS,
        "min_lr": 1e-7,
        "last_epoch": -1,
    }

# ============================================================
# Callbacks
# ============================================================
checkpoint_callback = ModelCheckpoint(
    dirpath=SAVE_DIR,
    filename='continued-{epoch:02d}-{train_loss:.4f}',
    save_top_k=3,
    monitor='train_loss',
    mode='min',
    save_last=True,
)

early_stop_callback = EarlyStopping(
    monitor='train_loss',
    patience=20,
    mode='min',
    verbose=True,
)

lr_monitor = LearningRateMonitor(logging_interval='step')

# ============================================================
# Determine Max Epochs Based on Last Checkpoint
# ============================================================
# Allow loading the full NeMo checkpoint (trusted source). Note that
# add_safe_globals only matters for weights_only=True loads; it is kept
# as a safeguard alongside the explicit weights_only=False below.
torch.serialization.add_safe_globals([DictConfig])

if os.path.exists(LAST_CKPT):
    ckpt_data = torch.load(LAST_CKPT, map_location="cpu", weights_only=False)
    last_epoch = ckpt_data.get("epoch", 0)
    new_max_epochs = last_epoch + ADDITIONAL_EPOCHS
    print(f"🧩 Last checkpoint epoch: {last_epoch} → continuing up to {new_max_epochs} epochs total.")
else:
    new_max_epochs = ADDITIONAL_EPOCHS

# ============================================================
# Trainer
# ============================================================
trainer = Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=new_max_epochs,
    log_every_n_steps=1,
    enable_checkpointing=True,
    default_root_dir=SAVE_DIR,
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
    gradient_clip_val=1.0,
    accumulate_grad_batches=4,
)

# ============================================================
# Continue Training
# ============================================================
if os.path.exists(LAST_CKPT):
    print(f"🚀 Continuing training from checkpoint: {LAST_CKPT}")
    trainer.fit(model, ckpt_path=LAST_CKPT)
else:
    print("⚠️ No checkpoint found, training from base model...")
    trainer.fit(model)

# ============================================================
# Save Final Model
# ============================================================
final_model_path = os.path.join(SAVE_DIR, "finetuned_model_continued.nemo")
model.save_to(final_model_path)
print(f"\n✅ Continued fine-tuned model saved to: {final_model_path}")
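# ------------------------------------------------------------
# Optional smoke test (a sketch under assumptions): reload the saved
# .nemo file and transcribe the first manifest entry to confirm the
# artifact is usable end to end. `transcribe()` is NeMo's standard
# ASR inference entry point; the sample path comes from the manifest
# validated above. Depending on the NeMo version, the returned items
# may be plain strings or Hypothesis objects, so the result is
# printed as-is rather than assuming a type.
# ------------------------------------------------------------
with open(manifest_path, "r", encoding="utf-8") as f:
    sample = json.loads(f.readline())

restored = EncDecHybridRNNTCTCBPEModel.restore_from(final_model_path)
print("Sample transcription:", restored.transcribe([sample["audio_filepath"]])[0])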