Spaces:
Build error
Build error
mrq
Modified the logic that determines valid voice folders; subdirectories within the voices folder are now allowed as well (for example, ./voices/SH/james/ will be named SH/james).
faa8da1 | import os | |
# Default the model/cache locations to folders under ./models/ in the current
# working directory, unless the environment already provides them. This sits
# before the remaining imports, presumably so those libraries see the values
# when they are imported.
for _env_name, _env_subdir in (
    ('XDG_CACHE_HOME', './models/'),
    ('TORTOISE_MODELS_DIR', './models/tortoise/'),
    ('TRANSFORMERS_CACHE', './models/transformers/'),
):
    if _env_name not in os.environ:
        os.environ[_env_name] = os.path.realpath(os.path.join(os.getcwd(), _env_subdir))
del _env_name, _env_subdir
| import argparse | |
| import time | |
| import math | |
| import json | |
| import base64 | |
| import re | |
| import urllib.request | |
| import signal | |
| import gc | |
| import subprocess | |
| import psutil | |
| import yaml | |
| import hashlib | |
| import string | |
| import random | |
| from tqdm import tqdm | |
| import torch | |
| import torchaudio | |
| import music_tag | |
| import gradio as gr | |
| import gradio.utils | |
| import pandas as pd | |
| from glob import glob | |
| from datetime import datetime | |
| from datetime import timedelta | |
| from tortoise.api import TextToSpeech as TorToise_TTS, MODELS, get_model_path, pad_or_truncate | |
| from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices | |
| from tortoise.utils.text import split_and_recombine_text | |
| from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc | |
| from whisper.normalizers.english import EnglishTextNormalizer | |
| from whisper.normalizers.basic import BasicTextNormalizer | |
| from whisper.tokenizer import LANGUAGES | |
# Pin the DVAE checkpoint download to a fixed revision of the upstream repository.
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"

WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
TTSES = ['tortoise']

INFERENCING = False
GENERATE_SETTINGS_ARGS = None

# UI label -> learning-rate scheduler name.
LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"}
LEARNING_RATE_SCHEDULE = [2, 4, 9, 18, 25, 33, 50]

# Resample transforms memoized by "input_rate:output_rate" (used by resample()).
RESAMPLERS = {}

# Clip-length bounds for training data (presumably in seconds — TODO confirm).
MIN_TRAINING_DURATION = 0.6
MAX_TRAINING_DURATION = 11.6097505669

# VALL-E support is optional: it is enabled only if every import below succeeds.
VALLE_ENABLED = False
try:
    from vall_e.emb.qnt import encode as valle_quantize
    from vall_e.emb.g2p import encode as valle_phonemize
    from vall_e.inference import TTS as VALLE_TTS
    import soundfile
except Exception:
    pass
else:
    VALLE_ENABLED = True

if VALLE_ENABLED:
    TTSES.append('vall-e')

# Module-level runtime state; everything starts out unset.
args = None
tts = None
tts_loading = False
webui = None
voicefixer = None
whisper_model = None
whisper_vad = None
whisper_diarize = None
whisper_align_model = None
training_state = None
current_voice = None
def resample( waveform, input_rate, output_rate=44100 ):
    """Downmix ``waveform`` to mono and resample it to ``output_rate``.

    Channels are averaged over dim 0. When the rates already match, only the
    downmix is applied. Resample transforms are created once per
    ``(input_rate, output_rate)`` pair and memoized in ``RESAMPLERS``.

    Returns:
        ``(waveform, output_rate)``.
    """
    mono = waveform.mean(dim=0, keepdim=True)
    if input_rate == output_rate:
        return mono, output_rate

    cache_key = f'{input_rate}:{output_rate}'
    resampler = RESAMPLERS.get(cache_key)
    if resampler is None:
        resampler = torchaudio.transforms.Resample(
            input_rate,
            output_rate,
            lowpass_filter_width=16,
            rolloff=0.85,
            resampling_method="kaiser_window",
            beta=8.555504641634386,
        )
        RESAMPLERS[cache_key] = resampler
    return resampler(mono), output_rate
def generate(**kwargs):
    """Run text-to-speech with the backend selected by ``args.tts_backend``.

    ``kwargs`` are forwarded unchanged to ``generate_tortoise`` or
    ``generate_valle`` and that function's result is returned; any other
    backend name yields None.
    """
    backends = {
        "tortoise": generate_tortoise,
        "vall-e": generate_valle,
    }
    handler = backends.get(args.tts_backend)
    return handler(**kwargs) if handler is not None else None
