RaphaelSchwinger committed on
Commit
27704b5
·
verified ·
1 Parent(s): 83339c7

Update with example code to run the model.

Browse files
Files changed (1) hide show
  1. README.md +111 -3
README.md CHANGED
@@ -24,12 +24,120 @@ The model accepts a mono image (spectrogram) as input (e.g., `torch.Size([16, 1,
24
  - melscale: n_mels: 128, n_stft: 513
25
  - dbscale: top_db: 80
26
 
 
 
27
  ```python
 
28
  import torch
29
- from transformers import AutoModelForImageClassification
30
- from datasets import load_dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- dataset = load_dataset("DBD-research-group/BirdSet", "HSN")
33
  ```
34
 
35
  ## Model Source
 
24
  - melscale: n_mels: 128, n_stft: 513
25
  - dbscale: top_db: 80
26
 
27
+ See [example inference notebook](https://github.com/DBD-research-group/BirdSet/blob/main/notebooks/tutorials/model_inference.ipynb):
28
+
29
  ```python
30
+ from transformers import ConvNextForImageClassification
31
  import torch
32
+ import torchaudio
33
+ from torchvision import transforms
34
+ import requests
35
+ import torchaudio
36
+ import io
37
+
38
+
39
+ # download the audio file of a bird sound: Common Crow
40
+ url = "https://xeno-canto.org/704485/download"
41
+ response = requests.get(url)
42
+ audio, sample_rate = torchaudio.load(io.BytesIO(response.content))
43
+ print("Original shape and sample rate: ", audio.shape, sample_rate)
44
+ # crop to 5 seconds
45
+ audio = audio[:, : 5 * sample_rate]
46
+ # resample to 32kHz
47
+ resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32000)
48
+ audio = resample(audio)
49
+ print("Resampled shape and sample rate: ", audio.shape, 32000)
50
+
51
+
52
+ CACHE_DIR = "../../data_birdset" # Change this to your own cache directory
53
+
54
+ # Load the model
55
+ model = ConvNextForImageClassification.from_pretrained(
56
+ "DBD-research-group/ConvNeXT-Base-BirdSet-XCL",
57
+ cache_dir=CACHE_DIR,
58
+ ignore_mismatched_sizes=True,
59
+ )
60
+
61
+
62
class PowerToDB(torch.nn.Module):
    """Convert a power spectrogram to decibel (dB) units.

    Mirrors librosa-style ``power_to_db`` as a ``torch.nn.Module`` so it can
    sit inside an audio preprocessing pipeline.
    See birdset.datamodule.components.augmentations.

    Args:
        ref: Reference value the spectrogram is scaled against, or a
            callable that computes one from the input magnitude. Default 1.0.
        amin: Lower clamp applied before taking ``log10``; must be > 0.
        top_db: If not ``None``, the output is clipped to
            ``log_spec.max() - top_db``; must be non-negative.
    """

    def __init__(self, ref=1.0, amin=1e-10, top_db=80.0):
        super().__init__()
        self.ref = ref
        self.amin = amin
        self.top_db = top_db

    def forward(self, S):
        """Return the dB-scaled version of the (power) spectrogram ``S``."""
        # Convert S to a float32 tensor if it is not already one.
        S = torch.as_tensor(S, dtype=torch.float32)

        if self.amin <= 0:
            raise ValueError("amin must be strictly positive")

        # Complex input: use its magnitude; real input is taken as-is.
        magnitude = S.abs() if torch.is_complex(S) else S

        # ref may be a callable (computed from the input) or a scalar.
        if callable(self.ref):
            ref_value = self.ref(magnitude)
        else:
            # Fix: place the reference tensor on the input's device so the
            # subtraction below does not fail for CUDA inputs (the amin
            # tensors are already created on magnitude.device).
            ref_value = torch.abs(
                torch.tensor(self.ref, dtype=S.dtype, device=magnitude.device)
            )

        # Shared lower clamp, hoisted so it is built once per call.
        amin_t = torch.tensor(self.amin, device=magnitude.device)

        # log_spec = 10*log10(max(S, amin)) - 10*log10(max(ref, amin))
        log_spec = 10.0 * torch.log10(torch.maximum(magnitude, amin_t))
        log_spec -= 10.0 * torch.log10(torch.maximum(ref_value, amin_t))

        # Clip the dynamic range to top_db below the peak, if requested.
        if self.top_db is not None:
            if self.top_db < 0:
                raise ValueError("top_db must be non-negative")
            log_spec = torch.maximum(log_spec, log_spec.max() - self.top_db)

        return log_spec
107
+
108
+
109
def preprocess(audio, sample_rate_of_audio):
    """Preprocess raw audio into the spectrogram format the model expects.

    Steps:
    - Resample the waveform to 32 kHz.
    - Power spectrogram with n_fft=1024, hop_length=320, power=2
      (513 frequency bins), projected onto 128 mel bands.
    - Convert power to decibels, then normalize with the AudioSet
      statistics mean=-4.268, std=4.569.
    """
    # Bring the waveform to the 32 kHz rate the model was trained on.
    resampler = torchaudio.transforms.Resample(
        orig_freq=sample_rate_of_audio, new_freq=32000
    )
    waveform_32k = resampler(audio)

    # Power spectrogram: 1024-point FFT, hop 320, squared magnitude.
    power_spec = torchaudio.transforms.Spectrogram(
        n_fft=1024, hop_length=320, power=2.0
    )(waveform_32k)

    # Project onto 128 mel bands over the 513 FFT bins.
    # NOTE(review): MelScale is left at its default sample_rate (16 kHz)
    # although the audio is 32 kHz — presumably this matches the training
    # pipeline; confirm against birdset's datamodule before changing.
    mel_spec = torchaudio.transforms.MelScale(n_mels=128, n_stft=513)(power_spec)

    # Decibel scaling followed by AudioSet normalization.
    db_spec = PowerToDB()(mel_spec)
    return transforms.Normalize((-4.268,), (4.569,))(db_spec)
130
+
131
+ preprocessed_audio = preprocess(audio, sample_rate)
132
+
133
+ logits = model(preprocessed_audio.unsqueeze(0)).logits
134
+ print("Logits shape: ", logits.shape)
135
+
136
+ top5 = torch.topk(logits, 5)
137
+ print("Top 5 logits:", top5.values)
138
+ print("Top 5 predicted classes:")
139
+ print([model.config.id2label[i] for i in top5.indices.squeeze().tolist()])
140
 
 
141
  ```
142
 
143
  ## Model Source