|
|
import logging |
|
|
import os |
|
|
import tarfile |
|
|
import zipfile |
|
|
from typing import Any, List, Optional |
|
|
|
|
|
import torchaudio |
|
|
|
|
|
_LG = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def _extract_tar(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]: |
|
|
if to_path is None: |
|
|
to_path = os.path.dirname(from_path) |
|
|
with tarfile.open(from_path, "r") as tar: |
|
|
files = [] |
|
|
for file_ in tar: |
|
|
file_path = os.path.join(to_path, file_.name) |
|
|
if file_.isfile(): |
|
|
files.append(file_path) |
|
|
if os.path.exists(file_path): |
|
|
_LG.info("%s already extracted.", file_path) |
|
|
if not overwrite: |
|
|
continue |
|
|
tar.extract(file_, to_path) |
|
|
return files |
|
|
|
|
|
|
|
|
def _extract_zip(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]: |
|
|
if to_path is None: |
|
|
to_path = os.path.dirname(from_path) |
|
|
|
|
|
with zipfile.ZipFile(from_path, "r") as zfile: |
|
|
files = zfile.namelist() |
|
|
for file_ in files: |
|
|
file_path = os.path.join(to_path, file_) |
|
|
if os.path.exists(file_path): |
|
|
_LG.info("%s already extracted.", file_path) |
|
|
if not overwrite: |
|
|
continue |
|
|
zfile.extract(file_, to_path) |
|
|
return files |
|
|
|
|
|
|
|
|
def _load_waveform( |
|
|
root: str, |
|
|
filename: str, |
|
|
exp_sample_rate: int, |
|
|
): |
|
|
path = os.path.join(root, filename) |
|
|
waveform, sample_rate = torchaudio.load(path) |
|
|
if exp_sample_rate != sample_rate: |
|
|
raise ValueError(f"sample rate should be {exp_sample_rate}, but got {sample_rate}") |
|
|
return waveform |
|
|
|