| from flask import Flask, render_template, request, jsonify |
| from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForXVector |
| import torchaudio |
| import torch |
| import io |
| import librosa |
| from scipy.spatial.distance import cosine |
| import numpy as np |
| import os |
| |
| |
|
|
# Flask application; static assets are served under the /static URL prefix.
app = Flask(__name__, static_url_path="/static")

# Reference recordings that uploaded audio is compared against:
# /compare_audio uses the MP3 and /compare_audio2 uses the WAV.
mp3_file_path = "arnold.mp3"
mp3_file_path2 = "arnold2.wav"
|
|
def _read_flag(path):
    """Return the full text of the flag file at *path*.

    The encoding is pinned to UTF-8 so the result does not depend on the
    server's locale (``open``'s default text encoding is locale-dependent).
    """
    with open(path, encoding="utf-8") as flag_file:
        return flag_file.read()


# Secrets revealed on a successful match: flag1 by /compare_audio,
# flag2 by /compare_audio2.
flag1 = _read_flag("flag1.txt")
flag2 = _read_flag("flag2.txt")
|
|
| |
# Prefer a local "model" directory when one exists; otherwise load by the
# Hugging Face model identifier.
themodel = "model" if os.path.exists("model") else "microsoft/unispeech-sat-large-sv"

# Feature extractor and x-vector speaker model shared by /compare_audio.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(themodel)
model = UniSpeechSatForXVector.from_pretrained(themodel)
|
|
| |
def preprocess_audio(audio_data):
    """Load audio and return it as a mono, 16 kHz numpy waveform.

    Args:
        audio_data: Anything ``torchaudio.load`` accepts (a path or a
            file-like object).

    Returns:
        The channel-averaged, resampled signal, squeezed and converted to a
        numpy array.
    """
    signal, rate = torchaudio.load(audio_data)

    # Downmix multi-channel input by averaging, keeping the channel axis.
    if signal.shape[0] > 1:
        signal = signal.mean(dim=0, keepdim=True)

    # The downstream models expect 16 kHz input.
    if rate != 16000:
        resampler = torchaudio.transforms.Resample(rate, 16000)
        signal = resampler(signal)

    return signal.squeeze().numpy()
|
|
@app.route('/')
def index():
    """Render the ``index.html`` template for the site root."""
    page = 'index.html'
    return render_template(page)
|
|
@app.route('/chal2')
def chal2():
    """Render the ``chal2.html`` template (the second challenge page)."""
    page = 'chal2.html'
    return render_template(page)
|
|
| |
| |
| |
| |
def _speaker_embedding(waveform):
    """Return the L2-normalised x-vector embedding of a preprocessed waveform.

    Args:
        waveform: Output of ``preprocess_audio`` (mono, resampled to 16 kHz).

    Returns:
        A CPU tensor holding the normalised ``model(...).embeddings`` output.
    """
    # preprocess_audio always resamples to 16 kHz. Stating it lets the
    # extractor reject a mismatch instead of silently computing bad features.
    inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        embeddings = model(**inputs).embeddings
    return torch.nn.functional.normalize(embeddings, dim=-1).cpu()


# Reference embeddings keyed by file path, so the expensive forward pass on the
# reference recording runs once per path instead of on every request.
# NOTE(review): assumes the reference file is not replaced on disk while the
# server is running.
_reference_embeddings = {}


def _reference_embedding():
    """Return the (cached) normalised embedding of ``mp3_file_path``."""
    cached = _reference_embeddings.get(mp3_file_path)
    if cached is None:
        cached = _speaker_embedding(preprocess_audio(mp3_file_path))
        _reference_embeddings[mp3_file_path] = cached
    return cached


