# Source revision: 0db822c (original file size: 2,612 bytes)
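"""
Benchmark inference latency and transcription output of the base Whisper model
(openai/whisper-large-v3) against the LoRA-merged checkpoint.

Both models transcribe the same audio files. Per-file latencies and transcripts
are logged to a Weights & Biases table for side-by-side comparison.
"""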
import logging
import sys
import time
from pathlib import Path

import wandb

# Make the project root importable so `src.` resolves when run as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.inference.transcribe import WhisperTranscriber  # noqa: E402

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s  %(levelname)-8s  %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)
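

# A possible refinement of the inline timing in main(), sketched as a
# HYPOTHETICAL helper (`time_call` is not part of this project or of any library).
# It runs `warmup` untimed calls first, so one-off costs such as lazy CUDA
# initialization are not charged to the first audio file. It then reports the
# median of `repeats` timed calls measured with time.perf_counter(). Each file is
# transcribed warmup + repeats times, which multiplies run time for a large model.
def time_call(fn, *args, warmup=1, repeats=3, **kwargs):
    """Return (last_result, median_seconds) for calling fn(*args, **kwargs)."""
    import statistics
    import time

    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    # Untimed warm-up calls absorb lazy initialization
    for _ in range(warmup):
        fn(*args, **kwargs)
    timings = []
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        timings.append(time.perf_counter() - start)
    # The median is less sensitive than the mean to a single slow run
    return result, statistics.median(timings)


# Hypothetical use inside main():
#     text, latency = time_call(base_transcriber.transcribe, str(root / af))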


def main():
    root = Path(__file__).parent.parent
    audio_files = ["1775560189.41808.wav", "audio (1).mp3", "audio.mp3"]

    # Fail fast on missing inputs *before* starting a WandB run, so a typo in a
    # filename does not leave an empty, failed run behind in the project.
    for af in audio_files:
        if not (root / af).exists():
            logger.error(f"Cannot find audio file {af!r} in {root}. Make sure it exists.")
            sys.exit(1)

    # Start the WandB run and the table that collects per-file results
    wandb.init(project="whisper-evaluation", name="Base_vs_LoRAMerged_Benchmark")
    table = wandb.Table(
        columns=["Audio File", "Base Time (s)", "LoRA Time (s)", "Base Transcript", "LoRA Transcript"]
    )
    results = {}
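
    # Load the base model and time each transcription.
    # perf_counter() is monotonic and high-resolution, so it suits interval timing
    # better than time.time(). The first file of each model may also absorb
    # one-off warm-up cost.
    logger.info("--- Loading Base Model ---")
    base_transcriber = WhisperTranscriber(model_path="openai/whisper-large-v3", language="arabic", task="transcribe")
    for af in audio_files:
        logger.info(f"[Base] Transcribing {af}...")
        start = time.perf_counter()
        text = base_transcriber.transcribe(str(root / af))
        latency = time.perf_counter() - start
        results[af] = {"base_time": latency, "base_text": text}

    # Release the base model before loading the merged one, so both sets of
    # weights are never held in memory at the same time (avoids OOM).
    del base_transcriber
    import gc
    import torch

    gc.collect()  # collect reference cycles that may still hold tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release unoccupied blocks held by PyTorch's caching allocator

    # Load the merged LoRA model and time the same files
    logger.info("--- Loading LoRA Merged Model ---")
    merged_path = root / "outputs/checkpoints/merged_model"
    lora_transcriber = WhisperTranscriber(model_path=str(merged_path), language="arabic", task="transcribe")
    for af in audio_files:
        logger.info(f"[LoRA] Transcribing {af}...")
        start = time.perf_counter()
        text = lora_transcriber.transcribe(str(root / af))
        latency = time.perf_counter() - start
        results[af]["lora_time"] = latency
        results[af]["lora_text"] = text
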
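    # Store latencies as rounded floats rather than pre-formatted strings, so WandB
    # types these columns as numbers and can sort and chart them.
    for af in audio_files:
        r = results[af]
        table.add_data(af, round(r["base_time"], 2), round(r["lora_time"], 2), r["base_text"], r["lora_text"])
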
    wandb.log({"Benchmark Results": table})
    wandb.finish()
    logger.info("Benchmark complete! View the results in your WandB dashboard.")
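

# The table logged by main() shows only raw transcripts. As a HYPOTHETICAL
# extension (`word_disagreement` is not part of this project or of WandB), the
# two outputs could be compared numerically with a word-level Levenshtein
# distance normalized by the length of the base transcript. This is the WER
# formula with the base output standing in for the reference, so it measures
# how far the LoRA output diverges from the base model, not which one is more
# accurate. A true WER/CER needs human reference transcripts.
def word_disagreement(base_text, lora_text):
    """Return edit_distance(base_words, lora_words) / len(base_words)."""
    ref = base_text.split()
    hyp = lora_text.split()
    if not ref:
        # Arbitrary convention: the ratio is undefined for an empty base transcript
        return 0.0 if not hyp else 1.0
    # prev[j] = edit distance between the ref words seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h_word in enumerate(hyp, start=1):
            cost = 0 if r_word == h_word else 1
            curr[j] = min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution or match
            )
        prev = curr
    return prev[-1] / len(ref)


# Hypothetical use as an extra table column in main():
#     word_disagreement(r["base_text"], r["lora_text"])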

if __name__ == "__main__":
    main()