ecker commited on
Commit
c27ee3c
·
1 Parent(s): 166d491

added update checking for dlas and tortoise-tts, caching voices (for a given model and voice name) so random latents will remain the same

Browse files
Files changed (2) hide show
  1. src/utils.py +17 -5
  2. tortoise-tts +1 -1
src/utils.py CHANGED
@@ -103,8 +103,12 @@ def generate(
103
  if seed == 0:
104
  seed = None
105
 
 
106
  def fetch_voice( voice ):
107
- print(f"Loading voice: {voice}")
 
 
 
108
 
109
  sample_voice = None
110
  if voice == "microphone":
@@ -126,7 +130,8 @@ def generate(
126
  sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
127
  voice_samples = None
128
 
129
- return (voice_samples, conditioning_latents, sample_voice)
 
130
 
131
  def get_settings( override=None ):
132
  settings = {
@@ -1479,12 +1484,19 @@ def curl(url):
1479
  print(e)
1480
  return None
1481
 
1482
- def check_for_updates():
1483
- if not os.path.isfile('./.git/FETCH_HEAD'):
 
 
 
 
 
 
 
1484
  print("Cannot check for updates: not from a git repo")
1485
  return False
1486
 
1487
- with open(f'./.git/FETCH_HEAD', 'r', encoding="utf-8") as f:
1488
  head = f.read()
1489
 
1490
  match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
 
103
  if seed == 0:
104
  seed = None
105
 
106
+ voice_cache = {}
107
  def fetch_voice( voice ):
108
+ print(f"Loading voice: {voice} with model {tts.autoregressive_model_hash[:8]}")
109
+ cache_key = f'{voice}:{tts.autoregressive_model_hash[:8]}'
110
+ if cache_key in voice_cache:
111
+ return voice_cache[cache_key]
112
 
113
  sample_voice = None
114
  if voice == "microphone":
 
130
  sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
131
  voice_samples = None
132
 
133
+ voice_cache[cache_key] = (voice_samples, conditioning_latents, sample_voice)
134
+ return voice_cache[cache_key]
135
 
136
  def get_settings( override=None ):
137
  settings = {
 
1484
  print(e)
1485
  return None
1486
 
1487
+ def check_for_updates( dir = None ):
1488
+ if dir is None:
1489
+ check_for_updates("./")
1490
+ check_for_updates("./dlas/")
1491
+ check_for_updates("./tortoise-tts/")
1492
+ return
1493
+
1494
+ git_dir = f'{dir}/.git/'
1495
+ if not os.path.isfile(f'{git_dir}/FETCH_HEAD'):
1496
  print("Cannot check for updates: not from a git repo")
1497
  return False
1498
 
1499
+ with open(f'{git_dir}/FETCH_HEAD', 'r', encoding="utf-8") as f:
1500
  head = f.read()
1501
 
1502
  match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
tortoise-tts CHANGED
@@ -1 +1 @@
1
- Subproject commit 26133c20314b77155e77be804b43909dab9809d6
 
1
+ Subproject commit cc36c0997c8711889ef8028002fc9e41abd5c5f0