| import contextlib |
| import functools |
| import hashlib |
| import logging |
| import os |
|
|
| import requests |
| import torch |
| import tqdm |
|
|
| from TTS.tts.layers.bark.model import GPT, GPTConfig |
| from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig |
|
|
# ``autocast()`` is the mixed-precision context used by ``inference_mode``:
# bfloat16 autocast on CUDA devices that report bf16 support, otherwise a no-op.
if (
    torch.cuda.is_available()
    and (
        hasattr(torch, "autocast")
        or (hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast"))
    )
    and torch.cuda.is_bf16_supported()
):
    if hasattr(torch, "autocast"):
        # ``torch.cuda.amp.autocast`` is deprecated in recent PyTorch releases (it warns on
        # every construction); the device-generic ``torch.autocast`` is the replacement.
        autocast = functools.partial(torch.autocast, device_type="cuda", dtype=torch.bfloat16)
    else:
        # Older PyTorch without ``torch.autocast``: fall back to the CUDA-specific API.
        autocast = functools.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
else:

    @contextlib.contextmanager
    def autocast():
        """No-op stand-in used when bfloat16 CUDA autocast is unavailable."""
        yield
|
|
|
|
| |
|
|
logger = logging.getLogger(name=__name__)

# Emitted once, at import time, when this torch build has no
# ``torch.nn.functional.scaled_dot_product_attention``.
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
    logger.warning(
        "torch version does not support flash attention. You will get significantly faster"
        " inference speed by upgrade torch to newest version / nightly."
    )
|
|
|
|
| def _md5(fname): |
| hash_md5 = hashlib.md5() |
| with open(fname, "rb") as f: |
| for chunk in iter(lambda: f.read(4096), b""): |
| hash_md5.update(chunk) |
| return hash_md5.hexdigest() |
|
|
|
|
def _download(from_s3_path, to_local_path, CACHE_DIR):
    """Stream ``from_s3_path`` to ``to_local_path``, showing a tqdm progress bar.

    Data is first written to ``<to_local_path>.part``. It is moved into place
    only after the transfer completes and its size matches the server's
    ``Content-Length`` (when one was sent). A failed download therefore never
    leaves a truncated file at ``to_local_path``.

    Args:
        from_s3_path: URL to fetch.
        to_local_path: Destination file path.
        CACHE_DIR: Directory created (if missing) before downloading.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        ValueError: if the number of bytes received differs from the announced size.
    """
    os.makedirs(CACHE_DIR, exist_ok=True)
    # ``load_model`` treats any existing file at the destination as a downloaded
    # checkpoint, so only a complete download may ever appear there.
    partial_path = os.fspath(to_local_path) + ".part"
    block_size = 1024
    try:
        with requests.get(from_s3_path, stream=True) as response:
            # Fail on HTTP errors instead of saving the error page as the checkpoint.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get("content-length", 0))
            progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
            # tqdm's context manager closes the bar even if the transfer raises.
            with progress_bar, open(partial_path, "wb") as file:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
        # A size of 0 means no Content-Length was sent (or an empty body): nothing to compare.
        if total_size_in_bytes not in (0, progress_bar.n):
            raise ValueError(
                f"ERROR, something went wrong: expected {total_size_in_bytes} bytes from"
                f" {from_s3_path}, received {progress_bar.n}"
            )
        os.replace(partial_path, to_local_path)
    except BaseException:
        # Remove the partial file on any failure (including Ctrl-C), then re-raise.
        with contextlib.suppress(FileNotFoundError):
            os.remove(partial_path)
        raise
|
|
|
|
class InferenceContext:
    """Context manager that temporarily overrides ``torch.backends.cudnn.benchmark``.

    On entry, the current global flag is saved and replaced with the value
    given to the constructor. On exit, the saved value is restored. Exceptions
    raised inside the ``with`` body are not suppressed, and ``__enter__``
    returns ``None``.
    """

    def __init__(self, benchmark=False):
        # Flag value that was active before ``__enter__``; recorded on entry.
        self._cudnn_benchmark = None
        # Flag value to install while the context is active.
        self._chosen_cudnn_benchmark = benchmark

    def __enter__(self):
        cudnn = torch.backends.cudnn
        # The right-hand side is evaluated first: the old flag is read, then the new one is set.
        self._cudnn_benchmark, cudnn.benchmark = cudnn.benchmark, self._chosen_cudnn_benchmark

    def __exit__(self, exc_type, exc_value, exc_traceback):
        torch.backends.cudnn.benchmark = self._cudnn_benchmark
|
|
|
|
# When CUDA is present, allow TF32 math for both cuBLAS matmuls and cuDNN kernels.
if torch.cuda.is_available():
    # Chained assignment sets the matmul flag first, then the cuDNN flag (same order as before).
    torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
|
|
|
|
@contextlib.contextmanager
def inference_mode():
    """Combine every context this module uses for inference into one context manager.

    From outermost to innermost, it enters:
      * ``InferenceContext()``, which sets cudnn benchmarking to its default of ``False``;
      * ``torch.inference_mode()``;
      * ``torch.no_grad()``;
      * ``autocast()``, the module's mixed-precision context.
    They are exited in reverse order.
    """
    with InferenceContext():
        with torch.inference_mode():
            with torch.no_grad():
                with autocast():
                    yield
|
|
|
|
def clear_cuda_cache():
    """Empty the CUDA caching allocator and synchronize; does nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
|
|
|
|
def load_model(ckpt_path, device, config, model_type="text"):
    """Load one Bark GPT sub-model from ``ckpt_path``, downloading it first if needed.

    Args:
        ckpt_path: Local checkpoint path. If it is missing, it is downloaded from
            ``config.REMOTE_MODEL_PATHS[model_type]["path"]``. When
            ``config.USE_SMALLER_MODELS`` is false, an existing file whose MD5
            differs from ``config.REMOTE_MODEL_PATHS[model_type]["checksum"]``
            is deleted and re-downloaded.
        device: Passed to ``torch.load(map_location=...)`` and ``model.to(...)``.
        config: Bark config object. ``USE_SMALLER_MODELS``, ``REMOTE_MODEL_PATHS``
            and ``CACHE_DIR`` are read from it. The parsed GPT config is stored on
            it as ``semantic_config`` (text), ``coarse_config`` or ``fine_config``.
        model_type: One of ``"text"``, ``"coarse"`` or ``"fine"``.

    Returns:
        ``(model, config)``: the model in eval mode, moved to ``device``, and the
        same ``config`` object with its sub-model config attribute filled in.

    Raises:
        NotImplementedError: if ``model_type`` is not one of the supported values.
        ValueError: if the checkpoint's state dict has extra or missing keys,
            ignoring keys ending in ``.attn.bias``.
    """
    # One table instead of two parallel if/elif chains that had to be kept in sync:
    # model_type -> (config class, model class, attribute of ``config`` that receives it).
    model_specs = {
        "text": (GPTConfig, GPT, "semantic_config"),
        "coarse": (GPTConfig, GPT, "coarse_config"),
        "fine": (FineGPTConfig, FineGPT, "fine_config"),
    }

    logger.info(f"loading {model_type} model from {ckpt_path}...")

    if device == "cpu":
        logger.warning("No GPU being used. Careful, Inference might be extremely slow!")
    if model_type not in model_specs:
        raise NotImplementedError(
            f"unsupported model_type {model_type!r}; expected one of {sorted(model_specs)}"
        )
    ConfigClass, ModelClass, config_attr = model_specs[model_type]

    if (
        not config.USE_SMALLER_MODELS
        and os.path.exists(ckpt_path)
        and _md5(ckpt_path) != config.REMOTE_MODEL_PATHS[model_type]["checksum"]
    ):
        logger.warning(f"found outdated {model_type} model, removing...")
        os.remove(ckpt_path)
    if not os.path.exists(ckpt_path):
        logger.info(f"{model_type} model not found, downloading...")
        _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)

    # NOTE(review): torch.load unpickles the file. A freshly downloaded checkpoint is
    # not checksum-verified here (only pre-existing files are). Consider
    # ``weights_only=True`` where the supported torch versions allow it.
    checkpoint = torch.load(ckpt_path, map_location=device)

    # Checkpoints with a single ``vocab_size`` are converted to the
    # ``input_vocab_size`` / ``output_vocab_size`` pair, in place.
    model_args = checkpoint["model_args"]
    if "input_vocab_size" not in model_args:
        model_args["input_vocab_size"] = model_args["vocab_size"]
        model_args["output_vocab_size"] = model_args["vocab_size"]
        del model_args["vocab_size"]

    gptconf = ConfigClass(**model_args)
    setattr(config, config_attr, gptconf)

    model = ModelClass(gptconf)
    state_dict = checkpoint["model"]

    # Strip the "_orig_mod." prefix from keys, presumably added when the model was
    # saved after ``torch.compile``. Iterate over a snapshot because keys are renamed in place.
    unwanted_prefix = "_orig_mod."
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix) :]] = state_dict.pop(key)

    # Build the model's key set once. ``.attn.bias`` entries are ignored on both sides
    # (presumably an attention-mask buffer that may or may not be registered).
    expected_keys = set(model.state_dict())
    loaded_keys = set(state_dict)
    extra_keys = {k for k in loaded_keys - expected_keys if not k.endswith(".attn.bias")}
    missing_keys = {k for k in expected_keys - loaded_keys if not k.endswith(".attn.bias")}
    if extra_keys:
        raise ValueError(f"extra keys found: {extra_keys}")
    if missing_keys:
        raise ValueError(f"missing keys: {missing_keys}")
    # strict=False only tolerates the ignored ``.attn.bias`` differences checked above.
    model.load_state_dict(state_dict, strict=False)

    n_params = model.get_num_params()
    val_loss = checkpoint["best_val_loss"].item()
    logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
    model.eval()
    model.to(device)
    # Drop references to the loaded tensors before clearing the CUDA cache.
    del checkpoint, state_dict
    clear_cuda_cache()
    return model, config
|
|