Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| import torch | |
| from torch.utils.data import Dataset | |
| import torchaudio | |
| import numpy as np | |
| # Modify to handle dynamic target duration (8s in this case) | |
| # def pad_audio(audio, sample_rate=16000, target_duration=8.0): | |
| # target_length = int(sample_rate * target_duration) # Calculate target length for 8 seconds | |
| # current_length = audio.shape[1] | |
| # if current_length < target_length: | |
| # padding = target_length - current_length | |
| # audio = torch.cat((audio, torch.zeros(audio.shape[0], padding)), dim=1) | |
| # else: | |
| # audio = audio[:, :target_length] | |
| # return audio | |
def pad_audio(audio, sample_rate=16000, target_duration=7.98):
    """Zero-pad or truncate a waveform to a fixed number of samples.

    Args:
        audio: Tensor indexed as ``[channel, sample]``; only dimension 1 is
            resized.
        sample_rate: Samples per second, used to turn ``target_duration``
            into a sample count.
        target_duration: Desired duration, in the same time unit that
            ``sample_rate`` is expressed per (seconds for Hz).

    Returns:
        A tensor whose dimension 1 is exactly
        ``int(sample_rate * target_duration)``. Shorter inputs are padded at
        the end with zeros; longer inputs are cut at the end.
    """
    target_length = int(sample_rate * target_duration)
    current_length = audio.shape[1]

    if current_length < target_length:
        padding = target_length - current_length
        # new_zeros keeps the padding on the same dtype/device as the input.
        silence = audio.new_zeros((audio.shape[0], padding))
        return torch.cat((audio, silence), dim=1)

    if current_length > target_length:
        # Always truncate. The previous special case for "one sample too
        # long" appended a sample instead, yielding target_length + 2 and
        # breaking the fixed-length guarantee that batching relies on.
        return audio[:, :target_length]

    return audio
| # Parse labels with 10ms frame intervals for 8-second audio | |
def parse_labels(file_path, audio_length, sample_rate, frame_duration=0.010,
                 boundary_frames=4):
    """Build a per-frame fake/real label vector from a segment label file.

    The file's first line is skipped as a header. Every following non-blank
    line must have the form ``start-end-authenticity``; segments whose
    authenticity field is ``'F'`` are marked with 1, everything else stays 0.

    Args:
        file_path: Path of the label text file.
        audio_length: Length of the audio, in the same unit as
            ``frame_duration`` (presumably seconds -- confirm with caller).
        sample_rate: Unused; kept so existing call sites keep working.
        frame_duration: Duration covered by one label frame.
        boundary_frames: Number of extra frames marked as fake on each side
            of a fake segment.

    Returns:
        ``np.float32`` array of length ``int(audio_length / frame_duration)``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank line does not split into three
            ``-``-separated fields or its times are not numeric.
    """
    frames_per_audio = int(audio_length / frame_duration)
    labels = np.zeros(frames_per_audio, dtype=np.float32)

    # Floor with a tiny tolerance: 0.29 / 0.01 evaluates to 28.999999999999996
    # and a bare int() would land one frame early.
    tolerance = 1e-6

    with open(file_path, 'r') as f:
        next(f, None)  # Skip header
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Tolerate blank / trailing lines instead of failing the unpack.
                continue
            start, end, authenticity = line.split('-')
            if authenticity != 'F':
                continue
            start_frame = int(float(start) / frame_duration + tolerance)
            end_frame = int(float(end) / frame_duration + tolerance)
            # One contiguous slice covers the segment plus its dilation on
            # both sides. Slicing clips at the array end, so segments that
            # extend (or start) past the audio length cannot raise, and no
            # unmarked hole is left at end_frame.
            first = max(0, start_frame - boundary_frames)
            last = max(first, end_frame + boundary_frames)
            labels[first:last] = 1
    return labels
