# VoidFilter — utils.py
# Helpers for word-level audio cutting, timestamp extraction, and
# seq2seq text filtering.
import numpy as np
def cut_audio(audio, timestamps, sample_rate=16000):
    """Extract and concatenate pieces of `audio` given second-based timestamps.

    Parameters
    ----------
    audio : np.ndarray
        1-D mono samples, or 2-D (samples, channels) — stereo is averaged to mono.
    timestamps : list[dict]
        Each dict has float keys "start" and "end" in seconds.
    sample_rate : int, optional
        Samples per second of `audio` (default 16000, the original hard-coded rate).

    Returns
    -------
    (np.ndarray, int)
        The concatenated segments and the sample rate. Empty timestamps yield
        an empty array instead of crashing in `np.concatenate`.
    """
    # Convert audio to mono if it's stereo
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)
    # Guard: np.concatenate([]) raises ValueError, so short-circuit here.
    if not timestamps:
        return audio[:0], sample_rate
    cut_segments = []
    last_index = len(timestamps) - 1
    for index, timestamp in enumerate(timestamps):
        # Seconds -> sample indices (truncating, as the original int() division did).
        start_sample = int(timestamp["start"] * sample_rate)
        end_sample = int(timestamp["end"] * sample_rate)
        if index == last_index:
            # Keep one extra second of tail audio on the final segment.
            end_sample += sample_rate
        cut_segments.append(audio[start_sample:end_sample])
    return np.concatenate(cut_segments), sample_rate
def get_word_timestamps(segments):
    """Flatten per-segment word objects into a list of dicts.

    Each entry has keys 'text', 'start', 'end' taken from the word
    objects' `.word`, `.start`, and `.end` attributes.
    """
    word_timestamps = []
    for segment in segments:
        word_timestamps.extend(
            {'text': w.word, 'start': w.start, 'end': w.end}
            for w in segment.words
        )
    return word_timestamps
def get_transcription(word_timestamps):
    """Concatenate the 'text' field of every entry into one string."""
    return ''.join(entry['text'] for entry in word_timestamps)
def get_modified_timestamps(word_timestamps, filtered_text):
    """Match each word of filtered_text[0] against word_timestamps entries.

    NOTE(review): this intentionally mirrors the original's side effects —
    it strips spaces from entries' 'text' IN PLACE while scanning, appends
    dict references (not copies), and unconditionally drops the last match
    via pop() (raises IndexError if nothing matched — same as before).
    """
    matches = []
    for target_word in filtered_text[0].split():
        # Re-scan from the start for every target word, normalizing as we go.
        for entry in word_timestamps:
            entry['text'] = entry['text'].replace(' ', '')
            if entry['text'] == target_word:
                matches.append(entry)
                break
    matches.pop()
    return matches
def filterText(text, model, tokenizer):
    """Run the seq2seq filtering model over `text` on CPU.

    Encodes `text` with the tokenizer (padded/truncated to 512 tokens),
    beam-searches a single filtered sequence, and returns the decoded
    prediction(s) as a list of strings.
    """
    device = 'cpu'
    model = model.to(device)
    encoded = tokenizer(
        text,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt',
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    # Beam search with repetition controls; one sequence returned.
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=150,
        no_repeat_ngram_size=2,
        min_new_tokens=1,
        repetition_penalty=2.0,
        length_penalty=0.5,
        num_beams=10,
        num_return_sequences=1,
    )
    return [
        tokenizer.decode(ids, skip_special_tokens=True,
                         clean_up_tokenization_spaces=True)
        for ids in output_ids
    ]