# Speach-To-Text/scripts/benchmark_models.py
# MIP-Tech's picture
# Deploy to HF Spaces
# 0db822c
"""
Benchmark inference performance and quality comparing Base vs Merged models.
Logs natively to WandB for side-by-side visualization.
"""
import sys
import time
import wandb
import logging
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.inference.transcribe import WhisperTranscriber
# Root logging setup: "HH:MM:SS LEVEL    message" records at INFO and above.
_LOG_FORMAT = "%(asctime)s %(levelname)-8s %(message)s"
_LOG_DATEFMT = "%H:%M:%S"
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)

# Module-level logger used by the benchmark code in this script.
logger = logging.getLogger(__name__)
def _transcribe_all(label, transcriber, root, audio_files):
    """Transcribe each audio file with one model and time every call.

    Args:
        label: Short model tag used in the progress log line (e.g. "Base").
        transcriber: Object exposing ``transcribe(path)``; its return value is
            stored as-is (presumably the transcript text -- confirm against
            ``WhisperTranscriber``).
        root: Directory the audio file names are resolved against.
        audio_files: File names relative to ``root``.

    Returns:
        Dict mapping each file name to ``(latency_seconds, transcript)``.
    """
    measurements = {}
    for af in audio_files:
        logger.info("[%s] Transcribing %s...", label, af)
        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by system clock adjustments the way time.time() can.
        start = time.perf_counter()
        text = transcriber.transcribe(str(root / af))
        measurements[af] = (time.perf_counter() - start, text)
    return measurements


def _release_model_memory():
    """Run the garbage collector and empty torch's CUDA cache if CUDA is available."""
    # Imported lazily, as in the original script, so torch is only loaded
    # once the first model has already been used.
    import gc
    import torch

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def main():
    """Benchmark the base Whisper model against the merged LoRA model.

    Each audio file is transcribed by both models (one model loaded at a
    time), and per-file latency plus transcript are logged to a WandB table
    named "Benchmark Results".

    Raises:
        SystemExit: With exit code 1 when any input audio file is missing.
    """
    root = Path(__file__).parent.parent
    audio_files = ["1775560189.41808.wav", "audio (1).mp3", "audio.mp3"]

    # Validate inputs BEFORE starting a WandB run, so a missing file does not
    # leave an empty run behind in the project.
    missing = [af for af in audio_files if not (root / af).exists()]
    for af in missing:
        logger.error(f"Cannot find audio file: {af} in root. Make sure it exists.")
    if missing:
        sys.exit(1)

    # Using the run as a context manager guarantees it is finished even when
    # model loading or transcription raises.
    with wandb.init(project="whisper-evaluation", name="Base_vs_LoRAMerged_Benchmark"):
        table = wandb.Table(
            columns=[
                "Audio File",
                "Base Time (s)",
                "LoRA Time (s)",
                "Base Transcript",
                "LoRA Transcript",
            ]
        )

        # Base model pass.
        logger.info("--- Loading Base Model ---")
        base_transcriber = WhisperTranscriber(
            model_path="openai/whisper-large-v3", language="arabic", task="transcribe"
        )
        base_results = _transcribe_all("Base", base_transcriber, root, audio_files)

        # Drop the base model before loading the second one to prevent OOM.
        del base_transcriber
        _release_model_memory()

        # Merged LoRA model pass.
        logger.info("--- Loading LoRA Merged Model ---")
        merged_path = root / "outputs/checkpoints/merged_model"
        lora_transcriber = WhisperTranscriber(
            model_path=str(merged_path), language="arabic", task="transcribe"
        )
        lora_results = _transcribe_all("LoRA", lora_transcriber, root, audio_files)

        # One table row per audio file; times are formatted to two decimals.
        for af in audio_files:
            base_time, base_text = base_results[af]
            lora_time, lora_text = lora_results[af]
            table.add_data(af, f"{base_time:.2f}", f"{lora_time:.2f}", base_text, lora_text)

        wandb.log({"Benchmark Results": table})

    logger.info("Benchmark complete! View the results in your WandB dashboard.")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()