import os
import tempfile

import librosa
import numpy as np
import streamlit as st
import torch
import torchaudio
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForCTC


# ---------------------------------------------------------------------------
# Hugging Face authentication
# ---------------------------------------------------------------------------
# The access token is read from the "hf_token" environment variable.
HF_TOKEN = os.getenv("hf_token")

# An unset variable (None) and an empty string are both unusable as a token,
# so reject either one before attempting to authenticate.
if not HF_TOKEN:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")

# Authenticate this process with the Hugging Face Hub using the token.
login(token=HF_TOKEN)


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
MODEL_NAME = "deepl-project/conformer-finetunning"


@st.cache_resource
def _load_asr(model_name):
    """Load the processor and CTC model once and place the model on a device.

    Streamlit re-executes this script on every widget interaction, so the
    load is wrapped in ``st.cache_resource`` to reuse the same objects across
    reruns instead of calling ``from_pretrained`` each time.

    Args:
        model_name: Hub repository id passed to ``from_pretrained``.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is ``"cuda"``
        when ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise.
    """
    loaded_processor = AutoProcessor.from_pretrained(model_name)
    loaded_model = AutoModelForCTC.from_pretrained(model_name)

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    loaded_model.to(target_device)
    # Printed only when the load actually runs, not on cached reruns.
    print(f"✅ Conformer Model loaded on {target_device}")
    return loaded_processor, loaded_model, target_device


# Module-level names used by the rest of the script.
processor, model, device = _load_asr(MODEL_NAME)


# ---------------------------------------------------------------------------
# Sidebar controls
# ---------------------------------------------------------------------------
# Every widget created inside this context is rendered in the sidebar.
with st.sidebar:
    st.title("🔧 Fine-Tuning Hyperparameters")

    # Integer slider: 1..10, starting at 3.
    num_epochs = st.slider("Epochs", min_value=1, max_value=10, value=3)

    # Discrete choices only; the slider snaps to one of the listed options.
    learning_rate = st.select_slider(
        "Learning Rate",
        options=[5e-4, 1e-4, 5e-5, 1e-5],
        value=5e-5,
    )
    batch_size = st.select_slider(
        "Batch Size",
        options=[2, 4, 8, 16],
        value=8,
    )

    # Float slider: 0.0..0.9, starting at 0.1. Used below as the scale of the
    # random noise added to the uploaded audio.
    attack_strength = st.slider(
        "Attack Strength",
        min_value=0.0,
        max_value=0.9,
        value=0.1,
    )


# ---------------------------------------------------------------------------
# Main page: upload audio, perturb it with random noise, and transcribe it
# ---------------------------------------------------------------------------
st.title("🎙️ Speech-to-Text ASR Conformer Model Finetunned on Libri Speech with Security Features 🎶")

audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    # Write the upload to a uniquely named temporary file rather than a fixed
    # path in the working directory: a fixed path would be shared by every
    # session that runs this script and is never cleaned up. The original
    # extension is kept so the file name reflects the uploaded format.
    upload_suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=upload_suffix, delete=False) as tmp:
        tmp.write(audio_file.read())
        audio_path = tmp.name

    try:
        # Decode and resample to 16 kHz.
        speech, sr = librosa.load(audio_path, sr=16000)
    finally:
        # Remove the temporary file whether or not decoding succeeded.
        os.remove(audio_path)

    if speech.size == 0:
        # Nothing to transcribe; avoid feeding an empty array to the model.
        st.error("❌ The uploaded audio file contains no samples.")
    else:
        # Add Gaussian noise scaled by the sidebar's attack strength, then
        # clamp the result to the [-1.0, 1.0] range.
        adversarial_speech = speech + (attack_strength * np.random.randn(*speech.shape))
        adversarial_speech = np.clip(adversarial_speech, -1.0, 1.0)

        inputs = processor(adversarial_speech, sampling_rate=sr, return_tensors="pt", padding=True)
        input_values = inputs.input_values.to(device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = model(input_values).logits

        # Pick the highest-scoring token id at each position along the last
        # axis and let the processor decode the ids to text.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        # The warning is driven purely by the slider value, not by any
        # analysis of the audio itself.
        if attack_strength > 0.2:
            st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")

        st.success("📄 Secure Transcription:")
        st.write(transcription[0])