Spaces:
Running
Running
| import os | |
| import torch | |
| import spaces | |
| import gradio as gr | |
| import soundfile as sf | |
| from snac import SNAC | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
# --- CONFIGURATION ---
# Hugging Face Hub repositories: the speech LLM, and the SNAC codec that turns
# its audio codes back into a waveform.
MODEL_ID = "maya-research/Veena"
SNAC_MODEL_ID = "hubertsiuzdak/snac_24khz"
# Shared secret callers must present; None when the MY_API_KEY secret is absent.
VALID_KEY = os.environ.get("MY_API_KEY")  # Set this in HF Space Secrets

# Token Offsets for Veena
# Six consecutive control-token IDs, 128257 through 128262, in exactly this order.
(
    START_OF_SPEECH_TOKEN,
    END_OF_SPEECH_TOKEN,
    START_OF_HUMAN_TOKEN,
    END_OF_HUMAN_TOKEN,
    START_OF_AI_TOKEN,
    END_OF_AI_TOKEN,
) = range(128257, 128263)
# ID of the first audio code; audio token IDs are at or above this value.
AUDIO_CODE_BASE_OFFSET = 128266
# --- MODEL LOADING ---
# NF4 4-bit weights with bfloat16 compute shrink the LLM enough for smaller/shared GPUs.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# device_map="auto" lets transformers/accelerate decide weight placement.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quant_config,
)

# The SNAC codec is loaded unquantized, switched to inference mode, and placed
# on the GPU when CUDA is available (CPU otherwise).
snac_model = SNAC.from_pretrained(SNAC_MODEL_ID)
snac_model = snac_model.eval()
snac_model = snac_model.to("cuda" if torch.cuda.is_available() else "cpu")
def deinterleave_snac_codes(tokens, base_offset, codebook_size=4096):
    """Split an interleaved stream of SNAC audio token IDs into SNAC's 3 code levels.

    Veena emits audio as 7-token frames whose slots belong to SNAC levels
    ``[L0, L1, L2, L2, L1, L2, L2]`` (the layout used by the reference decoder
    on the Veena model card), and the token in slot ``k`` is shifted by
    ``base_offset + k * codebook_size``.

    Args:
        tokens: Iterable of token IDs. IDs outside the audio-code range
            ``[base_offset, base_offset + 7 * codebook_size)`` (text, control
            tokens) are ignored.
        base_offset: Token ID corresponding to code 0 in slot 0.
        codebook_size: Number of codes per SNAC codebook.

    Returns:
        ``[level0, level1, level2]`` lists of raw codes (1, 2 and 4 codes per
        frame respectively), or None when there is no complete frame or a token
        falls outside its slot's range (i.e. the stream is misaligned).
    """
    frame_size = 7
    # Level that each slot of a frame belongs to.
    slot_levels = (0, 1, 2, 2, 1, 2, 2)
    audio_end = base_offset + frame_size * codebook_size

    audio_ids = [t for t in tokens if base_offset <= t < audio_end]
    # A generation cut off by max_new_tokens may end mid-frame: keep the whole
    # frames instead of discarding the entire clip.
    usable = len(audio_ids) - len(audio_ids) % frame_size
    if usable == 0:
        return None

    levels = [[], [], []]
    for start in range(0, usable, frame_size):
        for slot, level in enumerate(slot_levels):
            code = audio_ids[start + slot] - base_offset - slot * codebook_size
            if not 0 <= code < codebook_size:
                # Token belongs to a different slot: frames are misaligned.
                return None
            levels[level].append(code)
    return levels


