czyoung committed on
Commit
ddefe81
·
verified ·
1 Parent(s): 85b7ab1

Update sonogram_utility.py

Browse files
Files changed (1) hide show
  1. sonogram_utility.py +72 -1
sonogram_utility.py CHANGED
@@ -3,6 +3,8 @@ import random
3
  import copy
4
  from pyannote.core import Annotation, Segment
5
  import numpy as np
 
 
6
 
7
  def colors(n):
8
  '''
@@ -94,4 +96,73 @@ def loadAudioRTTM(sampleRTTM):
94
  speakerList[index].append((float(speakerResult[3]),float(speakerResult[4])))
95
  prediction[Segment(start,end)] = index
96
 
97
- return speakerList, prediction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import copy
4
  from pyannote.core import Annotation, Segment
5
  import numpy as np
6
+ import torch
7
+ import torchaudio
8
 
9
  def colors(n):
10
  '''
 
96
  speakerList[index].append((float(speakerResult[3]),float(speakerResult[4])))
97
  prediction[Segment(start,end)] = index
98
 
99
+ return speakerList, prediction
100
+
101
def splitIntoTimeSegments(testFile, maxDurationInSeconds=60):
    '''
    Load an audio file and split its waveform into consecutive time chunks.

    Fixes the committed version, whose `def` line was missing the trailing
    colon (a SyntaxError) and which appended one empty segment for an
    empty waveform.

    Parameters
    ----------
    testFile : str
        Path to an audio file readable by torchaudio.load.
    maxDurationInSeconds : int, optional
        Maximum length of each chunk, in seconds (default 60).

    Returns
    -------
    (list, int)
        A list of waveform slices of shape (channels, <= maxDuration*rate),
        and the sample rate reported by torchaudio.
    '''
    waveform, sample_rate = torchaudio.load(testFile)
    audioSegments = []

    totalSamples = waveform.shape[-1]
    samplesPerSegment = maxDurationInSeconds * sample_rate
    currentStart = 0
    # Walk the waveform in fixed-size windows; the final window may be shorter.
    while currentStart < totalSamples:
        currentEnd = min(currentStart + samplesPerSegment, totalSamples)
        audioSegments.append(waveform[:, currentStart:currentEnd])
        currentStart = currentEnd
    return audioSegments, sample_rate
120
+
121
def audioNormalize(waveform, sampleRate, stepSizeInSeconds=2, dbThreshold=-50, dbTarget=-5):
    '''
    Piecewise peak-normalize a mono or stereo waveform.

    The signal is scanned in windows of stepSizeInSeconds. Each window whose
    peak level (in dB) exceeds dbThreshold is amplified so that its loudest
    sample sits at dbTarget; quieter windows (silence / noise floor) are left
    untouched. The input tensor is not modified.

    Fixes vs. the committed version:
    - The loop bound was `len(...) - 1` with an exclusive slice end, so the
      final sample was never processed; it now covers the full signal.
    - When (len - 1) was an exact multiple of the step size the loop reached
      a zero-length window and `torch.max` raised on the empty slice; the
      single-bound loop below cannot produce an empty window.
    - The duplicated channel-0 / channel-1 code is folded into one loop
      (still capped at two channels, matching the original behavior).

    Parameters
    ----------
    waveform : Tensor
        Audio of shape (channels, samples); only the first two channels are
        adjusted, as in the original.
    sampleRate : int
        Samples per second of `waveform`.
    stepSizeInSeconds : float, optional
        Window length for the per-window gain decision (default 2).
    dbThreshold : float, optional
        Windows peaking below this level are skipped (default -50 dB).
    dbTarget : float, optional
        Desired peak level for adjusted windows (default -5 dB).

    Returns
    -------
    Tensor
        A gain-adjusted copy of `waveform`.
    '''
    copyWaveform = waveform.clone().detach()
    copyWaveform_db = waveform.clone().detach()
    transform = torchaudio.transforms.AmplitudeToDB(stype="amplitude", top_db=80)
    copyWaveform_db = transform(copyWaveform_db)

    stepSamples = int(stepSizeInSeconds * sampleRate)
    totalSamples = copyWaveform_db.shape[-1]
    # The original handled at most two channels (0 and 1) explicitly.
    numChannels = min(len(copyWaveform_db), 2)

    currStart = 0
    while currStart < totalSamples:
        currEnd = min(currStart + stepSamples, totalSamples)
        for channel in range(numChannels):
            window_db = copyWaveform_db[channel][currStart:currEnd]
            if torch.max(window_db).item() > dbThreshold:
                # min(target - db) == target - max(db): the gain that lifts
                # the loudest sample of this window exactly to dbTarget.
                gain = torch.min(dbTarget - window_db)
                adjustGain = torchaudio.transforms.Vol(gain, 'db')
                copyWaveform[channel][currStart:currEnd] = adjustGain(copyWaveform[channel][currStart:currEnd])
        currStart = currEnd
    return copyWaveform
145
+
146
class equalizeVolume(torch.nn.Module):
    '''nn.Module wrapper around audioNormalize, so volume equalization can be
    dropped into a torch transform pipeline like any other module.'''

    def forward(self, waveform, sampleRate, stepSizeInSeconds, dbThreshold, dbTarget):
        '''Return a volume-equalized copy of `waveform`; see audioNormalize.'''
        return audioNormalize(waveform, sampleRate, stepSizeInSeconds, dbThreshold, dbTarget)
150
+
151
def combineWaveforms(waveformList):
    '''Concatenate channel-matched waveform tensors end-to-end along the
    time axis (dim 1) and return the joined tensor.'''
    joined = torch.cat(waveformList, dim=1)
    return joined
153
+
154
def annotationToSpeakerList(myAnnotation):
    '''
    Convert a pyannote Annotation into per-speaker (start, duration) lists.

    Speakers are indexed in the order they first appear in
    myAnnotation.labels(); entry i of the returned list holds the
    (start, duration) pairs of the support segments of speaker i.
    '''
    tempSpeakerList = []
    tempSpeakerNames = []
    for speakerName in myAnnotation.labels():
        if speakerName in tempSpeakerNames:
            speakerIndex = tempSpeakerNames.index(speakerName)
        else:
            # First time we see this speaker: give it the next slot.
            speakerIndex = len(tempSpeakerNames)
            tempSpeakerNames.append(speakerName)
            tempSpeakerList.append([])

        # label_support yields the (merged) segments where this speaker talks.
        for seg in myAnnotation.label_support(speakerName):
            tempSpeakerList[speakerIndex].append((seg.start, seg.duration))
    return tempSpeakerList