import os
import subprocess
import sys

# Fix OMP_NUM_THREADS issue before any imports
os.environ["OMP_NUM_THREADS"] = "4"


# Install dependencies programmatically to avoid conflicts
def setup_dependencies():
    try:
        # Check if already installed
        if os.path.exists('/tmp/deps_installed'):
            return
        print("Installing transformers dev version...")
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir",
            "git+https://github.com/huggingface/transformers.git"
        ])
        # Mark as installed
        with open('/tmp/deps_installed', 'w') as f:
            f.write('done')
    except Exception as e:
        print(f"Dependencies setup error: {e}")


# Run setup
setup_dependencies()
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import librosa
import gradio as gr
from nemo.collections.tts.models import AudioCodecModel

# Add the parent directory to sys.path to import kanitts
# (os and sys are already imported above)
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from kanitts import Config

# Load configuration
config = Config.default()
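# For reference, the config fields this script reads (the Config class itself
# lives in kanitts):
#   config.model.model_name, config.model.max_new_tokens
#   config.audio.nemo_model_name, config.audio.sample_rate
#   config.tokens.{start_of_human, end_of_human, start_of_ai, end_of_ai,
#                  start_of_text, end_of_text, start_of_speech, end_of_speech,
#                  audio_tokens_start, codebook_size, pad_token}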
# Load KaniTTS model and tokenizer
kani_model_id = config.model.model_name
tokenizer = AutoTokenizer.from_pretrained(
    kani_model_id,
    trust_remote_code=True,
    use_fast=True,
)
model = AutoModelForCausalLM.from_pretrained(
    kani_model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
)
model.eval()

# Load Nemo codec
nemo_model_id = config.audio.nemo_model_name
nemo_codec = AudioCodecModel.from_pretrained(nemo_model_id).eval().cuda()

# Load Whisper for transcription
whisper_turbo_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device='cuda',
)
# KaniTTS token IDs from config
tokens = config.tokens
SOH_ID = tokens.start_of_human
EOH_ID = tokens.end_of_human
SOA_ID = tokens.start_of_ai
EOA_ID = tokens.end_of_ai
SOT_ID = tokens.start_of_text
EOT_ID = tokens.end_of_text
SOS_ID = tokens.start_of_speech
EOS_ID = tokens.end_of_speech
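# Prompt layout as assembled in infer() below (one turn of the chat format):
#   [SOH] <text tokens> [EOT] [EOH] [SOA] [SOS] <audio tokens> [EOS] [EOA]
# For voice cloning, the prompt stops right after the reference audio tokens
# so the model continues the audio for the remaining (target) text.
# SOT_ID and EOA_ID are defined for completeness but not placed in the prompt
# by this script.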
def tokenize_audio(waveform, target_sample_rate=22050):
    """
    Tokenize audio using the Nemo codec for KaniTTS.

    Expects a float waveform already at the codec's sample rate: the callers
    load audio with librosa at config.audio.sample_rate, so no resampling is
    done here (target_sample_rate is currently unused). Returns a flat list
    of offset token ids, 4 per frame.
    """
    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    waveform = waveform.to(dtype=torch.float32)
    # Ensure we have the right shape: [batch, samples]
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    waveform = waveform.to(nemo_codec.device)
    # Audio length in samples
    audio_len = torch.tensor([waveform.shape[-1]], dtype=torch.int64).to(waveform.device)
    # Encode audio to get token codes
    with torch.inference_mode():
        encoded_tokens, _ = nemo_codec.encode(audio=waveform, audio_len=audio_len)
    # encoded_tokens shape: [batch, num_codebooks, sequence_length]
    # For nemo-nano-codec: [1, 4, seq_len]
    codes = encoded_tokens[0]  # Remove batch dimension -> [4, seq_len]
    seq_len = codes.shape[1]
    # Flatten the 4 codebook levels per frame (KaniTTS uses 4 tokens per frame)
    all_codes = []
    for i in range(seq_len):
        # One frame across all 4 codebook levels, each shifted into its own id range
        for level in range(4):
            token_id = codes[level, i].item()
            offset_token = token_id + config.tokens.audio_tokens_start + (level * config.tokens.codebook_size)
            all_codes.append(offset_token)
    return all_codes
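# The loop above produces a frame-major interleave:
#   [f0_c0, f0_c1, f0_c2, f0_c3, f1_c0, ...]
# with each codebook level c shifted into its own id range of size
# codebook_size starting at audio_tokens_start. An equivalent vectorized form
# (a sketch, not used by the app; same layout without the Python loop):
def flatten_codes_vectorized(codes):
    """codes: [4, seq_len] tensor from the codec -> flat list of offset ids."""
    offsets = (config.tokens.audio_tokens_start
               + torch.arange(4) * config.tokens.codebook_size)
    shifted = codes.cpu().long() + offsets.unsqueeze(1)  # [4, seq_len]
    return shifted.t().reshape(-1).tolist()  # frame-major flatten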
def redistribute_codes(code_list):
    """
    Decode audio codes back to a waveform using the Nemo codec.
    """
    if len(code_list) % 4 != 0:
        print(f"Warning: Code list length {len(code_list)} is not divisible by 4")
        return None
    num_frames = len(code_list) // 4
    codebook_size = config.tokens.codebook_size
    # Separate the 4 codebook levels and remove the per-level offsets
    levels = [[], [], [], []]
    for i in range(num_frames):
        for level in range(4):
            raw = code_list[4 * i + level] - config.tokens.audio_tokens_start - level * codebook_size
            levels[level].append(raw % codebook_size)
    # Tensors in the format expected by Nemo: [batch, num_codebooks, sequence_length]
    codes = torch.stack([
        torch.tensor(lvl, dtype=torch.long) for lvl in levels
    ]).unsqueeze(0)  # Add batch dimension
    try:
        # Move to codec device
        codes = codes.to(nemo_codec.device)
        # Calculate length
        tokens_len = torch.tensor([codes.shape[-1]], dtype=torch.int64).to(nemo_codec.device)
        # Decode
        with torch.no_grad():
            audio_hat, _ = nemo_codec.decode(tokens=codes, tokens_len=tokens_len)
        return audio_hat.cpu()
    except Exception as e:
        print(f"Error decoding audio: {e}")
        return None
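# Optional sanity check (a sketch; not called by the app). Encodes a short
# tone and decodes it again to verify that tokenize_audio and
# redistribute_codes invert each other. The 440 Hz tone and 1 s duration are
# illustrative choices, not part of the original pipeline.
def _codec_round_trip_check(duration_s=1.0):
    sr = config.audio.sample_rate
    t = torch.arange(int(sr * duration_s), dtype=torch.float32) / sr
    wav = torch.sin(2 * torch.pi * 440.0 * t).unsqueeze(0)  # [1, samples]
    codes = tokenize_audio(wav)       # flat list, 4 tokens per frame
    assert len(codes) % 4 == 0
    return redistribute_codes(codes)  # decoded waveform, or None on error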
@spaces.GPU  # Allocates a GPU for this call on ZeroGPU hardware; otherwise effectively a no-op
def transcribe_audio(sample_audio_path, progress=gr.Progress()):
    """Transcribe uploaded audio using Whisper."""
    if not sample_audio_path:
        gr.Warning("Please upload an audio file first.")
        return ""
    try:
        progress(0, 'Loading audio...')
        audio_array, sample_rate = librosa.load(sample_audio_path, sr=config.audio.sample_rate)
        # Trim audio to max 15 seconds for transcription
        if len(audio_array) / sample_rate > 15:
            num_samples_to_keep = int(sample_rate * 15)
            audio_array = audio_array[:num_samples_to_keep]
        progress(0.5, 'Transcribing...')
        # Pass the sampling rate explicitly so the pipeline resamples
        # correctly (a bare array would be assumed to already be at
        # Whisper's 16 kHz)
        transcript = whisper_turbo_pipe(
            {"raw": audio_array, "sampling_rate": sample_rate}
        )['text'].strip()
        progress(1, 'Transcription complete!')
        return transcript
    except Exception as e:
        raise gr.Error(f"Transcription failed: {e}")
@spaces.GPU  # Allocates a GPU for this call on ZeroGPU hardware; otherwise effectively a no-op
def infer(sample_audio_path, ref_transcript, target_text, temperature, top_p, repetition_penalty, progress=gr.Progress()):
    if not target_text or not target_text.strip():
        gr.Warning("Please input text to generate audio.")
        return None
    if len(target_text) > 500:
        gr.Warning("Text is too long. Please keep it under 500 characters.")
        target_text = target_text[:500]
    target_text = target_text.strip()
    if sample_audio_path and (not ref_transcript or not ref_transcript.strip()):
        gr.Warning("Please provide a transcript for the reference audio or use the transcribe button.")
        return None

    with torch.no_grad():
        if sample_audio_path and ref_transcript:
            progress(0, 'Loading and trimming audio...')
            audio_array, sample_rate = librosa.load(sample_audio_path, sr=config.audio.sample_rate)
            # Trim audio to max 15 seconds
            if len(audio_array) / sample_rate > 15:
                gr.Warning("Trimming audio to the first 15 seconds.")
                num_samples_to_keep = int(sample_rate * 15)
                audio_array = audio_array[:num_samples_to_keep]
            prompt_wav = torch.from_numpy(audio_array).unsqueeze(0)
            prompt_wav = prompt_wav.to(dtype=torch.float32)
            progress(0.4, 'Encoding reference audio...')
            # Encode the prompt wav
            voice_tokens = tokenize_audio(prompt_wav)
            # Use the provided transcript instead of auto-transcribing
            prompt_text = ref_transcript.strip()
            progress(0.6, "Generating audio...")
            # Complete sentence (reference + target): the model sees the full
            # text but audio for only the reference portion
            complete_text = prompt_text + " " + target_text
            complete_text_ids = tokenizer.encode(complete_text, add_special_tokens=False)
            # Prompt: human says the complete sentence, AI provides partial
            # audio, and the model should continue the audio for the target part
            prompt_ids = (
                [SOH_ID]
                + complete_text_ids  # Full sentence as human input
                + [EOT_ID]
                + [EOH_ID]
                + [SOA_ID]
                + [SOS_ID]
                + voice_tokens  # Audio only for the reference part
            )
        else:
            # No reference audio: simple generation from text alone
            progress(0.6, "Generating audio...")
            target_text_ids = tokenizer.encode(target_text, add_special_tokens=False)
            prompt_ids = (
                [SOH_ID]
                + target_text_ids
                + [EOT_ID]
                + [EOH_ID]
                + [SOA_ID]
                + [SOS_ID]
            )

        print(f"Prompt length: {len(prompt_ids)} tokens")
        input_ids = torch.tensor([prompt_ids], dtype=torch.int64).cuda()

        # Generate the speech autoregressively
        outputs = model.generate(
            input_ids,
            max_new_tokens=config.model.max_new_tokens,
            eos_token_id=EOS_ID,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            pad_token_id=config.tokens.pad_token,
            use_cache=True,
        )
        generated_ids = outputs[0].tolist()
        print(f"Generated {len(generated_ids)} total tokens")

        progress(0.8, "Decoding generated audio...")
        # The prompt ends with SOS_ID (plus reference audio tokens in the
        # cloning case), so everything past the prompt should be audio tokens
        input_length = len(prompt_ids)
        speech_tokens = generated_ids[input_length:]
        print(f"Input prompt length: {input_length}, generated tokens: {len(speech_tokens)}")

        # Remove end-of-speech token if present
        if EOS_ID in speech_tokens:
            speech_tokens = speech_tokens[:speech_tokens.index(EOS_ID)]
        if not speech_tokens:
            raise gr.Error("Audio generation failed: No speech tokens were generated.")

        # Filter out non-audio tokens
        audio_tokens = [token for token in speech_tokens if token >= config.tokens.audio_tokens_start]
        if not audio_tokens:
            raise gr.Error("Audio generation failed: No valid audio tokens found.")

        print(f"Decoding {len(audio_tokens)} audio tokens")
        gen_wav_tensor = redistribute_codes(audio_tokens)
        if gen_wav_tensor is None:
            raise gr.Error("Audio decoding failed.")

        gen_wav = gen_wav_tensor.squeeze()
        progress(1, 'Synthesized!')
        return (config.audio.sample_rate, gen_wav.numpy())
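# Quick headless check (a sketch; not wired into the UI). Arguments mirror the
# Gradio inputs below; the path and texts are illustrative placeholders.
#   sr, wav = infer(None, "", "Hello from KaniTTS.", 1.4, 0.9, 1.1)
#   sr, wav = infer("ref.wav", "Transcript of ref.wav.", "New text.", 1.4, 0.9, 1.1)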
theme = gr.themes.Glass(
    primary_hue="cyan",
)

with gr.Blocks(theme=theme, title="KaniTTS Zero-Shot Voice Cloning") as app_tts:
    gr.Markdown("# KaniTTS Zero-Shot Voice Cloning")
    gr.Markdown("Upload reference audio, provide its transcript, and enter text to generate speech in the reference voice.")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    with gr.Row():
        ref_transcript_input = gr.Textbox(
            label="Reference Audio Transcript",
            lines=3,
            placeholder="Enter what the reference audio says, or use the transcribe button...",
            info="This should match exactly what is said in the reference audio",
        )
        transcribe_btn = gr.Button("Transcribe", variant="secondary", size="sm")
    gen_text_input = gr.Textbox(
        label="Text to Generate",
        lines=10,
        placeholder="Enter the text you want to generate in the reference voice...",
    )
    with gr.Row():
        temperature_slider = gr.Slider(
            minimum=0.0, maximum=2.0, value=1.4, step=0.05,
            label="Temperature",
            info="Higher values make output more random",
        )
        top_p_slider = gr.Slider(
            minimum=0.0, maximum=1.0, value=0.9, step=0.05,
            label="Top-p",
            info="Nucleus sampling threshold",
        )
        repetition_penalty_slider = gr.Slider(
            minimum=1.0, maximum=1.5, value=1.1, step=0.05,
            label="Repetition Penalty",
            info="Penalty for repeating tokens",
        )
    generate_btn = gr.Button("Generate Speech", variant="primary")
    audio_output = gr.Audio(label="Generated Audio")

    # Connect transcribe button
    transcribe_btn.click(
        transcribe_audio,
        inputs=[ref_audio_input],
        outputs=[ref_transcript_input],
    )
    # Connect generate button
    generate_btn.click(
        infer,
        inputs=[
            ref_audio_input,
            ref_transcript_input,
            gen_text_input,
            temperature_slider,
            top_p_slider,
            repetition_penalty_slider,
        ],
        outputs=[audio_output],
    )
with gr.Blocks() as app_info:
    gr.Markdown("""
# About KaniTTS

KaniTTS is a conversational text-to-speech model that can perform zero-shot voice cloning.

## How to use:
1. Upload a reference audio file (WAV or MP3, max 15 seconds)
2. Either enter the transcript manually or click "Transcribe" to auto-transcribe
3. Edit the transcript if needed to ensure accuracy
4. Enter the text you want to generate in that voice
5. Adjust generation parameters if needed
6. Click "Generate Speech"

The model will use your provided transcript to understand the reference voice and generate the target text in the same voice.

## Tips:
- Use clear, high-quality reference audio
- Keep reference audio under 15 seconds
- The model works best with conversational speech
- Try different temperature settings for varied results

## Credits:
- KaniTTS model by the KaniTTS team
- Nemo codec by NVIDIA
- Interface adapted from the Orpheus TTS demo
""")
with gr.Blocks() as app:
    gr.Markdown(
        """
# KaniTTS Zero-Shot Voice Cloning

This is a web interface for KaniTTS zero-shot voice cloning. Upload reference audio and generate speech in any voice!
"""
    )
    gr.TabbedInterface([app_tts, app_info], ["Voice Cloning", "About"])

if __name__ == "__main__":
    app.launch()