| import gradio as gr |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
| from snac import SNAC |
| import soundfile as sf |
|
|
| |
# 4-bit NF4 quantization config so the LLM fits in modest GPU memory;
# compute runs in bfloat16, with double quantization for extra savings.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
|
|
| |
# Load the Veena Hinglish TTS language model (4-bit quantized) and its
# tokenizer. device_map="auto" lets accelerate place layers on available
# devices. NOTE: downloads weights on first run.
model = AutoModelForCausalLM.from_pretrained(
    "akh99/veena-hinglish",
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("akh99/veena-hinglish", trust_remote_code=True)

# SNAC neural audio codec (24 kHz variant) used to turn generated audio
# codes into a waveform. Kept on CPU in eval mode; decoding is lightweight
# relative to LLM generation.
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to("cpu")
|
|
| |
# Special control-token IDs in the model's vocabulary: conversation-turn
# markers and the speech-span delimiters. Audio code tokens start at
# AUDIO_CODE_BASE_OFFSET and span 7 codebooks of 4096 entries each
# (see the range filter in generate_speech).
START_OF_SPEECH_TOKEN = 128257
END_OF_SPEECH_TOKEN = 128258
START_OF_HUMAN_TOKEN = 128259
END_OF_HUMAN_TOKEN = 128260
START_OF_AI_TOKEN = 128261
END_OF_AI_TOKEN = 128262
AUDIO_CODE_BASE_OFFSET = 128266

# Available speaker voices; each maps to a <spk_...> prompt prefix.
speakers = ["kavya", "agastya", "maitri", "vinaya"]
|
|
def generate_speech(text, speaker="kavya", temperature=0.01, top_p=0.3):
    """Generate speech audio for *text* in the given speaker's voice.

    Args:
        text: Input text (Hinglish) to synthesize.
        speaker: One of the voices listed in ``speakers``.
        temperature: Sampling temperature (near-greedy by default).
        top_p: Nucleus-sampling cutoff.

    Returns:
        1-D numpy float array of audio samples in [-1, 1] at 24 kHz.

    Raises:
        ValueError: If the model produced no complete audio frames.
    """
    # The speaker is selected via a special <spk_...> prefix in the prompt.
    prompt = f"<spk_{speaker}> {text}"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)

    # Wrap the prompt in human/AI turn markers and open the speech span.
    input_tokens = [
        START_OF_HUMAN_TOKEN,
        *prompt_tokens,
        END_OF_HUMAN_TOKEN,
        START_OF_AI_TOKEN,
        START_OF_SPEECH_TOKEN,
    ]
    input_ids = torch.tensor([input_tokens], device=model.device)

    # Heuristic token budget: roughly 1.3 frames per character, 7 SNAC
    # tokens per frame, plus a little slack; hard-capped at 700.
    max_tokens = min(int(len(text) * 1.3) * 7 + 21, 700)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=[END_OF_SPEECH_TOKEN, END_OF_AI_TOKEN]
        )

    # Keep only tokens that fall inside the 7x4096 audio codebook range.
    generated_ids = output[0][len(input_tokens):].tolist()
    snac_tokens = [
        token_id for token_id in generated_ids
        if AUDIO_CODE_BASE_OFFSET <= token_id < (AUDIO_CODE_BASE_OFFSET + 7 * 4096)
    ]

    # FIX: generation can stop mid-frame (EOS or the max_tokens cap),
    # leaving a token count that is not a multiple of 7. decode_snac_tokens
    # silently returns None in that case, crashing the caller later on.
    # Truncate to whole 7-token frames before decoding.
    snac_tokens = snac_tokens[: (len(snac_tokens) // 7) * 7]

    if not snac_tokens:
        raise ValueError("No audio tokens generated")

    audio = decode_snac_tokens(snac_tokens, snac_model)
    return audio
|
|
def decode_snac_tokens(snac_tokens, snac_model):
    """De-interleave flat 7-token SNAC frames and decode them to audio.

    Returns a 1-D numpy array clamped to [-1, 1], or None when the token
    list is empty or not a whole number of frames.
    """
    # Only complete 7-token frames can be decoded.
    if not snac_tokens or len(snac_tokens) % 7:
        return None

    device = next(snac_model.parameters()).device
    # Each of the 7 interleaved codebooks occupies its own 4096-wide band.
    offsets = [AUDIO_CODE_BASE_OFFSET + cb * 4096 for cb in range(7)]

    # Frame slot -> hierarchy level mapping: slot 0 feeds level 0; slots
    # 1 and 4 feed level 1; slots 2, 3, 5, 6 feed level 2 (in that order).
    level_slots = ([0], [1, 4], [2, 3, 5, 6])
    levels = ([], [], [])
    for start in range(0, len(snac_tokens), 7):
        frame = snac_tokens[start:start + 7]
        for lvl, slots in enumerate(level_slots):
            for slot in slots:
                levels[lvl].append(frame[slot] - offsets[slot])

    # Pack each level as a (1, T) int32 tensor and validate the code range.
    hierarchical_codes = []
    for codes in levels:
        packed = torch.tensor(codes, dtype=torch.int32, device=device).unsqueeze(0)
        if torch.any((packed < 0) | (packed > 4095)):
            raise ValueError("Invalid SNAC token values")
        hierarchical_codes.append(packed)

    with torch.no_grad():
        audio_hat = snac_model.decode(hierarchical_codes)

    return audio_hat.squeeze().clamp(-1, 1).cpu().numpy()
|
|
| |
| |
# Example CLI usage (kept as an inert triple-quoted string, never executed):
'''
text_mixed = "Laptop switch on nahi ho raha shayad charging problem hai."
audio = generate_speech(text_mixed, speaker="kavya")
sf.write("output_hinglish.wav", audio, 24000)
'''
|
|
def interface_fn(text, speaker):
    """Gradio callback: synthesize *text* with *speaker*, return a wav path."""
    samples = generate_speech(text, speaker=speaker)

    # Gradio's Audio output accepts a filepath; write 24 kHz mono wav.
    wav_path = "output.wav"
    sf.write(wav_path, samples, 24000)
    return wav_path
|
|
| |
# Minimal Gradio UI: text box + speaker dropdown -> generated audio file.
demo = gr.Interface(
    fn=interface_fn,
    inputs=[gr.Textbox(label="Text"), gr.Dropdown(choices=speakers, label="Speaker")],
    outputs=gr.Audio(label="Generated Audio")
)

# Launches the local web server (blocking call).
demo.launch()