File size: 3,539 Bytes
0db822c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import sys
import time
import wandb
import logging
import torch
import gc
from pathlib import Path

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.inference.transcribe import WhisperTranscriber

# Root logging configuration for this script: INFO and above, each record
# prefixed with an HH:MM:SS timestamp and a left-aligned, padded level name.
logging.basicConfig(
    format="%(asctime)s  %(levelname)-8s  %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
# Logger named after this module; used for all progress and error messages.
logger = logging.getLogger(name=__name__)

def _release_memory():
    """Run the garbage collector and, if CUDA is available, empty the CUDA cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _transcribe_all(model_path, audio_paths, label):
    """Transcribe every clip in *audio_paths* with the model at *model_path*.

    Args:
        model_path: Model identifier or local directory passed to
            ``WhisperTranscriber(model_path=...)``.
        audio_paths: Paths of the audio clips to transcribe.
        label: Short model name ("Base" / "Merged") used in log messages.

    Returns:
        Dict mapping each successfully transcribed file name to a
        ``(text, seconds)`` tuple. If the model fails to load, the dict is
        empty. A clip whose transcription raises is logged and omitted, and
        the remaining clips are still processed.
    """
    outputs = {}
    try:
        transcriber = WhisperTranscriber(model_path=model_path)
    except Exception:
        logger.exception(f"Error during {label} Model phase: could not load {model_path}")
        _release_memory()
        return outputs

    try:
        for ap in audio_paths:
            logger.info(f"Processing {ap.name} with {label}...")
            # perf_counter is monotonic and high-resolution, so it is the right
            # clock for measuring elapsed time (time.time() can jump).
            start = time.perf_counter()
            try:
                text = transcriber.transcribe(ap)
            except Exception:
                logger.exception(f"Error during {label} Model phase on {ap.name}")
                continue
            outputs[ap.name] = (text, time.perf_counter() - start)
    finally:
        # Drop the model and free memory even if something unexpected escapes,
        # so the next model is loaded into a clean device.
        del transcriber
        _release_memory()
    return outputs


def main():
    """Compare the base Whisper model against the merged fine-tuned checkpoint.

    Transcribes a fixed list of clips from the project root, first with
    ``openai/whisper-large-v3`` and then with
    ``outputs/checkpoints/merged_model``. Logs a W&B table with each clip,
    both transcripts, both latencies and their difference
    (base minus merged; positive means the merged model was faster).

    Returns early, without starting a W&B run, when no clip exists or the
    merged checkpoint directory is missing.
    """
    root = Path(__file__).parent.parent
    # Target only 2 files for a "Quick Comparison" as requested
    audio_files = [
        "1775560189.41808.wav",
        "audio.mp3",
    ]

    # Keep only files that exist on disk.
    valid_audio = []
    for af in audio_files:
        p = root / af
        if p.exists():
            valid_audio.append(p)
        else:
            logger.warning(f"File not found: {af}")

    if not valid_audio:
        logger.error("No valid audio files found to test.")
        return

    # Fail fast: checking here avoids a slow base pass whose results could
    # never be compared, and avoids leaving an unfinished W&B run behind.
    merged_model_path = root / "outputs" / "checkpoints" / "merged_model"
    if not merged_model_path.exists():
        logger.error(f"Merged model not found at {merged_model_path}")
        return

    wandb.init(project="whisper-evaluation", name=f"Base_vs_Merged_Calls_{int(time.time())}")
    try:
        table = wandb.Table(columns=["Audio Clip", "Filename", "Base Transcription", "Merged Transcription", "Base Time (s)", "Merged Time (s)", "Delta"])

        logger.info("--- Phase 1: Transcribing with Base Model (whisper-large-v3) ---")
        base = _transcribe_all("openai/whisper-large-v3", valid_audio, "Base")

        logger.info("--- Phase 2: Transcribing with Merged Model (Fine-tuned) ---")
        merged = _transcribe_all(str(merged_model_path), valid_audio, "Merged")

        for ap in valid_audio:
            fname = ap.name
            # Only compare clips that both models transcribed successfully.
            if fname not in base or fname not in merged:
                logger.warning(f"Skipping {fname}: missing base or merged transcription")
                continue
            base_text, base_time = base[fname]
            merged_text, merged_time = merged[fname]
            delta = base_time - merged_time
            table.add_data(
                wandb.Audio(str(ap), sample_rate=16000),
                fname,
                base_text,
                merged_text,
                f"{base_time:.2f}",
                f"{merged_time:.2f}",
                f"{delta:+.2f}",
            )

        wandb.log({"Model Comparison Results": table})
    finally:
        # Always close the run, even if logging or transcription raised.
        wandb.finish()
    logger.info("Evaluation complete! Check your W&B dashboard for the Side-by-Side comparison.")


if __name__ == "__main__":
    main()