class AudioDataset(Dataset):
    """Dataset yielding ``(waveform, frame_labels)`` pairs for audio files.

    Each waveform is loaded with torchaudio, resampled to ``sample_rate`` and
    padded/truncated by ``pad_audio``. Files whose base name starts with
    ``"RFP_R"`` get an all-zero label vector; all others read their labels
    from ``<label_dir>/<name>.wav_labels.txt`` via ``parse_labels``.
    """

    # Maximum number of replacement samples tried after an unreadable file,
    # so a dataset full of bad paths fails loudly instead of recursing forever.
    MAX_LOAD_RETRIES = 10

    # Label frame duration; matches the default used by parse_labels.
    FRAME_DURATION = 0.010

    def __init__(self, audio_files, label_dir, sample_rate=16000, target_length=7.98):
        """Store the file list and target geometry.

        Args:
            audio_files: Sequence of audio file paths.
            label_dir: Directory holding the ``*.wav_labels.txt`` files.
            sample_rate: Rate every waveform is resampled to.
            target_length: Fixed clip duration passed to ``pad_audio``.
        """
        self.audio_files = audio_files
        self.label_dir = label_dir
        self.sample_rate = sample_rate
        # Clip length in samples (a float product; not used inside the class).
        self.target_length = target_length * sample_rate
        self.raw_target_length = target_length

    def __len__(self):
        """Return the number of audio files."""
        return len(self.audio_files)

    def _load_item(self, audio_path):
        """Load one waveform and its frame labels; may raise OSError."""
        waveform, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            # Only build the resampler when the rates actually differ.
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)
        waveform = pad_audio(waveform, self.sample_rate, self.raw_target_length)

        audio_filename = os.path.basename(audio_path)
        if audio_filename.endswith(".wav"):
            # Strip only the extension, not every ".wav" inside the name.
            audio_filename = audio_filename[:-len(".wav")]

        if audio_filename.startswith("RFP_R"):
            labels = np.zeros(int(self.raw_target_length / self.FRAME_DURATION), dtype=np.float32)
        else:
            label_path = os.path.join(self.label_dir, f"{audio_filename}.wav_labels.txt")
            labels = parse_labels(label_path, self.raw_target_length, self.sample_rate).astype(np.float32)
        return waveform, torch.tensor(labels, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``(waveform, labels)`` for ``idx``.

        If the audio or label file cannot be opened, the error is printed and
        a randomly chosen other index is tried, at most ``MAX_LOAD_RETRIES``
        times.

        Raises:
            RuntimeError: If every attempt fails with an ``OSError``.
        """
        last_error = None
        for _ in range(self.MAX_LOAD_RETRIES + 1):
            audio_path = self.audio_files[idx]
            try:
                return self._load_item(audio_path)
            except OSError as e:  # IOError is an alias of OSError
                print(f"Error opening file {audio_path}: {e}")
                last_error = e
                idx = random.randint(0, len(self.audio_files) - 1)
        raise RuntimeError(
            f"Could not load any sample after {self.MAX_LOAD_RETRIES + 1} attempts"
        ) from last_error
def get_audio_file_paths(extrinsic_dir, intrinsic_dir, real_dir):
    """Collect the audio file paths used to build the dataset.

    Each directory is scanned (non-recursively) for names ending in ``.wav``
    that do not start with ``partial_fake``. The result keeps only files from
    ``extrinsic_dir`` and ``real_dir`` whose base name starts with
    ``"extrinsic"``.

    Args:
        extrinsic_dir: Directory of extrinsic audio files.
        intrinsic_dir: Directory of intrinsic audio files (scanned, but its
            files are not part of the result).
        real_dir: Directory of real audio files.

    Returns:
        List of paths, extrinsic directory first then real directory, each
        group sorted by file name so the order is reproducible.

    Raises:
        OSError: If any of the three directories cannot be listed.
    """
    def _list_wavs(directory):
        # os.listdir order is arbitrary; sort for reproducible results.
        return [os.path.join(directory, name)
                for name in sorted(os.listdir(directory))
                if name.endswith(".wav") and not name.startswith("partial_fake")]

    extrinsic_files = _list_wavs(extrinsic_dir)
    # NOTE(review): intrinsic files are listed but never returned. The call is
    # kept so an unreadable intrinsic_dir still raises as before -- confirm
    # whether these files were meant to be included.
    intrinsic_files = _list_wavs(intrinsic_dir)
    real_files = _list_wavs(real_dir)

    # NOTE(review): this prefix filter also drops any file in real_dir that is
    # not named "extrinsic*" -- confirm this is intended.
    return [path for path in extrinsic_files + real_files
            if os.path.basename(path).startswith("extrinsic")]