humair025 committed on
Commit
9ff76ea
·
verified ·
1 Parent(s): 4cba748

Update soprano/tts.py

Browse files
Files changed (1) hide show
  1. soprano/tts.py +16 -13
soprano/tts.py CHANGED
@@ -14,9 +14,12 @@ class SopranoTTS:
14
  device='cuda',
15
  cache_size_mb=10,
16
  decoder_batch_size=1):
17
- RECOGNIZED_DEVICES = ['cuda']
18
  RECOGNIZED_BACKENDS = ['auto', 'lmdeploy', 'transformers']
19
  assert device in RECOGNIZED_DEVICES, f"unrecognized device {device}, device must be in {RECOGNIZED_DEVICES}"
 
 
 
20
  if backend == 'auto':
21
  if device == 'cpu':
22
  backend = 'transformers'
@@ -31,21 +34,21 @@ class SopranoTTS:
31
 
32
  if backend == 'lmdeploy':
33
  from .backends.lmdeploy import LMDeployModel
34
- print("Imported lmdeploy.")
35
  self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb)
36
- print("Loaded model.")
37
  elif backend == 'transformers':
38
  from .backends.transformers import TransformersModel
39
  self.pipeline = TransformersModel(device=device)
40
 
41
- self.decoder = SopranoDecoder().cuda()
 
42
  decoder_path = hf_hub_download(repo_id='ekwek/Soprano-80M', filename='decoder.pth')
43
- self.decoder.load_state_dict(torch.load(decoder_path))
44
- self.decoder_batch_size=decoder_batch_size
45
- self.RECEPTIVE_FIELD = 4 # Decoder receptive field
46
- self.TOKEN_SIZE = 2048 # Number of samples per audio token
 
47
 
48
- self.infer("Hello world!") # warmup
49
 
50
  def _preprocess_text(self, texts, min_length=30):
51
  '''
@@ -139,8 +142,8 @@ class SopranoTTS:
139
  N = len(lengths)
140
  for i in range(N):
141
  batch_hidden_states.append(torch.cat([
142
- torch.zeros((1, 512, lengths[0]-lengths[i]), device='cuda'),
143
- hidden_states[idx+i].unsqueeze(0).transpose(1,2).cuda().to(torch.float32),
144
  ], dim=2))
145
  batch_hidden_states = torch.cat(batch_hidden_states)
146
  with torch.no_grad():
@@ -182,7 +185,7 @@ class SopranoTTS:
182
  if finished or len(hidden_states_buffer) >= self.RECEPTIVE_FIELD + chunk_size:
183
  if finished or chunk_counter == chunk_size:
184
  batch_hidden_states = torch.stack(hidden_states_buffer)
185
- inp = batch_hidden_states.unsqueeze(0).transpose(1, 2).cuda().to(torch.float32)
186
  with torch.no_grad():
187
  audio = self.decoder(inp)[0]
188
  if finished:
@@ -194,4 +197,4 @@ class SopranoTTS:
194
  print(f"Streaming latency: {1000*(time.time()-start_time):.2f} ms")
195
  first_chunk = False
196
  yield audio_chunk.cpu()
197
- chunk_counter += 1
 
14
  device='cuda',
15
  cache_size_mb=10,
16
  decoder_batch_size=1):
17
+ RECOGNIZED_DEVICES = ['cuda', 'cpu'] # Added 'cpu' support
18
  RECOGNIZED_BACKENDS = ['auto', 'lmdeploy', 'transformers']
19
  assert device in RECOGNIZED_DEVICES, f"unrecognized device {device}, device must be in {RECOGNIZED_DEVICES}"
20
+
21
+ self.device = device # Store device for later use
22
+
23
  if backend == 'auto':
24
  if device == 'cpu':
25
  backend = 'transformers'
 
34
 
35
  if backend == 'lmdeploy':
36
  from .backends.lmdeploy import LMDeployModel
 
37
  self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb)
 
38
  elif backend == 'transformers':
39
  from .backends.transformers import TransformersModel
40
  self.pipeline = TransformersModel(device=device)
41
 
42
+ # Load decoder and move to appropriate device
43
+ self.decoder = SopranoDecoder().to(device)
44
  decoder_path = hf_hub_download(repo_id='ekwek/Soprano-80M', filename='decoder.pth')
45
+ self.decoder.load_state_dict(torch.load(decoder_path, map_location=device))
46
+
47
+ self.decoder_batch_size = decoder_batch_size
48
+ self.RECEPTIVE_FIELD = 4 # Decoder receptive field
49
+ self.TOKEN_SIZE = 2048 # Number of samples per audio token
50
 
51
+ self.infer("Hello world!") # warmup
52
 
53
  def _preprocess_text(self, texts, min_length=30):
54
  '''
 
142
  N = len(lengths)
143
  for i in range(N):
144
  batch_hidden_states.append(torch.cat([
145
+ torch.zeros((1, 512, lengths[0]-lengths[i]), device=self.device), # Use self.device
146
+ hidden_states[idx+i].unsqueeze(0).transpose(1,2).to(self.device).to(torch.float32), # Use self.device
147
  ], dim=2))
148
  batch_hidden_states = torch.cat(batch_hidden_states)
149
  with torch.no_grad():
 
185
  if finished or len(hidden_states_buffer) >= self.RECEPTIVE_FIELD + chunk_size:
186
  if finished or chunk_counter == chunk_size:
187
  batch_hidden_states = torch.stack(hidden_states_buffer)
188
+ inp = batch_hidden_states.unsqueeze(0).transpose(1, 2).to(self.device).to(torch.float32) # Use self.device
189
  with torch.no_grad():
190
  audio = self.decoder(inp)[0]
191
  if finished:
 
197
  print(f"Streaming latency: {1000*(time.time()-start_time):.2f} ms")
198
  first_chunk = False
199
  yield audio_chunk.cpu()
200
+ chunk_counter += 1