# xuan3986's picture
# Upload 111 files
# 03022ee verified
import csv
import yaml
import codecs
import json
import torch
import torchaudio
import numpy as np
def load_yaml(yaml_path):
    """Parse the YAML file at ``yaml_path`` and return the loaded object.

    Args:
        yaml_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.load`` produces for the document.

    NOTE(review): the file is parsed with ``yaml.FullLoader``; if configs can
    come from untrusted sources, consider whether a safe loader is enough.
    """
    with open(yaml_path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)
def load_data_csv(fpath):
    """Read a CSV file into a dict keyed by each row's ``ID`` column.

    Args:
        fpath: Path of the CSV file. The first row is used as the header.

    Returns:
        dict mapping each ID to the remaining columns of its row
        (the ``ID`` entry itself is removed from the row).

    Raises:
        KeyError: If a row has no ``ID`` field.
        ValueError: If the same ID appears more than once.
    """
    records = {}
    with open(fpath, newline="") as handle:
        for entry in csv.DictReader(handle, skipinitialspace=True):
            if 'ID' not in entry:
                raise KeyError(
                    "CSV file has to have an 'ID' field, with unique ids for all data points."
                )
            # Pull the id out of the row so only the payload columns remain.
            data_id = entry.pop("ID")
            if data_id in records:
                raise ValueError(f"Duplicate id: {data_id}")
            records[data_id] = entry
    return records
def load_data_list(fpath):
    """Read a text file into a dict of ``{line_index: stripped_line}``.

    Args:
        fpath: Path of the text file.

    Returns:
        dict whose keys are 0-based line indices and whose values are the
        corresponding lines with surrounding whitespace removed.
    """
    with open(fpath) as handle:
        return dict(enumerate(line.strip() for line in handle))
def load_wav_scp(fpath):
    """Read a ``wav.scp``-style file into a ``{key: value}`` dict.

    Each non-empty line is expected to look like ``<key> <value>``. The line
    is split only at the first run of whitespace, so the value may itself
    contain spaces; this keeps the loader symmetric with ``write_wav_scp``,
    which writes ``"{key} {value}"``.

    Args:
        fpath: Path of the scp file.

    Returns:
        dict mapping the first field of each line to the rest of the line.
        If a key occurs several times, the last occurrence wins.

    Raises:
        ValueError: If a non-empty line contains a key but no value.
    """
    result = {}
    with open(fpath) as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            fields = line.split(maxsplit=1)
            if len(fields) != 2:
                raise ValueError(
                    f"{fpath}: line {line_no} must be '<key> <value>', got {line!r}"
                )
            key, value = fields
            result[key] = value
    return result
def load_json_file(json_file):
    """Load a UTF-8 encoded JSON file and return the decoded object.

    Args:
        json_file: Path of the JSON file.

    Returns:
        The Python object produced by ``json.load``.
    """
    with codecs.open(json_file, "r", encoding="utf-8") as reader:
        return json.load(reader)
def load_trans7time_list(filename):
    """Parse a whitespace-separated transcript-with-timestamps file.

    Every line must hold at least three fields: ``spk_id st ed [content...]``.
    ``st`` and ``ed`` are converted to ``float``; any remaining fields are
    concatenated without a separator to form ``content`` (empty string when
    the line has exactly three fields).

    Args:
        filename: Path of the file to read.

    Returns:
        list of ``(spk_id, st, ed, content)`` tuples, one per line.

    Raises:
        ValueError: If a line has fewer than three fields, or if ``st``/``ed``
            cannot be converted to ``float``.
    """
    with open(filename, "r") as reader:
        tokenized = [raw.strip().split() for raw in reader.readlines()]

    parsed = []
    for index, item in enumerate(tokenized):
        if len(item) < 3:
            raise ValueError(f"filename {filename}: item - {index} = {item}")
        spk_id, start, end, *words = item
        # Joining zero words yields "", which covers the three-field case too.
        parsed.append((spk_id, float(start), float(end), "".join(words)))
    return parsed
