ekwek commited on
Commit
ab4f445
·
verified ·
1 Parent(s): c2c4056

Upload 17 files

Browse files
soprano/backends/lmdeploy.py CHANGED
@@ -7,11 +7,16 @@ class LMDeployModel(BaseModel):
7
  def __init__(self,
8
  device='cuda',
9
  cache_size_mb=100,
 
10
  **kwargs):
11
  assert device == 'cuda', "lmdeploy only supports cuda devices, consider changing device or using a different backend instead."
12
  cache_size_ratio = cache_size_mb * 1024**2 / torch.cuda.get_device_properties('cuda').total_memory
13
  backend_config = TurbomindEngineConfig(cache_max_entry_count=cache_size_ratio)
14
- self.pipeline = pipeline('ekwek/Soprano-80M',
 
 
 
 
15
  log_level='ERROR',
16
  backend_config=backend_config)
17
 
 
7
  def __init__(self,
8
  device='cuda',
9
  cache_size_mb=100,
10
+ model_path=None,
11
  **kwargs):
12
  assert device == 'cuda', "lmdeploy only supports cuda devices, consider changing device or using a different backend instead."
13
  cache_size_ratio = cache_size_mb * 1024**2 / torch.cuda.get_device_properties('cuda').total_memory
14
  backend_config = TurbomindEngineConfig(cache_max_entry_count=cache_size_ratio)
15
+
16
+ # Use local model if path provided, otherwise use HuggingFace
17
+ model_name_or_path = model_path if model_path else 'ekwek/Soprano-80M'
18
+
19
+ self.pipeline = pipeline(model_name_or_path,
20
  log_level='ERROR',
21
  backend_config=backend_config)
22
 
soprano/backends/transformers.py CHANGED
@@ -1,20 +1,25 @@
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
3
  from .base import BaseModel
4
 
5
 
6
  class TransformersModel(BaseModel):
7
  def __init__(self,
8
  device='cuda',
 
9
  **kwargs):
10
  self.device = device
11
 
 
 
 
12
  self.model = AutoModelForCausalLM.from_pretrained(
13
- 'ekwek/Soprano-80M',
14
- torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
15
  device_map=device
16
  )
17
- self.tokenizer = AutoTokenizer.from_pretrained('ekwek/Soprano-80M')
18
  self.model.eval()
19
 
20
  def infer(self,
@@ -65,4 +70,81 @@ class TransformersModel(BaseModel):
65
  top_p=0.95,
66
  temperature=0.3,
67
  repetition_penalty=1.2):
