# Source: Hugging Face Space file "app.py" (author: msmaje)
# Last commit: "Update app.py" (5c09f15, verified), file size 8.34 kB
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import librosa
import numpy as np
from sklearn.preprocessing import LabelEncoder
import os
import warnings
warnings.filterwarnings('ignore')
# Model definition (intended to mirror the architecture used by the training script)
class TransferLearningModel(nn.Module):
    """
    ResNet-18 classifier adapted to single-channel (MFCC) input.

    The stock ResNet-18 is modified in two places: the first convolution
    takes 1 input channel instead of 3, and the final fully connected layer
    outputs ``num_classes`` logits. Dropout (p=0.5) is applied to the pooled
    features just before that final layer.

    The attribute names ``resnet`` and ``dropout`` determine the keys of the
    module's state dict, so they must not be renamed if saved weights are to
    keep loading.

    Args:
        num_classes: Number of output classes (size of the logit vector).
    """

    def __init__(self, num_classes):
        super().__init__()
        # Randomly initialised backbone: the trained weights are loaded
        # separately at deployment time. `weights=None` is the current
        # spelling of the deprecated `pretrained=False`.
        self.resnet = models.resnet18(weights=None)
        # Replace the first conv layer so it accepts single-channel input (MFCC)
        self.resnet.conv1 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        # Replace the final layer so it emits one logit per class
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, num_classes)
        # Dropout for regularization (inactive in eval mode)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """
        Compute class logits for a batch of single-channel inputs.

        The backbone is run stage by stage rather than via ``self.resnet(x)``
        so that dropout can be inserted between the pooled features and the
        final linear layer.

        Args:
            x: Input tensor with 1 channel in dimension 1 (as required by
                the replaced ``conv1``).

        Returns:
            Tensor of logits with one row per batch element.
        """
        backbone = self.resnet
        # Stem
        x = backbone.conv1(x)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)
        # Residual stages
        x = backbone.layer1(x)
        x = backbone.layer2(x)
        x = backbone.layer3(x)
        x = backbone.layer4(x)
        # Pool, flatten everything after the batch dimension, regularise, classify
        x = backbone.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        return backbone.fc(x)
# MFCC feature extraction
def extract_features(audio_data, sample_rate, max_pad_len=174):
    """
    Compute a fixed-width MFCC matrix for an audio signal.

    40 MFCC coefficients are computed per frame with librosa. The frame
    axis (axis 1) is then right-padded with NumPy's default constant fill,
    or cut, so the result has ``max_pad_len`` columns.

    Args:
        audio_data: Audio samples, passed to librosa as ``y``.
        sample_rate: Sampling rate of ``audio_data``, passed as ``sr``.
        max_pad_len: Target number of frames (columns) in the result.

    Returns:
        The padded or truncated MFCC array, or None if any step raised
        (the error is printed, not propagated).
    """
    try:
        coeffs = librosa.feature.mfcc(y=audio_data, sr=sample_rate, n_mfcc=40)
        missing_frames = max_pad_len - coeffs.shape[1]
        if missing_frames <= 0:
            # Already long enough: keep only the leading frames
            return coeffs[:, :max_pad_len]
        # Too short: extend the frame axis on the right only
        return np.pad(
            coeffs, pad_width=((0, 0), (0, missing_frames)), mode='constant'
        )
    except Exception as err:
        print(f"Error extracting features: {str(err)}")
        return None
# Initialize model and label encoder
device = torch.device('cpu')  # Use CPU for deployment

# Users that are allowed through the access check
authorized_users = ['user1', 'user2', 'user3', 'user4', 'user5', 'user6', 'user7']

# Label encoder mapping class indices (model outputs) back to user names.
# TODO: update this list with the actual user classes used in training.
label_encoder = LabelEncoder()
all_users = ['user1', 'user2', 'user3', 'user4', 'user5', 'user6', 'user7']  # Add all your users here
label_encoder.fit(all_users)

MODEL_PATH = 'voice_recognition_fullmodel.pth'


def _load_model(path, num_classes):
    """
    Load the trained network from ``path`` and return it in eval mode.

    Two formats are tried in order: a fully serialized model object, then a
    state dict loaded into a fresh ``TransferLearningModel``. Each attempt
    works on a local variable, so a half-initialised object is never
    returned: the result is either a fully loaded model or None.

    Args:
        path: Path of the checkpoint file.
        num_classes: Number of classes used to build the fallback model.

    Returns:
        The loaded model, or None if both attempts failed (errors are
        printed, not raised).
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code; only load checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which rejects fully pickled models - confirm against the pinned
    # torch version.
    try:
        # Attempt 1: the file holds the whole model object
        loaded = torch.load(path, map_location=device)
        loaded.eval()
        print("Model loaded successfully!")
        return loaded
    except Exception as e:
        print(f"Error loading model: {e}")

    try:
        # Attempt 2: the file holds a state dict for TransferLearningModel
        candidate = TransferLearningModel(num_classes)
        candidate.load_state_dict(torch.load(path, map_location=device))
        candidate.eval()
        print("Model loaded with state dict!")
        return candidate
    except Exception as e2:
        print(f"Error loading model with state dict: {e2}")

    return None


