File size: 4,951 Bytes
cf7ef74 762741a cf7ef74 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from snac import SNAC
import soundfile as sf
# 4-bit inference settings for bitsandbytes: NF4 quantization with double
# quantization enabled and bfloat16 as the compute dtype.
_four_bit_options = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
    "bnb_4bit_use_double_quant": True,
}
quantization_config = BitsAndBytesConfig(**_four_bit_options)
# Load the quantized language model and its tokenizer from the same Hub repo.
# NOTE: trust_remote_code=True allows Python code shipped with that repo to
# run locally, so the repo id should only ever point at a trusted source.
_veena_repo_id = "akh99/veena-hinglish"
model = AutoModelForCausalLM.from_pretrained(
    _veena_repo_id,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(_veena_repo_id, trust_remote_code=True)
# SNAC decoder used to turn audio codes back into a waveform; kept on the CPU
# and switched to eval mode.
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac_model = snac_model.eval().to("cpu")
# Control token IDs (fixed for Veena), assigned as start/end pairs.
# The HUMAN pair wraps the text prompt, the AI and SPEECH start tokens close
# the prompt, and the SPEECH/AI end tokens are used as stop tokens.
START_OF_SPEECH_TOKEN, END_OF_SPEECH_TOKEN = 128257, 128258
START_OF_HUMAN_TOKEN, END_OF_HUMAN_TOKEN = 128259, 128260
START_OF_AI_TOKEN, END_OF_AI_TOKEN = 128261, 128262
# First vocabulary ID of the audio-code range; the code at slot k of a
# 7-token frame is stored as AUDIO_CODE_BASE_OFFSET + k * 4096 + code.
AUDIO_CODE_BASE_OFFSET = 128266
# Available speakers (used to build the "<spk_NAME>" prompt tag).
speakers = [
    "kavya",
    "agastya",
    "maitri",
    "vinaya",
]
def generate_speech(text, speaker="kavya", temperature=0.01, top_p=0.3):
    """Generate speech from text using the specified speaker voice.

    The prompt is framed as ``[HUMAN] <spk_{speaker}> text [/HUMAN] [AI] [SPEECH]``,
    audio-code tokens are sampled from the language model, and the complete
    7-token frames among them are decoded with the SNAC decoder.

    Args:
        text: Text to synthesise; must contain a non-whitespace character.
        speaker: Speaker name inserted into the ``<spk_...>`` prompt tag.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        The decoded waveform produced by ``decode_snac_tokens``.

    Raises:
        ValueError: If ``text`` is empty or blank, or if the model produced
            no complete 7-token audio frame.
    """
    frame_size = 7  # audio tokens per SNAC frame

    # Reject blank input up front instead of spending a generation pass on it.
    if not text or not text.strip():
        raise ValueError("Text must not be empty")

    # Prepare input with speaker token
    prompt = f"<spk_{speaker}> {text}"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)

    # Construct full sequence: [HUMAN] <spk_speaker> text [/HUMAN] [AI] [SPEECH]
    input_tokens = [
        START_OF_HUMAN_TOKEN,
        *prompt_tokens,
        END_OF_HUMAN_TOKEN,
        START_OF_AI_TOKEN,
        START_OF_SPEECH_TOKEN,
    ]
    input_ids = torch.tensor([input_tokens], device=model.device)

    # Token budget grows with the text length and is capped at 700 tokens
    # (100 frames); both bounds are whole multiples of the frame size.
    max_tokens = min(int(len(text) * 1.3) * frame_size + 21, 700)

    # Generate audio tokens. The prompt is a single unpadded sequence, so the
    # attention mask is all ones; passing it explicitly avoids relying on the
    # mask being inferred from the pad/eos token configuration.
    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=[END_OF_SPEECH_TOKEN, END_OF_AI_TOKEN],
        )

    # Keep only the newly generated IDs that fall inside the audio-code range
    # (7 codebook slots of 4096 codes each); control tokens are dropped here.
    generated_ids = output[0][len(input_tokens):].tolist()
    audio_range_end = AUDIO_CODE_BASE_OFFSET + frame_size * 4096
    snac_tokens = [
        token_id for token_id in generated_ids
        if AUDIO_CODE_BASE_OFFSET <= token_id < audio_range_end
    ]

    # Generation can stop in the middle of a frame. The decoder returns None
    # for a token count that is not a multiple of 7, so drop the trailing
    # partial frame rather than handing callers a None waveform.
    usable = len(snac_tokens) - len(snac_tokens) % frame_size
    snac_tokens = snac_tokens[:usable]
    if not snac_tokens:
        raise ValueError("No audio tokens generated")

    # Decode audio
    return decode_snac_tokens(snac_tokens, snac_model)
