Spaces:
Sleeping
Sleeping
| """ | |
| Benchmark inference performance and quality comparing Base vs Merged models. | |
| Logs natively to WandB for side-by-side visualization. | |
| """ | |
| import sys | |
| import time | |
| import wandb | |
| import logging | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from src.inference.transcribe import WhisperTranscriber | |
# Console logging for progress and error messages: clock time, left-padded
# level name, then the message.
_LOG_CONFIG = {
    "level": logging.INFO,
    "format": "%(asctime)s %(levelname)-8s %(message)s",
    "datefmt": "%H:%M:%S",
}
logging.basicConfig(**_LOG_CONFIG)
logger = logging.getLogger(__name__)
def _release_accelerator_memory():
    """Run the garbage collector and, if CUDA is available, empty PyTorch's CUDA cache.

    Called between the two model passes so memory released by the first model
    can be reused when the second one is loaded (guards against OOM).
    """
    # Imported locally, as in the original script, so torch is only pulled in
    # when a benchmark actually runs.
    import gc

    import torch

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _time_transcriptions(transcriber, audio_paths, label):
    """Transcribe every file with ``transcriber`` and time each call.

    Args:
        transcriber: Object exposing ``transcribe(path: str)``; in this script a
            ``WhisperTranscriber``.
        audio_paths: dict mapping each file's display name to its ``Path``.
        label: Tag used in progress log lines, e.g. ``"Base"`` or ``"LoRA"``.

    Returns:
        dict mapping each display name to ``(latency_seconds, transcript)``.
    """
    results = {}
    for name, path in audio_paths.items():
        logger.info(f"[{label}] Transcribing {name}...")
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock, which can be adjusted mid-run and skew latencies.
        start = time.perf_counter()
        text = transcriber.transcribe(str(path))
        results[name] = (time.perf_counter() - start, text)
    return results


def main():
    """Benchmark the base Whisper large-v3 model against the merged LoRA checkpoint.

    Every audio file (resolved relative to the repository root, i.e. the parent
    of this script's directory) is transcribed first by the base model and then
    by the merged model. Per-file latency and transcript are logged as a single
    ``wandb.Table`` in one WandB run.

    Exits with status 1, before any WandB run is created, if an audio file or
    the merged checkpoint directory is missing.
    """
    root = Path(__file__).parent.parent
    audio_files = ["1775560189.41808.wav", "audio (1).mp3", "audio.mp3"]
    merged_path = root / "outputs/checkpoints/merged_model"

    # Validate every input up front. Doing it before wandb.init avoids leaving a
    # failed run on the dashboard. Doing it before the base pass means a missing
    # merged checkpoint is not discovered only after the base model has run.
    audio_paths = {af: root / af for af in audio_files}
    missing_audio = [af for af, path in audio_paths.items() if not path.exists()]
    for af in missing_audio:
        logger.error(f"Cannot find audio file: {af} in root. Make sure it exists.")
    merged_missing = not merged_path.exists()
    if merged_missing:
        logger.error(f"Cannot find merged model directory: {merged_path}. Make sure it exists.")
    if missing_audio or merged_missing:
        sys.exit(1)

    # The run's context manager finishes the run on exit and marks it as failed
    # if an exception escapes, instead of relying on the interpreter exit hook.
    with wandb.init(project="whisper-evaluation", name="Base_vs_LoRAMerged_Benchmark") as run:
        # NOTE(review): the first transcription on each model likely includes
        # one-off warm-up cost (e.g. lazy CUDA/kernel initialisation). Consider
        # an untimed warm-up call if the first file's latency looks inflated.
        logger.info("--- Loading Base Model ---")
        base_transcriber = WhisperTranscriber(model_path="openai/whisper-large-v3", language="arabic", task="transcribe")
        base_results = _time_transcriptions(base_transcriber, audio_paths, "Base")

        # Clear memory to prevent OOM when the merged model is loaded.
        del base_transcriber
        _release_accelerator_memory()

        logger.info("--- Loading LoRA Merged Model ---")
        lora_transcriber = WhisperTranscriber(model_path=str(merged_path), language="arabic", task="transcribe")
        lora_results = _time_transcriptions(lora_transcriber, audio_paths, "LoRA")

        table = wandb.Table(columns=["Audio File", "Base Time (s)", "LoRA Time (s)", "Base Transcript", "LoRA Transcript"])
        for af in audio_files:
            base_time, base_text = base_results[af]
            lora_time, lora_text = lora_results[af]
            # Log latencies as numbers (not preformatted strings) so WandB can
            # sort and chart these columns.
            table.add_data(af, round(base_time, 2), round(lora_time, 2), base_text, lora_text)
        run.log({"Benchmark Results": table})

    logger.info("Benchmark complete! View the results in your WandB dashboard.")


if __name__ == "__main__":
    main()