def generate_valle(**kwargs):
    """Synthesize ``kwargs['text']`` with the VALL-E backend and save the results.

    The text is split into lines: on ``delimiter`` when it occurs in the text,
    otherwise with ``split_and_recombine_text``. A line may begin with a JSON
    object followed by a space, which overrides settings for that line only,
    e.g. ``{"voice": "other"} Hello``. Only ``voice`` and keys already present
    in the generation settings are honoured.

    Each clip is then processed as follows:
    - written to ``{args.results_folder}/{voice}/``;
    - resampled to ``args.output_sample_rate`` and optionally volume-adjusted;
    - concatenated per candidate when there are several lines;
    - optionally passed through VoiceFixer;
    - annotated with its settings, either as embedded metadata or as a
      side-car ``.json``.

    The effective settings are also written to ``./config/generate.json``.

    ``progress`` is optional. When given, it is a gradio-style progress object
    (callable, with ``msg_prefix`` and ``tqdm``); when omitted or None,
    progress reporting is skipped.

    Returns:
        ``(sample_voice, output_voices, stats)``: ``sample_voice`` is always
        None for this backend, ``output_voices`` lists the final wav paths and
        ``stats`` is ``[[seed, "<seconds>"]]``.

    Raises:
        Exception: if the TTS is still initializing, a per-line JSON override
        is invalid, or a voice has no ``.wav`` reference clips.
    """
    # INFERENCING must be declared global so cancel_generate() can observe a
    # running job; without the declaration the assignments below only bind a local.
    global args, tts, INFERENCING

    parameters = dict(kwargs)

    voice = parameters['voice']
    progress = parameters.get('progress')
    if parameters['seed'] == 0:
        parameters['seed'] = None
    used_seed = parameters['seed']

    unload_whisper()
    unload_voicefixer()

    if not tts:
        # TODO: distinguish "still loading" from "unloaded" and load on demand
        if tts_loading:
            raise Exception("TTS is still initializing...")
        if progress is not None:
            progress(0, "Initializing TTS...")
        load_tts()
    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    do_gc()

    # This backend never produces a reference-audio preview.
    sample_voice = None

    def fetch_voice(voice):
        """Pick a random ``.wav`` reference clip for ``voice``.

        Looks in ``./training/{voice}/audio/`` first and falls back to
        ``./voices/{voice}/``.
        """
        voice_dir = f'./training/{voice}/audio/'
        if not os.path.isdir(voice_dir):
            voice_dir = f'./voices/{voice}/'

        files = [f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav"]
        if not files:
            # random.choice([]) would only raise an opaque IndexError
            raise Exception(f"No .wav reference clips found for voice: {voice}")
        return random.choice(files)

    def get_settings(override=None):
        """Build the ``tts.inference`` keyword arguments (plus 'reference') for one line."""
        settings = {
            'ar_temp': float(parameters['temperature']),
            'nar_temp': float(parameters['temperature']),
            'max_ar_samples': parameters['num_autoregressive_samples'],
        }

        # Overrides may switch the voice and replace known settings; unknown keys are ignored.
        selected_voice = voice
        if override is not None:
            selected_voice = override.get('voice', selected_voice)
            for k, v in override.items():
                if k in settings:
                    settings[k] = v

        settings['reference'] = fetch_voice(voice=selected_voice)
        return settings

    if not parameters['delimiter'] or parameters['delimiter'] == "\\n":
        parameters['delimiter'] = "\n"
    if parameters['delimiter'] in parameters['text']:
        texts = parameters['text'].split(parameters['delimiter'])
    else:
        texts = split_and_recombine_text(parameters['text'])

    full_start_time = time.time()

    outdir = f"{args.results_folder}/{voice}/"
    os.makedirs(outdir, exist_ok=True)
    # Files are named after the last component of the voice name. Nested voices
    # such as "SH/james" therefore write "james_<n>.wav" inside outdir instead
    # of into a non-existent "SH/" subfolder; flat names are unaffected.
    file_prefix = os.path.basename(os.path.normpath(voice))

    def output_path(name, extension="wav"):
        """Path of the output file for clip ``name``."""
        return f'{outdir}/{file_prefix}_{name}.{extension}'

    audio_cache = {}

    volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

    # Continue numbering after the highest index already present in outdir.
    # The prefix is escaped because voice names may contain regex metacharacters.
    used_indices = set()
    for file in os.listdir(outdir):
        filename = os.path.basename(file)
        extension = os.path.splitext(filename)[1]
        if extension != ".json" and extension != ".wav":
            continue
        match = re.findall(rf"^{re.escape(file_prefix)}_(\d+)(?:.+?)?{re.escape(extension)}$", filename)
        if match:
            used_indices.add(int(match[0]))
    idx = pad(max(used_indices) + 1 if used_indices else 0, 4)

    def get_name(line=0, candidate=0, combined=False):
        """Clip name: index, then line number (multi-line) or '_combined', then candidate."""
        name = f"{idx}"
        if combined:
            name = f"{name}_combined"
        elif len(texts) > 1:
            name = f"{name}_{line}"
        if parameters['candidates'] > 1:
            name = f"{name}_{candidate}"
        return name

    def get_info(voice, settings=None, latents=True):
        """Metadata for a clip: request parameters overlaid with matching ``settings`` keys.

        ``voice`` and ``latents`` are accepted for signature parity with the
        tortoise variant but are not used here.
        """
        info = dict(parameters)
        info['time'] = time.time() - full_start_time
        info['datetime'] = datetime.now().isoformat()
        info.pop('progress', None)  # the progress object is not metadata

        if info['delimiter'] == "\n":
            info['delimiter'] = "\\n"

        if settings is not None:
            for k in settings:
                if k in info:
                    info[k] = settings[k]
        return info

    INFERENCING = True
    try:
        for line, cut_text in enumerate(texts):
            msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
            if progress is not None:
                progress.msg_prefix = msg_prefix
            print(f"{msg_prefix} Generating line: {cut_text}")
            start_time = time.time()

            # Optional per-line settings override: '{"key": value, ...} text'
            override = None
            match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
            if match:
                json_text, line_text = match[0]
                try:
                    override = json.loads(json_text)
                except ValueError as e:
                    raise Exception("Prompt settings editing requested, but received invalid JSON") from e
                cut_text = line_text.strip()

            settings = get_settings(override=override)
            reference = settings.pop("reference")
            gen = tts.inference(cut_text, reference, **settings)

            run_time = time.time() - start_time
            print(f"Generating line took {run_time} seconds")

            if not isinstance(gen, list):
                gen = [gen]

            line_voice = override['voice'] if override and 'voice' in override else voice
            for j, (wav, sr) in enumerate(gen):
                name = get_name(line=line, candidate=j)

                settings['text'] = cut_text
                settings['time'] = run_time
                settings['datetime'] = datetime.now().isoformat()

                # Save immediately in case an error happens mid-batch, then reload with
                # torchaudio so the cached audio is a tensor for the steps below.
                path = output_path(name)
                soundfile.write(path, wav.cpu()[0, 0], sr)
                wav, sr = torchaudio.load(path)

                audio_cache[name] = {
                    'audio': wav,
                    'settings': get_info(voice=line_voice, settings=settings),
                }

            del gen
            do_gc()
    finally:
        # Always clear the flag, even if generation raised.
        INFERENCING = False

    # NOTE(review): resamples from tts.output_sample_rate although the clips were
    # reloaded at `sr` above; assumes both agree — TODO confirm.
    for k, entry in audio_cache.items():
        audio, _ = resample(entry['audio'], tts.output_sample_rate, args.output_sample_rate)
        if volume_adjust is not None:
            audio = volume_adjust(audio)
        entry['audio'] = audio
        torchaudio.save(output_path(k), audio, args.output_sample_rate)

    output_voices = []
    for candidate in range(parameters['candidates']):
        if len(texts) > 1:
            # Stitch this candidate's lines together into one combined clip.
            audio_clips = [audio_cache[get_name(line=line, candidate=candidate)]['audio'] for line in range(len(texts))]
            name = get_name(candidate=candidate, combined=True)
            audio = torch.cat(audio_clips, dim=-1)
            torchaudio.save(output_path(name), audio, args.output_sample_rate)

            audio_cache[name] = {
                'audio': audio.squeeze(0).cpu(),
                'settings': get_info(voice=voice),
                'output': True,
            }
        else:
            audio_cache[get_name(candidate=candidate)]['output'] = True

    if args.voice_fixer:
        if not voicefixer:
            if progress is not None:
                progress(0, "Loading voicefix...")
            load_voicefixer()

        # Collect results first and only swap outputs once every clip succeeded.
        # A mid-run failure then leaves the original clips as the outputs
        # instead of dropping them.
        fixed_cache = {}
        try:
            names = progress.tqdm(audio_cache, desc="Running voicefix...") if progress is not None else audio_cache
            for name in names:
                entry = audio_cache[name]
                entry.pop('audio', None)  # in-memory audio is no longer needed
                if not entry.get('output'):
                    continue

                voicefixer.restore(
                    input=output_path(name),
                    output=output_path(f'{name}_fixed'),
                    cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
                    #mode=mode,
                )
                fixed_cache[name] = {
                    'settings': entry['settings'],
                    'output': True,
                }
        except Exception as e:
            print(e)
            print("\nFailed to run Voicefixer")
            fixed_cache = {}

        for name, fixed_entry in fixed_cache.items():
            audio_cache[name]['output'] = False
            audio_cache[f'{name}_fixed'] = fixed_entry

    for name, entry in audio_cache.items():
        if not entry.get('output'):
            if args.prune_nonfinal_outputs:
                entry['pruned'] = True
                os.remove(output_path(name))
            continue

        output_voices.append(output_path(name))

        if not args.embed_output_metadata:
            with open(output_path(name, "json"), 'w', encoding="utf-8") as f:
                f.write(json.dumps(entry['settings'], indent='\t'))

    if args.embed_output_metadata:
        names = progress.tqdm(audio_cache, desc="Embedding metadata...") if progress is not None else audio_cache
        for name in names:
            if audio_cache[name].get('pruned'):
                continue

            metadata = music_tag.load_file(output_path(name))
            metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
            metadata.save()

    info = get_info(voice=voice, latents=False)
    saved_to = output_voices[0] if output_voices else outdir
    print(f"Generation took {info['time']} seconds, saved to '{saved_to}'\n")

    info['seed'] = used_seed
    info.pop('latents', None)

    os.makedirs('./config/', exist_ok=True)
    with open('./config/generate.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps(info, indent='\t'))

    stats = [
        [parameters['seed'], "{:.3f}".format(info['time'])]
    ]

    return (
        sample_voice,
        output_voices,
        stats,
    )
def generate_tortoise(**kwargs):
    """Synthesize ``kwargs['text']`` with the TorToiSe backend and save the results.

    The text is split into lines: on ``delimiter`` when it occurs in the text,
    otherwise with ``split_and_recombine_text``. A line may begin with a JSON
    object followed by a space, which overrides settings for that line only,
    e.g. ``{"voice": "other", "temperature": 0.5} Hello``. Only ``voice`` and
    keys already present in the generation settings are honoured. The override
    is split off before phonemization and the emotion prefix are applied.

    Each clip is then processed as follows:
    - written to ``{args.results_folder}/{voice}/``;
    - resampled to ``args.output_sample_rate`` and optionally volume-adjusted;
    - concatenated per candidate for multi-line input;
    - optionally passed through VoiceFixer;
    - annotated with its settings, either as embedded metadata or as a
      side-car ``.json``.

    The effective settings are also written to ``./config/generate.json``.

    ``progress`` is optional. When given, it is a gradio-style progress object
    (callable, with ``msg_prefix`` and ``tqdm``); when omitted or None,
    progress reporting is skipped.

    Returns:
        ``(sample_voice, output_voices, stats)``:
        - ``sample_voice`` is ``(tts.input_sample_rate, ndarray)`` holding the
          requested voice's reference audio, or None when no samples were loaded;
        - ``output_voices`` lists the final wav paths;
        - ``stats`` is ``[[seed, "<seconds>"]]``.

    Raises:
        Exception: if the TTS is still initializing, microphone audio is
        missing for the "microphone" voice, or a per-line JSON override is invalid.
    """
    # INFERENCING must be declared global so cancel_generate() can observe a
    # running job; without the declaration the assignments below only bind a local.
    global args, tts, INFERENCING

    parameters = dict(kwargs)

    voice = parameters['voice']
    progress = parameters.get('progress')
    if parameters['seed'] == 0:
        parameters['seed'] = None
    used_seed = parameters['seed']

    unload_whisper()
    unload_voicefixer()

    if not tts:
        # TODO: distinguish "still loading" from "unloaded" and load on demand
        if tts_loading:
            raise Exception("TTS is still initializing...")
        load_tts()
    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    do_gc()

    # Reference audio of the requested voice, returned as a preview; set by get_settings().
    sample_voice = None
    # "<voice>:<model hash prefix>" -> (voice_samples, conditioning_latents, sample audio)
    voice_cache = {}

    def fetch_voice(voice):
        """Return ``(voice_samples, conditioning_latents, sample)`` for ``voice``.

        Results are memoized per voice and autoregressive-model hash. When
        samples are available, latents are computed if missing. The samples
        are then collapsed into one ``sample`` tensor and ``voice_samples`` is
        returned as None.
        """
        cache_key = f'{voice}:{tts.autoregressive_model_hash[:8]}'
        if cache_key in voice_cache:
            return voice_cache[cache_key]

        print(f"Loading voice: {voice} with model {tts.autoregressive_model_hash[:8]}")

        sample = None
        if voice == "microphone":
            if parameters['mic_audio'] is None:
                raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
            voice_samples, conditioning_latents = [load_audio(parameters['mic_audio'], tts.input_sample_rate)], None
        elif voice == "random":
            voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
        else:
            if progress is not None:
                progress(0, desc=f"Loading voice: {voice}")
            voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)

        if voice_samples and len(voice_samples) > 0:
            if conditioning_latents is None:
                conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=parameters['voice_latents_chunks'])
            sample = torch.cat(voice_samples, dim=-1).squeeze().cpu()
            voice_samples = None

        voice_cache[cache_key] = (voice_samples, conditioning_latents, sample)
        return voice_cache[cache_key]

    def get_settings(override=None):
        """Build the ``tts.tts`` keyword arguments for one line, applying ``override``.

        Also loads the autoregressive/diffusion models and the tokenizer the
        settings name, and resolves the selected voice's samples/latents.
        """
        nonlocal sample_voice

        settings = {
            'temperature': float(parameters['temperature']),

            'top_p': float(parameters['top_p']),
            'diffusion_temperature': float(parameters['diffusion_temperature']),
            'length_penalty': float(parameters['length_penalty']),
            'repetition_penalty': float(parameters['repetition_penalty']),
            'cond_free_k': float(parameters['cond_free_k']),

            'num_autoregressive_samples': parameters['num_autoregressive_samples'],
            'sample_batch_size': args.sample_batch_size,
            'diffusion_iterations': parameters['diffusion_iterations'],

            'voice_samples': None,
            'conditioning_latents': None,

            'use_deterministic_seed': parameters['seed'],
            'return_deterministic_state': True,
            'k': parameters['candidates'],
            'diffusion_sampler': parameters['diffusion_sampler'],
            'breathing_room': parameters['breathing_room'],
            'progress': progress,
            'half_p': "Half Precision" in parameters['experimentals'],
            'cond_free': "Conditioning-Free" in parameters['experimentals'],
            'cvvp_amount': parameters['cvvp_weight'],

            'autoregressive_model': args.autoregressive_model,
            'diffusion_model': args.diffusion_model,
            'tokenizer_json': args.tokenizer_json,
        }

        # Overrides may switch the voice and replace known settings; unknown keys are ignored.
        selected_voice = voice
        if override is not None:
            selected_voice = override.get('voice', selected_voice)
            for k, v in override.items():
                if k in settings:
                    settings[k] = v

        if settings['autoregressive_model'] is not None:
            if settings['autoregressive_model'] == "auto":
                settings['autoregressive_model'] = deduce_autoregressive_model(selected_voice)
            tts.load_autoregressive_model(settings['autoregressive_model'])

        if settings['diffusion_model'] is not None:
            if settings['diffusion_model'] == "auto":
                settings['diffusion_model'] = deduce_diffusion_model(selected_voice)
            tts.load_diffusion_model(settings['diffusion_model'])

        if settings['tokenizer_json'] is not None:
            tts.load_tokenizer_json(settings['tokenizer_json'])

        settings['voice_samples'], settings['conditioning_latents'], fetched_sample = fetch_voice(voice=selected_voice)
        # Keep the requested voice's reference audio for the preview returned to the caller.
        if selected_voice == voice and fetched_sample is not None:
            sample_voice = fetched_sample

        # Clamp the batch size: reset it to the CLI value, fall back to the model's
        # default, and never exceed the number of autoregressive samples.
        settings['sample_batch_size'] = args.sample_batch_size
        if not settings['sample_batch_size']:
            settings['sample_batch_size'] = tts.autoregressive_batch_size
        if settings['num_autoregressive_samples'] < settings['sample_batch_size']:
            settings['sample_batch_size'] = settings['num_autoregressive_samples']

        if settings['conditioning_latents'] is not None and len(settings['conditioning_latents']) == 2 and settings['cvvp_amount'] > 0:
            print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents with 'Slimmer voice latents' unchecked.")
            settings['cvvp_amount'] = 0

        return settings

    if not parameters['delimiter'] or parameters['delimiter'] == "\\n":
        parameters['delimiter'] = "\n"
    if parameters['delimiter'] in parameters['text']:
        texts = parameters['text'].split(parameters['delimiter'])
    else:
        texts = split_and_recombine_text(parameters['text'])

    full_start_time = time.time()

    outdir = f"{args.results_folder}/{voice}/"
    os.makedirs(outdir, exist_ok=True)
    # Files are named after the last component of the voice name. Nested voices
    # such as "SH/james" therefore write "james_<n>.wav" inside outdir instead
    # of into a non-existent "SH/" subfolder; flat names are unaffected.
    file_prefix = os.path.basename(os.path.normpath(voice))

    def output_path(name, extension="wav"):
        """Path of the output file for clip ``name``."""
        return f'{outdir}/{file_prefix}_{name}.{extension}'

    audio_cache = {}

    volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

    # Continue numbering after the highest index already present in outdir.
    # The prefix is escaped because voice names may contain regex metacharacters.
    used_indices = set()
    for file in os.listdir(outdir):
        filename = os.path.basename(file)
        extension = os.path.splitext(filename)[1]
        if extension != ".json" and extension != ".wav":
            continue
        match = re.findall(rf"^{re.escape(file_prefix)}_(\d+)(?:.+?)?{re.escape(extension)}$", filename)
        if match:
            used_indices.add(int(match[0]))
    idx = pad(max(used_indices) + 1 if used_indices else 0, 4)

    def get_name(line=0, candidate=0, combined=False):
        """Clip name: index, then line number (multi-line) or '_combined', then candidate."""
        name = f"{idx}"
        if combined:
            name = f"{name}_combined"
        elif len(texts) > 1:
            name = f"{name}_{line}"
        if parameters['candidates'] > 1:
            name = f"{name}_{candidate}"
        return name

    def get_info(voice, settings=None, latents=True):
        """Collect the metadata stored alongside a clip.

        Steps:
        - start from the request parameters and add timing and model identity;
        - overlay matching keys from ``settings``;
        - when ``latents`` is set, embed the conditioning latents cached for
          ``voice`` as base64. For the "random" and "microphone" voices, the
          latents used for this clip are first written to the voice folder.
        """
        info = dict(parameters)
        info['time'] = time.time() - full_start_time
        info['datetime'] = datetime.now().isoformat()
        info['model'] = tts.autoregressive_model_path
        info['model_hash'] = tts.autoregressive_model_hash
        info.pop('progress', None)  # the progress object is not metadata

        if info['delimiter'] == "\n":
            info['delimiter'] = "\\n"

        if settings is not None:
            for k in settings:
                if k in info:
                    info[k] = settings[k]

            if 'half_p' in settings and 'cond_free' in settings:
                info['experimentals'] = []
                if settings['half_p']:
                    info['experimentals'].append("Half Precision")
                if settings['cond_free']:
                    info['experimentals'].append("Conditioning-Free")

        if latents and "latents" not in info:
            model_hash = settings["model_hash"][:8] if settings is not None and "model_hash" in settings else tts.autoregressive_model_hash[:8]

            voice_dir = f'{get_voice_dir()}/{voice}/'
            latents_path = f'{voice_dir}/cond_latents_{model_hash}.pth'

            if voice == "random" or voice == "microphone":
                if settings is not None and settings['conditioning_latents']:
                    os.makedirs(voice_dir, exist_ok=True)
                    # Persist the latents actually used for this clip.
                    torch.save(settings['conditioning_latents'], latents_path)

            if os.path.exists(latents_path):
                try:
                    with open(latents_path, 'rb') as f:
                        info['latents'] = base64.b64encode(f.read()).decode("ascii")
                except OSError:
                    pass  # best effort: metadata is still useful without embedded latents

        return info

    INFERENCING = True
    try:
        for line, cut_text in enumerate(texts):
            # Split off an optional per-line override ('{"key": value, ...} text')
            # before phonemization/emotion prefixing, which would otherwise hide it.
            override = None
            match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
            if match:
                json_text, line_text = match[0]
                try:
                    override = json.loads(json_text)
                except ValueError as e:
                    raise Exception("Prompt settings editing requested, but received invalid JSON") from e
                cut_text = line_text.strip()

            if should_phonemize():
                cut_text = phonemizer(cut_text)

            if parameters['emotion'] == "Custom":
                if parameters['prompt'] and parameters['prompt'].strip() != "":
                    cut_text = f"[{parameters['prompt']},] {cut_text}"
            elif parameters['emotion'] != "None" and parameters['emotion']:
                cut_text = f"[I am really {parameters['emotion'].lower()},] {cut_text}"

            msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
            if progress is not None:
                progress.msg_prefix = msg_prefix
            print(f"{msg_prefix} Generating line: {cut_text}")
            start_time = time.time()

            settings = get_settings(override=override)
            gen, additionals = tts.tts(cut_text, **settings)

            parameters['seed'] = additionals[0]
            run_time = time.time() - start_time
            print(f"Generating line took {run_time} seconds")

            if not isinstance(gen, list):
                gen = [gen]

            line_voice = override['voice'] if override and 'voice' in override else voice
            for j, g in enumerate(gen):
                audio = g.squeeze(0).cpu()
                name = get_name(line=line, candidate=j)

                settings['text'] = cut_text
                settings['time'] = run_time
                settings['datetime'] = datetime.now().isoformat()
                if args.tts_backend == "tortoise":
                    settings['model'] = tts.autoregressive_model_path
                    settings['model_hash'] = tts.autoregressive_model_hash

                audio_cache[name] = {
                    'audio': audio,
                    'settings': get_info(voice=line_voice, settings=settings),
                }
                # save here in case some error happens mid-batch
                torchaudio.save(output_path(name), audio, tts.output_sample_rate)

            del gen
            do_gc()
    finally:
        # Always clear the flag, even if generation raised.
        INFERENCING = False

    for k, entry in audio_cache.items():
        audio, _ = resample(entry['audio'], tts.output_sample_rate, args.output_sample_rate)
        if volume_adjust is not None:
            audio = volume_adjust(audio)
        entry['audio'] = audio
        torchaudio.save(output_path(k), audio, args.output_sample_rate)

    output_voices = []
    for candidate in range(parameters['candidates']):
        if len(texts) > 1:
            # Stitch this candidate's lines together into one combined clip.
            audio_clips = [audio_cache[get_name(line=line, candidate=candidate)]['audio'] for line in range(len(texts))]
            name = get_name(candidate=candidate, combined=True)
            audio = torch.cat(audio_clips, dim=-1)
            torchaudio.save(output_path(name), audio, args.output_sample_rate)

            audio_cache[name] = {
                'audio': audio.squeeze(0).cpu(),
                'settings': get_info(voice=voice),
                'output': True,
            }
        else:
            audio_cache[get_name(candidate=candidate)]['output'] = True

    if args.voice_fixer:
        if not voicefixer:
            if progress is not None:
                progress(0, "Loading voicefix...")
            load_voicefixer()

        # Collect results first and only swap outputs once every clip succeeded.
        # A mid-run failure then leaves the original clips as the outputs
        # instead of dropping them.
        fixed_cache = {}
        try:
            names = progress.tqdm(audio_cache, desc="Running voicefix...") if progress is not None else audio_cache
            for name in names:
                entry = audio_cache[name]
                entry.pop('audio', None)  # in-memory audio is no longer needed
                if not entry.get('output'):
                    continue

                voicefixer.restore(
                    input=output_path(name),
                    output=output_path(f'{name}_fixed'),
                    cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
                    #mode=mode,
                )
                fixed_cache[name] = {
                    'settings': entry['settings'],
                    'output': True,
                }
        except Exception as e:
            print(e)
            print("\nFailed to run Voicefixer")
            fixed_cache = {}

        for name, fixed_entry in fixed_cache.items():
            audio_cache[name]['output'] = False
            audio_cache[f'{name}_fixed'] = fixed_entry

    for name, entry in audio_cache.items():
        if not entry.get('output'):
            if args.prune_nonfinal_outputs:
                entry['pruned'] = True
                os.remove(output_path(name))
            continue

        output_voices.append(output_path(name))

        if not args.embed_output_metadata:
            with open(output_path(name, "json"), 'w', encoding="utf-8") as f:
                f.write(json.dumps(entry['settings'], indent='\t'))

    if args.embed_output_metadata:
        names = progress.tqdm(audio_cache, desc="Embedding metadata...") if progress is not None else audio_cache
        for name in names:
            if audio_cache[name].get('pruned'):
                continue

            metadata = music_tag.load_file(output_path(name))
            metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
            metadata.save()

    if sample_voice is not None:
        sample_voice = (tts.input_sample_rate, sample_voice.numpy())

    info = get_info(voice=voice, latents=False)
    saved_to = output_voices[0] if output_voices else outdir
    print(f"Generation took {info['time']} seconds, saved to '{saved_to}'\n")

    info['seed'] = used_seed
    info.pop('latents', None)

    os.makedirs('./config/', exist_ok=True)
    with open('./config/generate.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps(info, indent='\t'))

    stats = [
        [parameters['seed'], "{:.3f}".format(info['time'])]
    ]

    return (
        sample_voice,
        output_voices,
        stats,
    )
def cancel_generate():
    """Ask an in-flight generation to stop; does nothing when none is running.

    Sets ``tortoise.api.STOP_SIGNAL``, presumably polled by the tortoise
    backend during inference (TODO confirm).
    """
    if INFERENCING:
        import tortoise.api
        tortoise.api.STOP_SIGNAL = True
def hash_file(path, algo="md5", buffer_size=0):
    """Return the hex digest of the file at ``path``.

    Args:
        path: file to hash.
        algo: any fixed-length algorithm known to ``hashlib`` (e.g. "md5",
            "sha1", "sha256"); defaults to "md5".
        buffer_size: if > 0, read the file in chunks of this many bytes
            instead of all at once, bounding memory use.

    Returns:
        The lowercase hexadecimal digest string.

    Raises:
        Exception: for an unknown or unsupported algorithm, or a missing path.
    """
    # SHAKE digests need an explicit output length, so they are not supported here.
    if algo not in hashlib.algorithms_available or algo.startswith("shake_"):
        raise Exception(f'Unknown hash algorithm specified: {algo}')
    hasher = hashlib.new(algo)

    if not os.path.exists(path):
        raise Exception(f'Path not found: {path}')

    with open(path, 'rb') as f:
        if buffer_size > 0:
            for chunk in iter(lambda: f.read(buffer_size), b""):
                hasher.update(chunk)
        else:
            hasher.update(f.read())

    return hasher.hexdigest()
def update_baseline_for_latents_chunks( voice ):
    """Suggest a default "voice latents chunks" value for ``voice``.

    Also records ``voice`` as ``current_voice``.

    Returns:
        - 1 when the voice folder does not exist;
        - 0 when ``./training/{voice}/train.txt`` exists. ``compute_latents``
          treats 0 as "build latents from the training dataset".
        - otherwise, a chunk count derived from the total duration of the
          folder's ``.wav`` files. This count is clamped to at least 1 so it
          can never be mistaken for the dataset sentinel, and is 1 when there
          is no measurable audio.
    """
    global current_voice
    current_voice = voice

    path = f'{get_voice_dir()}/{voice}/'
    if not os.path.isdir(path):
        return 1

    dataset_file = f'./training/{voice}/train.txt'
    if os.path.exists(dataset_file):
        return 0  # 0 will leverage using the LJspeech dataset for computing latents

    total = 0
    total_duration = 0
    for file in os.listdir(path):
        if file[-4:] != ".wav":
            continue
        metadata = torchaudio.info(f'{path}/{file}')
        total_duration += metadata.num_frames / metadata.sample_rate
        total += 1

    # With a chunk duration of 0, use the average clip duration as the count;
    # otherwise split the total duration into chunks of the configured size.
    if args.autocalculate_voice_chunk_duration_size == 0:
        return max(1, int(total_duration / total)) if total > 0 else 1
    return max(1, int(total_duration / args.autocalculate_voice_chunk_duration_size)) if total_duration > 0 else 1
def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, progress=None):
    """Compute conditioning latents for a voice and save them in the voice folder.

    Args:
        voice: voice name. When given, samples are (re)loaded for it: from
            ``./training/{voice}/train.txt`` if ``voice_latents_chunks`` is 0
            and that dataset exists and is non-empty, otherwise via
            ``load_voice``.
        voice_samples: pre-loaded samples, used only when ``voice`` is falsy.
        voice_latents_chunks: ``slices`` passed to ``tts.get_conditioning_latents``.
            When the dataset is used, it is replaced by the number of dataset clips.
        progress: optional progress object forwarded to the TTS.

    Returns:
        The conditioning latents (a 4-tuple has its last element replaced by
        None), or None when there are no samples to compute from.

    Raises:
        Exception: if the TTS is still initializing.
    """
    global tts
    global args

    unload_whisper()
    unload_voicefixer()

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        load_tts()
    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    if args.autoregressive_model == "auto":
        tts.load_autoregressive_model(deduce_autoregressive_model(voice))

    if voice:
        load_from_dataset = voice_latents_chunks == 0

        if load_from_dataset:
            dataset_path = f'./training/{voice}/train.txt'
            if not os.path.exists(dataset_path):
                load_from_dataset = False
            else:
                with open(dataset_path, 'r', encoding="utf-8") as f:
                    lines = f.readlines()

                print("Leveraging dataset for computing latents")

                voice_samples = []
                max_length = 0
                for line in lines:
                    if not line.strip():
                        continue  # skip blank lines, e.g. a trailing newline
                    filename = f'./training/{voice}/{line.split("|")[0]}'
                    waveform = load_audio(filename, 22050)
                    max_length = max(max_length, waveform.shape[-1])
                    voice_samples.append(waveform)

                # Bring every clip to the longest clip's length.
                voice_samples = [pad_or_truncate(sample, max_length) for sample in voice_samples]

                voice_latents_chunks = len(voice_samples)
                if voice_latents_chunks == 0:
                    print("Dataset is empty!")
                    # Fall back to the voice folder rather than computing latents from nothing.
                    load_from_dataset = False

        if not load_from_dataset:
            voice_samples, _ = load_voice(voice, load_latents=False)

    if voice_samples is None or len(voice_samples) == 0:
        return None

    conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents, progress=progress)
    if len(conditioning_latents) == 4:
        conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)

    if voice:
        # A dataset-only voice may not have a voice folder yet.
        voice_dir = f'{get_voice_dir()}/{voice}'
        os.makedirs(voice_dir, exist_ok=True)
        outfile = f'{voice_dir}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
        torch.save(conditioning_latents, outfile)
        print(f'Saved voice latents: {outfile}')

    return conditioning_latents
