| from transformers import SeamlessM4TFeatureExtractor
|
| from transformers import Wav2Vec2BertModel
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import numpy as np
|
| import librosa
|
| import os
|
| import pickle
|
| import math
|
| import json
|
| import safetensors
|
| import json5
|
|
|
| from startts.examples.ftchar.models.codec.kmeans.repcodec_model import RepCodec
|
|
|
class JsonHParams:
    """Attribute- and item-accessible hyper-parameter container.

    Built from keyword arguments; nested dicts are converted recursively,
    so a section can be read either as ``cfg.model.lr`` or ``cfg["model"]["lr"]``.
    Implements enough of the mapping protocol (keys/items/values/len/
    contains/getitem/setitem) to be re-expanded with ``JsonHParams(**cfg)``.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # Recurse so nested config sections are attribute-accessible.
            # isinstance (not `type(v) == dict`) also accepts dict
            # subclasses such as OrderedDict, which JSON parsers may emit.
            if isinstance(v, dict):
                v = JsonHParams(**v)
            self[k] = v

    def keys(self):
        """Return a view of the top-level parameter names."""
        return self.__dict__.keys()

    def items(self):
        """Return a view of (name, value) pairs."""
        return self.__dict__.items()

    def values(self):
        """Return a view of the top-level values."""
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
|
|
|
|
|
def _load_config(config_fn, lowercase=False):
    """Load configurations into a dictionary.

    Supports single-level inheritance: if the parsed config contains a
    "base_config" key, the parent config is loaded (relative to the
    WORK_DIR environment variable) and the child is merged on top of it.

    Args:
        config_fn (str): path to configuration file (JSON5 syntax).
        lowercase (bool, optional): whether to change keys to lower case.
            Defaults to False.

    Returns:
        dict: dictionary that stores configurations

    Raises:
        EnvironmentError: if "base_config" is used but WORK_DIR is unset.
    """
    # Explicit encoding so the parse does not depend on the platform locale.
    with open(config_fn, "r", encoding="utf-8") as f:
        data = f.read()
    config_ = json5.loads(data)
    if "base_config" in config_:
        work_dir = os.getenv("WORK_DIR")
        if work_dir is None:
            # Previously this fell through to os.path.join(None, ...) and
            # raised an opaque TypeError; fail with a clear message instead.
            raise EnvironmentError(
                "WORK_DIR environment variable must be set to resolve "
                "'base_config' in " + str(config_fn)
            )
        p_config_path = os.path.join(work_dir, config_["base_config"])
        p_config_ = _load_config(p_config_path)
        # NOTE(review): override_config is not defined or imported in this
        # module — confirm it is provided elsewhere at runtime.
        config_ = override_config(p_config_, config_)
    if lowercase:
        # NOTE(review): get_lowercase_keys_config is likewise not defined
        # in this module — confirm its import.
        config_ = get_lowercase_keys_config(config_)
    return config_
|
|
|
|
|
def load_config(config_fn, lowercase=False):
    """Load a configuration file and wrap it in a JsonHParams object.

    Args:
        config_fn (str): path to configuration file
        lowercase (bool, optional): whether to lower-case all keys.
            Defaults to False.

    Returns:
        JsonHParams: an attribute-accessible view of the configuration
    """
    return JsonHParams(**_load_config(config_fn, lowercase=lowercase))
|
|
|
class Extract_wav2vectbert:
    """Extract discrete semantic codes from 16 kHz speech.

    Pipeline: SeamlessM4T feature extraction -> Wav2Vec2-BERT hidden
    states -> mean/std normalization (precomputed MaskGCT statistics) ->
    RepCodec semantic-codec quantization.
    """

    def __init__(self, device):
        """Load all models and statistics and move them to `device`.

        Args:
            device: torch device string or object (e.g. "cuda:0").
        """
        self.semantic_model = Wav2Vec2BertModel.from_pretrained("./MaskGCT_model/w2v_bert/")
        self.semantic_model.eval()
        self.semantic_model.to(device)
        # map_location="cpu" avoids a load failure on hosts without the
        # device the stats were saved from; tensors are moved to `device`
        # explicitly right below.
        self.stat_mean_var = torch.load(
            "./MaskGCT_model/wav2vec2bert_stats.pt", map_location="cpu"
        )
        self.semantic_mean = self.stat_mean_var["mean"].to(device)
        self.semantic_std = torch.sqrt(self.stat_mean_var["var"]).to(device)
        self.processor = SeamlessM4TFeatureExtractor.from_pretrained(
            "./MaskGCT_model/w2v_bert/")
        self.device = device

        cfg_maskgct = load_config('./MaskGCT_model/maskgct.json')
        cfg = cfg_maskgct.model.semantic_codec
        self.semantic_code_ckpt = r'./MaskGCT_model/semantic_codec/model.safetensors'
        self.semantic_codec = RepCodec(cfg=cfg)
        self.semantic_codec.eval()
        self.semantic_codec.to(device)
        # Load the checkpoint after .to(device) so the restored weights
        # land on the target device.
        safetensors.torch.load_model(self.semantic_codec, self.semantic_code_ckpt)

    @torch.no_grad()
    def extract_features(self, speech):
        """Compute w2v-BERT input features for raw speech.

        Args:
            speech: waveform at 16 kHz — assumed a float array of shape
                (n_samples,) or (batch, n_samples); TODO confirm with callers.

        Returns:
            (input_features, attention_mask): CPU tensors from the processor.
        """
        inputs = self.processor(speech, sampling_rate=16000, return_tensors="pt")
        input_features = inputs["input_features"]
        attention_mask = inputs["attention_mask"]
        return input_features, attention_mask

    @torch.no_grad()
    def extract_semantic_code(self, input_features, attention_mask):
        """Run w2v-BERT and quantize the normalized hidden states.

        Args:
            input_features: processor output, on self.device.
            attention_mask: matching mask, on self.device.

        Returns:
            (semantic_code, rec_feat): discrete codes and the codec's
            reconstructed features.
        """
        vq_emb = self.semantic_model(
            input_features=input_features,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Hidden layer 17 is the layer the precomputed normalization
        # statistics correspond to — keep in sync with wav2vec2bert_stats.pt.
        feat = vq_emb.hidden_states[17]
        feat = (feat - self.semantic_mean.to(feat)) / self.semantic_std.to(feat)

        semantic_code, rec_feat = self.semantic_codec.quantize(feat)
        return semantic_code, rec_feat

    def feature_extract(self, prompt_speech):
        """End-to-end: raw speech -> (semantic_code, rec_feat).

        Args:
            prompt_speech: 16 kHz waveform(s) as accepted by extract_features.

        Returns:
            (semantic_code, rec_feat) on self.device.
        """
        input_features, attention_mask = self.extract_features(prompt_speech)
        input_features = input_features.to(self.device)
        attention_mask = attention_mask.to(self.device)
        semantic_code, rec_feat = self.extract_semantic_code(input_features, attention_mask)
        return semantic_code, rec_feat
|
|
|
if __name__=='__main__':
    # Smoke test: load one utterance, tile it into a batch of three
    # identical waveforms, and run the full extraction pipeline on GPU.
    speech_path = 'test/magi1.wav'
    waveform, _sr = librosa.load(speech_path, sr=16000)
    batch = np.stack([waveform, waveform, waveform], axis=0)
    print(batch.shape)

    extractor = Extract_wav2vectbert('cuda:0')
    codes, recon = extractor.feature_extract(batch)
    print(codes.shape, recon.shape)
|
|
|
|
|