Spaces:
Runtime error
Runtime error
| import json | |
| import torch | |
| import gradio as gr | |
| import models as MOD | |
| import process_data as PD | |
| from transformers import pipeline | |
| from huggingface_hub import hf_hub_download | |
# Registry of the selectable anti-spoofing models.
# Each entry is keyed by the label shown in the UI and records:
#   eer_threshold     - decision threshold reported next to the score
#   data_process_func - name of the preprocessing function looked up on `PD`
#   note              - free-text remark appended to the response
#   model_class       - name of the model class looked up on `MOD`
#   model_checkpoint  - checkpoint file name fetched from the Hub repository
model_master = {
    "SSL-AASIST (Trained on ASV-Spoof5)": dict(
        eer_threshold=3.3330237865448,
        data_process_func="process_ssl_assist_input",
        note="This model is trained only on ASVSpoof 2024 training data.",
        model_class="Model",
        model_checkpoint="ssl_aasist_epoch_7.pth",
    ),
    "AASIST": dict(
        eer_threshold=1.8018419742584229,
        data_process_func="process_assist_input",
        note="This model is trained on ASVSpoof 2024 training data.",
        model_class="AASIST_Model",
        model_checkpoint="orig_aasist_epoch_1.pth",
    ),
}
# Start-up default: the SSL-AASIST model is instantiated on CPU and its
# checkpoint is pulled from the Hub so the first request can be served
# without a model switch. `loaded_model` tracks which registry entry the
# global `model` currently holds.
loaded_model = "SSL-AASIST (Trained on ASV-Spoof5)"
model = MOD.Model(None, "cpu")
base_model_file = hf_hub_download(
    repo_id="arnabdas8901/aasist-trained-asvspoof2024",
    filename="ssl_aasist_epoch_7.pth",
)
model.load_state_dict(torch.load(base_model_file, map_location="cpu"))
model.eval()  # inference mode
def process(file, type):
    """Score an audio file with the selected anti-spoofing model.

    Args:
        file: Path of the submitted audio file (the Gradio audio input of
            this app is configured with ``type="filepath"``).
        type: Key of ``model_master`` naming the model to use. The name
            shadows the builtin; it is kept so existing callers keep working.

    Returns:
        A JSON-formatted string containing the decision score, the model's
        threshold and an explanatory note.

    Raises:
        gr.Error: If no file was submitted or ``type`` is not a known model.
    """
    global model
    global loaded_model

    # Fail with a readable UI message instead of a raw KeyError/TypeError.
    if file is None:
        raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")
    if type not in model_master:
        raise gr.Error("No model selected! Please select one of the available models before submitting your request.")

    spec = model_master[type]
    inp = getattr(PD, spec["data_process_func"])(file)

    if loaded_model != type:
        # Build and load the replacement completely before touching the
        # globals: if the download or state-dict load fails, the previously
        # loaded model and its label stay consistent with each other.
        candidate = getattr(MOD, spec["model_class"])(None, "cpu")
        model_file = hf_hub_download("arnabdas8901/aasist-trained-asvspoof2024", filename=spec["model_checkpoint"])
        candidate.load_state_dict(torch.load(model_file, map_location="cpu"))
        candidate.eval()
        model, loaded_model = candidate, type

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # Element 1 of the squeezed output is the reported score
        # (presumably the bona fide class - confirm against the model code).
        op = model(inp).squeeze()[1].item()

    output_json = {
        "decision_score": str(op),
        "model_threshold": str(spec["eer_threshold"]),
        "optional_note": "1. Any score below threshold is indicative of fake. \n2. {}".format(spec["note"]),
    }
    return json.dumps(output_json, indent=4)
# Top-level container that will host the tabs.
demo = gr.Blocks()

# Tab 1: upload an audio file, pick a model, get the spoofing score as text.
file_proc = gr.Interface(
    title="Find the Fake: Analyze 'Real' or 'Fake'.",
    description="Analyze fake or real with a click of a button. Upload a .wav or .flac file.",
    fn=process,
    inputs=[
        gr.Audio(label="Audio file", type="filepath", sources=["upload"]),
        gr.Radio(
            ["SSL-AASIST (Trained on ASV-Spoof5)", "AASIST"],
            type="value",
            label="Select Model",
        ),
    ],
    outputs="text",
    # Every bundled clip is offered once per model, grouped by model.
    examples=[
        [clip, detector]
        for detector in ("SSL-AASIST (Trained on ASV-Spoof5)", "AASIST")
        for clip in ("./bonafide.flac", "./fake.flac")
    ],
    cache_examples=True,
    allow_flagging="never",
)
# -----------------------------------------------------------------------------
# ASR interface: Whisper Large V3 speech-recognition pipeline on CPU,
# processing audio in 30-second chunks.
# -----------------------------------------------------------------------------
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",
    device="cpu",
    chunk_length_s=30,
)
def transcribe(inputs):
    """Transcribe an audio file and report the detected language.

    Args:
        inputs: Path of the submitted audio file, or ``None`` when nothing
            was uploaded or recorded.

    Returns:
        A ``(language, text)`` tuple. ``language`` is ``None`` when the
        pipeline result carries no chunk to read it from.

    Raises:
        gr.Error: If ``inputs`` is ``None``.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    op = pipe(
        inputs,
        batch_size=8,
        generate_kwargs={"task": "transcribe"},
        return_timestamps=False,
        return_language=True,
    )

    # The language is attached to the first chunk. The chunk list can be
    # missing or empty (e.g. nothing was decoded), so guard the lookup
    # instead of failing with IndexError/KeyError and losing the text.
    chunks = op.get("chunks") or []
    lang = chunks[0].get("language") if chunks else None
    text = op["text"]
    return lang, text
# Tab 2: upload or record up to 30 s of speech and get the detected
# language plus the transcription.
transcribe_proc = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(
            type="filepath",
            label="Speech file (<30s)",
            max_length=30,
            sources=["microphone", "upload"],
            show_download_button=True,
        ),
    ],
    outputs=[
        gr.Text(label="Predicted Language", info="Language identification is performed automatically."),
        gr.Text(label="Predicted transcription", info="Best hypothesis."),
    ],
    title="Transcribe Anything.",
    # Typo fix in the user-facing text: "Automatactic" -> "Automatic".
    description=(
        "Automatic language identification and transcription service by Whisper Large V3. Upload a .wav or .flac file."
    ),
    allow_flagging="never",
)
# Mount both interfaces as tabs inside the Blocks container.
with demo:
    gr.TabbedInterface(
        interface_list=[file_proc, transcribe_proc],
        tab_names=["Analyze Audio File", "Transcribe Audio File"],
    )

# Cap the request queue at 10 pending jobs, then start the app with a
# public share link requested.
demo.queue(max_size=10).launch(share=True)