# superfluous, but it cleans up some things
class TrainingState():
    """Tracks one training run and parses the trainer's console/log output.

    Reads ``train.json`` / ``train.txt`` from the directory containing
    ``config_path``, optionally spawns the trainer subprocess, and turns
    "Training Metrics:" / "Validation Metrics:" lines into status messages
    and plot statistics. Uses the module-level ``args`` (``args.tts_backend``)
    and ``calc_iterations`` defined elsewhere in this file.
    """

    def __init__(self, config_path, keep_x_past_checkpoints=0, start=True):
        self.killed = False

        self.training_dir = os.path.dirname(config_path)
        with open(config_path, 'r') as file:
            self.yaml_config = yaml.safe_load(file)

        # context manager so the JSON file handle is closed right away
        with open(f"{self.training_dir}/train.json", 'r', encoding="utf-8") as f:
            self.json_config = json.load(f)

        self.dataset_path = f"{self.training_dir}/train.txt"
        with open(self.dataset_path, 'r', encoding="utf-8") as f:
            self.dataset_size = len(f.readlines())

        self.batch_size = self.json_config["batch_size"]
        self.save_rate = self.json_config["save_rate"]

        self.epoch = 0
        self.epochs = self.json_config["epochs"]
        self.it = 0
        self.its = calc_iterations( self.epochs, self.dataset_size, self.batch_size )
        self.step = 0
        # an empty train.txt would otherwise raise ZeroDivisionError here
        self.steps = int(self.its / self.dataset_size) if self.dataset_size else 0
        self.checkpoint = 0
        self.checkpoints = self._remaining_checkpoints()

        self.gpus = self.json_config['gpus']

        self.buffer = []

        self.open_state = False
        self.training_started = False

        self.info = {}

        self.it_rate = ""
        self.it_rates = 0

        self.epoch_rate = ""

        self.eta = "?"
        self.eta_hhmmss = "?"

        self.nan_detected = False

        self.last_info_check_at = 0
        self.statistics = {
            'loss': [],
            'lr': [],
            'grad_norm': [],
        }
        self.losses = []
        self.metrics = {
            'step': "",
            'rate': "",
            'loss': "",
        }

        self.loss_milestones = [ 1.0, 0.15, 0.05 ]

        if args.tts_backend == "vall-e":
            self.valle_last_it = 0
            self.valle_steps = 0

        if keep_x_past_checkpoints > 0:
            self.cleanup_old(keep=keep_x_past_checkpoints)
        if start:
            self.spawn_process(config_path=config_path, gpus=self.gpus)

    def _remaining_checkpoints(self):
        """Number of checkpoint saves left between ``self.it`` and ``self.its`` (0 if save_rate is 0)."""
        if not self.save_rate:
            return 0
        return int((self.its - self.it) / self.save_rate)

    def spawn_process(self, config_path, gpus=1):
        """Launch the trainer, merging stderr into a text-mode stdout pipe."""
        if args.tts_backend == "vall-e":
            self.cmd = ['deepspeed', f'--num_gpus={gpus}', '--module', 'vall_e.train', f'yaml="{config_path}"']
        else:
            self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]

        print("Spawning process: ", " ".join(self.cmd))
        self.process = subprocess.Popen(self.cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)

    def parse_metrics(self, data):
        """Fold one metrics record into the running counters and statistics.

        ``data`` is either a metrics dict or a raw log line containing
        "Training Metrics:" / "Validation Metrics:" followed by JSON.
        Returns the metrics dict, or ``None`` for a line without metrics.
        """
        if isinstance(data, str):
            # a raw log line was passed in: extract its JSON payload
            line = data
            if line.find('Training Metrics:') >= 0:
                data = json.loads(line.split("Training Metrics:")[-1])
                data['mode'] = "training"
            elif line.find('Validation Metrics:') >= 0:
                data = json.loads(line.split("Validation Metrics:")[-1])
                data['mode'] = "validation"
            else:
                return None

        self.info = data
        if 'epoch' in self.info:
            self.epoch = int(self.info['epoch'])
        if 'it' in self.info:
            self.it = int(self.info['it'])
        if 'step' in self.info:
            self.step = int(self.info['step'])
        if 'steps' in self.info:
            self.steps = int(self.info['steps'])

        # `elapsed_time` is renamed to `iteration_rate`
        if 'elapsed_time' in self.info:
            self.info['iteration_rate'] = self.info['elapsed_time']
            del self.info['elapsed_time']

        if 'iteration_rate' in self.info:
            it_rate = self.info['iteration_rate']
            # rates under 1 are shown inverted as it/s, otherwise as s/it
            self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
            self.it_rates += it_rate

            if self.it_rates > 0 and self.it * self.steps > 0:
                epoch_rate = self.it_rates / self.it * self.steps
                self.epoch_rate = f'{"{:.3f}".format(1/epoch_rate)}epoch/s' if 0 < epoch_rate and epoch_rate < 1 else f'{"{:.3f}".format(epoch_rate)}s/epoch'

            # ETA = remaining iterations * average iteration_rate seen so far
            if self.it > 0:
                try:
                    self.eta = (self.its - self.it) * (self.it_rates / self.it)
                    self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
                except (OverflowError, ValueError, TypeError):
                    self.eta_hhmmss = "?"
            else:
                self.eta_hhmmss = "?"

        self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
        if self.epochs != self.its:
            self.metrics['step'].append(f"{self.it}/{self.its}")
        if self.steps > 1:
            self.metrics['step'].append(f"{self.step}/{self.steps}")
        self.metrics['step'] = ", ".join(self.metrics['step'])

        if args.tts_backend == "tortoise":
            # fractional epoch; guard steps == 0 to avoid ZeroDivisionError
            epoch = self.epoch + (self.step / self.steps) if self.steps else self.epoch
        else:
            epoch = self.info['epoch'] if 'epoch' in self.info else self.it

        if self.it > 0:
            # metric keys to chart, per backend
            keys = {
                'lrs': ['lr'],
                'losses': ['loss_text_ce', 'loss_mel_ce'],
                'accuracies': [],
                'grad_norms': [],
            }
            if args.tts_backend == "vall-e":
                keys['lrs'] = [
                    'ar.lr', 'nar.lr',
                    'ar-half.lr', 'nar-half.lr',
                    'ar-quarter.lr', 'nar-quarter.lr',
                ]
                keys['losses'] = [
                    'ar.loss', 'nar.loss', 'ar+nar.loss',
                    'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
                    'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
                    # 'ar.loss.nll', 'nar.loss.nll',
                    # 'ar-half.loss.nll', 'nar-half.loss.nll',
                    # 'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
                ]
                keys['accuracies'] = [
                    'ar.loss.acc', 'nar.loss.acc',
                    'ar-half.loss.acc', 'nar-half.loss.acc',
                    'ar-quarter.loss.acc', 'nar-quarter.loss.acc',
                ]
                keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']

            for k in keys['lrs']:
                if k not in self.info:
                    continue
                self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

            # accuracies are charted alongside the losses
            for k in keys['accuracies']:
                if k not in self.info:
                    continue
                self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

            for k in keys['losses']:
                if k not in self.info:
                    continue

                # validation losses are prefixed with the set's name ("val" by default)
                prefix = ""
                if "mode" in self.info and self.info["mode"] == "validation":
                    prefix = f'{self.info["name"] if "name" in self.info else "val"}_'

                self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' })
                self.losses.append( self.statistics['loss'][-1] )

            for k in keys['grad_norms']:
                if k not in self.info:
                    continue
                self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

        return data

    def get_status(self):
        """Build the one-line "[step] [rate] [ETA] [loss]" status message."""
        self.metrics['rate'] = []
        if self.epoch_rate:
            self.metrics['rate'].append(self.epoch_rate)
        # strip the unit suffixes ("s/epoch"/"epoch/s" vs "s/it"/"it/s") and
        # only show the it-rate when its number differs from the epoch rate
        if self.it_rate and self.epoch_rate[:-7] != self.it_rate[:-4]:
            self.metrics['rate'].append(self.it_rate)
        self.metrics['rate'] = ", ".join(self.metrics['rate'])

        eta_hhmmss = self.eta_hhmmss if self.eta_hhmmss else "?"

        self.metrics['loss'] = []
        if 'lr' in self.info:
            self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')

        if len(self.losses) > 0:
            self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')

        # NOTE: an experimental loss-milestone estimate (self.loss_milestones)
        # used to sit here behind `if False`; it was unreachable and was removed.

        self.metrics['loss'] = ", ".join(self.metrics['loss'])

        message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]"
        if self.nan_detected:
            message = f"[!NaN DETECTED! {self.nan_detected}] {message}"

        return message

    def load_statistics(self, update=False):
        """Rebuild ``self.statistics`` from the training logs on disk.

        With ``update=False`` the statistics are cleared and every log is
        replayed. With ``update=True`` only the newest log is read and only
        entries with an iteration past ``self.last_info_check_at`` are parsed.
        Returns early when the training or log directory does not exist.
        """
        if not os.path.isdir(self.training_dir):
            return

        if args.tts_backend == "tortoise":
            log_dir = f'{self.training_dir}/finetune/'
            if not os.path.isdir(log_dir):
                return
            logs = sorted([f'{log_dir}{d}' for d in os.listdir(log_dir) if d[-4:] == ".log" ])
        else:
            log_dir = f'{self.training_dir}/logs/'
            if not os.path.isdir(log_dir):
                return
            logs = sorted([f'{log_dir}{d}/log.txt' for d in os.listdir(log_dir) ])

        if update:
            if not logs:
                return
            logs = [logs[-1]]

        highest_step = self.last_info_check_at

        if not update:
            self.statistics['loss'] = []
            self.statistics['lr'] = []
            self.statistics['grad_norm'] = []
            self.it_rates = 0

        unq = {}
        averager = None
        prev_state = 0
        # defaults so a validation line seen before any training line has values to inherit
        it = 0
        epoch = 0

        for log in logs:
            with open(log, 'r', encoding="utf-8") as f:
                lines = f.readlines()

            for line in lines:
                line = line.strip()
                if not line:
                    continue

                if line[-1] == ".":
                    line = line[:-1]

                if line.find('Training Metrics:') >= 0:
                    data = json.loads(line.split("Training Metrics:")[-1])
                    name = "train"
                    mode = "training"
                    prev_state = 0
                elif line.find('Validation Metrics:') >= 0:
                    data = json.loads(line.split("Validation Metrics:")[-1])
                    # validation lines inherit the iteration/epoch of the last line seen
                    if "it" not in data:
                        data['it'] = it
                    if "epoch" not in data:
                        data['epoch'] = epoch

                    # name = data['name'] if 'name' in data else "val"
                    mode = "validation"

                    # first validation line after a training line is labelled
                    # "subtrain", any further ones "val"
                    if prev_state == 0:
                        name = "subtrain"
                    else:
                        name = "val"

                    prev_state += 1
                else:
                    continue

                if "it" not in data:
                    continue

                it = data['it']
                epoch = data.get('epoch', epoch)

                # in update mode, skip entries a previous call already consumed
                if update and int(it) <= self.last_info_check_at:
                    continue
                highest_step = max(highest_step, int(it))

                if args.tts_backend == "vall-e":
                    # average every metric reported for the same iteration/name/mode
                    key = f'{it}_{name}'
                    if not averager or averager['key'] != key or averager['mode'] != mode:
                        averager = {
                            'key': key,
                            'name': name,
                            'mode': mode,
                            "metrics": {}
                        }
                    for k, v in data.items():
                        if v is None:
                            continue
                        averager['metrics'].setdefault(k, []).append(v)

                    unq[f'{it}_{mode}_{name}'] = averager
                else:
                    unq[f'{it}_{mode}_{name}'] = data

        blacklist = [ "batch", "eval" ]
        for key in unq:
            if args.tts_backend == "vall-e":
                stats = unq[key]
                data = {k: sum(v) / len(v) for k, v in stats['metrics'].items() if k not in blacklist }
                data['name'] = stats['name']
                data['mode'] = stats['mode']
                data['steps'] = len(stats['metrics']['it'])
            else:
                data = unq[key]
            self.parse_metrics(data)

        self.last_info_check_at = highest_step

    @staticmethod
    def _numbered_files(directory, suffix):
        """Return ``(number, filename)`` pairs for ``<digits><suffix>`` files in ``directory``, sorted by number.

        A missing directory yields an empty list; names whose stem is not all digits are ignored.
        """
        if not os.path.isdir(directory):
            return []
        found = []
        for name in os.listdir(directory):
            if not name.endswith(suffix):
                continue
            stem = name[:-len(suffix)]
            if stem.isdigit():
                found.append((int(stem), name))
        return sorted(found)

    def cleanup_old(self, keep=2):
        """Delete all but the newest ``keep`` tortoise model / training-state checkpoints.

        No-op when ``keep <= 0``, for the vall-e backend, or when ``finetune/`` is missing.
        """
        if keep <= 0:
            return

        if args.tts_backend == "vall-e":
            return

        if not os.path.isdir(f'{self.training_dir}/finetune/'):
            return

        # models first, then training states; original file names are kept so
        # zero-padded numbers still resolve to the right file
        for subdir, suffix in (('models', '_gpt.pth'), ('training_state', '.state')):
            directory = f'{self.training_dir}/finetune/{subdir}/'
            for _, name in self._numbered_files(directory, suffix)[:-keep]:
                path = f'{directory}{name}'
                print("Removing", path)
                os.remove(path)

    def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
        """Consume one line of trainer output.

        Returns ``(result, percent, message)``. ``result`` is the joined
        line buffer before training starts (when something is worth showing),
        the status message afterwards, or ``None``.
        """
        self.buffer.append(f'{line}')

        data = None
        percent = 0
        message = None
        should_return = False

        MESSAGE_START = 'Start training from epoch'
        MESSAGE_FINISHED = 'Finished training'
        MESSAGE_SAVING = 'Saving models and training states.'

        MESSAGE_METRICS_TRAINING = 'Training Metrics:'
        MESSAGE_METRICS_VALIDATION = 'Validation Metrics:'

        if line.find(MESSAGE_FINISHED) >= 0:
            self.killed = True
        # rip out iteration info
        elif not self.training_started:
            if line.find(MESSAGE_START) >= 0:
                self.training_started = True

                match = re.findall(r'epoch: ([\d,]+)', line)
                if match:
                    self.epoch = int(match[0].replace(",", ""))
                match = re.findall(r'iter: ([\d,]+)', line)
                if match:
                    self.it = int(match[0].replace(",", ""))

                self.checkpoints = self._remaining_checkpoints()

                self.load_statistics()

                should_return = True
        else:
            if line.find(MESSAGE_SAVING) >= 0:
                self.checkpoint += 1
                message = f"[{self.checkpoint}/{self.checkpoints}] Saving checkpoint..."
                # checkpoints can be 0 when its < save_rate
                percent = self.checkpoint / self.checkpoints if self.checkpoints > 0 else 0

                self.cleanup_old(keep=keep_x_past_checkpoints)
            elif line.find(MESSAGE_METRICS_TRAINING) >= 0:
                data = json.loads(line.split(MESSAGE_METRICS_TRAINING)[-1])
                data['mode'] = "training"
            elif line.find(MESSAGE_METRICS_VALIDATION) >= 0:
                data = json.loads(line.split(MESSAGE_METRICS_VALIDATION)[-1])
                data['mode'] = "validation"

            if data is not None:
                if ': nan' in line and not self.nan_detected:
                    self.nan_detected = self.it

                self.parse_metrics( data )
                message = self.get_status()

                if message:
                    percent = self.it / float(self.its) if self.its > 0 else 0 # self.epoch / float(self.epochs)
                    if progress is not None:
                        progress(percent, message)

                    self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}')
                    should_return = True

        if verbose and not self.training_started:
            should_return = True

        self.buffer = self.buffer[-buffer_size:]

        result = None
        if should_return:
            result = "".join(self.buffer) if not self.training_started else message

        return (
            result,
            percent,
            message,
        )
