Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import time | |
| import wandb | |
| import logging | |
| import torch | |
| import gc | |
| from pathlib import Path | |
| # Add project root to path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from src.inference.transcribe import WhisperTranscriber | |
# Console logging for the whole script: INFO and above, each record prefixed
# with a wall-clock timestamp and a level name padded to eight characters.
_LOGGING_SETTINGS = {
    "level": logging.INFO,
    "format": "%(asctime)s %(levelname)-8s %(message)s",
    "datefmt": "%H:%M:%S",
}
logging.basicConfig(**_LOGGING_SETTINGS)

# Module-level logger used by everything below.
logger = logging.getLogger(__name__)
def _release_memory() -> None:
    """Run the garbage collector and, if CUDA is available, empty its cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _transcribe_files(label, model_path, audio_paths):
    """Load one model and transcribe every path in *audio_paths*, timing each call.

    Args:
        label: Short model name used in log messages (e.g. ``"Base"``).
        model_path: Value forwarded to ``WhisperTranscriber(model_path=...)``.
        audio_paths: Iterable of ``pathlib.Path`` objects to transcribe.

    Returns:
        A dict mapping each file name to ``(transcription, seconds)``.
        Only files that finished before any failure are present, so the
        dict may be empty or partial.
    """
    outputs = {}
    transcriber = None
    try:
        transcriber = WhisperTranscriber(model_path=model_path)
        for audio_path in audio_paths:
            logger.info(f"Processing {audio_path.name} with {label}...")
            # perf_counter is monotonic, so the measured latency cannot be
            # skewed by system clock adjustments.
            start = time.perf_counter()
            text = transcriber.transcribe(audio_path)
            outputs[audio_path.name] = (text, time.perf_counter() - start)
    except Exception:
        # Best effort: keep whatever was transcribed so far, but record the
        # full traceback instead of only the exception message.
        logger.exception(f"Error during {label} Model phase")
    finally:
        # Drop the model even when the phase failed, so the next model does
        # not have to share memory with this one.
        del transcriber
        _release_memory()
    return outputs


def main() -> None:
    """Transcribe a fixed set of clips with two models and log a W&B table.

    The base model (``openai/whisper-large-v3``) and the merged checkpoint
    under ``outputs/checkpoints/merged_model`` are loaded one after the other.
    Each existing audio file is transcribed by both, and one table row per
    file (audio, transcriptions, latencies, latency delta) is logged to the
    ``whisper-evaluation`` W&B project.

    Returns early, without starting a W&B run, when no audio file exists or
    when the merged checkpoint directory is missing.
    """
    root = Path(__file__).parent.parent

    # Target only 2 files for a "Quick Comparison" as requested
    audio_files = [
        "1775560189.41808.wav",
        "audio.mp3",
    ]

    # Keep only the clips that actually exist under the project root.
    valid_audio = []
    for af in audio_files:
        candidate = root / af
        if candidate.exists():
            valid_audio.append(candidate)
        else:
            logger.warning(f"File not found: {af}")
    if not valid_audio:
        logger.error("No valid audio files found to test.")
        return

    # Fail fast: without the merged checkpoint there is nothing to compare,
    # so do not load the base model or open a W&B run that would stay empty.
    merged_model_path = root / "outputs" / "checkpoints" / "merged_model"
    if not merged_model_path.exists():
        logger.error(f"Merged model not found at {merged_model_path}")
        return

    def fmt_seconds(value):
        # Missing measurements stay None so the table cell is left empty.
        return None if value is None else f"{value:.2f}"

    wandb.init(project="whisper-evaluation", name=f"Base_vs_Merged_Calls_{int(time.time())}")
    try:
        table = wandb.Table(
            columns=[
                "Audio Clip",
                "Filename",
                "Base Transcription",
                "Merged Transcription",
                "Base Time (s)",
                "Merged Time (s)",
                "Delta",
            ]
        )

        # Pass 1: Base Model
        logger.info("--- Phase 1: Transcribing with Base Model (whisper-large-v3) ---")
        base_results = _transcribe_files("Base", "openai/whisper-large-v3", valid_audio)

        # Pass 2: Merged Model
        logger.info("--- Phase 2: Transcribing with Merged Model (Fine-tuned) ---")
        merged_results = _transcribe_files("Merged", str(merged_model_path), valid_audio)

        rows_logged = 0
        for audio_path in valid_audio:
            fname = audio_path.name
            base_text, base_time = base_results.get(fname, (None, None))
            merged_text, merged_time = merged_results.get(fname, (None, None))
            if base_time is None and merged_time is None:
                logger.warning(f"No transcription produced for {fname}; skipping row.")
                continue

            # Positive delta means the merged model was faster than the base.
            if base_time is not None and merged_time is not None:
                delta = f"{base_time - merged_time:+.2f}"
            else:
                delta = None

            try:
                table.add_data(
                    wandb.Audio(str(audio_path), sample_rate=16000),
                    fname,
                    base_text,
                    merged_text,
                    fmt_seconds(base_time),
                    fmt_seconds(merged_time),
                    delta,
                )
                rows_logged += 1
            except Exception:
                # One bad clip must not prevent the remaining rows from being logged.
                logger.exception(f"Could not add table row for {fname}")

        if not rows_logged:
            logger.warning("No comparison rows were produced; the logged table is empty.")
        wandb.log({"Model Comparison Results": table})
    finally:
        # Always close the run, including when an unexpected error propagates.
        wandb.finish()

    logger.info("Evaluation complete! Check your W&B dashboard for the Side-by-Side comparison.")
# Script entry point: run the comparison only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()