68
- raise NotImplementedError("transformers backend does not currently support streaming, please consider using lmdeploy backend instead.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper, TopPLogitsWarper
4
  from .base import BaseModel
5
 
6
 
7
  class TransformersModel(BaseModel):
8
  def __init__(self,
9
  device='cuda',
10
+ model_path=None,
11
  **kwargs):
12
  self.device = device
13
 
14
+ # Use local model if path provided, otherwise use HuggingFace
15
+ model_name_or_path = model_path if model_path else 'ekwek/Soprano-80M'
16
+
17
  self.model = AutoModelForCausalLM.from_pretrained(
18
+ model_name_or_path,
19
+ dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
20
  device_map=device
21
  )
22
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
23
  self.model.eval()
24
 
25
  def infer(self,
 
70
  top_p=0.95,
71
  temperature=0.3,
72
  repetition_penalty=1.2):
73
+
74
+ # Tokenize input
75
+ inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
76
+ input_ids = inputs['input_ids']
77
+
78
+ # Prepare Logits Processors for sampling
79
+ logits_processor = LogitsProcessorList()
80
+ if repetition_penalty != 1.0:
81
+ logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
82
+
83
+ logits_warper = LogitsProcessorList()
84
+ if temperature != 1.0:
85
+ logits_warper.append(TemperatureLogitsWarper(temperature=temperature))
86
+ if top_p < 1.0:
87
+ logits_warper.append(TopPLogitsWarper(top_p=top_p))
88
+
89
+ # Helper to sample next token
90
+ def get_next_token(logits, input_seq):
91
+ scores = logits_processor(input_seq, logits)
92
+ scores = logits_warper(input_seq, scores)
93
+ probs = torch.nn.functional.softmax(scores, dim=-1)
94
+ # Sample from the distribution
95
+ return torch.multinomial(probs, num_samples=1)
96
+
97
+ with torch.no_grad():
98
+ # Initial forward pass with the prompt
99
+ outputs = self.model(
100
+ input_ids,
101
+ use_cache=True,
102
+ output_hidden_states=True
103
+ )
104
+
105
+ past_key_values = outputs.past_key_values
106
+ next_token_logits = outputs.logits[:, -1, :]
107
+
108
+ # We need to maintain the full sequence for repetition penalty
109
+ generated_ids = input_ids
110
+
111
+ # Sample the first token
112
+ next_token = get_next_token(next_token_logits, generated_ids)
113
+
114
+ max_new_tokens = 512
115
+ eos_token_id = self.model.config.eos_token_id
116
+
117
+ for i in range(max_new_tokens):
118
+ # Append generated token to sequence history
119
+ generated_ids = torch.cat([generated_ids, next_token], dim=-1)
120
+
121
+ # Run forward pass for the single new token
122
+ outputs = self.model(
123
+ next_token,
124
+ past_key_values=past_key_values,
125
+ use_cache=True,
126
+ output_hidden_states=True
127
+ )
128
+
129
+ # Update cache and get hidden state
130
+ past_key_values = outputs.past_key_values
131
+ current_hidden_state = outputs.hidden_states[-1][:, -1, :] # Last layer, last token
132
+
133
+ finish_reason = None
134
+ if next_token.item() == eos_token_id:
135
+ finish_reason = 'stop'
136
+ elif i == max_new_tokens - 1:
137
+ finish_reason = 'length'
138
+
139
+ # Yield result matching lmdeploy format
140
+ yield {
141
+ 'finish_reason': finish_reason,
142
+ 'hidden_state': current_hidden_state
143
+ }
144
+
145
+ if finish_reason:
146
+ break
147
+
148
+ # Prepare for next iteration
149
+ next_token_logits = outputs.logits[:, -1, :]
150
+ next_token = get_next_token(next_token_logits, generated_ids)
soprano/cli.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Soprano TTS Command Line Interface
4
+ """
5
+ import argparse
6
+ from soprano import SopranoTTS
7
+ from soprano.utils.streaming import play_stream
8
+
9
def main():
    """Command-line entry point for Soprano TTS synthesis."""
    arg_parser = argparse.ArgumentParser(description='Soprano Text-to-Speech CLI')
    arg_parser.add_argument('text', help='Text to synthesize')
    arg_parser.add_argument('--output', '-o', default='output.wav',
                            help='Output audio file path (non-streaming only)')
    arg_parser.add_argument('--model-path', '-m',
                            help='Path to local model directory (optional)')
    arg_parser.add_argument('--device', '-d', default='auto',
                            choices=['auto', 'cuda', 'cpu', 'mps'],
                            help='Device to use for inference')
    arg_parser.add_argument('--backend', '-b', default='auto',
                            choices=['auto', 'transformers', 'lmdeploy'],
                            help='Backend to use for inference')
    arg_parser.add_argument('--cache-size', '-c', type=int, default=100,
                            help='Cache size in MB (for lmdeploy backend)')
    arg_parser.add_argument('--decoder-batch-size', '-bs', type=int, default=1,
                            help='Batch size when decoding audio')
    arg_parser.add_argument('--streaming', '-s', action='store_true',
                            help='Enable streaming playback to speakers')
    opts = arg_parser.parse_args()

    # Construct the engine once; every knob maps 1:1 onto a CLI flag.
    engine = SopranoTTS(
        backend=opts.backend,
        device=opts.device,
        cache_size_mb=opts.cache_size,
        decoder_batch_size=opts.decoder_batch_size,
        model_path=opts.model_path,
    )

    print(f"Generating speech for: '{opts.text}'")
    if opts.streaming:
        # Play chunks through the speakers as soon as they are produced.
        play_stream(engine.infer_stream(opts.text, chunk_size=1))
    else:
        engine.infer(opts.text, out_path=opts.output)
        print(f"Audio saved to: {opts.output}")

if __name__ == "__main__":
    main()
soprano/server.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import json
4
+ from typing import Generator
5
+
6
+ import numpy as np
7
+ from fastapi import FastAPI, HTTPException
8
+ from fastapi.responses import Response
9
+ from scipy.io.wavfile import write
10
+ from torch import Tensor
11
+
12
+ from soprano.tts import SopranoTTS
13
+
14
+ # Load model at startup
15
+ tts = SopranoTTS(cache_size_mb = 100)
16
+
17
+ app = FastAPI(title="Soprano TTS API")
18
+
19
+ def _tensor_to_wav_bytes(tensor: Tensor) -> bytes:
20
+ """
21
+ Convert a 1D fp32 torch tensor to a WAV byte stream.
22
+ """
23
+ # convert to int16
24
+ audio_int16 = (np.clip(tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
25
+
26
+ wav_io = io.BytesIO()
27
+ write(wav_io, 32000, audio_int16) # 32kHz sample rate
28
+ wav_io.seek(0)
29
+ return wav_io.read()
30
+
31
+
32
@app.post("/v1/audio/speech")
async def create_speech(payload: dict):
    """
    Minimal implementation of OpenAI's Speech endpoint.
    Fields:
    - input: string - text to synthesize
    - model, voice, etc. are accepted but ignored.
    - response_format: str - ignored, only support wav.
    """
    requested_text = payload.get("input")
    if not (isinstance(requested_text, str) and requested_text.strip()):
        raise HTTPException(status_code=400, detail="`input` field must be a non-empty string.")

    wav_bytes = _tensor_to_wav_bytes(tts.infer(requested_text))
    return Response(
        content=wav_bytes,
        media_type="audio/wav",
        headers={"Content-Disposition": 'attachment; filename="speech.wav"'},
    )
soprano/tts.py CHANGED
@@ -1,5 +1,7 @@
1
  from .vocos.decoder import SopranoDecoder
2
- from .utils.text import clean_text
 
 
3
  import torch
4
  import re
5
  from unidecode import unidecode
@@ -10,36 +12,43 @@ import time
10
 
11
 
12
  class SopranoTTS:
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def __init__(self,
14
  backend='auto',
15
- device='cuda',
16
- cache_size_mb=10,
17
- decoder_batch_size=1):
18
- RECOGNIZED_DEVICES = ['cuda']
19
- RECOGNIZED_BACKENDS = ['auto', 'lmdeploy', 'transformers']
20
- assert device in RECOGNIZED_DEVICES, f"unrecognized device {device}, device must be in {RECOGNIZED_DEVICES}"
21
- if backend == 'auto':
22
- if device == 'cpu':
23
- backend = 'transformers'
24
- else:
25
- try:
26
- import lmdeploy
27
- backend = 'lmdeploy'
28
- except ImportError:
29
- backend='transformers'
30
- print(f"Using backend {backend}.")
31
- assert backend in RECOGNIZED_BACKENDS, f"unrecognized backend {backend}, backend must be in {RECOGNIZED_BACKENDS}"
32
 
33
  if backend == 'lmdeploy':
34
  from .backends.lmdeploy import LMDeployModel
35
- self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb)
36
  elif backend == 'transformers':
37
  from .backends.transformers import TransformersModel
38
- self.pipeline = TransformersModel(device=device)
39
 
40
- self.decoder = SopranoDecoder().cuda()
41
- decoder_path = hf_hub_download(repo_id='ekwek/Soprano-80M', filename='decoder.pth')
42
- self.decoder.load_state_dict(torch.load(decoder_path))
 
 
 
 
 
43
  self.decoder_batch_size=decoder_batch_size
44
  self.RECEPTIVE_FIELD = 4 # Decoder receptive field
45
  self.TOKEN_SIZE = 2048 # Number of samples per audio token
@@ -55,7 +64,7 @@ class SopranoTTS:
55
  for text_idx, text in enumerate(texts):
56
  text = text.strip()
57
  cleaned_text = clean_text(text)
58
- sentences = re.split(r"(?<=[.!?])\s+", cleaned_text)
59
  processed = []
60
  for sentence in sentences:
61
  processed.append({
@@ -130,8 +139,8 @@ class SopranoTTS:
130
  N = len(lengths)
131
  for i in range(N):
132
  batch_hidden_states.append(torch.cat([
133
- torch.zeros((1, 512, lengths[0]-lengths[i]), device='cuda'),
134
- hidden_states[idx+i].unsqueeze(0).transpose(1,2).cuda().to(torch.float32),
135
  ], dim=2))
136
  batch_hidden_states = torch.cat(batch_hidden_states)
137
  with torch.no_grad():
@@ -173,7 +182,7 @@ class SopranoTTS:
173
  if finished or len(hidden_states_buffer) >= self.RECEPTIVE_FIELD + chunk_size:
174
  if finished or chunk_counter == chunk_size:
175
  batch_hidden_states = torch.stack(hidden_states_buffer)
176
- inp = batch_hidden_states.unsqueeze(0).transpose(1, 2).cuda().to(torch.float32)
177
  with torch.no_grad():
178
  audio = self.decoder(inp)[0]
179
  if finished:
 
1
  from .vocos.decoder import SopranoDecoder
2
+ from .utils.text_normalizer import clean_text
3
+ from .utils.text_splitter import split_and_recombine_text
4
+ from .utils.auto_select import select_device, select_backend
5
  import torch
6
  import re
7
  from unidecode import unidecode
 
12
 
13
 
14
  class SopranoTTS:
15
+ """
16
+ Soprano Text-to-Speech model.
17
+
18
+ Args:
19
+ backend: Backend to use for inference. Options:
20
+ - 'auto' (default): Automatically select best backend. Tries lmdeploy first (fastest),
21
+ falls back to transformers. CPU always uses transformers.
22
+ - 'lmdeploy': Force use of LMDeploy (fastest, CUDA only)
23
+ - 'transformers': Force use of HuggingFace Transformers (slower, all devices)
24
+ device: Device to run inference on ('auto', 'cuda', 'cpu', 'mps')
25
+ cache_size_mb: Cache size in MB for lmdeploy backend
26
+ decoder_batch_size: Batch size for decoder
27
+ """
28
  def __init__(self,
29
  backend='auto',
30
+ device='auto',
31
+ cache_size_mb=100,
32
+ decoder_batch_size=1,
33
+ model_path=None):
34
+ device = select_device(device=device)
35
+ backend = select_backend(backend=backend, device=device)
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  if backend == 'lmdeploy':
38
  from .backends.lmdeploy import LMDeployModel
39
+ self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb, model_path=model_path)
40
  elif backend == 'transformers':
41
  from .backends.transformers import TransformersModel
42
+ self.pipeline = TransformersModel(device=device, model_path=model_path)
43
 
44
+ self.device = device
45
+ self.backend = backend
46
+ self.decoder = SopranoDecoder().to(device)
47
+ if model_path:
48
+ decoder_path = os.path.join(model_path, 'decoder.pth')
49
+ else:
50
+ decoder_path = hf_hub_download(repo_id='ekwek/Soprano-80M', filename='decoder.pth')
51
+ self.decoder.load_state_dict(torch.load(decoder_path, map_location=device))
52
  self.decoder_batch_size=decoder_batch_size
53
  self.RECEPTIVE_FIELD = 4 # Decoder receptive field
54
  self.TOKEN_SIZE = 2048 # Number of samples per audio token
 
64
  for text_idx, text in enumerate(texts):
65
  text = text.strip()
66
  cleaned_text = clean_text(text)
67
+ sentences = split_and_recombine_text(cleaned_text)
68
  processed = []
69
  for sentence in sentences:
70
  processed.append({
 
139
  N = len(lengths)
140
  for i in range(N):
141
  batch_hidden_states.append(torch.cat([
142
+ torch.zeros((1, 512, lengths[0]-lengths[i]), device=self.device),
143
+ hidden_states[idx+i].unsqueeze(0).transpose(1,2).to(self.device).to(torch.float32),
144
  ], dim=2))
145
  batch_hidden_states = torch.cat(batch_hidden_states)
146
  with torch.no_grad():
 
182
  if finished or len(hidden_states_buffer) >= self.RECEPTIVE_FIELD + chunk_size:
183
  if finished or chunk_counter == chunk_size:
184
  batch_hidden_states = torch.stack(hidden_states_buffer)
185
+ inp = batch_hidden_states.unsqueeze(0).transpose(1, 2).to(self.device).to(torch.float32)
186
  with torch.no_grad():
187
  audio = self.decoder(inp)[0]
188
  if finished:
soprano/utils/auto_select.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
RECOGNIZED_DEVICES = ['auto', 'cuda', 'cpu', 'mps']
RECOGNIZED_BACKENDS = ['auto', 'lmdeploy', 'transformers']

def select_device(device='auto'):
    """Resolve 'auto' to the best available device and validate the result.

    NOTE(review): an explicitly passed device is only checked for
    membership in RECOGNIZED_DEVICES, not for actual availability.
    """
    if device == 'auto':
        # Preference order: CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            resolved = 'cuda'
        elif torch.backends.mps.is_available():
            resolved = 'mps'
        else:
            resolved = 'cpu'
    else:
        resolved = device

    assert resolved in RECOGNIZED_DEVICES, f"unrecognized device {resolved}, device must be in {RECOGNIZED_DEVICES}"
    print(f"Using device {resolved}")
    return resolved
18
+
19
def select_backend(backend='auto', device='auto'):
    """Choose an inference backend; 'auto' prefers lmdeploy when importable.

    CPU always resolves to transformers (lmdeploy asserts cuda downstream).
    """
    if backend == 'auto':
        backend = 'transformers'
        if device != 'cpu':
            try:
                import lmdeploy  # availability probe only
                backend = 'lmdeploy'
            except ImportError:
                pass

    assert backend in RECOGNIZED_BACKENDS, f"unrecognized backend {backend}, backend must be in {RECOGNIZED_BACKENDS}"
    print(f"Using backend {backend}")
    return backend
soprano/utils/streaming.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sounddevice as sd
2
+ import torch
3
+ import time
4
+
5
+
6
def play_stream(stream, sample_rate=32000):
    """
    Play streamed audio chunks to speakers in real time.

    Returns the delay (seconds) until the first chunk arrived, or None if
    the stream yielded nothing.
    """
    with sd.OutputStream(
        samplerate=sample_rate,
        channels=1,
        dtype='float32',
        blocksize=0
    ) as out_stream:
        begin = time.time()
        latency = None
        waiting_for_first = True
        for chunk in stream:
            if waiting_for_first:
                latency = time.time() - begin
                waiting_for_first = False

            if isinstance(chunk, torch.Tensor):
                chunk = chunk.detach().cpu()

            # Normalize to column-vector shape (N, 1) for the mono stream.
            if chunk.dim() == 1:
                chunk = chunk.unsqueeze(1)
            elif chunk.dim() == 2 and chunk.shape[0] == 1:
                chunk = chunk.transpose(0, 1)

            out_stream.write(chunk.numpy())
    return latency
soprano/utils/text_normalizer.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Normalize input text to a format that Soprano recognizes.
3
+ Adapted from https://github.com/neonbjb/tortoise-tts/blob/main/tortoise/utils/tokenizer.py
4
+ """
5
+ import re
6
+
7
+ import inflect
8
+ from unidecode import unidecode
9
+
10
+
11
+ _inflect = inflect.engine()
12
+
13
+ ####################################################################################################
14
+ # Abbreviations
15
+
16
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
17
+ ('mrs', 'misess'),
18
+ ('ms', 'miss'),
19
+ ('mr', 'mister'),
20
+ ('dr', 'doctor'),
21
+ ('st', 'saint'),
22
+ ('co', 'company'),
23
+ ('jr', 'junior'),
24
+ ('maj', 'major'),
25
+ ('gen', 'general'),
26
+ ('drs', 'doctors'),
27
+ ('rev', 'reverend'),
28
+ ('lt', 'lieutenant'),
29
+ ('hon', 'honorable'),
30
+ ('sgt', 'sergeant'),
31
+ ('capt', 'captain'),
32
+ ('esq', 'esquire'),
33
+ ('ltd', 'limited'),
34
+ ('col', 'colonel'),
35
+ ('ft', 'fort'),
36
+ ]]
37
+ _cased_abbreviations = [(re.compile('\\b%s\\b' % x[0]), x[1]) for x in [
38
+ ('TTS', 'text to speech'),
39
+ ('Hz', 'hertz'),
40
+ ('kHz', 'kilohertz'),
41
+ ('KBs', 'kilobytes'),
42
+ ('KB', 'kilobyte'),
43
+ ('MBs', 'megabytes'),
44
+ ('MB', 'megabyte'),
45
+ ('GBs', 'gigabytes'),
46
+ ('GB', 'gigabyte'),
47
+ ('TBs', 'terabytes'),
48
+ ('TB', 'terabyte'),
49
+ ('APIs', 'a p i\'s'),
50
+ ('API', 'a p i'),
51
+ ('CLIs', 'c l i\'s'),
52
+ ('CLI', 'c l i'),
53
+ ('CPUs', 'c p u\'s'),
54
+ ('CPU', 'c p u'),
55
+ ('GPUs', 'g p u\'s'),
56
+ ('GPU', 'g p u'),
57
+ ('Ave', 'avenue'),
58
+ ('etc', 'et cetera'),
59
+ ('Mon', 'monday'),
60
+ ('Tues', 'tuesday'),
61
+ ('Wed', 'wednesday'),
62
+ ('Thurs', 'thursday'),
63
+ ('Fri', 'friday'),
64
+ ('Sat', 'saturday'),
65
+ ('Sun', 'sunday'),
66
+ ('and/or', 'and or'),
67
+ ]]
68
+
69
def expand_abbreviations(text):
    """Replace known abbreviations (uncased titles and cased acronyms) with spoken forms."""
    for pattern, spoken in _abbreviations + _cased_abbreviations:
        text = pattern.sub(spoken, text)
    return text