# optional: turn off altair's row limit (max_rows=None); any failure,
# including altair not being installed, is printed rather than raised
try:
    import altair as alt
    alt.data_transformers.enable('default', max_rows=None)
except Exception as err:
    print(err)
def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)):
    """Start a training run for ``config_path`` and stream its parsed output.

    Generator: yields the status strings produced by ``TrainingState.parse``
    and forwards ``(percent, message)`` to ``progress``. Yields a single
    notice and stops if a training process is already running.
    """
    global training_state
    if training_state and training_state.process:
        # this is a generator, so `return <value>` would never be yielded to the caller
        yield "Training already in progress"
        return

    # ensure we have the dvae.pth
    if args.tts_backend == "tortoise":
        get_model_path('dvae.pth')

    # unclear if this is still necessary; it complained without it, despite training running in a separate process
    torch.multiprocessing.freeze_support()

    unload_tts()
    unload_whisper()
    unload_voicefixer()

    training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints)

    for line in iter(training_state.process.stdout.readline, ""):
        # stop_training() may have flagged the run or already cleared the global
        if training_state is None or training_state.killed:
            return

        result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress )
        print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
        if result:
            yield result

        if progress is not None and message:
            progress(percent, message)

    # stdout hit EOF: reap the process and clear the global state
    if training_state:
        training_state.process.stdout.close()
        training_state.process.wait()
        training_state = None
def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
    """Build the loss / learning-rate / grad-norm line plots for the training tab.

    Uses the live ``training_state`` when one exists. Otherwise, if
    ``config_path`` is given, a temporary ``TrainingState`` (no process
    spawned) is loaded just to read its logs; it is never stored in the
    global, so refreshing the plots can never discard a running session.

    Returns ``(losses, lrs, grad_norms)``; an entry is ``None`` when there is
    no data for that plot. ``x_lim`` defaults to the run's epoch count and a
    falsy ``y_lim`` leaves the y axis unbounded.
    """
    state = training_state
    if not state and config_path:
        state = TrainingState(config_path=config_path, start=False)
        state.load_statistics()

    if not state:
        return (None, None, None)

    x_lim = [ 0, x_lim if x_lim else state.epochs ]
    y_lim = [ 0, y_lim ] if y_lim else None

    def make_plot(key, title):
        # one LinePlot per statistics series, or None when the series is empty
        stats = state.statistics[key]
        if not stats:
            return None
        return gr.LinePlot.update(
            value = pd.DataFrame(stats),
            x_lim=x_lim, y_lim=y_lim,
            x="epoch", y="value",
            title=title, color="type", tooltip=['epoch', 'it', 'value', 'type'],
            width=500, height=350
        )

    return (
        make_plot('loss', "Loss Metrics"),
        make_plot('lr', "Learning Rate"),
        make_plot('grad_norm', "Gradient Normals"),
    )
def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
    """Re-attach to the running training process and stream its parsed output.

    Generator: yields status strings from ``TrainingState.parse``; yields a
    single notice when no training process exists.
    """
    global training_state
    if not training_state or not training_state.process:
        # this is a generator, so `return <value>` would never be yielded to the caller
        yield "Training not in progress"
        return

    for line in iter(training_state.process.stdout.readline, ""):
        # stop_training() may have flagged the run or already cleared the global
        if training_state is None or training_state.killed:
            return

        result, percent, message = training_state.parse( line=line, verbose=verbose, progress=progress )
        print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
        if result:
            yield result

        if progress is not None and message:
            progress(percent, message)
def stop_training():
    """Terminate the running training process and clear ``training_state``.

    For tortoise, stray ``./src/train.py`` processes found via psutil are
    killed as well. Returns a status string containing the exit code.
    """
    global training_state
    if training_state is None:
        return "No training in progress"

    print("Killing training process...")
    training_state.killed = True

    children = []
    if args.tts_backend == "tortoise":
        # best effort: process enumeration may fail on some platforms
        try:
            children = [
                p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline'])
                # cmdline is None for processes whose info could not be read
                if './src/train.py' in (p.info['cmdline'] or [])
            ]
        except Exception as e:
            print("Failed to enumerate training subprocesses:", e)

        training_state.process.stdout.close()
        training_state.process.terminate()
        training_state.process.kill()
    elif args.tts_backend == "vall-e":
        print(training_state.process.communicate(input='quit')[0])

    return_code = training_state.process.wait()

    # SIGKILL does not exist on Windows; fall back to SIGTERM there
    kill_signal = getattr(signal, 'SIGKILL', signal.SIGTERM)
    for p in children:
        try:
            os.kill( p['pid'], kill_signal )
        except (ProcessLookupError, PermissionError):
            # already exited, or not ours to kill
            pass

    training_state = None
    print("Killed training process.")
    return f"Training cancelled: {return_code}"
def get_halfp_model_path():
    """Return where the half-precision autoregressive model is written.

    ``<dir>/<name>.pth`` becomes ``<dir>/<name>_half.pth``. Only the file's own
    extension is rewritten, so a ".pth" elsewhere in the path is untouched and
    the result never equals the source path.
    """
    autoregressive_model_path = get_model_path('autoregressive.pth')
    root, ext = os.path.splitext(autoregressive_model_path)
    return f'{root}_half{ext}'
def convert_to_halfp():
    """Save a float16 copy of the autoregressive model to ``get_halfp_model_path()``."""
    autoregressive_model_path = get_model_path('autoregressive.pth')
    print(f'Converting model to half precision: {autoregressive_model_path}')
    model = torch.load(autoregressive_model_path)
    for k in model:
        # only floating-point tensors are cast; integer buffers and
        # non-tensor entries keep their original type
        if torch.is_tensor(model[k]) and model[k].is_floating_point():
            model[k] = model[k].half()

    outfile = get_halfp_model_path()
    torch.save(model, outfile)
    print(f'Converted model to half precision: {outfile}')
# collapses short segments into the previous segment
def whisper_sanitize( results, min_duration=None ):
    """Return a copy of whisper ``results`` with short segments merged backwards.

    A segment shorter than ``min_duration`` seconds (default
    ``MIN_TRAINING_DURATION``) has its text and end time folded into the
    previous kept segment, unless that segment already ends at or after it.
    The first segment is always kept. Kept segments are renumbered from 0.
    ``results`` itself is not modified.
    """
    if min_duration is None:
        min_duration = MIN_TRAINING_DURATION

    # deep copy via a JSON round-trip, then rebuild the segment list from the copy
    sanitized = json.loads(json.dumps(results))
    source_segments = sanitized['segments']
    sanitized['segments'] = []

    for segment in source_segments:
        length = segment['end'] - segment['start']
        if length >= min_duration or len(sanitized['segments']) == 0:
            sanitized['segments'].append(segment)
            continue

        last_segment = sanitized['segments'][-1]
        # segment already assimilated it, somehow
        if last_segment['end'] >= segment['end']:
            continue
        # (a disabled alternative check: skip if last_segment['text'] already ends with segment['text'])

        last_segment['text'] += segment['text']
        last_segment['end'] = segment['end']

    for i, segment in enumerate(sanitized['segments']):
        segment['id'] = i

    return sanitized
def whisper_transcribe( file, language=None ):
    """Transcribe ``file`` with the backend selected by ``args.whisper_backend``.

    - "openai/whisper": returns the model's own ``transcribe`` result.
    - "lightmare/whispercpp" and "m-bain/whisperx": return a dict with a
      space-joined ``'text'`` and ``'segments'`` entries carrying an ``id``.

    Loads the whisper model first if none is loaded. Raises ``ValueError``
    for an unrecognised backend instead of silently returning ``None``.
    """
    global whisper_model
    global whisper_vad
    global whisper_diarize
    global whisper_align_model

    # shouldn't happen, but it's for safety
    if not whisper_model:
        load_whisper_model(language=language)

    if args.whisper_backend == "openai/whisper":
        # normalise falsy values such as "" to None
        if not language:
            language = None

        return whisper_model.transcribe(file, language=language)

    if args.whisper_backend == "lightmare/whispercpp":
        res = whisper_model.transcribe(file)
        segments = whisper_model.extract_text_and_timestamps( res )

        result = {
            'text': [],
            'segments': []
        }
        for segment in segments:
            # timestamps are divided by 100 (presumably centiseconds -> seconds; confirm against whispercpp)
            reparsed = {
                'start': segment[0] / 100.0,
                'end': segment[1] / 100.0,
                'text': segment[2],
                'id': len(result['segments'])
            }
            result['text'].append( segment[2] )
            result['segments'].append(reparsed)

        result['text'] = " ".join(result['text'])
        return result

    if args.whisper_backend == "m-bain/whisperx":
        import whisperx
        from whisperx.diarize import assign_word_speakers

        device = "cuda" if get_device_name() == "cuda" else "cpu"
        if whisper_vad:
            # the batched variant omits a considerable amount of the end, so it stays disabled:
            # if args.whisper_batchsize > 1:
            #     result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
            # else:
            #     result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
            result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
        else:
            result = whisper_model.transcribe(file)

        align_model, metadata = whisper_align_model
        result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device)

        if whisper_diarize:
            diarize_segments = whisper_diarize(file)
            diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
            diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
            diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)

            # assumes each utterance is single speaker (needs fix)
            result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
            result_aligned["segments"] = result_segments
            result_aligned["word_segments"] = word_segments

        # drop per-word / per-char alignment data; tolerate segments that lack it
        for segment in result_aligned['segments']:
            segment.pop('word-segments', None)
            segment.pop('char-segments', None)

        result['segments'] = result_aligned['segments']
        result['text'] = []
        for segment in result['segments']:
            segment['id'] = len(result['text'])
            result['text'].append(segment['text'].strip())
        result['text'] = " ".join(result['text'])

        return result

    raise ValueError(f"Unsupported whisper backend: {args.whisper_backend}")
def validate_waveform( waveform, sample_rate, min_only=False, min_duration=None, max_duration=None ):
    """Check that a 2-D ``(channels, frames)`` waveform is usable for training.

    Returns an error string, or ``None`` when the waveform passes. A waveform
    with no samples, or with every sample equal to zero, is reported as empty.
    Durations are checked against ``min_duration`` / ``max_duration``
    (defaults ``MIN_TRAINING_DURATION`` / ``MAX_TRAINING_DURATION``); the
    upper bound is skipped when ``min_only`` is true.
    """
    # empty = no non-zero sample at all (also covers a zero-length tensor)
    if not torch.any(waveform != 0):
        return "Waveform is empty"

    if min_duration is None:
        min_duration = MIN_TRAINING_DURATION

    num_channels, num_frames = waveform.shape
    duration = num_frames / sample_rate

    if duration < min_duration:
        return "Duration too short ({:.3f}s < {:.3f}s)".format(duration, min_duration)

    if not min_only:
        if max_duration is None:
            max_duration = MAX_TRAINING_DURATION
        if duration > max_duration:
            return "Duration too long ({:.3f}s < {:.3f}s)".format(max_duration, duration)

    return None
def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
    """Transcribe every audio file of ``voice`` into ``./training/<voice>/whisper.json``.

    Each file is transcribed with ``whisper_transcribe``; a resampled,
    first-channel-only copy is written to ``./training/<voice>/audio/``.
    Existing results are loaded first and, with ``skip_existings``, reused.
    Afterwards short segments are merged via ``whisper_sanitize``; if anything
    changed, the previous file is kept as ``whisper.unsanitized.json``.
    Returns a status string.
    """
    global whisper_model

    unload_tts()

    if whisper_model is None:
        load_whisper_model(language=language)

    results = {}

    files = get_voice(voice, load_latents=False)
    indir = f'./training/{voice}/'
    infile = f'{indir}/whisper.json'

    os.makedirs(f'{indir}/audio/', exist_ok=True)

    TARGET_SAMPLE_RATE = 22050
    if args.tts_backend == "vall-e":
        TARGET_SAMPLE_RATE = 24000
    if tts:
        TARGET_SAMPLE_RATE = tts.input_sample_rate

    if os.path.exists(infile):
        with open(infile, 'r', encoding="utf-8") as f:
            results = json.load(f)

    for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
        basename = os.path.basename(file)

        if basename in results and skip_existings:
            print(f"Skipping already parsed file: {basename}")
            continue

        try:
            result = whisper_transcribe(file, language=language)
        except Exception as e:
            print("Failed to transcribe:", file, e)
            continue

        results[basename] = result
        waveform, sample_rate = torchaudio.load(file)
        # resample to the input rate, since it'll get resampled for training anyways
        # this should also "help" increase throughput a bit when filling the dataloaders
        waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
        # keep only the first channel of any multi-channel audio
        if waveform.shape[0] > 1:
            waveform = waveform[:1]
        torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)

        # rewrite the JSON after every file so progress survives an interruption
        with open(infile, 'w', encoding="utf-8") as f:
            f.write(json.dumps(results, indent='\t'))

        do_gc()

    modified = False
    for basename in results:
        try:
            sanitized = whisper_sanitize(results[basename])
            if len(sanitized['segments']) > 0 and len(sanitized['segments']) != len(results[basename]['segments']):
                results[basename] = sanitized
                modified = True
                print("Segments sanizited: ", basename)
        except Exception as e:
            print("Failed to sanitize:", basename, e)

    if modified:
        # os.replace (unlike os.rename) also overwrites an existing backup on Windows
        os.replace(infile, infile.replace(".json", ".unsanitized.json"))
        with open(infile, 'w', encoding="utf-8") as f:
            f.write(json.dumps(results, indent='\t'))

    return f"Processed dataset to: {indir}"
def slice_waveform( waveform, sample_rate, start, end, trim ):
    """Cut the ``[start, end)`` span (in seconds) out of a ``(channels, frames)`` waveform.

    Bounds are clamped to the waveform. The slice is validated with
    ``validate_waveform(..., min_only=True)``; when it is valid and ``trim``
    is set, ``torchaudio.functional.vad`` is applied. Returns
    ``(sliced, error)`` where ``error`` is ``None`` on success.
    """
    start = max(int(start * sample_rate), 0)
    # slice ends are exclusive, so clamp to the frame count (not count - 1)
    # to keep the final sample
    end = min(int(end * sample_rate), waveform.shape[-1])

    sliced = waveform[:, start:end]

    error = validate_waveform( sliced, sample_rate, min_only=True )
    if trim and not error:
        sliced = torchaudio.functional.vad( sliced, sample_rate )

    return sliced, error
def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, results=None, progress=gr.Progress() ):
    """Cut each transcribed file of ``voice`` into one clip per whisper segment.

    Segment boundaries are shifted by ``start_offset`` / ``end_offset``
    seconds. Clips are resampled, reduced to their first channel and written
    to ``./training/<voice>/audio/<name>_<id>.wav``. ``results`` defaults to
    the contents of ``./training/<voice>/whisper.json``. Returns a
    newline-joined log of skipped files/segments plus a summary line.
    """
    indir = f'./training/{voice}/'
    infile = f'{indir}/whisper.json'
    messages = []

    if not os.path.exists(infile):
        message = f"Missing dataset: {infile}"
        print(message)
        return message

    if results is None:
        with open(infile, 'r', encoding="utf-8") as f:
            results = json.load(f)

    TARGET_SAMPLE_RATE = 22050
    if args.tts_backend == "vall-e":
        TARGET_SAMPLE_RATE = 24000
    if tts:
        TARGET_SAMPLE_RATE = tts.input_sample_rate

    os.makedirs(f'{indir}/audio/', exist_ok=True)

    files = 0
    segments = 0
    for filename in results:
        # prefer the original voice file, fall back to one under ./training/
        path = f'./voices/{voice}/{filename}'
        if not os.path.exists(path):
            path = f'./training/{voice}/{filename}'

        if not os.path.exists(path):
            message = f"Missing source audio: {filename}"
            print(message)
            messages.append(message)
            continue

        files += 1
        result = results[filename]
        waveform, sample_rate = torchaudio.load(path)

        for segment in result['segments']:
            file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")

            sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
            if error:
                message = f"{error}, skipping... {file}"
                print(message)
                messages.append(message)
                continue

            sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE )

            # keep only the first channel of the clip that is actually saved
            if sliced.shape[0] > 1:
                sliced = sliced[:1]

            torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, encoding="PCM_S", bits_per_sample=16)
            segments += 1

    messages.append(f"Sliced segments: {files} => {segments}.")
    return "\n".join(messages)
