File size: 4,025 Bytes
03022ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import csv
import yaml
import codecs
import json
import torch
import torchaudio
import numpy as np


def load_yaml(yaml_path):
    """Parse a YAML file and return the resulting object.

    Args:
        yaml_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.load`` produces for the file when using
        ``yaml.FullLoader``.
    """
    with open(yaml_path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


def load_data_csv(fpath):
    """Read a CSV file into a dict keyed by the value of its ``ID`` column.

    Args:
        fpath: Path of the CSV file. The first row is used as the header;
            whitespace directly after a delimiter is skipped.

    Returns:
        A dict mapping each row's ``ID`` value to a dict of that row's
        remaining columns (the ``ID`` entry itself is removed).

    Raises:
        KeyError: If a row has no ``ID`` field.
        ValueError: If the same ``ID`` value appears more than once.
    """
    entries = {}
    with open(fpath, newline="") as csv_file:
        for record in csv.DictReader(csv_file, skipinitialspace=True):
            if "ID" not in record:
                raise KeyError(
                    "CSV file has to have an 'ID' field, with unique ids for all data points."
                )

            # The id becomes the outer key, so drop it from the row payload.
            key = record.pop("ID")
            if key in entries:
                raise ValueError(f"Duplicate id: {key}")
            entries[key] = record
    return entries


def load_data_list(fpath):
    """Read a text file into a dict of its lines keyed by line index.

    Args:
        fpath: Path of the text file.

    Returns:
        A dict mapping the 0-based line index to that line with leading and
        trailing whitespace stripped. Blank lines are kept as empty strings.
    """
    with open(fpath) as src:
        stripped = [line.strip() for line in src.readlines()]
    return dict(enumerate(stripped))


def load_wav_scp(fpath):
    """Read a ``<key> <value>`` file (wav.scp style) into a dict.

    Each non-blank line is split on the first run of whitespace only, so the
    value may itself contain spaces (e.g. a path with spaces). This keeps the
    function symmetric with ``write_wav_scp``, which writes ``"{key} {value}"``.

    Args:
        fpath: Path of the scp file.

    Returns:
        A dict mapping each key to the remainder of its line. If a key occurs
        more than once, the last occurrence wins.

    Raises:
        ValueError: If a non-blank line has a key but no value.
    """
    result = {}
    with open(fpath) as f:
        for line_no, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                # Tolerate blank lines (commonly a trailing one at end of file).
                continue
            parts = line.split(maxsplit=1)
            if len(parts) != 2:
                raise ValueError(
                    f"{fpath}:{line_no}: expected '<key> <value>', got {line!r}"
                )
            key, value = parts
            result[key] = value
    return result


def load_json_file(json_file):
    """Load a UTF-8 encoded JSON file.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The decoded JSON document.
    """
    with codecs.open(json_file, "r", encoding="utf-8") as reader:
        return json.load(reader)


def load_trans7time_list(filename):
    """
        Read a whitespace-separated transcript file into a list of
        trans7time tuples: (spk_id, st, ed, content).

        Each line holds a speaker id, a start time, an end time and
        optionally content tokens. The times are converted to float; the
        content tokens are concatenated without any separator (an empty
        string when the line has no content).

        Raises ValueError when a line has fewer than three fields, which
        includes blank lines.
    """
    with open(filename, "r") as reader:
        # Tokenise the whole file up front, then validate row by row.
        rows = [line.strip().split() for line in reader.readlines()]

    parsed = []
    for index, fields in enumerate(rows):
        if len(fields) < 3:
            raise ValueError(f"filename {filename}: item - {index} = {fields}")
        spk_id, start, end, *words = fields
        parsed.append((spk_id, float(start), float(end), "".join(words)))
    return parsed


def write_json_file(json_file, data):
    """Serialise ``data`` as UTF-8 JSON into ``json_file``.

    The output is indented by two spaces and non-ASCII characters are written
    verbatim rather than escaped.

    Args:
        json_file: Destination path; must end in ``.json`` or ``.JSON``.
        data: JSON-serialisable object to write.

    Raises:
        AssertionError: If ``json_file`` does not have a ``.json``/``.JSON``
            suffix (only when assertions are enabled).
    """
    has_json_suffix = str(json_file).endswith((".json", ".JSON"))
    assert has_json_suffix
    with codecs.open(json_file, "w", encoding="utf-8") as writer:
        json.dump(data, writer, indent=2, ensure_ascii=False)


def write_wav_scp(fpath, wav_scp):
    """Write a mapping to ``fpath`` as one ``<key> <value>`` line per entry.

    Args:
        fpath: Destination path; an existing file is overwritten.
        wav_scp: Mapping whose items are written in iteration order.
    """
    with open(fpath, "w") as out:
        out.writelines(f"{key} {value}\n" for key, value in wav_scp.items())


def write_trans7time_list(fpath, trans7time_list):
    """
        Write trans7time entries to fpath, one per line, as
        "spk_id start_time end_time text".

        trans7time_list: [(spk_id, start_time, end_time, text)]
        Line-feed and carriage-return characters are removed from text so
        that every entry occupies exactly one line.
    """
    def _as_line(entry):
        spk_id, start_time, end_time, text = entry
        single_line_text = text.replace("\n", "").replace("\r", "")
        return f'{spk_id} {start_time} {end_time} {single_line_text}\n'

    with open(fpath, 'w') as fw:
        fw.writelines(map(_as_line, trans7time_list))

def load_audio(input, ori_fs=None, obj_fs=None):
    """Load audio from a path or an in-memory array as a mono float tensor.

    Args:
        input: A file path (``str``), a ``numpy.ndarray`` or a
            ``torch.Tensor``. Any other object is returned unchanged.
        ori_fs: Sample rate of an array/tensor ``input``. It is not consulted
            for file paths, where the rate returned by ``torchaudio.load`` is
            used instead.
        obj_fs: Target sample rate. ``None`` disables resampling.

    Returns:
        For a path: the loaded waveform averaged over dim 0 (kept as a
        leading dimension of size 1), resampled to ``obj_fs`` when that is
        given and differs from the file's rate.
        For a 1-D or 2-D array/tensor: a ``float32`` tensor of shape
        ``(1, num_samples)``, resampled only when both ``ori_fs`` and
        ``obj_fs`` are given and differ.

    Raises:
        AssertionError: If an array/tensor ``input`` has more than two
            dimensions (only when assertions are enabled).
    """
    if isinstance(input, str):
        waveform, file_fs = torchaudio.load(input)
        waveform = waveform.mean(dim=0, keepdim=True)
        if obj_fs is not None and file_fs != obj_fs:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=file_fs, new_freq=obj_fs
            )
        return waveform

    if not isinstance(input, (np.ndarray, torch.Tensor)):
        # Unsupported types are handed back untouched.
        return input

    waveform = input
    if isinstance(waveform, np.ndarray):
        waveform = torch.from_numpy(waveform)

    integer_dtypes = (torch.int16, torch.int32, torch.int64)
    if waveform.dtype in integer_dtypes:
        # Every integer dtype is scaled by 2**15, i.e. samples are treated as
        # 16-bit PCM values regardless of the storage width.
        waveform = waveform.type(torch.float32) / 32768
    waveform = waveform.type(torch.float32)

    assert waveform.ndim <= 2
    if waveform.ndim == 2:
        rows, cols = waveform.shape
        if rows > cols:
            # The shorter axis is taken to be the channel axis; move it first.
            waveform = waveform.transpose(0, 1)
        waveform = waveform.mean(dim=0, keepdim=True)
    elif waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)

    if ori_fs is not None and obj_fs is not None and ori_fs != obj_fs:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=ori_fs, new_freq=obj_fs
        )
    return waveform