73
+
74
+ ####################################################################################################
75
+ # Numbers
76
+
77
+ _num_prefix_re = re.compile(r'#\d')
78
+ _num_suffix_re = re.compile(r'\b\d+(K|M|B|T)\b', re.IGNORECASE)
79
+ _num_letter_split_re = re.compile(r'(\d[a-z]|[a-z]\d)', re.IGNORECASE)
80
+
81
+ _comma_number_re = re.compile(r'(\d[\d\,]+\d)')
82
+ _date_re = re.compile(r'(^|[^/])(\d\d?[/-]\d\d?[/-]\d\d(?:\d\d)?)($|[^/])')
83
+ _phone_number_re = re.compile(r'(\(?\d{3}\)?[-.\s]\d{3}[-.\s]?\d{4})')
84
+ _time_re = re.compile(r'(\d\d?:\d\d(?::\d\d)?)')
85
+ _pounds_re = re.compile(r'£([\d\,]*\d+)')
86
+ _dollars_re = re.compile(r'\$([\d\.\,]*\d+)')
87
+ _decimal_number_re = re.compile(r'(\d+(?:\.\d+)+)')
88
+ _multiply_re = re.compile(r'(\d\s?\*\s?\d)')
89
+ _divide_re = re.compile(r'(\d\s?/\s?\d)')
90
+ _add_re = re.compile(r'(\d\s?\+\s?\d)')
91
+ _subtract_re = re.compile(r'(\d?\s?-\s?\d)') # also does negative numbers
92
+ _fraction_re = re.compile(r'(\d+(?:/\d+)+)')
93
+ _ordinal_re = re.compile(r'\d+(st|nd|rd|th)')
94
+ _number_re = re.compile(r'\d+')
95
+
96
+ def _expand_num_prefix(m):
97
+ match = m.group(0)
98
+ return f"number {match[1]}"
99
+
100
+ def _expand_num_suffix(m):
101
+ match = m.group(0)
102
+ if match[1].upper() == 'K': return f"{match[0]} thousand"
103
+ elif match[1].upper() == 'M': return f"{match[0]} million"
104
+ elif match[1].upper() == 'B': return f"{match[0]} billion"
105
+ elif match[1].upper() == 'T': return f"{match[0]} trillion"
106
+ return match # unexpected format
107
+
108
+ def _split_alphanumeric(m):
109
+ match = m.group(1)
110
+ return f"{match[0]} {match[1]}"
111
+
112
+ def _remove_commas(m):
113
+ return m.group(1).replace(',', '')
114
+
115
+ def _expand_date(m):
116
+ match = m.group(2)
117
+ match = re.split('[./-]', match)
118
+ return m.group(1) + ' dash '.join(match) + m.group(3)
119
+
120
+ def _expand_phone_number(m):
121
+ match = m.group(1)
122
+ match = re.sub(r'\D', '', match)
123
+ assert len(match) == 10
124
+ match = f"{' '.join(list(match[:3]))}, {' '.join(list(match[3:6]))}, {' '.join(list(match[6:]))}"
125
+ return match
126
+
127
+ def _expand_time(m):
128
+ match = m.group(1)
129
+ match = match.split(':')
130
+ if len(match) == 2:
131
+ hours, minutes = match
132
+ if minutes == '00':
133
+ if int(hours) == 0:
134
+ return '0'
135
+ elif int(hours) > 12: return f"{hours} minutes"
136
+ return f"{hours} o'clock"
137
+ elif minutes.startswith('0'):
138
+ minutes = f'oh {minutes[1:]}'
139
+ return f"{hours} {minutes}"
140
+ else:
141
+ hours, minutes, seconds = match
142
+ if int(hours) != 0:
143
+ return f"{hours} {'oh oh' if minutes == '00' else f'oh {minutes}' if minutes.startswith('0') else {minutes}} {'' if seconds == '00' else f'oh {seconds}' if seconds.startswith('0') else seconds}"
144
+ elif minutes != '00':
145
+ return f"{minutes} {'oh oh' if seconds == '00' else f'oh {seconds}' if seconds.startswith('0') else seconds}"
146
+ else:
147
+ return seconds
148
+
149
+ def _expand_dollars(m):
150
+ match = m.group(1)
151
+ parts = match.split('.')
152
+ if len(parts) > 2:
153
+ return match + ' dollars' # Unexpected format
154
+ dollars = int(parts[0]) if parts[0] else 0
155
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
156
+ if dollars and cents:
157
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
158
+ cent_unit = 'cent' if cents == 1 else 'cents'
159
+ return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
160
+ elif dollars:
161
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
162
+ return '%s %s' % (dollars, dollar_unit)
163
+ elif cents:
164
+ cent_unit = 'cent' if cents == 1 else 'cents'
165
+ return '%s %s' % (cents, cent_unit)
166
+ else:
167
+ return 'zero dollars'
168
+
169
+ def _expand_decimal_point(m):
170
+ match = m.group(1)
171
+ match = match.split('.')
172
+ return match[0] + ' point ' + ' point '.join(' '.join(list(match[i])) for i in range(1, len(match)))
173
+
174
+ def _expand_fraction(m):
175
+ match = m.group(1)
176
+ match = match.split('/')
177
+ return ' over '.join(match) if len(match)==2 else ' slash '.join(match)
178
+
179
+ def _expand_multiply(m):
180
+ return ' times '.join(m.group(1).split('*'))
181
+
182
+ def _expand_divide(m):
183
+ return ' over '.join(m.group(1).split('/'))
184
+
185
+ def _expand_add(m):
186
+ return ' plus '.join(m.group(1).split('+'))
187
+
188
+ def _expand_subtract(m):
189
+ return ' minus '.join(m.group(1).split('-'))
190
+
191
def _expand_ordinal(m):
    """Expand an ordinal like '1st' to words via inflect ('first')."""
    return _inflect.number_to_words(m.group(0), andword='')