# takes an LJSpeech-dataset-formatted .txt file and phonemizes it
def phonemize_txt_file( path ):
    """Phonemize the third ``|``-separated column of every line in ``path``.

    Writes ``<audio>|<phonemes>`` lines to the sibling ``.phn.txt`` file and
    returns the same text. Blank lines and lines with fewer than three
    columns are skipped.
    """
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    out_path = path.replace(".txt", ".phn.txt")
    reparsed = []
    # entries are appended as they are produced (presumably so partial work
    # survives an interruption); the file is rewritten cleanly at the end
    with open(out_path, 'a', encoding='utf-8') as f:
        for line in enumerate_progress(lines, desc='Phonemizing...'):
            split = line.rstrip("\r\n").split("|")
            if len(split) < 3:
                continue
            audio = split[0]
            text = split[2]

            phonemes = phonemizer( text )
            reparsed.append(f'{audio}|{phonemes}')
            f.write(f'\n{audio}|{phonemes}')

    joined = "\n".join(reparsed)
    with open(out_path, 'w', encoding='utf-8') as f:
        f.write(joined)

    return joined
# takes an LJSpeech-dataset-formatted .txt (and phonemized .phn.txt from the above) and creates a JSON that should slot in as whisper.json
def create_dataset_json( path ):
    """Convert an ``audio|text|...`` transcript into ``{audio: {'text': ...}}`` JSON.

    If the sibling ``.phn.txt`` (``audio|phonemes`` lines) exists, each entry
    also gets a ``'phonemes'`` field. Blank/malformed lines, and phoneme lines
    for audio absent from the transcript, are skipped. The result is written
    to the sibling ``.json`` file.
    """
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # an absent .phn.txt simply means no phonemes (previously: TypeError on None)
    phonemes = []
    phn_path = path.replace(".txt", ".phn.txt")
    if os.path.exists(phn_path):
        with open(phn_path, 'r', encoding='utf-8') as f:
            phonemes = f.readlines()

    data = {}
    for line in lines:
        split = line.split("|")
        if len(split) < 2:
            continue
        audio = split[0]
        text = split[1]

        data[audio] = {
            'text': text.strip()
        }

    for line in phonemes:
        split = line.split("|")
        if len(split) < 2 or split[0] not in data:
            continue
        data[split[0]]['phonemes'] = split[1].strip()

    with open(path.replace(".txt", ".json"), 'w', encoding='utf-8') as f:
        f.write(json.dumps(data, indent="\t"))
def phonemizer( text, language="en-us" ):
    """Convert ``text`` to phonemes with the ``phonemizer`` package.

    A bare "en" language code is mapped to "en-us". The backend is taken from
    ``args.phonemizer_backend``; stripping, punctuation preservation and
    stress marks are requested.
    """
    from phonemizer import phonemize

    language = "en-us" if language == "en" else language
    options = {
        'language': language,
        'strip': True,
        'preserve_punctuation': True,
        'with_stress': True,
        'backend': args.phonemizer_backend,
    }
    return phonemize( text, **options )
def should_phonemize():
    """Return True only when an IPA tokenizer is configured and phonemizer imports.

    The tokenizer counts as IPA when ``args.tokenizer_json`` ends in "ipa.json".
    """
    tokenizer_json = args.tokenizer_json
    if tokenizer_json is None or tokenizer_json[-8:] != "ipa.json":
        return False

    # the import itself is the availability check
    try:
        from phonemizer import phonemize
    except Exception:
        return False
    return True
def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, progress=gr.Progress() ):
    """Turn a voice's Whisper transcription into training/validation lists.

    Reads ``./training/<voice>/whisper.json``, whose results are keyed by
    source audio file name. Each file is used either whole or as its Whisper
    segments. Segments are used when:

    - ``use_segments`` is set,
    - the file's text is longer than 200 characters, or
    - its audio is at least ``MAX_TRAINING_DURATION`` seconds long.

    Missing segment clips are generated with ``slice_dataset``. Every usable
    clip becomes an ``audio/<file>|<text or phonemes>`` line in ``train.txt``
    or ``validation.txt``. With the ``vall-e`` backend, any missing ``.qnt.pt``
    (quantized audio) and ``.phn.txt`` (phoneme) files are also written to
    ``valle/`` for training lines.

    Args:
        voice: folder name under ``./training/``.
        use_segments: always use per-segment clips instead of whole files.
        text_length: clips whose text is shorter than this go to validation.
        audio_length: if > 0, clips shorter than this many seconds go to
            validation. Both thresholds are forced to 0 for ``vall-e``.
        progress: progress tracker passed to ``enumerate_progress``.

    Returns:
        A newline-joined report of every message produced, or only the
        "Missing dataset" message when ``whisper.json`` does not exist.
    """
    indir = f'./training/{voice}/'
    infile = f'{indir}/whisper.json'
    if not os.path.exists(infile):
        message = f"Missing dataset: {infile}"
        print(message)
        return message

    # close the handle deterministically (it used to be left open)
    with open(infile, 'r', encoding="utf-8") as f:
        results = json.load(f)

    errored = 0
    messages = []
    normalize = True
    phonemize = should_phonemize()
    lines = { 'training': [], 'validation': [] }
    segments = {}

    # built once and reused for every file; they are only called, never mutated, here
    english_normalizer = EnglishTextNormalizer()
    basic_normalizer = BasicTextNormalizer()

    def segment_in_range(segment):
        # segments at or outside (MIN_TRAINING_DURATION, MAX_TRAINING_DURATION) are never used
        duration = segment['end'] - segment['start']
        return not (duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration)

    def segment_filename(filename, segment):
        # name of the sliced clip for a segment, e.g. foo.wav -> foo_00003.wav
        return filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")

    # I'm not sure how the VALL-E implementation decides what's validation and what's not
    if args.tts_backend == "vall-e":
        text_length = 0
        audio_length = 0

    for filename in enumerate_progress(results, desc="Parsing results", progress=progress):
        use_segment = use_segments

        result = results[filename]
        lang = result['language']
        language = LANGUAGES[lang] if lang in LANGUAGES else lang
        normalizer = english_normalizer if language and language == "english" else basic_normalizer

        # check if unsegmented text exceeds 200 characters
        if not use_segment:
            if len(result['text']) > 200:
                message = f"Text length too long (200 < {len(result['text'])}), using segments: {filename}"
                print(message)
                messages.append(message)
                use_segment = True

        # check if unsegmented audio exceeds 11.6s
        if not use_segment:
            path = f'{indir}/audio/{filename}'
            if not os.path.exists(path):
                messages.append(f"Missing source audio: {filename}")
                errored += 1
                continue

            metadata = torchaudio.info(path)
            duration = metadata.num_frames / metadata.sample_rate
            if duration >= MAX_TRAINING_DURATION:
                message = f"Audio too large, using segments: {filename}"
                print(message)
                messages.append(message)
                use_segment = True

        # implicitly segment: slice the file if any usable segment clip is missing
        if use_segment and not use_segments:
            exists = all(
                os.path.exists(f'{indir}/audio/' + segment_filename(filename, segment))
                for segment in result['segments']
                if segment_in_range(segment)
            )

            if not exists:
                tmp = { filename: result }
                print(f"Audio not segmented, segmenting: {filename}")
                message = slice_dataset( voice, results=tmp )
                print(message)
                messages = messages + message.split("\n")

        if not use_segment:
            segments[filename] = {
                'text': result['text'],
                'lang': lang,
                'language': language,
                'normalizer': normalizer,
                'phonemes': result['phonemes'] if 'phonemes' in result else None
            }
        else:
            for segment in result['segments']:
                if not segment_in_range(segment):
                    continue

                segments[segment_filename(filename, segment)] = {
                    'text': segment['text'],
                    'lang': lang,
                    'language': language,
                    'normalizer': normalizer,
                    'phonemes': segment['phonemes'] if 'phonemes' in segment else None
                }

    quantize_jobs = []   # (output .qnt.pt path, waveform, sample_rate)
    phonemize_jobs = []  # (output .phn.txt path, normalized text)

    for file in enumerate_progress(segments, desc="Parsing segments", progress=progress):
        result = segments[file]
        path = f'{indir}/audio/{file}'
        if not os.path.exists(path):
            message = f"Missing segment, skipping... {file}"
            print(message)
            messages.append(message)
            errored += 1
            continue

        text = result['text']
        lang = result['lang']
        normalizer = result['normalizer']
        phonemes = result['phonemes']
        normalized = normalizer(text) if normalize else text

        if len(text) > 200:
            message = f"Text length too long (200 < {len(text)}), skipping... {file}"
            print(message)
            messages.append(message)
            errored += 1
            continue

        waveform, sample_rate = torchaudio.load(path)
        _, num_frames = waveform.shape
        duration = num_frames / sample_rate

        error = validate_waveform( waveform, sample_rate )
        if error:
            message = f"{error}, skipping... {file}"
            print(message)
            messages.append(message)
            errored += 1
            continue

        # phonemize only lines that survived the checks above (used to run before them)
        if phonemize and phonemes is None:
            phonemes = phonemizer( text, language=lang )

        culled = len(text) < text_length
        if not culled and audio_length > 0:
            culled = duration < audio_length

        line = f'audio/{file}|{phonemes if phonemize and phonemes else text}'

        lines['training' if not culled else 'validation'].append(line)

        if culled or args.tts_backend != "vall-e":
            continue

        os.makedirs(f'{indir}/valle/', exist_ok=True)

        # queue the VALL-E artifacts; the heavy work happens in the loops below
        qnt_file = f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'
        if not os.path.exists(qnt_file):
            quantize_jobs.append((qnt_file, waveform, sample_rate))

        phn_file = f'{indir}/valle/{file.replace(".wav",".phn.txt")}'
        if not os.path.exists(phn_file):
            phonemize_jobs.append((phn_file, normalized))

    for qnt_file, waveform, sample_rate in enumerate_progress(quantize_jobs, desc="Quantizing", progress=progress):
        quantized = valle_quantize( waveform, sample_rate ).cpu()
        torch.save(quantized, qnt_file)
        print("Quantized:", qnt_file)

    for phn_file, normalized in enumerate_progress(phonemize_jobs, desc="Phonemizing", progress=progress):
        try:
            phonemized = valle_phonemize( normalized )
            # context manager so the output file is flushed and closed
            with open(phn_file, 'w', encoding='utf-8') as f:
                f.write(" ".join(phonemized))
            print("Phonemized:", phn_file)
        except Exception:
            message = f"Failed to phonemize: {phn_file}: {normalized}"
            messages.append(message)
            print(message)

    training_joined = "\n".join(lines['training'])
    validation_joined = "\n".join(lines['validation'])

    with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
        f.write(training_joined)

    with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
        f.write(validation_joined)

    messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}, culled: {errored}).\n{training_joined}\n\n{validation_joined}")
    return "\n".join(messages)
def calc_iterations( epochs, lines, batch_size ):
    """Return the total number of training steps.

    One epoch takes ``ceil(lines / batch_size)`` steps. That count is
    multiplied by ``epochs`` and the result is rounded up to a whole number
    of steps.
    """
    steps_per_epoch = math.ceil(lines / batch_size)
    total_steps = math.ceil(epochs * steps_per_epoch)
    return int(total_steps)
def schedule_learning_rate( iterations, schedule=LEARNING_RATE_SCHEDULE ):
    """Scale fractional milestones by ``iterations``.

    Each entry of ``schedule`` is multiplied by ``iterations`` and truncated
    to an int. The result is used as the ``gen_lr_steps`` list in the training
    config.
    """
    milestones = []
    for fraction in schedule:
        milestones.append(int(iterations * fraction))
    return milestones
def optimize_training_settings( **kwargs ):
    """Sanity-check training settings and adjust them to be consistent.

    Clamps the following against each other, the dataset size and the
    detected devices:

    - batch size
    - gradient accumulation size
    - GPU count
    - save and validation rates (in epochs)
    - precision options

    Reads the keyword arguments ``voice``, ``batch_size``,
    ``gradient_accumulation_size``, ``gpus``, ``epochs``, ``save_rate``,
    ``validation_rate``, ``resume_state``, ``bitsandbytes`` and ``half_p``.

    Returns:
        ``(settings, messages)``: the adjusted settings dict and a list of
        human-readable notes describing each adjustment.

    Raises:
        Exception: if ``./training/<voice>/train.txt`` has no lines.
    """
    messages = []
    settings = {}
    settings.update(kwargs)

    dataset_path = f"./training/{settings['voice']}/train.txt"
    with open(dataset_path, 'r', encoding="utf-8") as f:
        lines = len(f.readlines())

    if lines == 0:
        raise Exception("Empty dataset.")

    if settings['batch_size'] > lines:
        settings['batch_size'] = lines
        messages.append(f"Batch size is larger than your dataset, clamping batch size to: {settings['batch_size']}")
    elif settings['batch_size'] < 1:
        # a zero/negative batch size would divide by zero in calc_iterations below
        settings['batch_size'] = 1
        messages.append(f"Batch size must be at least 1, clamping batch size to: {settings['batch_size']}")

    # Clamp the GPU count BEFORE the batch size is made divisible by it.
    # Previously a GPU count that did not exist still reshaped the batch size,
    # and a device count of 0 made the modulo below raise ZeroDivisionError.
    if settings['gpus'] > get_device_count():
        settings['gpus'] = get_device_count()
        messages.append(f"GPU count exceeds defacto GPU count, clamping to: {settings['gpus']}")

    if settings['gpus'] <= 1:
        settings['gpus'] = 1
    else:
        messages.append(f"! EXPERIMENTAL ! Multi-GPU training is extremely particular, expect issues.")

    if settings['gradient_accumulation_size'] == 0:
        settings['gradient_accumulation_size'] = 1

    if settings['batch_size'] / settings['gradient_accumulation_size'] < 2:
        settings['gradient_accumulation_size'] = max(1, int(settings['batch_size'] / 2))
        messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {settings['gradient_accumulation_size']}")
    elif settings['batch_size'] % settings['gradient_accumulation_size'] != 0:
        # Step down to the largest size that evenly divides the batch size.
        # Subtracting the remainder once is not enough: 17 % 7 = 3 gives 4, yet 17 % 4 = 1.
        while settings['gradient_accumulation_size'] > 1 and settings['batch_size'] % settings['gradient_accumulation_size'] != 0:
            settings['gradient_accumulation_size'] -= 1
        messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {settings['gradient_accumulation_size']}")

    if settings['batch_size'] % settings['gpus'] != 0:
        settings['batch_size'] -= settings['batch_size'] % settings['gpus']
        if settings['batch_size'] == 0:
            settings['batch_size'] = 1
        messages.append(f"Batch size not neatly divisible by GPU count, adjusting batch size to: {settings['batch_size']}")

    # named so it does not shadow the imported get_device_batch_size()
    def batch_cap_for_vram( vram ):
        # (VRAM threshold in GB, suggested batch-ratio cap), checked from largest to smallest
        DEVICE_BATCH_SIZE_MAP = [
            (70, 128), # based on an A100-80G, I can safely get a ratio of 4096:32 = 128
            (32, 64),  # based on my two 6800XTs, I can only really safely get a ratio of 128:2 = 64
            (16, 8),   # based on an A4000, I can do a ratio of 512:64 = 8:1
            (8, 4),    # interpolated
            (6, 2),    # based on my 2060, it only really lets me have a batch ratio of 2:1
        ]
        for k, v in DEVICE_BATCH_SIZE_MAP:
            if vram > (k-1):
                return v
        return 1

    # assuming you have equal GPUs
    vram = get_device_vram() * settings['gpus']
    batch_ratio = int(settings['batch_size'] / settings['gradient_accumulation_size'])
    batch_cap = batch_cap_for_vram(vram)

    if batch_ratio > batch_cap:
        settings['gradient_accumulation_size'] = int(settings['batch_size'] / batch_cap)
        messages.append(f"Batch ratio ({batch_ratio}) is expected to exceed your VRAM capacity ({'{:.3f}'.format(vram)}GB, suggested {batch_cap} batch size cap), adjusting gradient accumulation size to: {settings['gradient_accumulation_size']}")

    iterations = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])

    if settings['epochs'] < settings['save_rate']:
        settings['save_rate'] = settings['epochs']
        messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {settings['save_rate']}")

    if settings['epochs'] < settings['validation_rate']:
        settings['validation_rate'] = settings['epochs']
        messages.append(f"Validation rate is too small for the given iteration step, clamping validation rate to: {settings['validation_rate']}")

    if settings['resume_state'] and not os.path.exists(settings['resume_state']):
        settings['resume_state'] = None
        messages.append("Resume path specified, but does not exist. Disabling...")

    if settings['bitsandbytes']:
        messages.append("! EXPERIMENTAL ! BitsAndBytes requested.")

    if settings['half_p']:
        if settings['bitsandbytes']:
            settings['half_p'] = False
            messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
        else:
            messages.append("! EXPERIMENTAL ! Half Precision requested.")
            if not os.path.exists(get_halfp_model_path()):
                convert_to_halfp()

    # epochs == 0 used to raise ZeroDivisionError here (iterations is 0 in that case)
    steps = int(iterations / settings['epochs']) if settings['epochs'] else 0

    messages.append(f"For {settings['epochs']} epochs with {lines} lines in batches of {settings['batch_size']}, iterating for {iterations} steps ({steps} steps per epoch)")

    return settings, messages
