Spaces:
Runtime error
Runtime error
| from typing import Dict | |
| import torch | |
| import gradio as gr | |
| import whisper | |
| from whisper.tokenizer import get_tokenizer | |
| import classify | |
| from datasets import load_dataset | |
# Process-wide cache of loaded Whisper models, keyed by model name, so each
# model is loaded at most once and reused across classification requests.
model_cache = {}
def _parse_class_names(class_names: str) -> list:
    """Split a comma-separated string into unique, non-empty class names.

    Whitespace around each name is stripped, empty entries (for example the
    one produced by a trailing comma) are dropped, and duplicates are removed
    while keeping the position of the first occurrence.
    """
    stripped = (name.strip() for name in (class_names or "").split(","))
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(name for name in stripped if name))


def _get_model(model_name: str):
    """Return the Whisper model for *model_name*, loading it on first use."""
    model = model_cache.get(model_name)
    if model is None:
        model = whisper.load_model(model_name)
        model_cache[model_name] = model
    return model


def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
    """Score an audio file against a set of candidate class names.

    Args:
        audio_path: Path of the audio file; passed unchanged to
            ``classify.calculate_audio_features``.
        class_names: Comma-separated candidate labels. Surrounding
            whitespace, empty entries and duplicates are ignored.
        model_name: Name of the Whisper model to load. Names containing
            ``".en"`` select the English-only tokenizer, all others the
            multilingual one.

    Returns:
        A mapping from each candidate class name to its softmax score, in
        the order the names were given.

    Raises:
        ValueError: If ``class_names`` contains no non-empty label.
    """
    names = _parse_class_names(class_names)
    if not names:
        # Fail early with a clear message instead of scoring an empty label.
        raise ValueError("At least one non-empty class name is required.")

    tokenizer = get_tokenizer(multilingual=".en" not in model_name)
    model = _get_model(model_name)

    # Scores computed from the class names alone (no audio is passed in);
    # presumably the model's text-only prior for each label.
    internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
        model=model,
        class_names=names,
        tokenizer=tokenizer,
    )
    audio_features = classify.calculate_audio_features(audio_path, model)
    average_logprobs = classify.calculate_average_logprobs(
        model=model,
        audio_features=audio_features,
        class_names=names,
        tokenizer=tokenizer,
    )
    # Subtract the audio-independent scores before normalising across classes.
    average_logprobs -= internal_lm_average_logprobs
    scores = average_logprobs.softmax(-1).tolist()
    return dict(zip(names, scores))
def main() -> None:
    """Build the Gradio interface for the classifier and launch it."""
    # Default candidate labels attached to every example clip. No trailing
    # comma: it would add an empty class name to the candidate list.
    class_names = (
        "[dog barking],[helicopter whirring],[laughing],[birds chirping],"
        "[clock ticking],[popping],[sneezing],[sigh],[slurping],[mouth sounds],"
        "[clearing throat]"
    )
    audio_paths = [
        "./data/(dog)1-100032-A-0.wav",
        "./data/(helicopter)1-181071-A-40.wav",
        "./data/(laughing)1-1791-A-26.wav",
        "./data/(chirping_birds)1-34495-A-14.wav",
        "./data/(clock_tick)1-21934-A-38.wav",
        "./data/clears_throat1.wav",
        "./data/mouth_sounds1.wav",
        "./data/pop1.wav",
        "./data/sigh1.wav",
        "./data/slurp1.wav",
    ]
    # One example row per clip: [audio file, candidate labels, model name],
    # matching the order of the interface inputs below.
    examples = [[audio_path, class_names, "small"] for audio_path in audio_paths]

    description = (
        '<div style="text-align: center;">'
        "<p>This demo allows you to try out zero-shot audio classification using "
        "<a href=https://github.com/openai/whisper>Whisper</a>.</p>"
        "<p>Github: <a href=https://github.com/jumon/zac>https://github.com/jumon/zac</a></p>"
        "<p>Example audio files are from the <a href=https://github.com/karolpiczak/ESC-50>ESC-50"
        "</a> dataset (CC BY-NC 3.0).</p></div>"
    )

    demo = gr.Interface(
        fn=zero_shot_classify,
        inputs=[
            # NOTE(review): the `source=` keyword is Gradio-version dependent
            # (newer releases use `sources=[...]`) - confirm against the
            # pinned Gradio version.
            gr.Audio(
                label="Input Audio",
                show_label=False,
                source="microphone",
                type="filepath",
            ),
            gr.Textbox(lines=1, label="Candidate class names (comma-separated)"),
            gr.Radio(
                choices=["tiny", "base", "small", "medium", "large"],
                value="small",
                label="Model Name",
            ),
        ],
        outputs="label",
        examples=examples,
        title="Zero-shot Audio Classification using Whisper",
        description=description,
    )
    demo.launch()
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()