@app.route('/compare_audio', methods=['POST'])
def compare_audio():
    """Challenge 1: compare an uploaded recording with the reference speaker.

    Expects a multipart file field ``audio_data``. The upload and the reference
    recording are embedded with the UniSpeech-SAT x-vector model and compared
    by cosine similarity (rounded to 3 decimals). A score of at least 0.89
    reveals ``flag1`` and a link to ``/chal2``.

    Returns:
        JSON ``{"result": <html>}`` when the comparison runs, or
        ``{"error": <message>}`` if any exception is raised.
    """
    threshold = 0.89
    try:
        recorded_audio = request.files['audio_data']
        embeddings_normalized = _speaker_embedding(preprocess_audio(recorded_audio))
        reference_normalized = _reference_embedding()

        cosine_sim = torch.nn.CosineSimilarity(dim=-1)
        similarity = cosine_sim(embeddings_normalized, reference_normalized).item()
        similarity = round(similarity, 3)

        # Fail closed. A NaN score (e.g. from non-finite samples in the upload)
        # compares False against everything, so `similarity < threshold` alone
        # would have granted access.
        if not similarity >= threshold:
            result = "Authorization Failed! " + str(similarity) + " < 0.890<br>Do your best Terminator impression"
        else:
            result = "Good job! Match: " + str(similarity) + "<br>" + flag1 + "<br><a href='/chal2'>Click here to open the next challenge</a>"

        return jsonify({'result': result})
    except Exception as e:
        print("Caught: " + str(e))
        return jsonify({'error': 'An error occurred during audio comparison. Im fragile please dont abuse.'})
|
|
def extract_mfcc(audio_bytes):
    """Compute 13 MFCCs for an encoded audio clip.

    Args:
        audio_bytes: Raw encoded audio, decoded by ``preprocess_audio2``.

    Returns:
        The ``librosa.feature.mfcc`` matrix computed at 16 kHz with
        ``n_mfcc=13``.
    """
    return librosa.feature.mfcc(
        y=preprocess_audio2(audio_bytes),
        sr=16000,
        n_mfcc=13,
    )
|
|
def preprocess_audio2(audio_bytes):
    """Decode raw audio bytes into a silence-trimmed 16 kHz mono waveform.

    Args:
        audio_bytes: Encoded audio file contents.

    Returns:
        A numpy waveform with leading and trailing audio quieter than 20 dB
        below the peak removed (``librosa.effects.trim(top_db=20)``).
    """
    # Shares decoding, downmixing and resampling with preprocess_audio, so the
    # result is already a numpy array. librosa documents `y` as np.ndarray;
    # the original passed a torch.Tensor and relied on implicit conversion.
    waveform = preprocess_audio(io.BytesIO(audio_bytes))
    trimmed, _ = librosa.effects.trim(waveform, top_db=20)
    return trimmed
|
|
@app.route('/compare_audio2', methods=['POST'])
def compare_audio2():
    """Challenge 2: compare an uploaded recording with ``mp3_file_path2`` by MFCCs.

    Expects a multipart file field ``audio_data``. Each clip's MFCC matrix is
    averaged over time, and the two mean vectors are compared by cosine
    similarity (rounded to 3 decimals). A score of at least 0.940 reveals
    ``flag2``.

    Returns:
        JSON ``{"result": <html>}`` when the comparison runs, or
        ``{"error": <message>}`` if any exception is raised.
    """
    threshold = 0.940
    try:
        recorded_audio = request.files['audio_data'].read()
        # Close the reference file deterministically instead of leaking the handle.
        with open(mp3_file_path2, 'rb') as reference_file:
            reference_audio = reference_file.read()

        mfcc1 = extract_mfcc(recorded_audio)
        mfcc2 = extract_mfcc(reference_audio)
        similarity = 1 - cosine(np.mean(mfcc1, axis=1), np.mean(mfcc2, axis=1))
        similarity = round(similarity, 3)

        # Fail closed. `cosine` yields NaN for a zero-norm vector, and NaN
        # compares False against everything, so `similarity < threshold` alone
        # would have granted access.
        if not similarity >= threshold:
            result = "Authorization Failed! " + str(similarity) + " < 0.940<br>Say: 'With great power comes great responsibility' as Arnold Schwarzenegger"
        else:
            result = "Good job! Match: " + str(similarity) + "<br>" + flag2

        return jsonify({'result': result})
    except Exception as e:
        print("Caught: " + str(e))
        return jsonify({'error': 'An error occurred during audio comparison. Im fragile please dont abuse.'})
|
|
if __name__ == '__main__':
    # The server listens on all interfaces. Flask's interactive debugger lets a
    # browser execute arbitrary code, so debug mode is opt-in via FLASK_DEBUG
    # instead of always on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=8080, debug=debug_enabled)