Upload 5 files
Browse files- README.md +90 -3
- config.json +38 -0
- inference_example.py +84 -0
- model_architecture.py +67 -0
- requirements.txt +4 -0
README.md
CHANGED
|
@@ -1,3 +1,90 @@
|
|
| 1 |
-
--
|
| 2 |
-
|
| 3 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Mouse USV Detector - ProSAP1/Shank2 Male-Oestrus Female Interactions
|
| 2 |
+
|
| 3 |
+
Deep learning model for detecting ultrasonic vocalizations (USVs) in mouse recordings from male-oestrus female social interactions.
|
| 4 |
+
|
| 5 |
+
## Model Details
|
| 6 |
+
|
| 7 |
+
- **Model Type**: Convolutional Neural Network (CNN)
|
| 8 |
+
- **Task**: Binary classification (USV vs. noise)
|
| 9 |
+
- **Architecture**: 4-layer CNN with batch normalization and dropout
|
| 10 |
+
- **Parameters**: 10.4M
|
| 11 |
+
- **Framework**: PyTorch 2.0+
|
| 12 |
+
|
| 13 |
+
## Performance
|
| 14 |
+
|
| 15 |
+
- **Validation Accuracy**: 96.0%
|
| 16 |
+
- **Noise Detection**: 94.7%
|
| 17 |
+
- **USV Detection**: 98.0%
|
| 18 |
+
|
| 19 |
+
## Training Protocol
|
| 20 |
+
|
| 21 |
+
- **Subject**: S2-4-65
|
| 22 |
+
- **Strain**: ProSAP1/Shank2
|
| 23 |
+
- **Behavior**: Male-oestrus female interactions (10 min + 3 min)
|
| 24 |
+
- **Dataset**: 7,188 training samples, 2,146 validation samples
|
| 25 |
+
- **Epochs**: 10
|
| 26 |
+
|
| 27 |
+
## Audio Specifications
|
| 28 |
+
|
| 29 |
+
- **Sample Rate**: 250 kHz (ultrasonic)
|
| 30 |
+
- **USV Frequency Range**: 40-100 kHz
|
| 31 |
+
- **Input Format**: 64x64 spectrogram patches
|
| 32 |
+
|
| 33 |
+
## Usage
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
import torch
|
| 37 |
+
from model_architecture import load_model
|
| 38 |
+
from inference_example import predict
|
| 39 |
+
|
| 40 |
+
# Load model
|
| 41 |
+
model = load_model('final_usv_model.pth')
|
| 42 |
+
|
| 43 |
+
# Predict on audio file
|
| 44 |
+
result = predict('audio.wav', model)
|
| 45 |
+
print(f"USV: {result['is_usv']}, Confidence: {result['confidence']:.2%}")
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## Requirements
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
pip install torch numpy librosa scipy
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Files
|
| 55 |
+
|
| 56 |
+
- `final_usv_model.pth` - Trained model weights (41.9 MB)
|
| 57 |
+
- `model_architecture.py` - CNN architecture definition
|
| 58 |
+
- `inference_example.py` - Example inference code
|
| 59 |
+
- `config.json` - Model configuration and metadata
|
| 60 |
+
- `requirements.txt` - Python dependencies
|
| 61 |
+
|
| 62 |
+
## Citation
|
| 63 |
+
|
| 64 |
+
If you use this model, please cite:
|
| 65 |
+
|
| 66 |
+
```bibtex
|
| 67 |
+
@misc{usv_detector_prosap1_shank2,
|
| 68 |
+
title={Mouse USV Detector for ProSAP1/Shank2 Social Interactions},
|
| 69 |
+
author={Your Name},
|
| 70 |
+
year={2025},
|
| 71 |
+
publisher={Hugging Face},
|
| 72 |
+
howpublished={\url{https://huggingface.co/your-username/model-name}}
|
| 73 |
+
}
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
## License
|
| 77 |
+
|
| 78 |
+
[Specify your license here]
|
| 79 |
+
|
| 80 |
+
## Methodology
|
| 81 |
+
|
| 82 |
+
Based on the DeepSqueak methodology for USV detection:
|
| 83 |
+
- Spectrogram-based feature extraction
|
| 84 |
+
- Tonality calculation for USV identification
|
| 85 |
+
- Automated detection with manual validation
|
| 86 |
+
- Deep learning classification for robust detection
|
| 87 |
+
|
| 88 |
+
## Contact
|
| 89 |
+
|
| 90 |
+
[Your contact information]
|
config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_name": "Mouse_USV_Detector_ProSAP1_Shank2",
|
| 3 |
+
"model_type": "USVDetectorCNN",
|
| 4 |
+
"task": "audio-classification",
|
| 5 |
+
"protocol": "S2-4-65 ProSAP1/Shank2 Male-Oestrus Female Interactions",
|
| 6 |
+
"architecture": {
|
| 7 |
+
"input_size": [64, 64],
|
| 8 |
+
"num_classes": 2,
|
| 9 |
+
"classes": ["noise", "usv"]
|
| 10 |
+
},
|
| 11 |
+
"audio_preprocessing": {
|
| 12 |
+
"sample_rate": 250000,
|
| 13 |
+
"nfft": 0.0032,
|
| 14 |
+
"overlap": 0.0028,
|
| 15 |
+
"hop_length_samples": 800,
|
| 16 |
+
"freq_range": [40000, 100000],
|
| 17 |
+
"freq_range_description": "Mouse USV frequency range (40-100 kHz)"
|
| 18 |
+
},
|
| 19 |
+
"training": {
|
| 20 |
+
"dataset_size": 9334,
|
| 21 |
+
"train_samples": 7188,
|
| 22 |
+
"val_samples": 2146,
|
| 23 |
+
"num_epochs": 10,
|
| 24 |
+
"batch_size": 32,
|
| 25 |
+
"learning_rate": 0.001,
|
| 26 |
+
"optimizer": "Adam",
|
| 27 |
+
"loss_function": "CrossEntropyLoss"
|
| 28 |
+
},
|
| 29 |
+
"performance": {
|
| 30 |
+
"validation_accuracy": 96.0,
|
| 31 |
+
"noise_accuracy": 94.7,
|
| 32 |
+
"usv_accuracy": 98.0
|
| 33 |
+
},
|
| 34 |
+
"framework": "pytorch",
|
| 35 |
+
"pytorch_version": "2.0+",
|
| 36 |
+
"model_size_mb": 41.9,
|
| 37 |
+
"parameters": 10467202
|
| 38 |
+
}
|
inference_example.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import librosa
|
| 4 |
+
from scipy.ndimage import zoom
|
| 5 |
+
from model_architecture import load_model
|
| 6 |
+
|
| 7 |
+
def preprocess_audio(audio_path, sr=250000, nfft=0.0032, overlap=0.0028,
                     freq_range=(40000, 100000)):
    '''
    Preprocess an audio file into a normalized 64x64 spectrogram patch.

    Args:
        audio_path: Path to .wav file.
        sr: Sample rate in Hz (250 kHz for ultrasonic recordings).
        nfft: FFT window size in seconds.
        overlap: Overlap between consecutive windows in seconds.
        freq_range: (low, high) frequency band to keep, in Hz.

    Returns:
        torch.Tensor: Preprocessed spectrogram of shape (1, 1, 64, 64).
    '''
    # Load audio, resampling to `sr` if the file's native rate differs.
    audio, _ = librosa.load(audio_path, sr=sr)

    # Generate spectrogram. Use round() rather than int(): binary float error
    # makes (0.0032 - 0.0028) * 250000 evaluate to 99.999..., which int()
    # truncates to a hop of 99 samples instead of the intended 100.
    nfft_samples = round(nfft * sr)
    hop_length = round((nfft - overlap) * sr)
    spec = librosa.stft(audio, n_fft=nfft_samples, hop_length=hop_length)
    spec_db = librosa.amplitude_to_db(np.abs(spec), ref=np.max)

    # Keep only the mouse-USV frequency band.
    freqs = librosa.fft_frequencies(sr=sr, n_fft=nfft_samples)
    freq_mask = (freqs >= freq_range[0]) & (freqs <= freq_range[1])
    spec_db = spec_db[freq_mask, :]

    # Resize to the 64x64 input expected by the CNN (order=1: bilinear).
    zoom_factors = (64 / spec_db.shape[0], 64 / spec_db.shape[1])
    spec_resized = zoom(spec_db, zoom_factors, order=1)

    # Per-patch standardization; epsilon guards against zero variance.
    spec_resized = (spec_resized - np.mean(spec_resized)) / (np.std(spec_resized) + 1e-8)

    # Add batch and channel dimensions -> (1, 1, 64, 64).
    return torch.FloatTensor(spec_resized).unsqueeze(0).unsqueeze(0)
|
| 45 |
+
|
| 46 |
+
def predict(audio_path, model):
    '''
    Classify an audio file as USV or noise.

    Args:
        audio_path: Path to .wav file.
        model: Loaded USVDetectorCNN model.

    Returns:
        dict: 'is_usv' flag, winning-class confidence, and both class
        probabilities.
    '''
    patch = preprocess_audio(audio_path)

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(patch)
        probs = torch.softmax(logits, dim=1)[0]
        label = torch.argmax(logits, dim=1).item()

    return {
        'is_usv': label == 1,
        'confidence': probs[label].item(),
        'usv_probability': probs[1].item(),
        'noise_probability': probs[0].item(),
    }
|
| 72 |
+
|
| 73 |
+
# Example usage
if __name__ == "__main__":
    # Load the trained weights, then classify one recording.
    model = load_model('final_usv_model.pth')
    result = predict('test_audio.wav', model)

    print(f"USV Detected: {result['is_usv']}")
    for heading, key in [("Confidence", 'confidence'),
                         ("USV Probability", 'usv_probability'),
                         ("Noise Probability", 'noise_probability')]:
        print(f"{heading}: {result[key]:.2%}")
|
model_architecture.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class USVDetectorCNN(nn.Module):
    '''
    CNN for Mouse Ultrasonic Vocalization (USV) Detection.

    Trained on ProSAP1/Shank2 male-oestrus female interaction recordings.
    Classifies spectrogram patches as USV or noise.

    Input: (batch_size, 1, 64, 64) - Spectrogram patch in 40-100 kHz range
    Output: (batch_size, 2) - Logits for [noise, usv]
    '''

    def __init__(self, input_size=(64, 64), num_classes=2):
        super().__init__()

        # Four conv stages as (in_channels, out_channels, dropout prob).
        # Layers are appended in the exact order of the original Sequential
        # so state_dict keys (features.0.*, features.1.*, ...) still match
        # previously saved checkpoints.
        stages = [(1, 64, 0.2), (64, 128, 0.2), (128, 256, 0.3), (256, 512, 0.3)]
        feature_layers = []
        for cin, cout, p in stages:
            feature_layers.extend([
                nn.Conv2d(cin, cout, kernel_size=3, padding=1),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
                nn.Dropout2d(p),
            ])
        self.features = nn.Sequential(*feature_layers)

        # Each stage halves both spatial dims -> overall factor of 16.
        flat_size = 512 * (input_size[0] // 16) * (input_size[1] // 16)

        self.classifier = nn.Sequential(
            nn.Linear(flat_size, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        # Feature extraction, flatten per sample, then dense classifier.
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)
|
| 61 |
+
|
| 62 |
+
def load_model(model_path='final_usv_model.pth', device='cpu'):
    '''
    Load the trained USV detector model.

    Args:
        model_path: Path to the saved state_dict (.pth) file.
        device: Device to load the weights onto and place the model on.

    Returns:
        USVDetectorCNN: Model in eval mode, on `device`.
    '''
    model = USVDetectorCNN(input_size=(64, 64), num_classes=2)
    # weights_only=True restricts torch.load to tensor/primitive data and
    # prevents arbitrary code execution from a maliciously pickled checkpoint.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # map_location only remaps the loaded tensors; the freshly constructed
    # model's parameters still live on CPU, so move the model explicitly.
    model.to(device)
    model.eval()
    return model
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
numpy>=1.24.0
|
| 3 |
+
librosa>=0.10.0
|
| 4 |
+
scipy>=1.10.0
|