File size: 3,418 Bytes
ebba35f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
Audio preprocessor - normalize audio for analysis.
"""
import torch
import torchaudio
import numpy as np
from pathlib import Path
from typing import Tuple, Optional
import tempfile
import os


class AudioPreprocessor:
    """Normalize audio to a standard format (mono, 16 kHz, peak-normalized) for analysis."""

    # Target format constants for all normalized output.
    TARGET_SAMPLE_RATE = 16000  # Hz
    TARGET_CHANNELS = 1

    def __init__(self):
        pass

    def load_audio(self, audio_path: str) -> Tuple[torch.Tensor, int]:
        """
        Load an audio file from disk.

        Args:
            audio_path: Path to the audio file.

        Returns:
            Tuple of (waveform, sample_rate); waveform has shape (channels, samples).
        """
        # Use soundfile backend to avoid torchcodec dependency
        waveform, sample_rate = torchaudio.load(audio_path, backend="soundfile")
        return waveform, sample_rate

    def normalize(self, waveform: torch.Tensor, sample_rate: int) -> Tuple[torch.Tensor, int]:
        """
        Normalize audio to mono, 16 kHz, normalized amplitude.

        Args:
            waveform: Input audio of shape (channels, samples).
            sample_rate: Sample rate of the input waveform in Hz.

        Returns:
            Tuple of (normalized_waveform, target_sample_rate). The waveform
            has shape (1, samples) and a peak amplitude of 0.95 (unless silent).
        """
        # Convert to mono by averaging channels.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to the target rate only when needed.
        if sample_rate != self.TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=self.TARGET_SAMPLE_RATE
            )
            waveform = resampler(waveform)

        # Peak-normalize to 0.95 to leave a little headroom; skip silent
        # audio (max_amp == 0) to avoid division by zero.
        max_amp = waveform.abs().max()
        if max_amp > 0:
            waveform = waveform / max_amp * 0.95

        return waveform, self.TARGET_SAMPLE_RATE

    def process_file(self, audio_path: str, output_path: Optional[str] = None) -> Tuple[torch.Tensor, int, dict]:
        """
        Load and normalize an audio file, optionally saving the result.

        Args:
            audio_path: Path to the input audio file.
            output_path: If given, the normalized audio is written here.

        Returns:
            Tuple of (waveform, sample_rate, metadata) where metadata records
            the original and normalized sample rate, channel count, and duration.
        """
        # Load
        waveform, orig_sr = self.load_audio(audio_path)
        orig_duration = waveform.shape[1] / orig_sr
        orig_channels = waveform.shape[0]

        # Normalize
        waveform, sample_rate = self.normalize(waveform, orig_sr)

        # Save if output path provided. Route through save_audio so the
        # parent directory is created, consistent with direct save_audio use.
        if output_path:
            self.save_audio(waveform, sample_rate, output_path)

        metadata = {
            'original_sample_rate': orig_sr,
            'original_channels': orig_channels,
            'original_duration': orig_duration,
            'normalized_sample_rate': sample_rate,
            'normalized_duration': waveform.shape[1] / sample_rate
        }

        return waveform, sample_rate, metadata

    def get_duration(self, waveform: torch.Tensor, sample_rate: int) -> float:
        """Get duration in seconds."""
        return waveform.shape[1] / sample_rate

    def save_audio(self, waveform: torch.Tensor, sample_rate: int, output_path: str):
        """Save audio to file, creating the parent directory if needed."""
        parent = os.path.dirname(output_path)
        # dirname is '' for a bare filename; os.makedirs('') would raise
        # FileNotFoundError, so only create the directory when one is given.
        if parent:
            os.makedirs(parent, exist_ok=True)
        torchaudio.save(output_path, waveform, sample_rate)

    def extract_segment(self, waveform: torch.Tensor, sample_rate: int,
                        start: float, end: float) -> torch.Tensor:
        """
        Extract the [start, end) time window from a waveform.

        Args:
            waveform: Audio of shape (channels, samples).
            sample_rate: Sample rate in Hz.
            start: Segment start time in seconds.
            end: Segment end time in seconds.

        Returns:
            Tensor of shape (channels, segment_samples).
        """
        # Clamp start at 0: a negative sample index would silently slice
        # from the end of the tensor via Python's negative indexing.
        start_sample = max(0, int(start * sample_rate))
        end_sample = int(end * sample_rate)
        return waveform[:, start_sample:end_sample]