def decode_audio(tokens):
    """Converts Veena's tokens into a waveform suitable for writing as a WAV file.

    Args:
        tokens: Token IDs produced by the model (non-audio IDs are ignored).

    Returns:
        A numpy array of samples clamped to [-1, 1] (squeezed SNAC decoder
        output), or None if the tokens contain no complete, well-formed audio frame.
    """
    codes_lvl = deinterleave_snac_codes(tokens, AUDIO_CODE_BASE_OFFSET)
    if codes_lvl is None:
        return None

    # nn.Module has no `.device` attribute; take it from the codec's weights.
    device = next(snac_model.parameters()).device
    codes = [torch.tensor([lvl], dtype=torch.long, device=device) for lvl in codes_lvl]
    with torch.no_grad():
        audio_values = snac_model.decode(codes)
    # Clamp so the float -> PCM conversion when writing the WAV cannot overflow.
    return audio_values.clamp(-1.0, 1.0).cpu().numpy().squeeze()
def is_valid_api_key(provided, expected):
    """Return True only when *provided* equals a configured, non-empty *expected* key.

    Fails closed: if *expected* is None or empty (e.g. the MY_API_KEY secret was
    never set) every key is rejected; otherwise a missing secret would be
    "matched" by a null key sent through the API. ``hmac.compare_digest`` keeps
    the comparison time independent of how much of the key was correct.
    """
    import hmac

    if not expected or not isinstance(provided, str):
        return False
    return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))


# NOTE(review): `spaces` is imported at file level but this function is not
# decorated with @spaces.GPU; if the Space runs on ZeroGPU hardware, GPU work
# must happen inside such a decorated function -- confirm the Space hardware.
def generate_veena_speech(text, api_key, speaker="kavya"):
    """Synthesize *text* with Veena and return the path of a 24 kHz WAV file.

    Args:
        text: Text to speak.
        api_key: Must match the MY_API_KEY secret (VALID_KEY).
        speaker: Voice name, inserted into the prompt as ``<spk_{speaker}>``.

    Returns:
        Path to a freshly written, per-request WAV file, or None if the model
        produced no decodable audio.

    Raises:
        gr.Error: if the API key is invalid or *text* is empty.
    """
    import tempfile

    # Security check for n8n
    if not is_valid_api_key(api_key, VALID_KEY):
        raise gr.Error("Invalid API Key")
    if not text or not text.strip():
        raise gr.Error("Text to speak must not be empty")

    # Prompt layout (as on the Veena model card):
    #   [HUMAN] <spk_x> text [/HUMAN] [AI] [SPEECH]
    # add_special_tokens=False stops the tokenizer from inserting a BOS token
    # in the middle of the sequence.
    text_ids = tokenizer.encode(f"<spk_{speaker}> {text}", add_special_tokens=False)
    prompt = [
        START_OF_HUMAN_TOKEN,
        *text_ids,
        END_OF_HUMAN_TOKEN,
        START_OF_AI_TOKEN,
        START_OF_SPEECH_TOKEN,
    ]
    input_ids = torch.tensor([prompt], device=model.device)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=1024,
            do_sample=True,
            eos_token_id=[END_OF_SPEECH_TOKEN, END_OF_AI_TOKEN],
        )

    # Decode only the newly generated tokens so prompt IDs never reach the codec.
    generated = output[0, input_ids.shape[-1]:].tolist()
    audio_data = decode_audio(generated)
    if audio_data is None:
        return None

    # One file per request: a shared "output.wav" could be overwritten by a
    # concurrent request before Gradio has picked it up.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name
    sf.write(output_path, audio_data, 24000)
    return output_path
# --- GRADIO INTERFACE ---
# UI and API endpoint (registered under the name "predict") wrapping
# generate_veena_speech; input order matches its (text, api_key, speaker) parameters.
demo = gr.Interface(
    fn=generate_veena_speech,
    inputs=[
        gr.Textbox(label="Text to Speak"),
        # Masked entry so the key is not displayed on screen.
        gr.Textbox(type="password", label="API Key"),
        gr.Dropdown(
            label="Speaker",
            choices=["kavya", "agastya", "maitri", "vinaya"],
            value="kavya",
        ),
    ],
    outputs=gr.Audio(label="Generated Audio"),
    api_name="predict",
)
demo.launch()