# Speach-To-Text / scripts / evaluate_calls_wandb.py
# MIP-Tech's picture
# Deploy to HF Spaces
# 0db822c
import os
import sys
import time
import wandb
import logging
import torch
import gc
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.inference.transcribe import WhisperTranscriber
# Root logging setup: INFO and above, each record rendered as
# "HH:MM:SS LEVEL    message".
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
)
# Module-level logger used by everything in this script.
logger = logging.getLogger(__name__)
def _release_memory():
    """Run the garbage collector and, if CUDA is available, empty its cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _transcribe_files(model_path, label, audio_paths):
    """Transcribe each path with one model and time every call.

    Args:
        model_path: Value passed to ``WhisperTranscriber(model_path=...)``.
        label: Short phase name ("Base" / "Merged") used in log messages.
        audio_paths: Iterable of ``Path`` objects to transcribe.

    Returns:
        Dict mapping file name to ``(transcription, seconds)``. Files whose
        transcription raised are simply absent, and the dict is empty when
        the model itself could not be loaded.
    """
    outputs = {}
    transcriber = None
    try:
        transcriber = WhisperTranscriber(model_path=model_path)
        for audio_path in audio_paths:
            fname = audio_path.name
            logger.info(f"Processing {fname} with {label}...")
            # Per-file guard: one bad clip must not discard the other clips.
            try:
                start = time.perf_counter()
                text = transcriber.transcribe(audio_path)
                outputs[fname] = (text, time.perf_counter() - start)
            except Exception as e:
                logger.exception(f"{label} transcription failed for {fname}: {e}")
    except Exception as e:
        logger.exception(f"Error during {label} Model phase: {e}")
    finally:
        # Drop the model even on failure so the next phase does not have to
        # load its own model while this one is still referenced.
        del transcriber
        _release_memory()
    return outputs


def main():
    """Compare base and merged Whisper transcriptions and log them to W&B.

    Looks for a fixed pair of audio files in the project root, transcribes
    each of them first with ``openai/whisper-large-v3`` and then with the
    model stored under ``outputs/checkpoints/merged_model``, and logs one
    table row per file (audio, both transcriptions, both latencies and the
    latency delta ``base - merged``) under "Model Comparison Results".

    All inputs are validated before the W&B run is started, so a missing
    audio file set or a missing merged checkpoint exits without creating a
    run. Values that could not be produced are logged as "N/A".
    """
    root = Path(__file__).parent.parent

    # Target only 2 files for a "Quick Comparison" as requested
    audio_files = [
        "1775560189.41808.wav",
        "audio.mp3",
    ]

    # Filter only existing files
    valid_audio = []
    for af in audio_files:
        p = root / af
        if p.exists():
            valid_audio.append(p)
        else:
            logger.warning(f"File not found: {af}")

    if not valid_audio:
        logger.error("No valid audio files found to test.")
        return

    # Fail fast: check the merged checkpoint before the base pass and before
    # opening a W&B run, instead of discovering it is missing afterwards.
    merged_model_path = root / "outputs" / "checkpoints" / "merged_model"
    if not merged_model_path.exists():
        logger.error(f"Merged model not found at {merged_model_path}")
        return

    wandb.init(project="whisper-evaluation", name=f"Base_vs_Merged_Calls_{int(time.time())}")
    try:
        table = wandb.Table(
            columns=[
                "Audio Clip",
                "Filename",
                "Base Transcription",
                "Merged Transcription",
                "Base Time (s)",
                "Merged Time (s)",
                "Delta",
            ]
        )

        # The two models are loaded one after the other, never together.
        logger.info("--- Phase 1: Transcribing with Base Model (whisper-large-v3) ---")
        base_runs = _transcribe_files("openai/whisper-large-v3", "Base", valid_audio)

        logger.info("--- Phase 2: Transcribing with Merged Model (Fine-tuned) ---")
        merged_runs = _transcribe_files(str(merged_model_path), "Merged", valid_audio)

        for ap in valid_audio:
            fname = ap.name
            base_run = base_runs.get(fname)
            merged_run = merged_runs.get(fname)
            if base_run is None and merged_run is None:
                logger.warning(f"No transcription available for {fname}; skipping row.")
                continue

            base_text, base_time = base_run if base_run is not None else ("N/A", None)
            merged_text, merged_time = merged_run if merged_run is not None else ("N/A", None)

            # The delta is only meaningful when both phases produced a timing.
            delta = (
                f"{base_time - merged_time:+.2f}"
                if base_time is not None and merged_time is not None
                else "N/A"
            )
            table.add_data(
                wandb.Audio(str(ap), sample_rate=16000),
                fname,
                base_text,
                merged_text,
                "N/A" if base_time is None else f"{base_time:.2f}",
                "N/A" if merged_time is None else f"{merged_time:.2f}",
                delta,
            )

        wandb.log({"Model Comparison Results": table})
    finally:
        # Always close the run, even if building or logging the table failed.
        wandb.finish()

    logger.info("Evaluation complete! Check your W&B dashboard for the Side-by-Side comparison.")
# Script entry point: run the comparison only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()