Spaces:
Build error
Build error
| import errno | |
| import gc | |
| import os | |
| import sys | |
| import torch | |
| # from .s2f_dir.src.speech_encoder.WavLM import WavLM, WavLMConfig | |
| from transformers import Wav2Vec2FeatureExtractor, WavLMModel | |
| from .s2f_dir.src import autoencoder as ae | |
| from .util import * | |
# Becomes True after __init_fix_seed() has run, so seeding happens only once.
g_fix_seed = False
# Module-level cache of the WavLM feature extractor and encoder. Both are
# filled in by create_model() the first time an "stf_v3" model is loaded and
# are reused by later calls.
g_audio_processor = None
g_audio_encoder = None
class ModelInfo:
    """Bundle a loaded model with the objects and settings used to build it.

    Attributes:
        model: The loaded network (None when the config is "remote").
        audio_processor: Audio feature extractor used by "stf_v3" models.
        audio_encoder: Audio encoder used by "stf_v3" models.
        args: Parsed configuration; ``args.model_type`` selects the variant.
        device: Device identifier the caller asked for.
        work_root_path: Kept for debugging.
        config_path: Kept for debugging.
        checkpoint_path: Kept for debugging.
        verbose: When true, ``__del__`` prints the model's reference count.
    """

    def __init__(
        self,
        model,
        audio_processor,
        audio_encoder,
        args,
        device,
        work_root_path,
        config_path,
        checkpoint_path,
        verbose=False,
    ):
        self.model = model
        self.audio_processor = audio_processor
        self.audio_encoder = audio_encoder
        self.args = args
        self.device = device
        # snow: the fields below are stored for debugging purposes.
        self.work_root_path = work_root_path
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.verbose = verbose

    def __del__(self):
        """Drop this object's references to the model and audio components.

        The finalizer also runs for instances whose ``__init__`` never
        completed and may be invoked more than once, so every attribute is
        looked up defensively instead of assuming it is still present.
        """
        state = vars(self)
        if state.get("verbose") and "model" in state:
            print("del model , gc:", sys.getrefcount(self.model))
        state.pop("model", None)
        # Only "stf_v3" models own audio components worth releasing here.
        if getattr(state.get("args"), "model_type", None) == "stf_v3":
            state.pop("audio_encoder", None)
            state.pop("audio_processor", None)
def __init_fix_seed(random_seed, verbose=False):
    """Call ``fix_seed(random_seed)`` the first time this function runs.

    The module-level flag ``g_fix_seed`` records that seeding has happened;
    every later call is a no-op.

    Args:
        random_seed: Value forwarded to ``fix_seed``.
        verbose: When true, print a message before seeding.
    """
    global g_fix_seed
    already_seeded = g_fix_seed == True  # noqa: E712 - same comparison as before
    if not already_seeded:
        if verbose:
            print("fix seed")
        fix_seed(random_seed)
        g_fix_seed = True
def create_model(
    config_path, checkpoint_path, work_root_path, device, verbose=False, wavlm_path=None
):
    """Load a Speech2Face model and wrap it in a ``ModelInfo``.

    For "stf_v3" models the WavLM feature extractor and encoder are loaded as
    well; they are cached in module-level globals and shared by later calls.

    Args:
        config_path: Path of the configuration file read by ``read_config``.
        checkpoint_path: Path of the checkpoint holding the model weights.
        work_root_path: Stored on the returned ``ModelInfo`` unchanged.
        device: Requested device; the value "cuda" enables the multi-GPU path
            when more than one GPU is visible.
        verbose: When true, print progress messages.
        wavlm_path: Directory of the pretrained WavLM files. Defaults to
            ``hf_wavlm`` two directory levels above this file.

    Returns:
        A ``ModelInfo``. For a "remote" configuration it carries no model and
        no audio components.

    Raises:
        FileNotFoundError: If the config file is missing, or if the checkpoint
            file is missing for a non-remote configuration.
    """
    global g_audio_encoder
    global g_audio_processor

    __init_fix_seed(random_seed=1234, verbose=verbose)
    if verbose:
        print("load model")

    if not os.path.exists(config_path):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), config_path)
    args = read_config(config_path)

    model_type = getattr(args, "model_type", None)
    if model_type == "remote":
        # Remote configurations have nothing to load locally.
        return ModelInfo(
            model=None,
            audio_processor=None,
            audio_encoder=None,
            args=args,
            device=device,
            work_root_path=work_root_path,
            config_path=config_path,
            checkpoint_path=checkpoint_path,
            verbose=verbose,
        )

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_path
        )

    # snow: this setting was added later, so older configs may not have it.
    # A single truthiness test covers None as well as an empty value.
    if not model_type:
        model_type = "stf_v1"
        args.model_type = model_type

    model = ae.Speech2Face(
        3,
        (3, args.img_size, args.img_size),
        (1, 96, args.mel_step_size),
        model_type,
    )

    use_wavlm = model_type == "stf_v3"
    if use_wavlm and g_audio_encoder is None:
        if wavlm_path is None:
            # Imported here so this block does not rely on a wildcard import
            # to provide Path.
            from pathlib import Path

            wavlm_path = f"{Path(__file__).parent.parent}/hf_wavlm"
        if verbose:
            print(f"@@@@@@@@@@@@@@@@@@ {wavlm_path}")
        # Publish both globals only after both loads succeeded, so a failed
        # load never leaves the cache half populated.
        processor = Wav2Vec2FeatureExtractor.from_pretrained(wavlm_path)
        encoder = WavLMModel.from_pretrained(wavlm_path)
        g_audio_processor = processor
        g_audio_encoder = encoder

    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if "state_dict" in checkpoint:
        model.load_state_dict(checkpoint["state_dict"])
    else:
        model.load_state_dict(checkpoint)
    del checkpoint
    gc.collect()

    if device == "cuda" and torch.cuda.device_count() > 1:
        gpus = list(range(torch.cuda.device_count()))
        print("Multi GPU activate, gpus : ", gpus)
        model = torch.nn.DataParallel(model, device_ids=gpus)
        # The encoder is cached across calls; wrap it only once so repeated
        # calls do not stack DataParallel wrappers on top of each other.
        if use_wavlm and not isinstance(g_audio_encoder, torch.nn.DataParallel):
            g_audio_encoder = torch.nn.DataParallel(g_audio_encoder, device_ids=gpus)

    # NOTE(review): placement is hard-coded to GPU 0 and ignores `device`; the
    # `.to(device)` variants were disabled in the original code - confirm
    # before changing.
    model.cuda(0)
    model.eval()
    if use_wavlm:
        g_audio_encoder.cuda(0)
        g_audio_encoder.eval()

    model_data = ModelInfo(
        model=model,
        audio_processor=g_audio_processor,
        audio_encoder=g_audio_encoder,
        args=args,
        device=device,
        work_root_path=work_root_path,
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        verbose=verbose,
    )
    if verbose:
        print("load model complete")
    return model_data