File size: 4,624 Bytes
384e020
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import random
import torch
from torch.utils.data import Dataset
import torchaudio
import numpy as np

# An earlier pad_audio implementation targeted a fixed 8.0 s clip and simply
# truncated anything longer than the target; the version below targets 7.98 s.
def pad_audio(audio, sample_rate=16000, target_duration=7.98):
    """Pad or truncate ``audio`` to exactly ``target_duration`` seconds.

    Args:
        audio: Waveform tensor of shape (channels, samples).
        sample_rate: Sample rate of ``audio`` in Hz.
        target_duration: Desired clip length in seconds (7.98 s is
            127,680 samples at 16 kHz).

    Returns:
        Tensor of shape (channels, int(sample_rate * target_duration)).
    """
    target_length = int(sample_rate * target_duration)
    current_length = audio.shape[1]

    if current_length < target_length:
        # Zero-pad on the right up to the target length.
        padding = target_length - current_length
        audio = torch.cat((audio, torch.zeros(audio.shape[0], padding)), dim=1)
    elif current_length > target_length:
        # Always truncate. The previous implementation instead *padded* one
        # extra zero when the input was exactly one sample too long, producing
        # clips of target_length + 2 samples; that inconsistency breaks
        # fixed-size batching, so every over-long input is now truncated.
        audio = audio[:, :target_length]

    return audio

def parse_labels(file_path, audio_length, sample_rate, frame_duration=0.010):
    """Build a per-frame fake/real label vector from a segment label file.

    The label file has one header line followed by lines of the form
    ``start-end-authenticity`` (times in seconds, authenticity ``F`` = fake).
    Frames overlapping an ``F`` segment are labeled 1, and the 4 frames on
    each side of a segment boundary are also labeled 1.

    Args:
        file_path: Path to the ``*_labels.txt`` file.
        audio_length: Audio duration in seconds.
        sample_rate: Unused; kept for interface compatibility with callers.
        frame_duration: Frame hop in seconds (default 10 ms).

    Returns:
        np.ndarray of float32, shape (int(audio_length / frame_duration),).
    """
    frames_per_audio = int(audio_length / frame_duration)
    labels = np.zeros(frames_per_audio, dtype=np.float32)

    with open(file_path, 'r') as f:
        for line in f.readlines()[1:]:  # First line is a header.
            start, end, authenticity = line.strip().split('-')
            if authenticity != 'F':
                continue

            start_frame = int(float(start) / frame_duration)
            end_frame = int(float(end) / frame_duration)
            # Inclusive of end_frame: the original slice stopped at
            # end_frame - 1 while the widening loop below started at
            # end_frame + 1, leaving the boundary frame itself unlabeled
            # (off-by-one). Clip explicitly against the array length.
            labels[start_frame:min(end_frame + 1, frames_per_audio)] = 1

            # Widen each boundary by the 4 closest frames.
            for offset in range(1, 5):
                if start_frame - offset >= 0:
                    labels[start_frame - offset] = 1
                if end_frame + offset < frames_per_audio:
                    labels[end_frame + offset] = 1

    return labels

class AudioDataset(Dataset):
    """Dataset of fixed-length waveforms with 10 ms frame-level fake labels.

    Each item is a ``(waveform, labels)`` pair: the waveform is resampled to
    ``sample_rate`` and padded/truncated to ``target_length`` seconds, and the
    labels are a float32 tensor with one entry per 10 ms frame.
    """

    def __init__(self, audio_files, label_dir, sample_rate=16000, target_length=7.98):
        self.audio_files = audio_files
        self.label_dir = label_dir
        self.sample_rate = sample_rate
        # Target length in samples and in seconds, respectively.
        self.target_length = target_length * sample_rate
        self.raw_target_length = target_length

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        try:
            waveform, orig_sr = torchaudio.load(audio_path)
            resampler = torchaudio.transforms.Resample(orig_sr, self.sample_rate)
            waveform = pad_audio(resampler(waveform), self.sample_rate, self.raw_target_length)

            stem = os.path.basename(audio_path).replace(".wav", "")
            if stem.startswith("RFP_R"):
                # Real recordings ("RFP_R*") have no label file: all-zero labels.
                frame_labels = np.zeros(int(self.raw_target_length / 0.010), dtype=np.float32)
            else:
                label_path = os.path.join(self.label_dir, f"{stem}.wav_labels.txt")
                frame_labels = parse_labels(
                    label_path, self.raw_target_length, self.sample_rate
                ).astype(np.float32)

            return waveform, torch.tensor(frame_labels, dtype=torch.float32)

        except (OSError, IOError) as e:
            # Loading OR label parsing failed: log it and substitute a random
            # item so the DataLoader keeps running.
            print(f"Error opening file {audio_path}: {e}")
            replacement = random.randint(0, len(self.audio_files) - 1)
            return self.__getitem__(replacement)


def get_audio_file_paths(extrinsic_dir, intrinsic_dir, real_dir):
    """Collect training .wav paths from the three dataset directories.

    Lists ``.wav`` files (skipping names starting with ``partial_fake``) from
    each directory, then keeps only paths whose basename starts with
    ``"extrinsic"``.

    NOTE(review): ``intrinsic_files`` is built but never included in the
    result, and the ``"extrinsic"`` prefix filter also excludes every file
    from ``real_dir`` — confirm this restriction is intentional.
    """
    def _wavs(directory):
        # Listing also raises if the directory is missing, matching the
        # original's eager listdir of all three directories.
        return [
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith(".wav") and not name.startswith("partial_fake")
        ]

    extrinsic_files = _wavs(extrinsic_dir)
    intrinsic_files = _wavs(intrinsic_dir)  # Unused; kept for the directory check.
    real_files = _wavs(real_dir)

    return [
        path
        for path in extrinsic_files + real_files
        if os.path.basename(path).startswith("extrinsic")
    ]