sulstice2 commited on
Commit
8d28a33
·
verified ·
1 Parent(s): 9b8fd2b

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +90 -3
  2. config.json +38 -0
  3. inference_example.py +84 -0
  4. model_architecture.py +67 -0
  5. requirements.txt +4 -0
README.md CHANGED
@@ -1,3 +1,90 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mouse USV Detector - ProSAP1/Shank2 Male-Oestrus Female Interactions
2
+
3
+ Deep learning model for detecting ultrasonic vocalizations (USVs) in mouse recordings from male-oestrus female social interactions.
4
+
5
+ ## Model Details
6
+
7
+ - **Model Type**: Convolutional Neural Network (CNN)
8
+ - **Task**: Binary classification (USV vs. noise)
9
+ - **Architecture**: 4-layer CNN with batch normalization and dropout
10
+ - **Parameters**: 10.4M
11
+ - **Framework**: PyTorch 2.0+
12
+
13
+ ## Performance
14
+
15
+ - **Validation Accuracy**: 96.0%
16
+ - **Noise Detection**: 94.7%
17
+ - **USV Detection**: 98.0%
18
+
19
+ ## Training Protocol
20
+
21
+ - **Subject**: S2-4-65
22
+ - **Strain**: ProSAP1/Shank2
23
+ - **Behavior**: Male-oestrus female interactions (10 min + 3 min)
24
+ - **Dataset**: 7,188 training samples, 2,146 validation samples
25
+ - **Epochs**: 10
26
+
27
+ ## Audio Specifications
28
+
29
+ - **Sample Rate**: 250 kHz (ultrasonic)
30
+ - **USV Frequency Range**: 40-100 kHz
31
+ - **Input Format**: 64x64 spectrogram patches
32
+
33
+ ## Usage
34
+
35
+ ```python
36
+ import torch
37
+ from model_architecture import load_model
38
+ from inference_example import predict
39
+
40
+ # Load model
41
+ model = load_model('final_usv_model.pth')
42
+
43
+ # Predict on audio file
44
+ result = predict('audio.wav', model)
45
+ print(f"USV: {result['is_usv']}, Confidence: {result['confidence']:.2%}")
46
+ ```
47
+
48
+ ## Requirements
49
+
50
+ ```bash
51
+ pip install torch numpy librosa scipy
52
+ ```
53
+
54
+ ## Files
55
+
56
+ - `final_usv_model.pth` - Trained model weights (41.9 MB)
57
+ - `model_architecture.py` - CNN architecture definition
58
+ - `inference_example.py` - Example inference code
59
+ - `config.json` - Model configuration and metadata
60
+ - `requirements.txt` - Python dependencies
61
+
62
+ ## Citation
63
+
64
+ If you use this model, please cite:
65
+
66
+ ```bibtex
67
+ @misc{usv_detector_prosap1_shank2,
68
+ title={Mouse USV Detector for ProSAP1/Shank2 Social Interactions},
69
+ author={Your Name},
70
+ year={2025},
71
+ publisher={Hugging Face},
72
+ howpublished={\url{https://huggingface.co/your-username/model-name}}
73
+ }
74
+ ```
75
+
76
+ ## License
77
+
78
+ [Specify your license here]
79
+
80
+ ## Methodology
81
+
82
+ Based on the DeepSqueak methodology for USV detection:
83
+ - Spectrogram-based feature extraction
84
+ - Tonality calculation for USV identification
85
+ - Automated detection with manual validation
86
+ - Deep learning classification for robust detection
87
+
88
+ ## Contact
89
+
90
+ [Your contact information]
config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "Mouse_USV_Detector_ProSAP1_Shank2",
3
+ "model_type": "USVDetectorCNN",
4
+ "task": "audio-classification",
5
+ "protocol": "S2-4-65 ProSAP1/Shank2 Male-Oestrus Female Interactions",
6
+ "architecture": {
7
+ "input_size": [64, 64],
8
+ "num_classes": 2,
9
+ "classes": ["noise", "usv"]
10
+ },
11
+ "audio_preprocessing": {
12
+ "sample_rate": 250000,
13
+ "nfft": 0.0032,
14
+ "overlap": 0.0028,
15
+ "hop_length_samples": 800,
16
+ "freq_range": [40000, 100000],
17
+ "freq_range_description": "Mouse USV frequency range (40-100 kHz)"
18
+ },
19
+ "training": {
20
+ "dataset_size": 9334,
21
+ "train_samples": 7188,
22
+ "val_samples": 2146,
23
+ "num_epochs": 10,
24
+ "batch_size": 32,
25
+ "learning_rate": 0.001,
26
+ "optimizer": "Adam",
27
+ "loss_function": "CrossEntropyLoss"
28
+ },
29
+ "performance": {
30
+ "validation_accuracy": 96.0,
31
+ "noise_accuracy": 94.7,
32
+ "usv_accuracy": 98.0
33
+ },
34
+ "framework": "pytorch",
35
+ "pytorch_version": "2.0+",
36
+ "model_size_mb": 41.9,
37
+ "parameters": 10467202
38
+ }
inference_example.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import librosa
4
+ from scipy.ndimage import zoom
5
+ from model_architecture import load_model
6
+
7
+ def preprocess_audio(audio_path, sr=250000, nfft=0.0032, overlap=0.0028,
8
+ freq_range=(40000, 100000)):
9
+ '''
10
+ Preprocess audio file into spectrogram patch
11
+
12
+ Args:
13
+ audio_path: Path to .wav file
14
+ sr: Sample rate (250 kHz for ultrasonic)
15
+ nfft: FFT window size in seconds
16
+ overlap: Overlap between windows in seconds
17
+ freq_range: Frequency range to extract (Hz)
18
+
19
+ Returns:
20
+ torch.Tensor: Preprocessed spectrogram (1, 1, 64, 64)
21
+ '''
22
+ # Load audio
23
+ audio, _ = librosa.load(audio_path, sr=sr)
24
+
25
+ # Generate spectrogram
26
+ nfft_samples = int(nfft * sr)
27
+ hop_length = int((nfft - overlap) * sr)
28
+ spec = librosa.stft(audio, n_fft=nfft_samples, hop_length=hop_length)
29
+ spec_db = librosa.amplitude_to_db(np.abs(spec), ref=np.max)
30
+
31
+ # Filter to USV frequency range
32
+ freqs = librosa.fft_frequencies(sr=sr, n_fft=nfft_samples)
33
+ freq_mask = (freqs >= freq_range[0]) & (freqs <= freq_range[1])
34
+ spec_db = spec_db[freq_mask, :]
35
+
36
+ # Resize to 64x64
37
+ zoom_factors = (64 / spec_db.shape[0], 64 / spec_db.shape[1])
38
+ spec_resized = zoom(spec_db, zoom_factors, order=1)
39
+
40
+ # Normalize
41
+ spec_resized = (spec_resized - np.mean(spec_resized)) / (np.std(spec_resized) + 1e-8)
42
+
43
+ # Convert to tensor
44
+ return torch.FloatTensor(spec_resized).unsqueeze(0).unsqueeze(0)
45
+
46
+ def predict(audio_path, model):
47
+ '''
48
+ Predict if audio contains USV
49
+
50
+ Args:
51
+ audio_path: Path to .wav file
52
+ model: Loaded USVDetectorCNN model
53
+
54
+ Returns:
55
+ dict: Prediction results
56
+ '''
57
+ # Preprocess
58
+ spec_tensor = preprocess_audio(audio_path)
59
+
60
+ # Predict
61
+ with torch.no_grad():
62
+ output = model(spec_tensor)
63
+ probabilities = torch.softmax(output, dim=1)
64
+ prediction = torch.argmax(output, dim=1).item()
65
+
66
+ return {
67
+ 'is_usv': prediction == 1,
68
+ 'confidence': probabilities[0][prediction].item(),
69
+ 'usv_probability': probabilities[0][1].item(),
70
+ 'noise_probability': probabilities[0][0].item()
71
+ }
72
+
73
+ # Example usage
74
+ if __name__ == "__main__":
75
+ # Load model
76
+ model = load_model('final_usv_model.pth')
77
+
78
+ # Predict
79
+ result = predict('test_audio.wav', model)
80
+
81
+ print(f"USV Detected: {result['is_usv']}")
82
+ print(f"Confidence: {result['confidence']:.2%}")
83
+ print(f"USV Probability: {result['usv_probability']:.2%}")
84
+ print(f"Noise Probability: {result['noise_probability']:.2%}")
model_architecture.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class USVDetectorCNN(nn.Module):
5
+ '''
6
+ CNN for Mouse Ultrasonic Vocalization (USV) Detection
7
+
8
+ Trained on ProSAP1/Shank2 male-oestrus female interaction recordings.
9
+ Classifies spectrogram patches as USV or noise.
10
+
11
+ Input: (batch_size, 1, 64, 64) - Spectrogram patch in 40-100 kHz range
12
+ Output: (batch_size, 2) - Logits for [noise, usv]
13
+ '''
14
+
15
+ def __init__(self, input_size=(64, 64), num_classes=2):
16
+ super().__init__()
17
+
18
+ self.features = nn.Sequential(
19
+ nn.Conv2d(1, 64, kernel_size=3, padding=1),
20
+ nn.BatchNorm2d(64),
21
+ nn.ReLU(inplace=True),
22
+ nn.MaxPool2d(2, 2),
23
+ nn.Dropout2d(0.2),
24
+
25
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
26
+ nn.BatchNorm2d(128),
27
+ nn.ReLU(inplace=True),
28
+ nn.MaxPool2d(2, 2),
29
+ nn.Dropout2d(0.2),
30
+
31
+ nn.Conv2d(128, 256, kernel_size=3, padding=1),
32
+ nn.BatchNorm2d(256),
33
+ nn.ReLU(inplace=True),
34
+ nn.MaxPool2d(2, 2),
35
+ nn.Dropout2d(0.3),
36
+
37
+ nn.Conv2d(256, 512, kernel_size=3, padding=1),
38
+ nn.BatchNorm2d(512),
39
+ nn.ReLU(inplace=True),
40
+ nn.MaxPool2d(2, 2),
41
+ nn.Dropout2d(0.3),
42
+ )
43
+
44
+ flat_size = 512 * (input_size[0] // 16) * (input_size[1] // 16)
45
+
46
+ self.classifier = nn.Sequential(
47
+ nn.Linear(flat_size, 1024),
48
+ nn.ReLU(inplace=True),
49
+ nn.Dropout(0.5),
50
+ nn.Linear(1024, 512),
51
+ nn.ReLU(inplace=True),
52
+ nn.Dropout(0.5),
53
+ nn.Linear(512, num_classes)
54
+ )
55
+
56
+ def forward(self, x):
57
+ x = self.features(x)
58
+ x = x.view(x.size(0), -1)
59
+ x = self.classifier(x)
60
+ return x
61
+
62
+ def load_model(model_path='final_usv_model.pth', device='cpu'):
63
+ '''Load the trained model'''
64
+ model = USVDetectorCNN(input_size=(64, 64), num_classes=2)
65
+ model.load_state_dict(torch.load(model_path, map_location=device))
66
+ model.eval()
67
+ return model
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=2.0.0
2
+ numpy>=1.24.0
3
+ librosa>=0.10.0
4
+ scipy>=1.10.0