Spaces:
Runtime error
Runtime error
| import os | |
| import pytest | |
| import torch | |
| import whisper | |
| from whisper.tokenizer import get_tokenizer | |
def test_transcribe(model_name: str):
    """Transcribe the bundled ``jfk.flac`` sample and sanity-check the output.

    The model named by ``model_name`` is loaded onto CUDA when available
    (CPU otherwise) and run with ``temperature=0.0`` and word timestamps
    enabled.  The test then verifies, in order:

    * the reported language is ``"en"``;
    * the full text equals the concatenation of the segment texts;
    * a few expected phrases occur in the lower-cased text;
    * decoding all segment tokens reproduces the text, and the
      timestamp-annotated decoding starts with ``<|0.00|>``;
    * every word timing has ``start < end``, and the word "Americans"
      spans the 1.8 mark (presumably seconds -- confirm against the
      transcribe API).

    NOTE(review): ``model_name`` has to be supplied by a parametrize mark or
    a fixture; neither is visible in this part of the file -- confirm.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    asr_model = whisper.load_model(model_name).to(target_device)

    # The audio sample is expected to sit next to this test module.
    sample_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    # Model names ending in ".en" get the language pinned; all others are
    # passed ``None``.
    forced_language = None
    if model_name.endswith(".en"):
        forced_language = "en"

    output = asr_model.transcribe(
        sample_path,
        language=forced_language,
        temperature=0.0,
        word_timestamps=True,
    )
    assert output["language"] == "en"

    full_text = output["text"]
    segments = output["segments"]
    assert full_text == "".join(piece["text"] for piece in segments)

    lowered = full_text.lower()
    assert "my fellow americans" in lowered
    assert "your country" in lowered
    assert "do for you" in lowered

    # Round-trip: the flattened token ids must decode back to the same text.
    codec = get_tokenizer(asr_model.is_multilingual)
    token_ids = []
    for piece in segments:
        token_ids.extend(piece["tokens"])
    assert codec.decode(token_ids) == full_text
    assert codec.decode_with_timestamps(token_ids).startswith("<|0.00|>")

    # Word-level timings: each must be a non-empty interval, and the
    # reference word "Americans" must be seen at least once.
    saw_reference_word = False
    for piece in segments:
        for word_timing in piece["words"]:
            assert word_timing["start"] < word_timing["end"]
            if word_timing["word"].strip(" ,") != "Americans":
                continue
            assert word_timing["start"] <= 1.8
            assert word_timing["end"] >= 1.8
            saw_reference_word = True

    assert saw_reference_word