import os
import random
import torch
from torch.utils.data import Dataset
import torchaudio
import numpy as np
# Superseded implementation kept below for reference: it padded/truncated to
# exactly sample_rate * target_duration with no special one-sample-over case.
# def pad_audio(audio, sample_rate=16000, target_duration=8.0):
# target_length = int(sample_rate * target_duration) # Calculate target length for 8 seconds
# current_length = audio.shape[1]
# if current_length < target_length:
# padding = target_length - current_length
# audio = torch.cat((audio, torch.zeros(audio.shape[0], padding)), dim=1)
# else:
# audio = audio[:, :target_length]
# return audio
def pad_audio(audio, sample_rate=16000, target_duration=7.98):
    """Pad or truncate a waveform to a fixed clip length.

    Args:
        audio: Tensor of shape (channels, samples).
        sample_rate: Samples per second of ``audio``.
        target_duration: Desired clip length in seconds (default 7.98 s).

    Returns:
        Tensor of shape (channels, int(sample_rate * target_duration)):
        right-padded with zeros when too short, truncated when too long.
    """
    target_length = int(sample_rate * target_duration)
    current_length = audio.shape[1]
    if current_length < target_length:
        # Right-pad with silence up to the target length.
        padding = target_length - current_length
        audio = torch.cat((audio, torch.zeros(audio.shape[0], padding)), dim=1)
    elif current_length > target_length:
        # BUG FIX: the previous code appended one extra zero sample when the
        # clip was exactly one sample over the target, returning a length of
        # target_length + 2 while every other path returns target_length —
        # inconsistent lengths break batch collation. Truncate uniformly.
        audio = audio[:, :target_length]
    return audio
def parse_labels(file_path, audio_length, sample_rate, frame_duration=0.010):
    """Build a per-frame fake/real label vector from an annotation file.

    The file holds a header line followed by ``start-end-authenticity``
    records (times in seconds; authenticity 'F' marks a fake segment).
    Frames covered by a fake segment get label 1, and up to four boundary
    frames on each side of the segment are also marked 1. ``sample_rate``
    is accepted for interface compatibility but is not used.

    Returns:
        float32 numpy array of int(audio_length / frame_duration) frames.

    NOTE(review): int() truncates, so e.g. audio_length=7.98 with 10 ms
    frames yields 797 frames rather than 798 due to float rounding — this
    must stay consistent with the frame count computed in AudioDataset;
    confirm it matches the model's expected frame count.
    """
    total_frames = int(audio_length / frame_duration)
    labels = np.zeros(total_frames, dtype=np.float32)
    with open(file_path, 'r') as handle:
        records = handle.readlines()
    for record in records[1:]:  # first line is a header
        start_s, end_s, tag = record.strip().split('-')
        begin_t = float(start_s)
        finish_t = float(end_s)
        if tag != 'F':
            continue
        first = int(begin_t / frame_duration)
        last = int(finish_t / frame_duration)
        labels[first:last] = 1
        # Widen the fake region by up to 4 boundary frames on each side.
        for step in range(1, 5):
            if first - step >= 0:
                labels[first - step] = 1
            if last + step < total_frames:
                labels[last + step] = 1
    return labels
class AudioDataset(Dataset):
    """Dataset of fixed-length waveforms with per-frame spoof labels.

    Each item is a ``(waveform, labels)`` pair: the waveform is resampled
    to ``sample_rate`` and padded/truncated to ``target_length`` seconds,
    and ``labels`` holds one float per 10 ms frame (1 = fake, 0 = real).
    """

    def __init__(self, audio_files, label_dir, sample_rate=16000, target_length=7.98):
        self.audio_files = audio_files
        self.label_dir = label_dir
        self.sample_rate = sample_rate
        # Target length in samples (kept for backward compatibility; the
        # padding below works from the raw duration instead).
        self.target_length = target_length * sample_rate
        # Target length in seconds.
        self.raw_target_length = target_length

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        # BUG FIX: the original recursed into __getitem__ on a load failure,
        # which can exhaust the call stack when many files are unreadable.
        # Retry with a loop instead (same observable behavior otherwise).
        while True:
            audio_path = self.audio_files[idx]
            try:
                waveform, sr = torchaudio.load(audio_path)
                waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)
                waveform = pad_audio(waveform, self.sample_rate, self.raw_target_length)
                audio_filename = os.path.basename(audio_path).replace(".wav", "")
                if audio_filename.startswith("RFP_R"):
                    # Real recordings (RFP_R*) have no label file: all-zero labels.
                    # NOTE(review): int() truncates (e.g. int(7.98 / 0.010) is
                    # 797, not 798, from float rounding) — this intentionally
                    # mirrors the frame count parse_labels computes; verify it
                    # matches the model's expected number of frames.
                    labels = np.zeros(int(self.raw_target_length / 0.010), dtype=np.float32)
                else:
                    label_path = os.path.join(self.label_dir, f"{audio_filename}.wav_labels.txt")
                    labels = parse_labels(label_path, self.raw_target_length, self.sample_rate).astype(np.float32)
                return waveform, torch.tensor(labels, dtype=torch.float32)
            except OSError as e:  # IOError is an alias of OSError
                print(f"Error opening file {audio_path}: {e}")
                # Fall back to a random other item and try again.
                idx = random.randint(0, len(self.audio_files) - 1)
def get_audio_file_paths(extrinsic_dir, intrinsic_dir, real_dir):
    """Collect the .wav file paths used for training.

    Each directory is listed, keeping ``*.wav`` files whose names do not
    start with ``partial_fake``; the final result keeps only files whose
    basename starts with ``extrinsic``.

    NOTE(review): ``intrinsic_files`` never reaches the result, and real
    files survive the final filter only if they happen to be named
    ``extrinsic*`` — presumably a deliberate experiment toggle; verify.
    The directories are still listed so a missing path raises
    FileNotFoundError early.
    """
    def _list_wavs(directory):
        # All .wav files in `directory`, skipping partially-fake clips.
        return [os.path.join(directory, f) for f in os.listdir(directory)
                if f.endswith(".wav") and not f.startswith("partial_fake")]

    extrinsic_files = _list_wavs(extrinsic_dir)
    intrinsic_files = _list_wavs(intrinsic_dir)  # unused below; kept for dir validation
    real_files = _list_wavs(real_dir)
    # BUG FIX: the original wrote startswith(("extrinsic")) — parentheses
    # around a single string are NOT a tuple, so this was always a plain
    # string match; made explicit. Also removed the trailing "|" artifact
    # that broke the return statement.
    audio_files = [f for f in extrinsic_files + real_files
                   if os.path.basename(f).startswith("extrinsic")]
    return audio_files