| import argparse |
| import numpy as np |
|
|
| from rich import print |
|
|
| from bark_infinity import config |
|
|
| logger = config.logger |
|
|
| from bark_infinity import generation |
| from bark_infinity import api |
|
|
| from bark_infinity import text_processing |
| import time |
|
|
| import random |
|
|
| text_prompts_in_this_file = [] |
|
|
|
|
| import torch |
| from torch.utils import collect_env |
|
|
|
|
# Built-in fallback prompts. main() draws from this list with random.sample
# when neither --text_prompt nor --prompt_file is supplied.

# The first prompt embeds the current date/time in words. Building it is
# best-effort: if the helper raises, the error is printed and the list simply
# ends up one prompt shorter.
try:
    text_prompts_in_this_file.append(
        f"It's {text_processing.current_date_time_in_words()} And if you're hearing this, Bark is working. But you didn't provide any text"
    )
except Exception as err:
    print(f"An error occurred: {err}")

# Static prompts. The loop variable is deliberately named ``text_prompt`` so
# the module-level name stays bound to the last prompt after the loop.
for text_prompt in (
    """
In the beginning the Universe was created. This has made a lot of people very angry and been widely regarded as a bad move. However, Bark is working.
""",
    """
A common mistake that people make when trying to design something completely foolproof is to underestimate the ingenuity of complete fools.
""",
):
    text_prompts_in_this_file.append(text_prompt)
|
|
|
|
def get_group_args(group_name, updated_args):
    """Return the parsed arguments that belong to one configuration group.

    Args:
        group_name: Key used to index ``config.DEFAULTS``.
        updated_args: Namespace-like object; its attributes are read with
            ``vars()``.

    Returns:
        A new dict mapping each attribute name of ``updated_args`` that is
        also a key of ``dict(config.DEFAULTS[group_name])`` to its value in
        ``updated_args``.

    Raises:
        Whatever ``config.DEFAULTS[group_name]`` raises for an unknown group
        (presumably ``KeyError`` -- confirm against ``config``). The lookup
        now happens up front, so it is raised even when ``updated_args`` has
        no attributes.
    """
    # Build the key lookup once; the original rebuilt this dict for every
    # attribute it tested.
    group_keys = dict(config.DEFAULTS[group_name])
    return {
        key: value
        for key, value in vars(updated_args).items()
        if key in group_keys
    }
|
|
|
|
def main(args):
    """Run the Bark Infinity command-line workflow.

    Steps, in order:
      1. Apply log level and global ``generation`` overrides from ``args``.
      2. Print any requested diagnostic reports (skipped with ``silent``).
      3. Handle the ``list_speakers`` / ``render_npz_samples`` shortcuts,
         which return without generating audio.
      4. Resolve the prompts: inline ``text_prompt``, else the split contents
         of ``prompt_file``, else two random built-in test prompts.
      5. Preload models (skipped for dry runs and when DirectML is in use).
      6. Call ``api.generate_audio_long`` once per prompt per output
         iteration.

    Args:
        args: Namespace-like object carrying every CLI option. It is mutated
            in place: ``current_iteration`` and ``text_prompt`` are
            overwritten before each generation call, and the whole namespace
            is forwarded to ``api.generate_audio_long`` as keyword arguments.

    Returns:
        None.
    """
    if args.loglevel is not None:
        logger.setLevel(args.loglevel)

    # Global generation switches. An explicit CLI value always wins.
    if args.OFFLOAD_CPU is not None:
        generation.OFFLOAD_CPU = args.OFFLOAD_CPU
    elif generation.get_SUNO_USE_DIRECTML() is not True:
        # No explicit choice: CPU offload defaults to on unless DirectML is
        # in use.
        generation.OFFLOAD_CPU = True
    if args.USE_SMALL_MODELS is not None:
        generation.USE_SMALL_MODELS = args.USE_SMALL_MODELS
    if args.GLOBAL_ENABLE_MPS is not None:
        generation.GLOBAL_ENABLE_MPS = args.GLOBAL_ENABLE_MPS

    # Diagnostic reports.
    if not args.silent:
        if args.detailed_gpu_report or args.show_all_reports:
            print(api.startup_status_report(quick=False))
        elif not args.text_prompt and not args.prompt_file:
            # No input was given at all: show the quick status summary.
            print(api.startup_status_report(quick=True))
        if args.detailed_hugging_face_cache_report or args.show_all_reports:
            print(api.hugging_face_cache_report())
        if args.detailed_cuda_report or args.show_all_reports:
            print(api.cuda_status_report())
        if args.detailed_numpy_report:
            print(api.numpy_report())
        if args.run_numpy_benchmark or args.show_all_reports:
            # Imported lazily so the debug module is only loaded on request.
            from bark_infinity.debug import numpy_benchmark

            numpy_benchmark()

    # Utility modes that do not generate audio.
    if args.list_speakers:
        api.list_speakers()
        return

    if args.render_npz_samples:
        api.render_npz_samples()
        return

    # Work out which prompts to synthesize.
    if args.text_prompt:
        text_prompts_to_process = [args.text_prompt]
    elif args.prompt_file:
        text_file = text_processing.load_text(args.prompt_file)
        if text_file is None:
            logger.error(f"Error loading file: {args.prompt_file}")
            return
        # Reuse the text loaded above; the original read the file a second
        # time here.
        text_prompts_to_process = text_processing.split_text(
            text_file,
            args.split_input_into_separate_prompts_by,
            args.split_input_into_separate_prompts_by_value,
        )

        print(f"\nProcessing file: {args.prompt_file}")
        print(f" Looks like: {len(text_prompts_to_process)} prompt(s)")
    else:
        print("No --text_prompt or --prompt_file specified, using test prompt.")
        text_prompts_to_process = random.sample(text_prompts_in_this_file, 2)

    # Every prompt is rendered once per output iteration, so the workload is
    # the product of the two counts (the original summed them).
    things = len(text_prompts_to_process) * args.output_iterations
    if things > 10 and args.dry_run is False:
        print(
            f"WARNING: You are about to process {things} prompts. Consider using '--dry-run' to test things first."
        )

    print("Loading Bark models...")
    if not args.dry_run and generation.get_SUNO_USE_DIRECTML() is not True:
        generation.preload_models(
            args.text_use_gpu,
            args.text_use_small,
            args.coarse_use_gpu,
            args.coarse_use_small,
            args.fine_use_gpu,
            args.fine_use_small,
            args.codec_use_gpu,
            args.force_reload,
        )

    print("Done.")

    prompt_count = len(text_prompts_to_process)
    for idx, text_prompt in enumerate(text_prompts_to_process, start=1):
        if prompt_count > 1:
            print(f"\nPrompt {idx}/{prompt_count}:")

        for iteration in range(1, args.output_iterations + 1):
            if args.output_iterations > 1:
                print(f"\nIteration {iteration} of {args.output_iterations}.")
            if iteration == 1:
                # NOTE(review): the "ss" prefix looks like leftover debug
                # output; kept as-is to preserve console output.
                print("ss", text_prompt)

            # The API receives the full option set as keyword arguments, so
            # the per-call values are written back onto the namespace first.
            args.current_iteration = iteration
            args.text_prompt = text_prompt
            args_dict = vars(args)

            api.generate_audio_long(**args_dict)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, let config fill in the
    # grouped defaults, then hand main() a fresh Namespace built from the
    # resulting mapping.
    cli_args = config.create_argument_parser().parse_args()
    merged_args = config.update_group_args_with_defaults(cli_args)
    main(argparse.Namespace(**merged_args))
|
|