# Scraped file-viewer header — Spaces status: Sleeping; file size: 2,612 bytes; revision 0db822c.
"""
Benchmark inference performance and quality comparing Base vs Merged models.
Logs natively to WandB for side-by-side visualization.
"""
import sys
import time
import wandb
import logging
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.inference.transcribe import WhisperTranscriber
# Root logging setup: INFO threshold, compact HH:MM:SS timestamps, and a
# fixed-width level column so messages line up in the console.
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
)

# Module-scoped logger used by the benchmark below.
logger = logging.getLogger(__name__)
def _time_transcriptions(transcriber, label, root, audio_files):
    """Transcribe each audio file with *transcriber* and time every call.

    Args:
        transcriber: Object exposing ``transcribe(path)``; whatever that call
            returns is stored as the transcript.
        label: Short tag shown in the progress log line (e.g. ``"Base"``).
        root: Directory that the audio file names are resolved against.
        audio_files: File names relative to *root*.

    Returns:
        Dict mapping each file name to a ``(seconds, transcript)`` tuple.
    """
    timings = {}
    for name in audio_files:
        logger.info("[%s] Transcribing %s...", label, name)
        # perf_counter() is monotonic, so the measured interval cannot be
        # skewed by wall-clock adjustments the way time.time() can.
        started = time.perf_counter()
        transcript = transcriber.transcribe(str(root / name))
        timings[name] = (time.perf_counter() - started, transcript)
    return timings


def main():
    """Benchmark the base Whisper model against the merged LoRA model.

    Each audio file is transcribed once with ``openai/whisper-large-v3`` and
    once with the model stored under ``outputs/checkpoints/merged_model``.
    Per-file latency and transcript for both models are written to a single
    WandB table so they can be compared side by side.

    Exits with status 1 (after logging every missing file) if any of the
    expected audio files is absent from the project root.
    """
    root = Path(__file__).parent.parent
    audio_files = ["1775560189.41808.wav", "audio (1).mp3", "audio.mp3"]

    # Validate inputs BEFORE starting the WandB run, so a missing file does
    # not leave an empty run behind in the project dashboard.
    missing = [af for af in audio_files if not (root / af).exists()]
    if missing:
        for af in missing:
            logger.error("Cannot find audio file: %s in root. Make sure it exists.", af)
        sys.exit(1)

    # Initialize WandB
    wandb.init(project="whisper-evaluation", name="Base_vs_LoRAMerged_Benchmark")
    table = wandb.Table(
        columns=["Audio File", "Base Time (s)", "LoRA Time (s)", "Base Transcript", "LoRA Transcript"]
    )

    # Load Base
    logger.info("--- Loading Base Model ---")
    base_transcriber = WhisperTranscriber(
        model_path="openai/whisper-large-v3", language="arabic", task="transcribe"
    )
    base_runs = _time_transcriptions(base_transcriber, "Base", root, audio_files)

    # Drop the base model before loading the second one to prevent OOM.
    del base_transcriber
    import gc
    import torch

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Load Merged LoRA
    logger.info("--- Loading LoRA Merged Model ---")
    merged_path = root / "outputs/checkpoints/merged_model"
    lora_transcriber = WhisperTranscriber(
        model_path=str(merged_path), language="arabic", task="transcribe"
    )
    lora_runs = _time_transcriptions(lora_transcriber, "LoRA", root, audio_files)

    # One table row per audio file: timings formatted to two decimals,
    # followed by both transcripts.
    for af in audio_files:
        base_time, base_text = base_runs[af]
        lora_time, lora_text = lora_runs[af]
        table.add_data(af, f"{base_time:.2f}", f"{lora_time:.2f}", base_text, lora_text)

    wandb.log({"Benchmark Results": table})
    wandb.finish()
    logger.info("Benchmark complete! View the results in your WandB dashboard.")
# Run the benchmark only when this file is executed as a script, so that
# importing the module has no side effects beyond logging setup.
if __name__ == "__main__":
    main()
# End of scraped file view.