# Load model; stays None when loading failed, which predict_voice checks for
model = _load_model(MODEL_PATH, len(all_users))
def predict_voice(audio_file, confidence_threshold=0.7):
    """
    Identify the speaker in an audio file and decide whether to grant access.

    Access is granted only when the top class probability reaches
    ``confidence_threshold`` and the predicted user is in
    ``authorized_users``.

    Args:
        audio_file: Path of the audio file to analyse, or None when no
            audio was provided.
        confidence_threshold: Minimum softmax probability of the top class
            required to pass the confidence check.

    Returns:
        A 4-tuple ``(access decision, predicted user, confidence, details)``.
        On failure the predicted user is "Error" and the confidence is 0.0;
        exceptions raised during processing are reported in ``details``
        instead of being propagated.
    """
    if model is None:
        return "❌ Model not loaded", "Error", 0.0, "Unable to load model"

    if audio_file is None:
        return "❌ No audio file provided", "Error", 0.0, "Please upload an audio file"

    try:
        # Load audio data
        audio_data, sample_rate = librosa.load(audio_file, res_type='kaiser_fast')

        # Extract features (None signals that extraction failed)
        features = extract_features(audio_data, sample_rate)
        if features is None:
            return "❌ Could not extract features", "Error", 0.0, "Feature extraction failed"

        # The 2-D MFCC matrix gets leading batch and channel dimensions
        features = (
            torch.tensor(features, dtype=torch.float32)
            .unsqueeze(0)
            .unsqueeze(0)
            .to(device)
        )

        # Inference only: no gradients needed
        with torch.no_grad():
            outputs = model(features)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        predicted_user = label_encoder.inverse_transform([predicted.item()])[0]
        confidence_score = confidence.item()

        # Security check 1: the model must be sufficiently sure
        if confidence_score < confidence_threshold:
            return (
                "❌ Access Denied - Low Confidence",
                predicted_user,
                confidence_score,
                f"Confidence {confidence_score:.3f} below threshold {confidence_threshold}",
            )

        # Security check 2: the recognised user must be on the allow list
        if predicted_user not in authorized_users:
            return (
                "❌ Access Denied - Unauthorized User",
                predicted_user,
                confidence_score,
                f"User '{predicted_user}' not in authorized list",
            )

        return (
            "✅ Access Granted",
            predicted_user,
            confidence_score,
            f"Welcome {predicted_user}! High confidence recognition.",
        )
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing
        return "❌ Error processing audio", "Error", 0.0, f"Error: {str(e)}"
# Create Gradio interface
def create_interface():
    """
    Build the Gradio Blocks UI for the voice-recognition access check.

    The layout has two columns: inputs (audio upload/recording, confidence
    threshold slider, analyze button) and outputs (access decision,
    predicted user, confidence score, details). Clicking the button calls
    ``predict_voice`` with the audio path and threshold and writes its four
    return values to the four output components.

    Returns:
        The assembled ``gr.Blocks`` app; it is not launched here.
    """
    # The authorised-user line is generated from authorized_users so the
    # page cannot drift from the list the access check actually uses.
    intro_md = (
        "\n"
        "# 🎤 Voice Recognition Security System\n"
        "This system uses advanced voice recognition to control access. "
        "Upload an audio file to test the system.\n"
        f"**Authorized Users:** {', '.join(authorized_users)}\n"
    )
    instructions_md = (
        "\n"
        "1. **Upload Audio**: Click on the audio component to upload a .wav, .mp3, or other audio file\n"
        "2. **Record Audio**: Use the microphone button to record directly\n"
        "3. **Set Threshold**: Adjust the confidence threshold (higher = more strict)\n"
        "4. **Analyze**: Click 'Analyze Voice' to process the audio\n"
        "The system will determine if the speaker is authorized and grant/deny access accordingly.\n"
    )

    with gr.Blocks(title="Voice Recognition Security System", theme=gr.themes.Soft()) as demo:
        gr.Markdown(intro_md)

        with gr.Row():
            # Left column: inputs
            with gr.Column():
                audio_input = gr.Audio(
                    label="Upload Audio File",
                    type="filepath",
                    sources=["upload", "microphone"],
                )
                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Confidence Threshold",
                )
                predict_btn = gr.Button("🔍 Analyze Voice", variant="primary")

            # Right column: results
            with gr.Column():
                access_result = gr.Textbox(
                    label="Access Decision",
                    placeholder="Upload audio to see result...",
                    lines=2,
                )
                predicted_user = gr.Textbox(
                    label="Predicted User",
                    placeholder="No prediction yet...",
                )
                confidence_score = gr.Number(
                    label="Confidence Score",
                    precision=3,
                )
                details = gr.Textbox(
                    label="Details",
                    placeholder="Additional information will appear here...",
                    lines=3,
                )

        # Instructions section
        gr.Markdown("### 📋 Instructions")
        gr.Markdown(instructions_md)

        # Connect the interface: output order matches predict_voice's return tuple
        predict_btn.click(
            fn=predict_voice,
            inputs=[audio_input, confidence_slider],
            outputs=[access_result, predicted_user, confidence_score, details],
        )

    return demo
# Script entry point: build the UI and start the Gradio server.
# Only runs when the file is executed directly, not when it is imported.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()