"""Quick side-by-side comparison of base Whisper vs. the merged fine-tuned model.

Transcribes a small, fixed set of audio clips first with
``openai/whisper-large-v3`` and then with the locally merged checkpoint, and
logs both transcriptions plus their wall-clock latencies to a Weights & Biases
table.
"""

import gc
import logging
import os
import sys
import time
from pathlib import Path

import torch
import wandb

# Add project root to path so the ``src`` package is importable when this file
# is executed directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.inference.transcribe import WhisperTranscriber

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)

# Target only 2 files for a "Quick Comparison" (paths relative to project root).
AUDIO_FILES = [
    "1775560189.41808.wav",
    "audio.mp3",
]
BASE_MODEL_PATH = "openai/whisper-large-v3"
# Placeholder shown in the table when a model produced no result for a clip.
MISSING = "N/A"


def _free_memory():
    """Run the garbage collector and release cached CUDA memory, if CUDA is present."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _fmt_seconds(value):
    """Format a latency in seconds to two decimals, or ``MISSING`` when None."""
    return MISSING if value is None else f"{value:.2f}"


def _transcribe_all(label, model_path, audio_paths):
    """Transcribe every clip in *audio_paths* with a single model.

    Args:
        label: Human-readable model name used in log messages.
        model_path: Value passed as ``model_path`` to ``WhisperTranscriber``.
        audio_paths: Iterable of ``Path`` objects to transcribe.

    Returns:
        Dict mapping each successfully transcribed ``Path`` to a
        ``(text, latency_seconds)`` tuple. Clips that fail are logged (with
        traceback) and omitted, so one bad clip does not abort the whole pass.
        If the model itself fails to load, the dict is empty.

    The model is always released afterwards, even on error. That way the next
    model can be loaded without both being resident at once.
    """
    results = {}
    try:
        transcriber = WhisperTranscriber(model_path=model_path)
    except Exception:
        logger.exception(f"Failed to load {label} model from {model_path}")
        _free_memory()  # drop anything a partial load may have allocated
        return results

    try:
        for ap in audio_paths:
            logger.info(f"Processing {ap.name} with {label}...")
            start = time.perf_counter()  # monotonic clock for durations
            try:
                text = transcriber.transcribe(ap)
            except Exception:
                logger.exception(f"{label} transcription failed for {ap.name}")
                continue
            results[ap] = (text, time.perf_counter() - start)
    finally:
        del transcriber
        _free_memory()
    return results


def main():
    """Run both transcription passes and log a comparison table to W&B."""
    root = Path(__file__).parent.parent

    # Keep only the audio files that actually exist.
    valid_audio = []
    for af in AUDIO_FILES:
        p = root / af
        if p.exists():
            valid_audio.append(p)
        else:
            logger.warning(f"File not found: {af}")

    if not valid_audio:
        logger.error("No valid audio files found to test.")
        return

    # Check for the merged checkpoint before doing any expensive work. That
    # way a missing checkpoint neither wastes a full base-model pass nor
    # leaves an unfinished W&B run behind.
    merged_model_path = root / "outputs" / "checkpoints" / "merged_model"
    if not merged_model_path.exists():
        logger.error(f"Merged model not found at {merged_model_path}")
        return

    wandb.init(project="whisper-evaluation", name=f"Base_vs_Merged_Calls_{int(time.time())}")
    try:
        # Pass 1: Base Model
        logger.info("--- Phase 1: Transcribing with Base Model (whisper-large-v3) ---")
        base = _transcribe_all("Base", BASE_MODEL_PATH, valid_audio)

        # Pass 2: Merged Model
        logger.info("--- Phase 2: Transcribing with Merged Model (Fine-tuned) ---")
        merged = _transcribe_all("Merged", str(merged_model_path), valid_audio)

        table = wandb.Table(columns=["Audio Clip", "Filename", "Base Transcription", "Merged Transcription", "Base Time (s)", "Merged Time (s)", "Delta"])
        for ap in valid_audio:
            if ap not in base and ap not in merged:
                continue  # nothing to report for this clip
            base_text, base_time = base.get(ap, (MISSING, None))
            merged_text, merged_time = merged.get(ap, (MISSING, None))
            # Positive delta means the merged model was faster than the base model.
            if base_time is None or merged_time is None:
                delta = MISSING
            else:
                delta = f"{base_time - merged_time:+.2f}"
            table.add_data(
                wandb.Audio(str(ap), sample_rate=16000),
                ap.name,
                base_text,
                merged_text,
                _fmt_seconds(base_time),
                _fmt_seconds(merged_time),
                delta,
            )

        wandb.log({"Model Comparison Results": table})
    finally:
        # Always close the run, even if a phase or the logging step raised.
        wandb.finish()

    logger.info("Evaluation complete! Check your W&B dashboard for the Side-by-Side comparison.")


if __name__ == "__main__":
    main()