|
|
import gradio as gr |
|
|
import torch |
|
|
import librosa |
|
|
import numpy as np |
|
|
import torch.nn as nn |
|
|
import torchvision.models as models |
|
|
|
|
|
|
|
|
class TransferLearningModel(nn.Module):
    """ResNet-18 backbone adapted for single-channel spectrogram input.

    The first convolution is swapped for a 1-channel version (MFCC matrices
    are fed as one-channel images) and the classifier head is replaced so the
    output matches ``num_classes``. Dropout is applied between global pooling
    and the final linear layer, which is why ``forward`` re-traces the ResNet
    stages manually instead of calling ``self.resnet(x)`` directly.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.resnet = models.resnet18(weights=None)
        # Accept 1-channel input instead of RGB; otherwise identical to the
        # stock ResNet-18 stem.
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Replace the ImageNet head with one sized for our label set.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        net = self.resnet
        # Stem: conv -> batch-norm -> ReLU -> max-pool.
        out = net.maxpool(net.relu(net.bn1(net.conv1(x))))
        # The four residual stages, in order.
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            out = stage(out)
        # Global average pool, flatten, dropout, then the classifier head.
        out = torch.flatten(net.avgpool(out), 1)
        return net.fc(self.dropout(out))
|
|
|
|
|
|
|
|
# Class labels in the exact order of the model's output logits.
# Index 0 ('unknown') is the rejection class; the rest are enrolled speakers.
LABELS = ['unknown', 'user1', 'user2', 'user3', 'user4', 'user5', 'user6', 'user7']
|
|
|
|
|
|
|
|
# Pick the GPU when available; model and input tensors are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = TransferLearningModel(num_classes=len(LABELS))

# Trained checkpoint (a state_dict), expected alongside this script.
model_path = "voice_recognition_final_enhanced.pth"

try:
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()  # inference mode: disables dropout
except Exception as e:
    # Fail soft: the UI still launches and predict() reports the missing model.
    print(f"[ERROR] Failed to load model: {e}")
    model = None
|
|
|
|
|
|
|
|
def extract_features_from_file(file_path, max_pad_len=174):
    """Load an audio file and return a fixed-size (40, max_pad_len) MFCC array.

    The time axis is zero-padded on the right when the clip is short, or
    truncated to the first ``max_pad_len`` frames when it is long, so every
    clip yields the same shape. Returns None if loading or extraction fails.
    """
    try:
        signal, rate = librosa.load(file_path, sr=None, res_type='kaiser_fast')
        mfcc = librosa.feature.mfcc(y=signal, sr=rate, n_mfcc=40)
        frames = mfcc.shape[1]
        if frames < max_pad_len:
            # Clip too short: zero-pad the time axis out to max_pad_len.
            return np.pad(mfcc, pad_width=((0, 0), (0, max_pad_len - frames)), mode='constant')
        # Long enough (or exact): keep only the first max_pad_len frames.
        return mfcc[:, :max_pad_len]
    except Exception as e:
        print(f"[ERROR] Feature extraction failed: {e}")
        return None
|
|
|
|
|
|
|
|
def predict(file):
    """Gradio handler: identify the speaker in an uploaded audio file.

    Parameters
    ----------
    file : str | None
        Filepath supplied by ``gr.Audio(type="filepath")``; ``None`` when the
        user submits without uploading. A tempfile-like object exposing
        ``.name`` (older Gradio versions) is also accepted.

    Returns
    -------
    str
        A human-readable access-granted or access-denied message.
    """
    if model is None:
        return "Error: Model not loaded."
    if file is None:
        # Submit pressed with no audio uploaded.
        return "Error: No audio file provided."

    # BUG FIX: gr.Audio(type="filepath") delivers a plain path string; the
    # original code assumed a file object and called file.name, which raises
    # AttributeError on a str. Accept both forms.
    file_path = file if isinstance(file, str) else getattr(file, "name", file)

    features = extract_features_from_file(file_path)
    if features is None:
        return "Error: Could not extract features from audio."

    # Shape (1, 1, 40, max_pad_len): add batch and channel dims for the CNN.
    input_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probabilities, 1)

    predicted_user = LABELS[predicted.item()]
    confidence_score = confidence.item()

    # Reject when the model is unsure or explicitly predicts 'unknown'.
    if confidence_score < 0.7 or predicted_user == 'unknown':
        return f"β Unknown user or low confidence (Confidence: {confidence_score:.3f})"
    # BUG FIX: this string literal was split across two physical lines in the
    # source (a SyntaxError); rejoined here. NOTE(review): the leading 'β'
    # markers look like mojibake of emoji (likely ❌/✅) — confirm the
    # intended glyphs against the original source.
    return f"β Access granted to {predicted_user} (Confidence: {confidence_score:.3f})"
|
|
|
|
|
|
|
|
# Gradio UI definition: single audio upload in, plain-text verdict out.
iface = gr.Interface(
    fn=predict,
    # type="filepath" means predict() receives a path string (or None).
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="ποΈ Voice Recognition",
    description=(
        "Upload an audio file (wav, mp3, etc.) containing a user's voice. "
        "The model will analyze the audio and predict the user identity if recognized. "
        "If the user is unknown or the confidence is low, access will be denied. "
        "This system supports 7 authorized users and detects unknown users for security."
    ),
    # Markdown rendered below the interface.
    article=(
        "### How to Use\n"
        "1. Click the 'Browse' button to upload an audio file.\n"
        "2. Wait for the model to process and display the prediction result.\n"
        "3. The output will show the predicted user and confidence score.\n"
        "4. If the user is unknown or confidence is below threshold, access will be denied.\n\n"
        "### Supported Users\n"
        "- user1, user2, user3, user4, user5, user6, user7\n\n"
        "### Notes\n"
        "- Ensure audio quality is good for best results.\n"
        "- Supported audio formats include wav, mp3, flac, ogg, m4a, aac.\n"
        "- The model uses MFCC features and a ResNet18-based CNN architecture.\n"
        "- For questions or issues, please refer to the project README or contact support."
    )
)
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio web server (blocking call).
    iface.launch()
|
|
|
|
|
|