def save_training_settings( **kwargs ):
    """Persist training settings and render the backend's config template.

    Steps:

    1. Write the incoming keyword arguments as-is to
       ``./training/<voice>/train.json``.
    2. Derive the total iterations.
    3. Convert save/validation rates from epochs to iterations.
    4. Work out the validation batch size, tokenizer, optimizer and
       learning-rate scheme.
    5. Substitute every non-None setting into the ``${key}`` placeholders of
       a template:
       - ``./models/.template.dlas.yaml`` -> ``train.yaml`` for the
         ``tortoise`` backend;
       - ``./models/.template.valle.yaml`` -> ``config.yaml`` for ``vall-e``.

    Returns:
        ``(settings, messages)``: the derived settings dict and a list of
        human-readable notes.
    """
    messages = []
    settings = {}
    settings.update(kwargs)

    outjson = f'./training/{settings["voice"]}/train.json'
    with open(outjson, 'w', encoding="utf-8") as f:
        f.write(json.dumps(settings, indent='\t') )

    settings['dataset_path'] = f"./training/{settings['voice']}/train.txt"
    settings['validation_path'] = f"./training/{settings['voice']}/validation.txt"

    with open(settings['dataset_path'], 'r', encoding="utf-8") as f:
        lines = len(f.readlines())

    settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])

    if not settings['source_model'] or settings['source_model'] == "auto":
        settings['source_model'] = f"./models/tortoise/autoregressive{'_half' if settings['half_p'] else ''}.pth"

    if settings['half_p']:
        if not os.path.exists(get_halfp_model_path()):
            convert_to_halfp()

    messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")

    # epochs == 0 used to raise ZeroDivisionError (iterations is 0 in that case)
    iterations_per_epoch = settings['iterations'] / settings['epochs'] if settings['epochs'] else 0

    # save/validation rates are given in epochs; the trainer wants iterations
    settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
    settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)

    iterations_per_epoch = int(iterations_per_epoch)

    if settings['save_rate'] < 1:
        settings['save_rate'] = 1

    settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
    if not os.path.exists(settings['validation_path']):
        settings['validation_enabled'] = False
        messages.append("Validation not found, disabling validation...")
    elif settings['validation_batch_size'] == 0:
        settings['validation_enabled'] = False
        messages.append("Validation batch size == 0, disabling validation...")
    else:
        with open(settings['validation_path'], 'r', encoding="utf-8") as f:
            validation_lines = len(f.readlines())

        if validation_lines < settings['validation_batch_size']:
            settings['validation_batch_size'] = validation_lines
            messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}")

    settings['tokenizer_json'] = args.tokenizer_json if args.tokenizer_json else get_tokenizer_jsons()[0]

    if settings['gpus'] > get_device_count():
        settings['gpus'] = get_device_count()

    # what an utter mistake this was
    settings['optimizer'] = 'adamw' # if settings['gpus'] == 1 else 'adamw_zero'

    if 'learning_rate_scheme' not in settings or settings['learning_rate_scheme'] not in LEARNING_RATE_SCHEMES:
        settings['learning_rate_scheme'] = "Multistep"

    settings['learning_rate_scheme'] = LEARNING_RATE_SCHEMES[settings['learning_rate_scheme']]

    learning_rate_schema = [f"default_lr_scheme: {settings['learning_rate_scheme']}"]
    if settings['learning_rate_scheme'] == "MultiStepLR":
        if not settings['learning_rate_schedule']:
            settings['learning_rate_schedule'] = LEARNING_RATE_SCHEDULE
        elif isinstance(settings['learning_rate_schedule'],str):
            settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule'])

        settings['learning_rate_schedule'] = schedule_learning_rate( iterations_per_epoch, settings['learning_rate_schedule'] )

        learning_rate_schema.append(f" gen_lr_steps: {settings['learning_rate_schedule']}")
        learning_rate_schema.append(f" lr_gamma: 0.5")
    elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart":
        epochs = settings['epochs']
        # At least one restart, and a period of at least one epoch.
        # 0 restarts divided by zero; more restarts than epochs produced a 0-length period.
        restarts = max(1, int(settings['learning_rate_restarts']))
        restart_period = max(1, int(epochs / restarts))

        if 'learning_rate_warmup' not in settings:
            settings['learning_rate_warmup'] = 0
        if 'learning_rate_min' not in settings:
            settings['learning_rate_min'] = 1e-08

        if 'learning_rate_period' not in settings:
            settings['learning_rate_period'] = [ iterations_per_epoch * restart_period for x in range(epochs) ]

        settings['learning_rate_restarts'] = [ iterations_per_epoch * (x+1) * restart_period for x in range(restarts) ] # [52, 104, 156, 208]

        if 'learning_rate_restart_weights' not in settings:
            weights = [ ( restarts - x - 1 ) / restarts for x in range(restarts) ] # [.75, .5, .25, .125]
            # The formula makes the last weight 0, so it is replaced with half the previous one.
            # That needs at least two entries: restarts == 1 used to raise IndexError.
            if len(weights) > 1:
                weights[-1] = weights[-2] * 0.5
            settings['learning_rate_restart_weights'] = weights

        learning_rate_schema.append(f" T_period: {settings['learning_rate_period']}")
        learning_rate_schema.append(f" warmup: {settings['learning_rate_warmup']}")
        learning_rate_schema.append(f" eta_min: !!float {settings['learning_rate_min']}")
        learning_rate_schema.append(f" restarts: {settings['learning_rate_restarts']}")
        learning_rate_schema.append(f" restart_weights: {settings['learning_rate_restart_weights']}")
    settings['learning_rate_scheme'] = "\n".join(learning_rate_schema)

    # exactly one of pretrain_model_gpt / resume_state is left active in the template
    if settings['resume_state']:
        settings['source_model'] = f"# pretrain_model_gpt: '{settings['source_model']}'"
        settings['resume_state'] = f"resume_state: '{settings['resume_state']}'"
    else:
        settings['source_model'] = f"pretrain_model_gpt: '{settings['source_model']}'"
        settings['resume_state'] = f"# resume_state: '{settings['resume_state']}'"

    def use_template(template, out):
        # Fill every ${key} placeholder in `template` with str(settings[key]) and write to `out`.
        # Named `contents` so it does not shadow the imported `yaml` module.
        with open(template, 'r', encoding="utf-8") as f:
            contents = f.read()

        # i could just load and edit the YAML directly, but this is easier, as I don't need to bother with path traversals
        for k in settings:
            if settings[k] is None:
                continue
            contents = contents.replace(f"${{{k}}}", str(settings[k]))

        with open(out, 'w', encoding="utf-8") as f:
            f.write(contents)

    if args.tts_backend == "tortoise":
        use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
    elif args.tts_backend == "vall-e":
        settings['model_name'] = "[ 'ar-quarter', 'nar-quarter' ]"
        use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/config.yaml')

    messages.append(f"Saved training output")
    return settings, messages
def import_voices(files, saveAs=None, progress=None):
    """Import uploaded files into the voices directory.

    For each file, ``read_generate_settings`` is asked for embedded settings
    and latents.

    - If latents are present, they are written to
      ``<voices>/<saveAs>/cond_latents.pth``.
    - Otherwise the file must be a WAV. It is copied into the voice folder,
      optionally resampled to 44.1 kHz and cleaned with ``voicefixer`` when
      ``args.voice_fixer`` is set.

    Args:
        files: a single uploaded file object or a list of them (``.name`` is
            read as the on-disk path).
        saveAs: target voice name. If None, the ``voice`` key from the first
            file's embedded settings is used.
        progress: optional progress tracker for ``enumerate_progress``.

    Raises:
        Exception: if no voice name can be determined, or a non-WAV file has
            no embedded latents.
    """
    if not isinstance(files, list):
        files = [files]

    for file in enumerate_progress(files, desc="Importing voice files", progress=progress):
        j, latents = read_generate_settings(file, read_latents=True)

        if j is not None and saveAs is None:
            saveAs = j['voice']
        if saveAs is None or saveAs == "":
            raise Exception("Specify a voice name")

        outdir = f'{get_voice_dir()}/{saveAs}/'
        os.makedirs(outdir, exist_ok=True)

        if latents:
            latents_path = f'{outdir}/cond_latents.pth'
            # report the destination; this used to print the raw latent bytes
            print(f"Importing latents to {latents_path}")
            with open(latents_path, 'wb') as f:
                f.write(latents)
            latents = latents_path
            print(f"Imported latents to {latents}")
        else:
            filename = file.name
            # accept any capitalisation of the extension (e.g. .WAV)
            if not filename.lower().endswith(".wav"):
                raise Exception("Please convert to a WAV first")

            path = f"{outdir}/{os.path.basename(filename)}"
            print(f"Importing voice to {path}")

            waveform, sample_rate = torchaudio.load(filename)

            if args.voice_fixer:
                if not voicefixer:
                    load_voicefixer()

                waveform, sample_rate = resample(waveform, sample_rate, 44100)
                torchaudio.save(path, waveform, sample_rate)

                print(f"Running 'voicefixer' on voice sample: {path}")
                voicefixer.restore(
                    input = path,
                    output = path,
                    cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
                    #mode=mode,
                )
            else:
                torchaudio.save(path, waveform, sample_rate)

            print(f"Imported voice to {path}")
def relative_paths( dirs ):
    """Express each path in ``dirs`` relative to the current directory.

    Each result is prefixed with ``./`` and uses forward slashes, even on
    Windows.
    """
    converted = []
    for entry in dirs:
        rel = os.path.relpath( entry ).replace("\\", "/")
        converted.append('./' + rel)
    return converted
def get_voice( name, dir=get_voice_dir(), load_latents=True ):
    """List the files that make up the voice ``name``.

    Args:
        name: voice name relative to ``dir``. It may contain a sub-folder,
            e.g. ``SH/james``.
        dir: root voices directory.
        load_latents: also include ``*.pth`` files.

    Returns:
        A sorted list of ``*.wav`` / ``*.mp3`` / ``*.flac`` paths (plus
        ``*.pth`` when ``load_latents`` is set), or ``None`` if
        ``{dir}/{name}/`` is not a directory.
    """
    folder = f'{dir}/{name}/'
    if not os.path.isdir(folder):
        return None

    patterns = ['*.wav', '*.mp3', '*.flac']
    if load_latents:
        patterns.append('*.pth')

    found = []
    for pattern in patterns:
        found.extend(glob(f'{folder}/{pattern}'))
    return sorted( found )
def get_voice_list(dir=get_voice_dir(), append_defaults=False):
    """Return the sorted names of every usable voice under ``dir``.

    - A non-empty top-level folder is a voice if ``get_voice`` finds files
      directly in it.
    - Otherwise each immediate sub-folder with voice files is listed as
      ``"<folder>/<subfolder>"``, e.g. ``./voices/SH/james/`` becomes
      ``SH/james``.

    The reserved names ``random`` and ``microphone`` are skipped as folders.
    They are appended after the sorted list when ``append_defaults`` is true.
    ``dir`` is created if missing.
    """
    defaults = [ "random", "microphone" ]
    os.makedirs(dir, exist_ok=True)

    def has_voice_files(voice_name):
        return len(get_voice( voice_name, dir=dir )) > 0

    found = []
    for name in os.listdir(dir):
        if name in defaults or not os.path.isdir(f'{dir}/{name}'):
            continue
        if len(os.listdir(os.path.join(dir, name))) == 0:
            continue

        if has_voice_files(name):
            found.append(name)
            continue

        # no files at the top level: look one level down for grouped voices
        for subdir in os.listdir(f'{dir}/{name}'):
            if os.path.isdir(f'{dir}/{name}/{subdir}') and has_voice_files(f'{name}/{subdir}'):
                found.append(f'{name}/{subdir}')

    found.sort()
    if append_defaults:
        found = found + defaults
    return found
def get_valle_models(dir="./training/"):
    """Return the ``config.yaml`` path of every training folder in ``dir``.

    Only folders that contain a ``config.yaml`` are included. Paths are
    sorted so the order is stable, rather than depending on ``os.listdir``.
    Returns an empty list when ``dir`` does not exist; this used to raise
    ``FileNotFoundError``.
    """
    if not os.path.isdir(dir):
        return []
    return sorted([ f'{dir}/{d}/config.yaml' for d in os.listdir(dir) if os.path.exists(f'{dir}/{d}/config.yaml') ])
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False, auto=False):
    """List the autoregressive model checkpoints available for selection.

    Order:

    1. the stock ``autoregressive.pth``;
    2. its half-precision copy, if it exists;
    3. ``*.pth`` files in ``dir`` (sorted);
    4. every ``./training/<name>/finetune/models/<step>_gpt.pth`` checkpoint,
       sorted numerically by step within each training folder.

    Args:
        dir: folder of extra finetuned models (created if missing).
        prefixed: prefix each entry as ``[xxxxxxxx] path`` using the first
            8 characters of ``hash_file(path)``.
        auto: prepend the pseudo-entry ``"auto"``.

    Returns:
        A list of ``./``-relative paths (see ``relative_paths``).
    """
    os.makedirs(dir, exist_ok=True)
    base = [get_model_path('autoregressive.pth')]
    halfp = get_halfp_model_path()
    if os.path.exists(halfp):
        base.append(halfp)

    additionals = sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ])

    found = []
    # ./training/ may not exist yet (nothing trained), which used to raise FileNotFoundError
    if os.path.isdir('./training/'):
        for training in os.listdir('./training/'):
            models_dir = f'./training/{training}/finetune/models/'
            if not os.path.isdir(models_dir):
                continue
            # Only numeric "<step>_gpt.pth" names: int() used to raise ValueError on any other name.
            # Sort by step number but keep the real file name, so zero-padded names still resolve.
            checkpoints = [ d for d in os.listdir(models_dir) if d[-8:] == "_gpt.pth" and d[:-8].isdecimal() ]
            checkpoints.sort(key=lambda d: int(d[:-8]))
            found = found + [ f'{models_dir}{d}' for d in checkpoints ]

    paths = relative_paths(base + additionals + found)

    # Prefix AFTER relativizing. relative_paths treats "[hash] path" as one relative
    # path and turned it into "./[hash] path".
    if prefixed:
        paths = [ f'[{hash_file(path)[:8]}] {path}' for path in paths ]

    if auto:
        paths = ["auto"] + paths

    return paths
def get_diffusion_models(dir="./models/finetunes/", prefixed=False):
    """Return the selectable diffusion models as ``./``-relative paths.

    Only the stock ``diffusion_decoder.pth`` is currently returned. ``dir``
    and ``prefixed`` are accepted but not used.
    """
    stock_model = get_model_path('diffusion_decoder.pth')
    return relative_paths([ stock_model ])
def get_tokenizer_jsons( dir="./models/tokenizers/" ):
    """Return the tokenizer JSON files available for selection.

    The bundled tortoise tokenizer comes first, followed by any ``*.json``
    files in ``dir`` (sorted). All are returned as ``./``-relative paths.
    """
    candidates = [ "./modules/tortoise-tts/tortoise/data/tokenizer.json" ]
    if os.path.isdir(dir):
        candidates += sorted(f'{dir}/{d}' for d in os.listdir(dir) if d.endswith(".json"))
    return relative_paths(candidates)
def tokenize_text( text, config=None, stringed=True, skip_specials=False ):
    """Tokenize ``text`` and show how it decodes back into pieces.

    Uses the loaded TTS model's tokenizer when one exists. Otherwise it
    builds a ``VoiceBpeTokenizer`` from ``config``, which defaults to
    ``args.tokenizer_json`` or the first entry of ``get_tokenizer_jsons()``.

    Returns:
        When ``stringed``: the token ids and the space-split decoded pieces,
        one per line. Otherwise: just the list of decoded pieces.
    """
    from tortoise.utils.tokenizer import VoiceBpeTokenizer

    if not config:
        config = args.tokenizer_json if args.tokenizer_json else get_tokenizer_jsons()[0]

    tokenizer = tts.tokenizer if tts else VoiceBpeTokenizer(config)

    token_ids = tokenizer.encode(text)
    pieces = tokenizer.tokenizer.decode(token_ids, skip_special_tokens=skip_specials).split(" ")

    if not stringed:
        return pieces
    return "\n".join([ str(token_ids), str(pieces) ])
def get_dataset_list(dir="./training/"):
    """Return the sorted names of the folders in ``dir`` that hold a ``train.txt``.

    Returns an empty list when ``dir`` does not exist; this used to raise
    ``FileNotFoundError``. A single ``isfile`` check replaces listing every
    sub-folder's contents.
    """
    if not os.path.isdir(dir):
        return []
    return sorted([d for d in os.listdir(dir) if os.path.isfile(os.path.join(dir, d, "train.txt")) ])
def get_training_list(dir="./training/"):
    """Return the sorted paths of the training configs found under ``dir``.

    The ``tortoise`` backend looks for ``<dir>/<name>/train.yaml``; any other
    backend looks for ``<dir>/<name>/config.yaml``. Paths are now built from
    ``dir`` itself; they were previously hard-coded to ``./training/``. The
    default still yields ``./training/<name>/...``. Returns an empty list when
    ``dir`` does not exist.
    """
    config_name = "train.yaml" if args.tts_backend == "tortoise" else "config.yaml"
    if not os.path.isdir(dir):
        return []
    # strip trailing separators so "./training/" still yields "./training/<name>/..."
    root = dir.rstrip("/\\")
    return sorted([f'{root}/{d}/{config_name}' for d in os.listdir(dir) if os.path.isfile(os.path.join(dir, d, config_name)) ])