193
+
194
+ def _expand_number(m):
195
+ num = int(m.group(0))
196
+ if num > 1000 and num < 3000:
197
+ if num == 2000:
198
+ return 'two thousand'
199
+ elif num > 2000 and num < 2010:
200
+ return 'two thousand ' + _inflect.number_to_words(num % 100)
201
+ elif num % 100 == 0:
202
+ return _inflect.number_to_words(num // 100) + ' hundred'
203
+ else:
204
+ return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
205
+ else:
206
+ return _inflect.number_to_words(num, andword='')
207
+
208
def normalize_numbers(text):
    """Expand every numeric pattern in *text* to spoken words.

    Order is load-bearing: contextual forms (dates, phone numbers, times,
    currency, arithmetic) must be recognized before bare integers are
    expanded, or _number_re would consume their digits first.
    """
    text = re.sub(_num_prefix_re, _expand_num_prefix, text)
    text = re.sub(_num_suffix_re, _expand_num_suffix, text)
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_date_re, _expand_date, text)
    text = re.sub(_phone_number_re, _expand_phone_number, text)
    text = re.sub(_time_re, _expand_time, text)
    text = re.sub(_pounds_re, r'\1 pounds', text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_multiply_re, _expand_multiply, text)
    text = re.sub(_divide_re, _expand_divide, text)
    text = re.sub(_add_re, _expand_add, text)
    text = re.sub(_subtract_re, _expand_subtract, text)

    # Fractions/ordinals after arithmetic so '1/2' in '1/2/2025' isn't eaten.
    text = re.sub(_fraction_re, _expand_fraction, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    for _ in range(2): # need to do this twice to find all matches
        text = re.sub(_num_letter_split_re, _split_alphanumeric, text)
    text = re.sub(_number_re, _expand_number, text)
    return text
229
+
230
+ ####################################################################################################
231
+ # Special characters & other patterns
232
+
233
+ _special_characters = [(re.compile(x[0]), x[1]) for x in [
234
+ ('@', ' at '),
235
+ ('&', ' and '),
236
+ ('%', ' percent '),
237
+ (':', '.'),
238
+ (';', ','),
239
+ (r'\+', ' plus '),
240
+ (r'\\', ' backslash '),
241
+ ('~', ' about '),
242
+ ('(^| )<3', ' heart '),
243
+ ('<=', ' less than or equal to '),
244
+ ('>=', ' greater than or equal to '),
245
+ ('<', ' less than '),
246
+ ('>', ' greater than '),
247
+ ('=', ' equals '),
248
+ ('/', ' slash '),
249
+ ('_', ' '),
250
+ (r'\*', ' '),
251
+ ]]
252
+ _link_header_re = re.compile(r'(https?://)')
253
+ _dash_re = re.compile(r'(. - .)')
254
+ _dot_re = re.compile(r'([A-Z]\.[A-Z])', re.IGNORECASE)
255
+ _parentheses_re = re.compile(r'[\(\[\{].*[\)\]\}](.|$)')
256
+
257
def expand_special_characters(text):
    """Apply the _special_characters substitution table in declaration order."""
    for pattern, spoken in _special_characters:
        text = pattern.sub(spoken, text)
    return text
261
+
262
def _expand_link_header(m):
    """Replace a URL scheme prefix ('http://' or 'https://') with its spoken spelling."""
    return 'h t t p s colon slash slash '
264
+
265
+ def _expand_dash(m):
266
+ match = m.group(0)
267
+ return f"{match[0]}, {match[4]}"
268
+
269
+ def _expand_dot(m):
270
+ match = m.group(0)
271
+ return f"{match[0]} dot {match[2]}"
272
+
273
+ def _expand_parantheses(m):
274
+ match = m.group(0)
275
+ match = re.sub(r'[\(\[\{]', ', ', match)
276
+ match = re.sub(r'[\)\]\}][^$.!?,]', ', ', match)
277
+ match = re.sub(r'[\)\]\}]', '', match)
278
+ return match
279
+
280
def normalize_special(text):
    """Expand URLs, spoken dashes, letter-dot-letter pairs, and bracketed asides."""
    passes = (
        (_link_header_re, _expand_link_header),
        (_dash_re, _expand_dash),
        (_dot_re, _expand_dot),
        (_parentheses_re, _expand_parantheses),
    )
    for pattern, repl in passes:
        text = re.sub(pattern, repl, text)
    return text
286
+
287
+ ####################################################################################################
288
+ # Misc
289
+
290
def lowercase(text):
    """Return *text* lowercased."""
    return str.lower(text)
292
+
293
def convert_to_ascii(text):
    """Transliterate non-ASCII characters to their closest ASCII form via unidecode."""
    return unidecode(text)
295
+
296
def normalize_newlines(text):
    """Join lines into one string, adding '.' to lines lacking end punctuation."""
    lines = []
    for raw in text.split('\n'):
        line = raw.strip()
        if line and line[-1] not in '.!?':
            line += '.'
        lines.append(line)  # empty lines kept, matching original join behavior
    return ' '.join(lines)
304
+
305
def remove_unknown_characters(text):
    """Drop characters outside the model's alphabet, then strip leftover markup chars."""
    allowed_only = re.sub(r"[^A-Za-z !\$%&'\*\+,-./0123456789<>\?_]", "", text)
    return re.sub(r"[<>/_+]", "", allowed_only)
309
+
310
def collapse_whitespace(text):
    """Squeeze whitespace runs and reattach punctuation to the preceding word."""
    squeezed = re.sub(r'\s+', ' ', text)
    attached = re.sub(r' [.\?!,]', lambda m: m.group(0)[1], squeezed)
    return attached.strip()
314
+
315
def dedup_punctuation(text):
    """Collapse repeated punctuation while preserving real ellipses ('...')."""
    steps = (
        (r"\.\.\.+", "[ELLIPSIS]"),         # protect ellipses first
        (r",+", ","),
        (r"[\.,]*\.[\.,]*", "."),
        (r"[\.,!]*![\.,!]*", "!"),
        (r"[\.,!\?]*\?[\.,!\?]*", "?"),
        (r"\[ELLIPSIS\]", "..."),           # restore ellipses
    )
    for pattern, repl in steps:
        text = re.sub(pattern, repl, text)
    return text
323
+
324
def clean_text(text):
    """Run the full normalization pipeline: ASCII -> numbers -> special -> cleanup."""
    pipeline = (
        convert_to_ascii,
        normalize_newlines,
        normalize_numbers,
        normalize_special,
        expand_abbreviations,
        expand_special_characters,
        lowercase,
        remove_unknown_characters,
        collapse_whitespace,
        dedup_punctuation,
    )
    for step in pipeline:
        text = step(text)
    return text
336
+
337
+
338
+ if __name__ == '__main__':
339
+ print(clean_text('1,2,3,456,176'))
340
+ print(clean_text('123,456,789'))
341
+ print(clean_text('123,456,789th'))
342
+ print(clean_text('123-456-7890'))
343
+ print(clean_text('111-111-1111'))
344
+ print(clean_text('(111) 111-1111'))
345
+ print(clean_text('A(111) 111-1111'))
346
+ print(clean_text('A (111) 111-1111'))
347
+ print(clean_text('$2.47'))
348
+ print(clean_text('$247'))
349
+ print(clean_text('$0.27'))
350
+ print(clean_text('$1.00'))
351
+ print(clean_text('£20'))
352
+ for i in range(1990, 2030):
353
+ print(clean_text(str(i)))
354
+ print(clean_text('2656'))
355
+ print(clean_text('1024'))
356
+ print(clean_text('2.47023'))
357
+ print(clean_text('20.47023'))
358
+ print(clean_text('1.17.1.1'))
359
+ print(clean_text('111.111.1111'))
360
+ print(clean_text('1/1/2025'))
361
+ print(clean_text('1-1-2025'))
362
+ print(clean_text('1-1-25'))
363
+ print(clean_text('A 1/1/11 A'))
364
+ print(clean_text('A 1/1 A'))
365
+ print(clean_text('1/1'))
366
+ print(clean_text('1/10'))
367
+ print(clean_text('1/1/10'))
368
+ print(clean_text('11/1/1/10'))
369
+
370
+ print(clean_text('0:00'))
371
+ print(clean_text('12:00'))
372
+ print(clean_text('13:00'))
373
+ print(clean_text('8:00'))
374
+ print(clean_text('8:05'))
375
+ print(clean_text('8:15'))
376
+ print(clean_text('0:00:00'))
377
+ print(clean_text('00:01:10'))
378
+ print(clean_text('00:10:01'))
379
+ print(clean_text('01:01:01'))
380
+ print(clean_text('00:01:00'))
381
+ print(clean_text('01:00:00'))
382
+
383
+ print(clean_text('-1 + 2 * 3 - 4 / 5'))
384
+ print(clean_text('-1+2*3-5/4/25'))
385
+
386
+ print(clean_text('100x1'))
387
+ print(clean_text('100k'))
388
+ print(clean_text('100m'))
389
+ print(clean_text('100b'))
390
+ print(clean_text('100t'))
391
+
392
+ print(clean_text('#1'))
393
+
394
+ print(clean_text('12:00'))
395
+ print(clean_text('11:59'))
396
+ print(clean_text('01:00'))
397
+ print(clean_text('0100'))
398
+
399
+ print(clean_text('1st 2nd 3rd 4th'))
400
+ print(clean_text('1K 1M 1B 1T 1K1M1B1T'))
401
+ print(clean_text('and/or'))
soprano/utils/text_splitter.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copied from https://github.com/neonbjb/tortoise-tts/blob/main/tortoise/utils/text.py
3
+ """
4
+ import re
5
+
6
+
7
+ def split_and_recombine_text(text, desired_length=30, max_length=300):
8
+ """Split text it into chunks of a desired length trying to keep sentences intact."""
9
+ # normalize text, remove redundant whitespace and convert non-ascii quotes to ascii
10
+ text = re.sub(r'\n\n+', '\n', text)
11
+ text = re.sub(r'\s+', ' ', text)
12
+ text = re.sub(r'[“”]', '"', text)
13
+
14
+ rv = []
15
+ in_quote = False
16
+ current = ""
17
+ split_pos = []
18
+ pos = -1
19
+ end_pos = len(text) - 1
20
+
21
+ def seek(delta):
22
+ nonlocal pos, in_quote, current
23
+ is_neg = delta < 0
24
+ for _ in range(abs(delta)):
25
+ if is_neg:
26
+ pos -= 1
27
+ current = current[:-1]
28
+ else:
29
+ pos += 1
30
+ current += text[pos]
31
+ if text[pos] == '"':
32
+ in_quote = not in_quote
33
+ return text[pos]
34
+
35
+ def peek(delta):
36
+ p = pos + delta
37
+ return text[p] if p < end_pos and p >= 0 else ""
38
+
39
+ def commit():
40
+ nonlocal rv, current, split_pos
41
+ rv.append(current)
42
+ current = ""
43
+ split_pos = []
44
+
45
+ while pos < end_pos:
46
+ c = seek(1)
47
+ # do we need to force a split?
48
+ if len(current) >= max_length:
49
+ if len(split_pos) > 0 and len(current) > (desired_length / 2):
50
+ # we have at least one sentence and we are over half the desired length, seek back to the last split
51
+ d = pos - split_pos[-1]
52
+ seek(-d)
53
+ else:
54
+ # no full sentences, seek back until we are not in the middle of a word and split there
55
+ while c not in '!?.\n ' and pos > 0 and len(current) > desired_length:
56
+ c = seek(-1)
57
+ commit()
58
+ # check for sentence boundaries
59
+ elif not in_quote and (c in '!?\n' or (c == '.' and peek(1) in '\n ')):
60
+ # seek forward if we have consecutive boundary markers but still within the max length
61
+ while pos < len(text) - 1 and len(current) < max_length and peek(1) in '!?.':
62
+ c = seek(1)
63
+ split_pos.append(pos)
64
+ if len(current) >= desired_length:
65
+ commit()
66
+ # treat end of quote as a boundary if its followed by a space or newline
67
+ elif in_quote and peek(1) == '"' and peek(2) in '\n ':
68
+ seek(2)
69
+ split_pos.append(pos)
70
+ rv.append(current)
71
+
72
+ # clean up, remove lines with only whitespace or punctuation
73
+ rv = [s.strip() for s in rv]
74
+ rv = [s for s in rv if len(s) > 0 and not re.match(r'^[\s\.,;:!?]*$', s)]
75
+
76
+ return rv
soprano/vocos/spectral_ops.py CHANGED
@@ -24,7 +24,7 @@ class ISTFT(nn.Module):
24
  self.n_fft = n_fft
25
  self.hop_length = hop_length
26
  self.win_length = win_length
27
- window = torch.hann_window(win_length).to('cuda')
28
  self.register_buffer("window", window)
29
 
30
  def forward(self, spec: torch.Tensor) -> torch.Tensor:
 
24
  self.n_fft = n_fft
25
  self.hop_length = hop_length
26
  self.win_length = win_length
27
+ window = torch.hann_window(win_length)
28
  self.register_buffer("window", window)
29
 
30
  def forward(self, spec: torch.Tensor) -> torch.Tensor:
soprano/webui.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Gradio Web Interface for Soprano TTS
"""

import argparse
import socket
import time
import gradio as gr
import numpy as np
from soprano import SopranoTTS
from soprano.utils.streaming import play_stream


# NOTE(review): arguments are parsed at import time, so importing this module
# from other code would consume sys.argv — confirm it is only run as a script.
parser = argparse.ArgumentParser(description='Soprano Text-to-Speech Gradio WebUI')
parser.add_argument('--model-path', '-m',
                    help='Path to local model directory (optional)')
parser.add_argument('--device', '-d', default='auto',
                    choices=['auto', 'cuda', 'cpu', 'mps'],
                    help='Device to use for inference')
parser.add_argument('--backend', '-b', default='auto',
                    choices=['auto', 'transformers', 'lmdeploy'],
                    help='Backend to use for inference')
parser.add_argument('--cache-size', '-c', type=int, default=100,
                    help='Cache size in MB (for lmdeploy backend)')
parser.add_argument('--decoder-batch-size', '-bs', type=int, default=1,
                    help='Batch size when decoding audio')
args = parser.parse_args()

# Initialize model
print("Loading Soprano TTS model...")
model = SopranoTTS(
    backend=args.backend,
    device=args.device,
    cache_size_mb=args.cache_size,
    decoder_batch_size=args.decoder_batch_size,
    model_path=args.model_path
)
# Read back the device/backend the model actually chose — presumably the
# resolved values when 'auto' was requested (TODO confirm against SopranoTTS).
# These are shown in the UI banner.
device = model.device
backend = model.backend
print("Model loaded successfully!")

# Output sample rate of generated audio in Hz, used when packaging numpy
# audio for the Gradio Audio component and for duration/RTF math.
SAMPLE_RATE = 32000
44
+
45
+
46
def generate_speech(
    text: str,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    chunk_size: int,
    streaming: bool,
):
    """Generator backing the "Generate Speech" button.

    Yields ``(audio, status)`` pairs for the Gradio outputs, where ``audio``
    is either ``None`` or a ``(sample_rate, int16 ndarray)`` tuple.
    """
    if not text.strip():
        yield None, "Please enter some text to generate speech."
        return

    try:
        if streaming:
            # Streaming path: audio is played directly on the server's sound
            # device via play_stream, so the Audio widget receives nothing.
            token_stream = model.infer_stream(
                text,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                chunk_size=chunk_size,
            )
            yield None, "⏳ Streaming..."

            first_audio_latency = play_stream(token_stream)

            yield None, (
                f"✓ Streaming complete | "
                f"{first_audio_latency*1000:.2f} ms latency"
            )
        else:
            # Batch path: synthesize the whole utterance, then hand the PCM
            # data to the Gradio Audio component.
            t0 = time.perf_counter()
            waveform = model.infer(
                text,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
            elapsed = time.perf_counter() - t0

            samples = waveform.cpu().numpy()
            pcm16 = (samples * 32767).astype(np.int16)

            duration = len(samples) / SAMPLE_RATE
            speed = duration / elapsed if elapsed > 0 else float("inf")

            yield (SAMPLE_RATE, pcm16), (
                f"✓ Generated {duration:.2f} s audio | "
                f"Generation time: {elapsed:.3f} s "
                f"({speed:.2f}x realtime)"
            )
    except Exception as e:
        # Surface any backend failure in the status box instead of crashing
        # the Gradio worker.
        yield None, f"✗ Error: {str(e)}"
105
+
106
+
107
# Create Gradio interface.
# `theme` and `css` are gr.Blocks() constructor options (Blocks.launch() does
# not accept them), so they are applied here where gradio honors them.
with gr.Blocks(
    title="Soprano TTS",
    theme=gr.themes.Soft(primary_hue="green"),
    css="""
    a {
        color: var(--primary-600);
    }
    a:hover {
        color: var(--primary-700);
    }
    """,
) as demo:
    gr.Markdown(
        f"""# 🗣️ Soprano TTS

<div align="center">
  <img width="300" height="300" alt="soprano-github" src="https://github.com/user-attachments/assets/4d612eac-23b8-44e6-8c59-d7ac14ebafd1" />
</div>

**Device:** {device.upper()} | **Backend:** {backend}

**Model Weights:** https://huggingface.co/ekwek/Soprano-80M
**Model Demo:** https://huggingface.co/spaces/ekwek/Soprano-TTS
**GitHub:** https://github.com/ekwek1/soprano
"""
    )
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Text to Synthesize",
                placeholder="Enter text here...",
                value="Soprano is an extremely lightweight text to speech model designed to produce highly realistic speech at unprecedented speed.",
                lines=5,
                max_lines=10,
            )
            streaming = gr.Checkbox(
                label="Stream Audio",
                value=False,
                info="Note: This bypasses the Gradio interface and streams audio directly to your speaker."
            )
            with gr.Accordion("Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.3,
                    step=0.05,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top P",
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.2,
                    step=0.1,
                    label="Repetition Penalty",
                )
                # BUGFIX: gr.Slider has no `precision` parameter (that belongs
                # to gr.Number) and passing it raised a TypeError when the UI
                # was built; step=1 already restricts this slider to integers.
                chunk_size = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=1,
                    step=1,
                    label="Chunk Size (Streaming only)",
                )
            generate_btn = gr.Button("Generate Speech", variant="primary", size="lg")
        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Generated Speech",
                type="numpy",
                autoplay=True,
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False,
                lines=3,
                max_lines=10
            )
    # Canned prompts; each row fills [text, temperature, top_p, repetition_penalty].
    gr.Examples(
        examples=[
            ["Soprano is an extremely lightweight text to speech model.", 0.3, 0.95, 1.2],
            ["Artificial intelligence is transforming the world.", 0.5, 0.90, 1.2],
            ["I'm so excited, I can't even wait!", 0.3, 0.95, 1.2],
            ["Why don't you go ahead and try it?", 0.3, 0.95, 1.2],
        ],
        inputs=[text_input, temperature, top_p, repetition_penalty],
        label="Example Prompts",
    )
    # Wire the button to the generator; generate_speech yields (audio, status).
    generate_btn.click(
        fn=generate_speech,
        inputs=[text_input, temperature, top_p, repetition_penalty, chunk_size, streaming],
        outputs=[audio_output, status_output],
    )
    gr.Markdown(
        """
    ### Usage tips:

    - Soprano works best when each sentence is between 2 and 15 seconds long.
    - Although Soprano recognizes numbers and some special characters, it occasionally mispronounces them.
      Best results can be achieved by converting these into their phonetic form.
      (1+1 -> one plus one, etc)
    - If Soprano produces unsatisfactory results, you can easily regenerate it for a new, potentially better generation.
      You may also change the sampling settings for more varied results.
    - Avoid improper grammar such as not using contractions, multiple spaces, etc.
    """
    )
208
+
209
+
210
def find_free_port(start_port=7860, max_tries=100):
    """Return the first bindable TCP port in [start_port, start_port + max_tries).

    Availability is probed by actually binding a throwaway socket, which is
    closed again before returning. Raises OSError when every candidate port
    in the range is taken.
    """
    for offset in range(max_tries):
        candidate = start_port + offset
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            probe.bind(("", candidate))
        except OSError:
            # Port in use (or otherwise unbindable) — try the next one.
            continue
        else:
            return candidate
        finally:
            probe.close()
    raise OSError("Could not find a free port")
219
+
220
def main():
    """Entry point: pick a free local port and serve the Gradio app on it."""
    # Start Gradio interface
    port = find_free_port(7860)
    print(f"Starting Gradio interface on port {port}")
    # BUGFIX: Blocks.launch() does not accept `theme` or `css` keyword
    # arguments — those are gr.Blocks() constructor options — so passing
    # them here raised a TypeError at startup. They have been removed;
    # theming must be configured on the gr.Blocks(...) constructor instead.
    demo.launch(
        server_name="127.0.0.1",
        server_port=port,
        share=False,
    )


if __name__ == "__main__":
    main()