def write_json_file(json_file, data):
    """Serialize ``data`` as indented UTF-8 JSON into ``json_file``.

    Args:
        json_file: Destination path; its string form must end with
            ``.json`` or ``.JSON`` (checked with an ``assert``).
        data: Object to serialize with ``json.dump``.

    Non-ASCII characters are written as-is (``ensure_ascii=False``) and the
    output is indented by two spaces.
    """
    assert str(json_file).endswith((".json", ".JSON"))
    with codecs.open(json_file, "w", encoding="utf-8") as writer:
        json.dump(data, writer, indent=2, ensure_ascii=False)
def write_wav_scp(fpath, wav_scp):
    """Write a mapping to ``fpath`` as one ``"<key> <value>"`` line per entry.

    Args:
        fpath: Destination path; the file is overwritten.
        wav_scp: Mapping whose items are written in iteration order.
    """
    with open(fpath, "w") as out:
        out.writelines(
            f"{utt_id} {location}\n" for utt_id, location in wav_scp.items()
        )
def write_trans7time_list(fpath, trans7time_list):
    """Write transcript-with-timestamps entries, one per line.

    Args:
        fpath: Destination path; the file is overwritten.
        trans7time_list: Iterable of ``(spk_id, start_time, end_time, text)``
            tuples. Line-feed and carriage-return characters are removed from
            ``text`` so every entry stays on a single line.
    """
    # Translation table that deletes "\n" and "\r".
    strip_newlines = str.maketrans("", "", "\n\r")
    with open(fpath, 'w') as out:
        for speaker, begin, end, content in trans7time_list:
            one_line = content.translate(strip_newlines)
            out.write(f'{speaker} {begin} {end} {one_line}\n')
def load_audio(input, ori_fs=None, obj_fs=None):
    """Return audio as a float32 tensor of shape ``(1, num_samples)``.

    Args:
        input: One of
            * ``str`` - loaded with ``torchaudio.load`` and averaged over
              dim 0; ``ori_fs`` is not used on this path.
            * ``np.ndarray`` / ``torch.Tensor`` - at most 2-D. int16/int32/
              int64 data is divided by 32768. A 2-D input is transposed when
              its first dimension is larger than its second (presumably to
              put channels first - confirm with callers) and then averaged
              over dim 0; a 1-D input gets a leading dimension added.
            * anything else - returned unchanged.
        ori_fs: Sample rate of an array/tensor ``input``.
        obj_fs: Target sample rate. Resampling happens only when the source
            rate is known and differs from ``obj_fs``.

    Returns:
        The processed tensor, or ``input`` itself for unsupported types.
    """
    # Path on disk: the sample rate comes from the file itself.
    if isinstance(input, str):
        waveform, file_fs = torchaudio.load(input)
        waveform = waveform.mean(dim=0, keepdim=True)
        if obj_fs is not None and file_fs != obj_fs:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=file_fs, new_freq=obj_fs
            )
        return waveform

    # Unsupported types are handed back untouched.
    if not isinstance(input, (np.ndarray, torch.Tensor)):
        return input

    waveform = torch.from_numpy(input) if isinstance(input, np.ndarray) else input

    integer_dtypes = (torch.int16, torch.int32, torch.int64)
    if waveform.dtype in integer_dtypes:
        waveform = waveform.to(torch.float32) / 32768
    waveform = waveform.to(torch.float32)

    assert waveform.ndim <= 2
    if waveform.ndim == 2:
        dim0, dim1 = waveform.shape
        if dim0 > dim1:
            waveform = waveform.transpose(0, 1)
        waveform = waveform.mean(dim=0, keepdim=True)
    elif waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)

    rates_known = ori_fs is not None and obj_fs is not None
    if rates_known and ori_fs != obj_fs:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=ori_fs, new_freq=obj_fs
        )
    return waveform