def pad(num, zeroes):
    """Zero-pad ``num`` to a width of ``zeroes + 1`` characters.

    The width is one MORE than ``zeroes``: ``pad(7, 4) == "00007"``. Segment
    clip file names are built with this, so the width must stay stable.
    """
    width = zeroes + 1
    text = str(num)
    return text.zfill(width)
def curl(url, timeout=30):
    """Fetch ``url`` and decode its body as JSON.

    This is best-effort. Any failure (bad URL, network error, timeout, invalid
    JSON) is printed and ``None`` is returned instead of raising.

    Args:
        url: address to fetch.
        timeout: seconds to wait on the connection, so an unresponsive host
            cannot block the caller forever. The old version had no timeout.

    Returns:
        The decoded JSON value, or ``None`` on failure.
    """
    try:
        req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
        # the context manager closes the connection even if reading fails
        # (it used to leak on any exception)
        with urllib.request.urlopen(req, timeout=timeout) as conn:
            data = conn.read()
        return json.loads(data.decode())
    except Exception as e:
        print(e)
        return None
def check_for_updates( dir = None ):
    """Check whether a git checkout is behind its Gitea remote.

    With no ``dir``, this checks the main repo and the ``dlas`` and
    ``tortoise-tts`` submodules. It returns ``True`` if any of them has an
    update; it used to return ``None``.

    Otherwise ``dir`` is a git metadata directory. The first commit and https
    URL parsed from its ``FETCH_HEAD`` are compared against the commit of the
    first branch returned by ``https://<host>/api/v1/repos/<owner>/<repo>/branches/``.

    Returns:
        ``True`` if the remote commit differs from the local one, otherwise
        ``False``. This includes every case where the check cannot be done,
        such as a malformed API response, which used to raise.
    """
    if dir is None:
        # check every repo (no short-circuit) so each one prints its status
        results = [
            check_for_updates("./.git/"),
            check_for_updates("./.git/modules/dlas/"),
            check_for_updates("./.git/modules/tortoise-tts/"),
        ]
        return any(results)

    git_dir = dir
    if not os.path.isfile(f'{git_dir}/FETCH_HEAD'):
        print(f"Cannot check for updates for {dir}: not from a git repo")
        return False

    with open(f'{git_dir}/FETCH_HEAD', 'r', encoding="utf-8") as f:
        head = f.read()

    match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
    if not match:
        print(f"Cannot check for updates for {dir}: cannot parse FETCH_HEAD")
        return False

    local, host, owner, repo = match[0]

    res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") #this only works for gitea instances
    # None (fetch failed), an empty list, or a non-Gitea/error payload all end up here
    try:
        remote = res[0]["commit"]["id"]
    except (TypeError, KeyError, IndexError):
        print(f"Cannot check for updates for {dir}: cannot fetch from remote")
        return False

    if remote != local:
        print(f"New version found for {dir}: {local[:8]} => {remote[:8]}")
        return True

    return False
def enumerate_progress(iterable, desc=None, progress=None, verbose=None):
    """Wrap ``iterable`` in a progress bar.

    Without a ``progress`` tracker, a plain ``tqdm`` bar is returned.
    Otherwise this delegates to ``progress.tqdm`` with ``track_tqdm=True``,
    prefixing ``desc`` with the tracker's ``msg_prefix`` when it has one.
    ``desc`` is also printed when ``verbose`` is truthy.
    """
    if verbose and desc is not None:
        print(desc)

    if progress is not None:
        label = desc
        if hasattr(progress, 'msg_prefix'):
            label = f'{progress.msg_prefix} {desc}'
        return progress.tqdm(iterable, desc=label, track_tqdm=True)

    return tqdm(iterable, disable=False) #not verbose)
def notify_progress(message, progress=None, verbose=True):
    """Report ``message``.

    Prints it when ``verbose`` is set. If a ``progress`` tracker is given, it
    is called with a 0 fraction and the message as its description.
    """
    if verbose:
        print(message)
    if progress is not None:
        progress(0, desc=message)
def get_args():
    """Return the module-level ``args`` namespace populated by ``setup_args``.

    The value is only read, so no ``global`` declaration is needed.
    """
    return args
def setup_args():
    """Parse command-line options into the module-level ``args``.

    Defaults come from ``default_arguments`` below. Any keys in
    ``./config/exec.json`` override them, and the command line overrides
    both. Afterwards this:

    - applies ``--device-override``,
    - derives ``args.embed_output_metadata``,
    - splits ``--listen`` (``[host:port][/path]``) into ``listen_host``,
      ``listen_port`` and ``listen_path``.

    Note: ``store_true`` flags whose default is True (e.g.
    ``--latents-lean-and-mean``) cannot be turned off from the command line,
    only via ``exec.json``.

    Returns:
        The parsed ``argparse.Namespace``, also stored in the global ``args``.
    """
    global args

    default_arguments = {
        'share': False,
        'listen': None,
        'check-for-updates': False,
        'models-from-local-only': False,
        'low-vram': False,
        'sample-batch-size': None,
        'unsqueeze-sample-batches': False,
        'embed-output-metadata': True,
        'latents-lean-and-mean': True,
        'voice-fixer': False, # getting tired of long initialization times in a Colab for downloading a large dataset for it
        'voice-fixer-use-cuda': True,
        'force-cpu-for-conditioning-latents': False,
        'defer-tts-load': False,
        'device-override': None,
        'prune-nonfinal-outputs': True,
        'concurrency-count': 2,
        'autocalculate-voice-chunk-duration-size': 10,
        'output-sample-rate': 44100,
        'output-volume': 1,
        'results-folder': "./results/",
        'hf-token': None,
        'tts-backend': TTSES[0],
        'autoregressive-model': None,
        'diffusion-model': None,
        'vocoder-model': VOCODERS[-1],
        'tokenizer-json': None,
        'phonemizer-backend': 'espeak',
        'valle-model': None,
        'whisper-backend': 'openai/whisper',
        'whisper-model': "base",
        'whisper-batchsize': 1,
        'training-default-halfp': False,
        'training-default-bnb': True,
    }

    if os.path.isfile('./config/exec.json'):
        with open(f'./config/exec.json', 'r', encoding="utf-8") as f:
            try:
                overrides = json.load(f)
                for k in overrides:
                    default_arguments[k] = overrides[k]
            except Exception as e:
                # best-effort: a malformed exec.json falls back to the built-in defaults
                print(e)

    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")
    parser.add_argument("--listen", default=default_arguments['listen'], help="Path for Gradio to listen on")
    parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup")
    parser.add_argument("--models-from-local-only", action='store_true', default=default_arguments['models-from-local-only'], help="Only loads models from disk, does not check for updates for models")
    parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage")
    # store_true: passing the flag must set no_embed_output_metadata to True.
    # The old store_false with a default of False could never change anything.
    parser.add_argument("--no-embed-output-metadata", action='store_true', default=not default_arguments['embed-output-metadata'], help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)")
    parser.add_argument("--latents-lean-and-mean", action='store_true', default=default_arguments['latents-lean-and-mean'], help="Exports the bare essentials for latents.")
    parser.add_argument("--voice-fixer", action='store_true', default=default_arguments['voice-fixer'], help="Uses python module 'voicefixer' to improve audio quality, if available.")
    parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.")
    parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
    parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
    parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
    parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
    parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
    parser.add_argument("--unsqueeze-sample-batches", default=default_arguments['unsqueeze-sample-batches'], action='store_true', help="Unsqueezes sample batches to process one by one after sampling")
    parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
    parser.add_argument("--autocalculate-voice-chunk-duration-size", type=float, default=default_arguments['autocalculate-voice-chunk-duration-size'], help="Number of seconds to suggest voice chunk size for (for example, 100 seconds of audio at 10 seconds per chunk will suggest 10 chunks)")
    parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
    parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
    parser.add_argument("--results-folder", type=str, default=default_arguments['results-folder'], help="Sets output directory")
    parser.add_argument("--hf-token", type=str, default=default_arguments['hf-token'], help="HuggingFace Token")
    parser.add_argument("--tts-backend", default=default_arguments['tts-backend'], help="Specifies which TTS backend to use.")
    parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
    parser.add_argument("--diffusion-model", default=default_arguments['diffusion-model'], help="Specifies which diffusion model to use for sampling.")
    # takes a value (was store_true, so "--vocoder-model <name>" was rejected)
    parser.add_argument("--vocoder-model", default=default_arguments['vocoder-model'], help="Specifies with vocoder to use")
    parser.add_argument("--tokenizer-json", default=default_arguments['tokenizer-json'], help="Specifies which tokenizer json to use for tokenizing.")
    parser.add_argument("--phonemizer-backend", default=default_arguments['phonemizer-backend'], help="Specifies which phonemizer backend to use.")
    parser.add_argument("--valle-model", default=default_arguments['valle-model'], help="Specifies which VALL-E model to use for sampling.")
    # takes a value (was store_true, so "--whisper-backend <name>" was rejected)
    parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)")
    parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
    parser.add_argument("--whisper-batchsize", type=int, default=default_arguments['whisper-batchsize'], help="Specifies batch size for WhisperX")
    parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
    parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
    parser.add_argument("--os", default="unix", help="Specifies which OS, easily")

    args = parser.parse_args()

    args.embed_output_metadata = not args.no_embed_output_metadata

    # Apply the override only when one was given. The check used to be inverted,
    # so an override was never applied.
    if args.device_override:
        set_device_name(args.device_override)

    if args.sample_batch_size == 0 and get_device_batch_size() == 1:
        print("!WARNING! Automatically deduced sample batch size returned 1.")

    args.listen_host = None
    args.listen_port = None
    args.listen_path = None
    if args.listen:
        try:
            match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]

            args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
            args.listen_port = match[1] if match[1] != "" else None
            args.listen_path = match[2] if match[2] != "" else "/"
        except Exception:
            # unparseable --listen values leave host/port/path as None
            pass

    if args.listen_port is not None:
        args.listen_port = int(args.listen_port)
        if args.listen_port == 0:
            args.listen_port = None

    return args
def get_default_settings( hypenated=True ):
    """Snapshot the persisted runtime options from the global ``args`` namespace.

    Every option is read from ``args`` under its underscore spelling
    (``low-vram`` -> ``args.low_vram``). A falsy ``listen`` is reported as
    ``None``.

    Args:
        hypenated: True (default) keeps the hyphenated CLI spelling as the
            key; False converts hyphens to underscores, matching the attribute
            names on ``args``.

    Returns:
        dict mapping option name to its current value, in a fixed order.
    """
    option_names = (
        'listen',
        'share',
        'low-vram',
        'check-for-updates',
        'models-from-local-only',
        'force-cpu-for-conditioning-latents',
        'defer-tts-load',
        'prune-nonfinal-outputs',
        'device-override',
        'sample-batch-size',
        'unsqueeze-sample-batches',
        'embed-output-metadata',
        'latents-lean-and-mean',
        'voice-fixer',
        'voice-fixer-use-cuda',
        'concurrency-count',
        'output-sample-rate',
        'autocalculate-voice-chunk-duration-size',
        'output-volume',
        'results-folder',
        'hf-token',
        'tts-backend',
        'autoregressive-model',
        'diffusion-model',
        'vocoder-model',
        'tokenizer-json',
        'phonemizer-backend',
        'valle-model',
        'whisper-backend',
        'whisper-model',
        'whisper-batchsize',
        'training-default-halfp',
        'training-default-bnb',
    )

    snapshot = {}
    for option in option_names:
        attribute = option.replace("-", "_")
        value = getattr(args, attribute)
        # An empty/unset listen address is stored as None rather than "".
        if option == 'listen' and not value:
            value = None
        snapshot[option if hypenated else attribute] = value
    return snapshot
def update_args( **kwargs ):
    """Apply setting overrides to the global ``args`` and persist them.

    Starts from the current values (``get_default_settings(hypenated=False)``),
    overlays ``kwargs``, writes every known setting back onto ``args`` and
    then calls ``save_args_settings()``. Keyword arguments that are not known
    settings are ignored.
    """
    global args

    current = get_default_settings(hypenated=False)
    settings = dict(current)
    settings.update(kwargs)

    # Copy every known setting (including output_sample_rate) from the merged
    # dict. Iterating the names returned by get_default_settings keeps this
    # in sync with that function instead of maintaining a second list by hand.
    for name in current:
        setattr(args, name, settings[name])

    save_args_settings()
def save_args_settings():
    """Write the current ``args`` settings to ``./config/exec.json``.

    Keys use the hyphenated spelling from ``get_default_settings()`` and the
    JSON is tab-indented. ``./config/`` is created if missing.
    """
    # Serialise before opening the file: opening with 'w' truncates it, so a
    # serialisation error would otherwise leave an empty config behind.
    payload = json.dumps(get_default_settings(), indent='\t')

    os.makedirs('./config/', exist_ok=True)
    with open('./config/exec.json', 'w', encoding="utf-8") as f:
        f.write(payload)
| # super kludgy )`; | |
def import_generate_settings(file = None):
    """Build the generation settings, overlaying saved values on defaults.

    Args:
        file: settings source passed to ``read_generate_settings`` (latents
            are not decoded). Falls back to ``./config/generate.json`` when
            falsy.

    Returns:
        dict of generation settings: the built-in defaults below, updated
        with whatever could be read from ``file``.
    """
    source = file if file else "./config/generate.json"

    merged = dict(
        text=None,
        delimiter=None,
        emotion=None,
        prompt=None,
        voice="random",
        mic_audio=None,
        voice_latents_chunks=None,
        candidates=None,
        seed=None,
        num_autoregressive_samples=16,
        diffusion_iterations=30,
        temperature=0.8,
        diffusion_sampler="DDIM",
        breathing_room=8,
        cvvp_weight=0.0,
        top_p=0.8,
        diffusion_temperature=1.0,
        length_penalty=1.0,
        repetition_penalty=2.0,
        cond_free_k=2.0,
        experimentals=None,
    )

    saved, _unused_latents = read_generate_settings(source, read_latents=False)
    if saved is None:
        return merged
    merged.update(saved)
    return merged
