import os
import secrets
import tempfile

import torch
import spaces
import gradio as gr
import soundfile as sf
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# --- CONFIGURATION ---
MODEL_ID = "maya-research/Veena"
SNAC_MODEL_ID = "hubertsiuzdak/snac_24khz"
VALID_KEY = os.environ.get("MY_API_KEY")  # Set this in HF Space Secrets

SPEAKERS = ("kavya", "agastya", "maitri", "vinaya")
SAMPLE_RATE = 24000  # output rate written to the WAV file (snac_24khz decoder)
MAX_NEW_TOKENS = 1024
# Sampling settings taken from the Veena model-card usage example.
TEMPERATURE = 0.4
TOP_P = 0.9
REPETITION_PENALTY = 1.05

# Token offsets for Veena
START_OF_SPEECH_TOKEN = 128257
END_OF_SPEECH_TOKEN = 128258
START_OF_HUMAN_TOKEN = 128259
END_OF_HUMAN_TOKEN = 128260
START_OF_AI_TOKEN = 128261
END_OF_AI_TOKEN = 128262
AUDIO_CODE_BASE_OFFSET = 128266
CODEBOOK_SIZE = 4096  # entries per SNAC codebook slice in the LLM vocabulary
TOKENS_PER_FRAME = 7  # 1 coarse + 2 medium + 4 fine codes per audio frame
AUDIO_CODE_END = AUDIO_CODE_BASE_OFFSET + TOKENS_PER_FRAME * CODEBOOK_SIZE

# --- MODEL LOADING ---
# 4-bit config allows it to run on smaller/shared GPUs
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quant_config,
    device_map="auto",
)
snac_model = SNAC.from_pretrained(SNAC_MODEL_ID).eval().to(
    "cuda" if torch.cuda.is_available() else "cpu"
)


def decode_audio(tokens):
    """Convert Veena audio-code tokens into a waveform.

    Args:
        tokens: Iterable of token ids. Ids outside the audio-code range
            ``[AUDIO_CODE_BASE_OFFSET, AUDIO_CODE_END)`` (text, speaker and
            control tokens) are ignored.

    Returns:
        A 1-D float numpy array clamped to [-1, 1], or ``None`` when there is
        no complete 7-token frame or a code falls outside its codebook slice.
    """
    snac_tokens = [t for t in tokens if AUDIO_CODE_BASE_OFFSET <= t < AUDIO_CODE_END]

    # Keep every complete frame; a trailing partial frame (e.g. generation
    # stopped at max_new_tokens mid-frame) is dropped instead of discarding
    # the whole utterance.
    n_frames = len(snac_tokens) // TOKENS_PER_FRAME
    if n_frames == 0:
        return None

    codes_lvl = [[], [], []]
    for f in range(n_frames):
        base = f * TOKENS_PER_FRAME
        # Frame position k is encoded in its own CODEBOOK_SIZE-wide slice.
        frame = [
            snac_tokens[base + k] - AUDIO_CODE_BASE_OFFSET - k * CODEBOOK_SIZE
            for k in range(TOKENS_PER_FRAME)
        ]
        if any(not 0 <= c < CODEBOOK_SIZE for c in frame):
            return None  # model emitted codes out of frame order
        # Frame interleaving is [L0, L1, L2, L2, L1, L2, L2].
        codes_lvl[0].append(frame[0])
        codes_lvl[1].extend((frame[1], frame[4]))
        codes_lvl[2].extend((frame[2], frame[3], frame[5], frame[6]))

    # torch.nn.Module does not define `.device`; read it from the weights.
    snac_device = next(snac_model.parameters()).device
    codes = [
        torch.tensor([lvl], dtype=torch.long, device=snac_device) for lvl in codes_lvl
    ]
    with torch.no_grad():
        audio_values = snac_model.decode(codes)
    return audio_values.squeeze().clamp(-1.0, 1.0).cpu().numpy()


def _check_api_key(api_key):
    """Raise gr.Error unless *api_key* matches the configured secret."""
    if not VALID_KEY:
        # Fail closed with a clear message when the Space secret is missing.
        raise gr.Error("API key is not configured on the server")
    # Constant-time comparison; bytes so non-ASCII input cannot raise TypeError.
    if not isinstance(api_key, str) or not secrets.compare_digest(
        api_key.encode("utf-8"), VALID_KEY.encode("utf-8")
    ):
        raise gr.Error("Invalid API Key")


@spaces.GPU
def generate_veena_speech(text, api_key, speaker="kavya"):
    """Synthesize *text* with Veena and return the path of a WAV file.

    Args:
        text: Text to speak; must contain non-whitespace characters.
        api_key: Shared secret checked against the ``MY_API_KEY`` env var
            (used by the n8n integration).
        speaker: One of ``SPEAKERS``.

    Returns:
        Path to a freshly created WAV file at ``SAMPLE_RATE``.

    Raises:
        gr.Error: Bad/missing API key, empty text, unknown speaker, or the
            model produced no decodable audio.
    """
    _check_api_key(api_key)
    if not text or not text.strip():
        raise gr.Error("Text must not be empty")
    if speaker not in SPEAKERS:
        raise gr.Error(f"Unknown speaker '{speaker}'. Choose one of: {', '.join(SPEAKERS)}")

    # Veena prompt: [HUMAN] <spk_NAME> text [/HUMAN] [AI] [SPEECH].
    # add_special_tokens=False avoids inserting a BOS token mid-prompt.
    text_tokens = tokenizer.encode(f"<spk_{speaker}> {text}", add_special_tokens=False)
    prompt = [
        START_OF_HUMAN_TOKEN,
        *text_tokens,
        END_OF_HUMAN_TOKEN,
        START_OF_AI_TOKEN,
        START_OF_SPEECH_TOKEN,
    ]
    input_ids = torch.tensor([prompt], device=model.device)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            repetition_penalty=REPETITION_PENALTY,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=[END_OF_SPEECH_TOKEN, END_OF_AI_TOKEN],
        )

    # Decode only the newly generated ids, not the echoed prompt.
    generated = output[0, input_ids.shape[1]:].tolist()
    audio_data = decode_audio(generated)
    if audio_data is None:
        raise gr.Error("The model did not produce any decodable audio")

    # Unique path per request so concurrent calls don't overwrite each other.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name
    sf.write(output_path, audio_data, SAMPLE_RATE)
    return output_path


# --- GRADIO INTERFACE ---
demo = gr.Interface(
    fn=generate_veena_speech,
    inputs=[
        gr.Textbox(label="Text to Speak"),
        gr.Textbox(label="API Key", type="password"),
        gr.Dropdown(choices=list(SPEAKERS), value="kavya", label="Speaker"),
    ],
    outputs=gr.Audio(label="Generated Audio"),
    api_name="predict",
)

demo.launch()