""" Benchmark inference performance and quality comparing Base vs Merged models. Logs natively to WandB for side-by-side visualization. """ import sys import time import wandb import logging from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from src.inference.transcribe import WhisperTranscriber logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%H:%M:%S", ) logger = logging.getLogger(__name__) def main(): root = Path(__file__).parent.parent audio_files = ["1775560189.41808.wav", "audio (1).mp3", "audio.mp3"] # Initialize WandB wandb.init(project="whisper-evaluation", name="Base_vs_LoRAMerged_Benchmark") table = wandb.Table(columns=["Audio File", "Base Time (s)", "LoRA Time (s)", "Base Transcript", "LoRA Transcript"]) results = {} # Check paths for af in audio_files: if not (root / af).exists(): logger.error(f"Cannot find audio file: {af} in root. Make sure it exists.") sys.exit(1) # Load Base logger.info("--- Loading Base Model ---") base_transcriber = WhisperTranscriber(model_path="openai/whisper-large-v3", language="arabic", task="transcribe") for af in audio_files: logger.info(f"[Base] Transcribing {af}...") start = time.time() text = base_transcriber.transcribe(str(root / af)) latency = time.time() - start results[af] = {"base_time": latency, "base_text": text} # Clear memory to prevent OOM del base_transcriber import torch import gc gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # Load Merged LoRA logger.info("--- Loading LoRA Merged Model ---") merged_path = root / "outputs/checkpoints/merged_model" lora_transcriber = WhisperTranscriber(model_path=str(merged_path), language="arabic", task="transcribe") for af in audio_files: logger.info(f"[LoRA] Transcribing {af}...") start = time.time() text = lora_transcriber.transcribe(str(root / af)) latency = time.time() - start results[af]["lora_time"] = latency results[af]["lora_text"] = text for af in audio_files: r = results[af] table.add_data(af, f"{r['base_time']:.2f}", f"{r['lora_time']:.2f}", r["base_text"], r["lora_text"]) wandb.log({"Benchmark Results": table}) wandb.finish() logger.info("Benchmark complete! View the results in your WandB dashboard.") if __name__ == "__main__": main()