def reset_generate_settings():
    """Clear the saved generation settings and return the defaults.

    Overwrites ``./config/generate.json`` with an empty JSON object, creating
    ``./config/`` first if it does not exist (opening a file in a missing
    directory would raise). Then returns ``import_generate_settings()``,
    which, with nothing saved, yields its built-in defaults.
    """
    os.makedirs('./config/', exist_ok=True)
    with open('./config/generate.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps({}, indent='\t'))
    return import_generate_settings()
def read_generate_settings(file, read_latents=True):
    """Load generation settings saved alongside an output.

    Args:
        file: one of the following; ``None`` yields no settings.
            - A ``.wav`` path; its ``lyrics`` tag holds the settings as JSON.
            - A ``.json`` path.
            - A path-like object.
            - An object with a ``.name`` attribute holding the path.
            - A one-element list of any of the above.
        read_latents: when True, base64-decode an embedded ``latents`` field.

    Returns:
        A tuple ``(settings, latents)``.
        - ``settings``: the parsed data with ``latents`` removed and a
          numeric ``time`` formatted to three decimals, or ``None`` if
          nothing could be read.
        - ``latents``: the decoded bytes, or ``None``.
    """
    j = None
    latents = None

    if isinstance(file, list) and len(file) == 1:
        file = file[0]

    try:
        if file is not None:
            if isinstance(file, os.PathLike):
                # Path objects also have `.name` (basename only), so convert
                # them before the `.name` unwrapping below.
                file = os.fspath(file)
            elif hasattr(file, 'name'):
                file = file.name
            path = str(file)
            if path.endswith(".wav"):
                metadata = music_tag.load_file(path)
                if 'lyrics' in metadata:
                    j = json.loads(str(metadata['lyrics']))
            elif path.endswith(".json"):
                # Explicit encoding: the platform default is not always UTF-8.
                with open(path, 'r', encoding="utf-8") as f:
                    j = json.load(f)
    except FileNotFoundError:
        # No saved settings yet is a normal condition, not an error.
        pass
    except Exception as e:
        # Best effort: report the problem but still return "no settings".
        print(f"Failed to read generation settings from {file}: {e}")

    if j is not None:
        if 'latents' in j:
            if read_latents:
                latents = base64.b64decode(j['latents'])
            del j['latents']
        # Only format numeric durations; an already formatted string is kept.
        if "time" in j and isinstance(j["time"], (int, float)):
            j["time"] = "{:.3f}".format(j["time"])

    return (
        j,
        latents,
    )
def version_check_tts( min_version ):
    """Return True if the loaded TTS reports a version >= ``min_version``.

    Versions are compared lexicographically as tuples, with missing trailing
    components treated as 0. For example, ``(2, 5, 0)`` and ``(2, 4, 2)``
    satisfy ``(2, 4, 2)``, but ``(2, 4, 1)`` does not.

    Returns:
        False if the TTS object has no ``version`` attribute.

    Raises:
        Exception: if no TTS instance is loaded.
    """
    global tts
    if not tts:
        raise Exception("TTS is not initialized")

    if not hasattr(tts, 'version'):
        return False

    installed = tuple(tts.version)
    required = tuple(min_version)
    # Pad the shorter tuple so (2, 4) compares equal to (2, 4, 0).
    width = max(len(installed), len(required))
    installed += (0,) * (width - len(installed))
    required += (0,) * (width - len(required))
    return installed >= required
def load_tts( restart=False,
    # TorToiSe configs
    autoregressive_model=None, diffusion_model=None, vocoder_model=None, tokenizer_json=None,
    # VALL-E configs
    valle_model=None,
):
    """Instantiate the TTS engine for ``args.tts_backend`` into the global ``tts``.

    Model arguments that are given override the matching ``args`` fields and
    are stored into them. Omitted ones fall back to the values already in
    ``args``. An autoregressive model of ``"auto"`` is resolved via
    ``deduce_autoregressive_model()``.

    The module-level ``tts_loading`` flag is set to True for the duration of
    the load; the ``update_*`` helpers read it to detect an in-progress load.
    It is reset in a ``finally`` so a failed load cannot leave it stuck.

    Args:
        restart: unload the current engine first.

    Returns:
        The global ``tts`` object.
    """
    global args
    global tts
    # Must be global: a plain assignment would only bind a local and the
    # module flag read elsewhere would never change.
    global tts_loading

    if restart:
        unload_tts()

    tts_loading = True
    try:
        if args.tts_backend == "tortoise":
            if autoregressive_model:
                args.autoregressive_model = autoregressive_model
            else:
                autoregressive_model = args.autoregressive_model

            if autoregressive_model == "auto":
                autoregressive_model = deduce_autoregressive_model()

            if diffusion_model:
                args.diffusion_model = diffusion_model
            else:
                diffusion_model = args.diffusion_model

            if vocoder_model:
                args.vocoder_model = vocoder_model
            else:
                vocoder_model = args.vocoder_model

            if tokenizer_json:
                args.tokenizer_json = tokenizer_json
            else:
                tokenizer_json = args.tokenizer_json

            if get_device_name() == "cpu":
                print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")

            print(f"Loading TorToiSe... (AR: {autoregressive_model}, diffusion: {diffusion_model}, vocoder: {vocoder_model})")
            tts = TorToise_TTS(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches)
        elif args.tts_backend == "vall-e":
            if valle_model:
                args.valle_model = valle_model
            else:
                valle_model = args.valle_model

            print(f"Loading VALL-E... (Config: {valle_model})")
            tts = VALLE_TTS(config=valle_model)

        print("Loaded TTS, ready for generation.")
    finally:
        tts_loading = False

    return tts
def unload_tts():
    """Drop the global TTS instance if one is loaded, then call ``do_gc()``."""
    global tts

    if tts:
        # Rebinding the global releases this module's reference to the engine.
        tts = None
        print("Unloaded TTS")

    do_gc()
def reload_tts():
    """Unload the current TTS engine and load it again from ``args``."""
    # load_tts(restart=True) calls unload_tts() before loading.
    load_tts(restart=True)
def get_current_voice():
    """Return the active voice name, or ``None`` if it cannot be determined.

    Uses the ``current_voice`` global when it is set. Otherwise it falls back
    to the ``voice`` entry saved in ``./config/generate.json``.
    """
    global current_voice
    if current_voice:
        return current_voice

    settings, _ = read_generate_settings("./config/generate.json", read_latents=False)
    # Test for the key in the settings dict, not for a substring of the
    # stored voice name.
    if settings and "voice" in settings:
        return settings["voice"]
    return None
def deduce_autoregressive_model(voice=None):
    """Pick the autoregressive model path to use for ``voice``.

    Preference order:
      1. ``./models/finetunes/{voice}.pth`` if it exists;
      2. the highest-numbered ``<N>_gpt.pth`` under
         ``./training/{voice}/finetune/models/``;
      3. ``args.autoregressive_model`` unless it is ``"auto"``;
      4. the stock model from ``get_model_path('autoregressive.pth')``.

    Args:
        voice: voice name; defaults to ``get_current_voice()``.
    """
    if not voice:
        voice = get_current_voice()

    if voice:
        finetune_path = f'./models/finetunes/{voice}.pth'
        if os.path.exists(finetune_path):
            return finetune_path

        models_dir = f'./training/{voice}/finetune/models/'
        if os.path.isdir(models_dir):
            suffix = "_gpt.pth"
            # Only purely numeric prefixes are step counts; anything else
            # (e.g. "latest_gpt.pth") would make int() raise.
            steps = sorted(
                int(name[:-len(suffix)])
                for name in os.listdir(models_dir)
                if name.endswith(suffix) and name[:-len(suffix)].isdecimal()
            )
            if steps:
                # models_dir already ends in '/', so no extra separator.
                return f'{models_dir}{steps[-1]}{suffix}'

    if args.autoregressive_model != "auto":
        return args.autoregressive_model

    return get_model_path('autoregressive.pth')
def update_autoregressive_model(autoregressive_model_path):
    """Switch the autoregressive model and persist the choice.

    Accepts a file path, ``"auto"``, or a label of the form
    ``"[xxxxxxxx] path"``: eight hex digits in brackets, which are stripped.
    The choice is stored in ``args`` and saved. If a TTS instance is loaded,
    the model is swapped on it.

    Returns:
        The model path that was loaded, or ``None`` in any of these cases:
        - the input was invalid;
        - no TTS is loaded;
        - the model is already active.

    Raises:
        Exception: for a non-TorToiSe backend, or while the TTS is still
            initialising.
    """
    global tts

    if args.tts_backend != "tortoise":
        # Raising a bare string is itself a TypeError; wrap it properly.
        raise Exception(f"Unsupported backend: {args.tts_backend}")

    if not autoregressive_model_path:
        print(f"Invalid model: {autoregressive_model_path}")
        return

    match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
    if match:
        autoregressive_model_path = match[0]

    # "auto" is resolved further down via deduce_autoregressive_model(), so it
    # must not be rejected by the file-existence check.
    if autoregressive_model_path != "auto" and not os.path.exists(autoregressive_model_path):
        print(f"Invalid model: {autoregressive_model_path}")
        return

    args.autoregressive_model = autoregressive_model_path
    save_args_settings()
    print(f'Stored autoregressive model to settings: {autoregressive_model_path}')

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        return

    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    if autoregressive_model_path == "auto":
        autoregressive_model_path = deduce_autoregressive_model()

    if autoregressive_model_path == tts.autoregressive_model_path:
        return

    tts.load_autoregressive_model(autoregressive_model_path)

    do_gc()

    return autoregressive_model_path
def update_diffusion_model(diffusion_model_path):
    """Switch the diffusion model and persist the choice.

    Accepts a file path or a label of the form ``"[xxxxxxxx] path"``: eight
    hex digits in brackets, which are stripped. The choice is stored in
    ``args`` and saved. If a TTS instance is loaded, the model is swapped
    on it.

    Returns:
        The model path that was loaded, or ``None`` in any of these cases:
        - the input was invalid;
        - no TTS is loaded;
        - the model is already active.

    Raises:
        Exception: for a non-TorToiSe backend, or while the TTS is still
            initialising.
    """
    global tts

    if args.tts_backend != "tortoise":
        # Raising a bare string is itself a TypeError; wrap it properly.
        raise Exception(f"Unsupported backend: {args.tts_backend}")

    if not diffusion_model_path:
        print(f"Invalid model: {diffusion_model_path}")
        return

    match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', diffusion_model_path)
    if match:
        diffusion_model_path = match[0]

    if not os.path.exists(diffusion_model_path):
        print(f"Invalid model: {diffusion_model_path}")
        return

    args.diffusion_model = diffusion_model_path
    save_args_settings()
    print(f'Stored diffusion model to settings: {diffusion_model_path}')

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        return

    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    # NOTE(review): "auto" cannot get past the existence check above, so this
    # branch only runs if a file literally named "auto" exists. Confirm that
    # deduce_diffusion_model is defined before enabling "auto" here.
    if diffusion_model_path == "auto":
        diffusion_model_path = deduce_diffusion_model()

    if diffusion_model_path == tts.diffusion_model_path:
        return

    tts.load_diffusion_model(diffusion_model_path)

    do_gc()

    return diffusion_model_path
def update_vocoder_model(vocoder_model):
    """Switch the vocoder and persist the choice.

    The choice is stored in ``args`` and saved. If a TTS instance is loaded,
    the vocoder is reloaded on it.

    Returns:
        ``vocoder_model`` after reloading, or ``None`` if no TTS is loaded.

    Raises:
        Exception: for a non-TorToiSe backend, or while the TTS is still
            initialising.
    """
    global tts

    if args.tts_backend != "tortoise":
        # Raising a bare string is itself a TypeError; wrap it properly.
        raise Exception(f"Unsupported backend: {args.tts_backend}")

    args.vocoder_model = vocoder_model
    save_args_settings()
    print(f'Stored vocoder model to settings: {vocoder_model}')

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        return

    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    print(f"Loading model: {vocoder_model}")
    tts.load_vocoder_model(vocoder_model)
    print(f"Loaded model: {tts.vocoder_model}")

    do_gc()

    return vocoder_model
def update_tokenizer(tokenizer_json):
    """Switch the tokenizer vocabulary and persist the choice.

    The choice is stored in ``args`` and saved. If a TTS instance is loaded,
    the tokenizer is reloaded on it.

    Returns:
        ``tokenizer_json`` after reloading, or ``None`` if no TTS is loaded.

    Raises:
        Exception: for a non-TorToiSe backend, or while the TTS is still
            initialising.
    """
    global tts

    if args.tts_backend != "tortoise":
        # Raising a bare string is itself a TypeError; wrap it properly.
        raise Exception(f"Unsupported backend: {args.tts_backend}")

    args.tokenizer_json = tokenizer_json
    save_args_settings()
    print(f'Stored tokenizer to settings: {tokenizer_json}')

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        return

    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    print(f"Loading tokenizer vocab: {tokenizer_json}")
    tts.load_tokenizer_json(tokenizer_json)
    print(f"Loaded tokenizer vocab: {tts.tokenizer_json}")

    do_gc()

    # Return this function's own argument; `vocoder_model` is not defined here.
    return tokenizer_json
def load_voicefixer(restart=False):
    """Create the global VoiceFixer instance.

    The ``voicefixer`` package is imported lazily. On any failure, the error
    is printed and the global is cleared instead of raising.

    Args:
        restart: unload the current instance first.
    """
    global voicefixer

    if restart:
        unload_voicefixer()

    try:
        print("Loading Voicefixer")
        from voicefixer import VoiceFixer as _VoiceFixer
        voicefixer = _VoiceFixer()
        print("Loaded Voicefixer")
    except Exception as err:
        print(f"Error occurred while tring to initialize voicefixer: {err}")
        # Drop any instance that may be held so callers see "not loaded".
        if voicefixer:
            voicefixer = None
        voicefixer = None
def unload_voicefixer():
    """Release the global VoiceFixer instance if one is loaded, then call ``do_gc()``."""
    global voicefixer

    if voicefixer:
        # Rebinding the global releases this module's reference.
        voicefixer = None
        print("Unloaded Voicefixer")

    do_gc()
def load_whisper_model(language=None, model_name=None, progress=None):
    """Load the Whisper transcription model for ``args.whisper_backend``.

    Sets the module global ``whisper_model``. For the ``m-bain/whisperx``
    backend it also sets ``whisper_vad`` (when the pyannote model can be
    loaded) and ``whisper_align_model``.

    Args:
        language: optional language code. Its uses:
            - Selects ``{model_name}.{language}`` when that name is listed in
              ``WHISPER_SPECIALIZED_MODELS``.
            - Is passed to the whispercpp and whisperx loaders.
        model_name: model to load; defaults to ``args.whisper_model``. When
            given, it is stored in ``args`` and persisted.
        progress: optional progress reporter passed to ``notify_progress``.

    Raises:
        Exception: if ``args.whisper_backend`` is not in ``WHISPER_BACKENDS``.
    """
    global whisper_model
    global whisper_vad
    global whisper_diarize
    global whisper_align_model

    if args.whisper_backend not in WHISPER_BACKENDS:
        raise Exception(f"unavailable backend: {args.whisper_backend}")

    if not model_name:
        model_name = args.whisper_model
    else:
        args.whisper_model = model_name
        save_args_settings()

    if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
        model_name = f'{model_name}.{language}'
        print(f"Loading specialized model for language: {language}")

    notify_progress(f"Loading Whisper model: {model_name}", progress)

    if args.whisper_backend == "openai/whisper":
        import whisper
        try:
            # NOTE: this only guards the load itself; a model that fits in
            # VRAM here could still run out of memory later on real data.
            whisper_model = whisper.load_model(model_name)
        except Exception as e:
            # `except Exception` rather than a bare `except:` so Ctrl-C and
            # SystemExit still propagate; report the actual cause.
            print(f"Failed to load Whisper on the default device ({e}); falling back to loading Whisper on CPU.")
            whisper_model = whisper.load_model(model_name, device="cpu")
    elif args.whisper_backend == "lightmare/whispercpp":
        from whispercpp import Whisper
        if not language:
            language = 'auto'

        b_lang = language.encode('ascii')
        whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
    elif args.whisper_backend == "m-bain/whisperx":
        import whisperx
        device = "cuda" if get_device_name() == "cuda" else "cpu"
        whisper_model = whisperx.load_model(model_name, device)

        if not args.hf_token:
            print("No huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model.")

        try:
            from pyannote.audio import Inference, Pipeline
            whisper_vad = Inference(
                "pyannote/segmentation",
                pre_aggregation_hook=lambda segmentation: segmentation,
                use_auth_token=args.hf_token,
                device=torch.device(device),
            )
            # whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
        except Exception as e:
            # The VAD model is optional: continue without it, but say why.
            print(f"Failed to load VAD model, continuing without it: {e}")

        whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device)

    print("Loaded Whisper model")
def unload_whisper():
    """Release every loaded Whisper component, then call ``do_gc()``.

    Each of ``whisper_vad``, ``whisper_diarize``, ``whisper_align_model`` and
    ``whisper_model`` that is currently set is reset to ``None``.
    "Unloaded Whisper" is printed only when the main model was loaded.
    """
    global whisper_model, whisper_vad, whisper_diarize, whisper_align_model

    if whisper_vad:
        whisper_vad = None

    if whisper_diarize:
        whisper_diarize = None

    if whisper_align_model:
        whisper_align_model = None

    if whisper_model:
        whisper_model = None
        print("Unloaded Whisper")

    do_gc()
| # shamelessly borrowed from Voldy's Web UI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/extras.py#L74 | |
def merge_models( primary_model_name, secondary_model_name, alpha, progress=gr.Progress() ):
    """Linearly interpolate two model checkpoints and save the result.

    For every fp32/fp16 tensor present in both checkpoints with matching
    shapes, the output is ``(1 - alpha) * primary + alpha * secondary``. All
    other entries keep the primary checkpoint's value.

    The result is written to
    ``./models/finetunes/{primary}_{secondary}_{alpha:.3f}_merge.pth``.

    Args:
        primary_model_name: path of the base checkpoint.
        secondary_model_name: path of the checkpoint blended in.
        alpha: blend weight of the secondary checkpoint.
        progress: progress reporter passed to ``enumerate_progress``.

    Returns:
        str: a message naming the saved file.
    """
    key_blacklist = []

    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    def read_model( filename ):
        print(f"Loading {filename}")
        # SECURITY: torch.load unpickles the file, so only merge trusted
        # checkpoints. map_location="cpu" keeps both checkpoints off the GPU
        # while merging.
        return torch.load(filename, map_location="cpu")

    theta_func = weighted_sum

    theta_0 = read_model(primary_model_name)
    theta_1 = read_model(secondary_model_name)

    for key in enumerate_progress(theta_0.keys(), desc="Merging...", progress=progress):
        if key in key_blacklist:
            print("Skipping ignored key:", key)
            continue

        # Keys absent from the secondary checkpoint keep the primary value
        # instead of aborting the whole merge with a KeyError.
        if key not in theta_1:
            print("Skipping key missing from secondary model:", key)
            continue

        a = theta_0[key]
        b = theta_1[key]

        if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor):
            print("Skipping non-tensor key:", key)
            continue

        if a.dtype != torch.float32 and a.dtype != torch.float16:
            print("Skipping key:", key, a.dtype)
            continue

        if b.dtype != torch.float32 and b.dtype != torch.float16:
            print("Skipping key:", key, b.dtype)
            continue

        if a.shape != b.shape:
            print("Skipping key with mismatched shapes:", key, tuple(a.shape), tuple(b.shape))
            continue

        theta_0[key] = theta_func(a, b, alpha)

    del theta_1

    primary_basename = os.path.splitext(os.path.basename(primary_model_name))[0]
    secondary_basename = os.path.splitext(os.path.basename(secondary_model_name))[0]
    suffix = "{:.3f}".format(alpha)

    output_dir = './models/finetunes/'
    os.makedirs(output_dir, exist_ok=True)
    output_path = f'{output_dir}{primary_basename}_{secondary_basename}_{suffix}_merge.pth'

    torch.save(theta_0, output_path)
    message = f"Saved to {output_path}"
    print(message)
    return message