def decode_snac_tokens(snac_tokens, snac_model):
    """Split an interleaved SNAC token stream into levels and decode it.

    The stream is consumed in frames of 7 tokens. Within a frame, slot 0
    feeds the coarse level, slots 1 and 4 the medium level, and slots
    2, 3, 5 and 6 the fine level. Each token is converted to a codebook
    index by subtracting ``AUDIO_CODE_BASE_OFFSET + slot * 4096``.

    Args:
        snac_tokens: Flat sequence of audio token IDs.
        snac_model: SNAC model providing ``parameters()`` and ``decode()``.

    Returns:
        ``None`` if ``snac_tokens`` is empty or its length is not a multiple
        of 7; otherwise the decoded audio as a NumPy array clamped to
        the range [-1, 1].

    Raises:
        ValueError: If any resulting codebook index is outside 0..4095.
    """
    if not snac_tokens:
        return None
    if len(snac_tokens) % 7:
        return None

    # Build the code tensors on whichever device the SNAC weights live on.
    snac_device = next(snac_model.parameters()).device

    # Frame slots feeding each hierarchical level, in output order.
    level_slots = ((0,), (1, 4), (2, 3, 5, 6))
    frame_starts = range(0, len(snac_tokens), 7)
    codes_by_level = [
        [
            snac_tokens[start + slot] - (AUDIO_CODE_BASE_OFFSET + slot * 4096)
            for start in frame_starts
            for slot in slots
        ]
        for slots in level_slots
    ]

    # One (1, n) int32 tensor per level, range-checked before decoding.
    hierarchical_codes = []
    for level_codes in codes_by_level:
        level_tensor = torch.tensor(
            level_codes, dtype=torch.int32, device=snac_device
        ).unsqueeze(0)
        out_of_range = (level_tensor < 0) | (level_tensor > 4095)
        if torch.any(out_of_range):
            raise ValueError("Invalid SNAC token values")
        hierarchical_codes.append(level_tensor)

    # Decode with SNAC; no gradients are needed for inference.
    with torch.no_grad():
        waveform = snac_model.decode(hierarchical_codes)
    waveform = waveform.squeeze().clamp(-1, 1)
    return waveform.cpu().numpy()
# --- Example Usage ---
# Hinglish example. The snippet is wrapped in a bare triple-quoted string, so
# it is never executed; it shows how to synthesise one sentence with the
# "kavya" voice and write the result to a WAV file at 24000 Hz.
'''
text_mixed = "Laptop switch on nahi ho raha shayad charging problem hai."
audio = generate_speech(text_mixed, speaker="kavya")
sf.write("output_hinglish.wav", audio, 24000)
'''
def interface_fn(text, speaker):
    """Gradio callback: synthesise ``text`` and return the WAV file path.

    Args:
        text: Text entered by the user.
        speaker: Speaker chosen in the dropdown. A missing selection (``None``
            or an empty value) falls back to ``generate_speech``'s default
            speaker instead of being interpolated into the prompt.

    Returns:
        Path of the WAV file (written at 24000 Hz) for Gradio to play.

    Raises:
        ValueError: If no waveform could be decoded; errors raised by
            ``generate_speech`` propagate unchanged.
    """
    # Only forward the speaker when one was actually selected, so an empty
    # selection never becomes a "<spk_None>" prompt tag.
    if speaker:
        audio = generate_speech(text, speaker=speaker)
    else:
        audio = generate_speech(text)

    # The decoder reports an incomplete token stream by returning None; fail
    # with a clear message instead of passing None on to the file writer.
    if audio is None:
        raise ValueError("Generated audio could not be decoded")

    # Save to a file for Gradio to play
    output_path = "output.wav"
    sf.write(output_path, audio, 24000)
    return output_path
# Assemble the Gradio UI: a text box and a speaker dropdown feed
# interface_fn, whose returned file path is played by an audio component.
text_input = gr.Textbox(label="Text")
speaker_input = gr.Dropdown(choices=speakers, label="Speaker")
audio_output = gr.Audio(label="Generated Audio")
demo = gr.Interface(
    fn=interface_fn,
    inputs=[text_input, speaker_input],
    outputs=audio_output,
)
# Start the web server for the interface.
demo.launch()