File size: 3,367 Bytes
101858b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from __future__ import annotations

from pathlib import Path
import platform
import shutil
import subprocess
import wave
try:
    import winsound
except ImportError:  # pragma: no cover - only exercised on non-Windows hosts
    winsound = None

import numpy as np
import soundfile as sf
import torch
import torchaudio


def select_best_device(explicit: str | None = None) -> str:
    if explicit:
        return explicit
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def select_runtime_dtype(device: str, preferred: torch.dtype | None = None) -> torch.dtype:
    if preferred is not None:
        return preferred
    if device == "cuda":
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32


def ensure_parent_dir(path: str | Path) -> Path:
    resolved = Path(path)
    resolved.parent.mkdir(parents=True, exist_ok=True)
    return resolved


def save_waveform(path: str | Path, waveform: torch.Tensor, sample_rate: int) -> Path:
    """Write ``waveform`` to ``path`` and return the output path.

    The parent directory is created first. A 1-D tensor is written as a
    single channel; otherwise dimension 0 is treated as the channel axis.
    ``torchaudio.save`` is tried first. If it raises, the audio is clamped to
    [-1.0, 1.0] and written as a 16-bit PCM WAV with the standard-library
    ``wave`` module instead.

    Args:
        path: Destination file path.
        waveform: Audio samples; detached from autograd and moved to CPU.
        sample_rate: Sample rate in Hz written into the file header.

    Returns:
        ``path`` as a :class:`pathlib.Path`.
    """
    output_path = ensure_parent_dir(path)
    audio = waveform.detach().cpu()
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    # Upcast half-precision audio (e.g. produced under the bfloat16/float16
    # dtypes this module picks for CUDA). bfloat16 tensors cannot be converted
    # with .numpy(), so the fallback below would otherwise crash with a
    # TypeError instead of writing the file.
    if audio.dtype in (torch.float16, torch.bfloat16):
        audio = audio.to(torch.float32)
    try:
        torchaudio.save(str(output_path), audio, sample_rate)
    except Exception:
        # Deliberate best-effort fallback: whatever torchaudio.save raises,
        # still produce a plain PCM16 WAV.
        clipped = audio.clamp(-1.0, 1.0)
        pcm16 = (clipped.numpy() * 32767.0).astype(np.int16)
        with wave.open(str(output_path), "wb") as handle:
            handle.setnchannels(int(pcm16.shape[0]))
            handle.setsampwidth(2)
            handle.setframerate(int(sample_rate))
            # Transpose to (frames, channels) so the C-order bytes are
            # interleaved frame by frame, as WAV data is laid out.
            handle.writeframes(pcm16.T.tobytes())
    return output_path


def load_waveform(path: str | Path) -> tuple[torch.Tensor, int]:
    """Read an audio file and return ``(waveform, sample_rate)``.

    ``torchaudio.load`` is tried first and its result is passed through. If it
    raises, the file is re-read with soundfile using ``always_2d=True``. That
    array is transposed so channels come first, converted to a float32 tensor,
    and returned with the sample rate cast to ``int``.
    """
    try:
        tensor, rate = torchaudio.load(str(path))
    except Exception:
        frames, rate = sf.read(str(path), always_2d=True)
        channels_first = frames.T
        tensor = torch.from_numpy(channels_first).to(dtype=torch.float32)
        return tensor, int(rate)
    return tensor, rate


def detect_platform() -> str:
    """Return the host OS name from ``platform.system()``, lower-cased.

    Callers compare the result against ``"windows"``, ``"darwin"`` and ``"linux"``.
    """
    system_name = platform.system()
    return system_name.lower()


def native_playback_command(audio_path: str | Path) -> list[str] | None:
    """Build an argv list that plays ``audio_path`` with a command-line player.

    - Windows: always None (this module plays audio there through winsound).
    - macOS ("darwin"): ``afplay`` if it is on PATH.
    - Linux: the first of aplay, paplay, ffplay, xdg-open found on PATH.
      ffplay gets ``-nodisp -autoexit`` so it runs headless and exits when done.

    Returns None when no suitable player is found or the OS is not one of the above.
    """
    target = str(Path(audio_path))
    host = detect_platform()

    if host == "windows":
        return None
    if host == "darwin":
        return ["afplay", target] if shutil.which("afplay") else None
    if host != "linux":
        return None

    extra_flags = {"ffplay": ("-nodisp", "-autoexit")}
    for player in ("aplay", "paplay", "ffplay", "xdg-open"):
        if not shutil.which(player):
            continue
        return [player, *extra_flags.get(player, ()), target]
    return None


def play_audio_file(audio_path: str | Path, *, block: bool = True) -> bool:
    resolved = Path(audio_path)
    if not resolved.exists():
        return False
    system = detect_platform()
    if system == "windows":
        if winsound is None:
            return False
        flags = winsound.SND_FILENAME
        if not block:
            flags |= winsound.SND_ASYNC
        winsound.PlaySound(str(resolved), flags)
        return True
    command = native_playback_command(resolved)
    if command is None:
        return False
    if block:
        subprocess.run(command, check=False)
    else:
        subprocess.Popen(command)
    return True