# NOTE(review): the lines below were Hugging Face Spaces page residue
# ("Spaces: Paused") captured by the scrape — not part of the application.
# Unsloth patches transformers at import time, so it is imported first.
from unsloth import FastLanguageModel

import tempfile

import gradio as gr
import librosa
import soundfile as sf
import torch
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    AutoTokenizer,
)
# -----------------------------
# CONFIG
# -----------------------------
STT_MODEL_ID = "EpistemeAI/Audiogemma-3N-finetune"  # speech-to-text checkpoint
TTS_MODEL_ID = "EpistemeAI/LexiVox"                 # text-to-speech checkpoint
TARGET_SR = 16000   # sample rate (Hz) used for loading input and saving output audio
MAX_TOKENS = 512    # generation cap for the STT transcription
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
# -----------------------------
# LOAD STT MODEL
# -----------------------------
# Loads the audio-capable Gemma checkpoint plus its processor once at startup.
print("Loading STT model...")
processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
stt_model = AutoModelForImageTextToText.from_pretrained(
    STT_MODEL_ID,
    torch_dtype="auto",   # use the checkpoint's native dtype
    device_map="auto",    # place/shard across available accelerators
)
stt_model.eval()  # inference only — disable dropout etc.
# -----------------------------
# LOAD TTS MODEL (UNSLOTH)
# -----------------------------
# FastLanguageModel returns both the model and its tokenizer, so no separate
# AutoTokenizer load is needed (removed the commented-out leftover).
print("Loading TTS model with Unsloth...")
tts_model, tts_tokenizer = FastLanguageModel.from_pretrained(
    model_name=TTS_MODEL_ID,
    max_seq_length=2048,  # context window for generation
    dtype=None,           # None -> auto-detect the best dtype for this GPU
    load_in_4bit=False,   # True would 4-bit quantize to reduce memory usage
    # token="hf_...",     # only needed for gated models (e.g. meta-llama/Llama-2-7b-hf)
)
FastLanguageModel.for_inference(tts_model)  # enable Unsloth's fast inference path
tts_model.eval()
# -----------------------------
# STT FUNCTION
# -----------------------------
def transcribe(audio_path):
    """Transcribe an audio file to German text with the Audiogemma STT model.

    Args:
        audio_path: Path to an audio file the processor's chat template can load.

    Returns:
        The decoded transcription string (newly generated text only; the chat
        prompt is stripped before decoding).
    """
    prompt = "Transcribe the audio accurately in German."
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio_path},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = {k: v.to(stt_model.device) for k, v in inputs.items()}
    with torch.inference_mode():
        outputs = stt_model.generate(
            **inputs,
            max_new_tokens=MAX_TOKENS,
            # Greedy decoding; `temperature` removed — transformers ignores it
            # (and warns) when do_sample=False.
            do_sample=False,
        )
    # Slice off the prompt tokens so only the generated answer is decoded;
    # decoding the full sequence would prepend the chat template/prompt text.
    prompt_len = inputs["input_ids"].shape[-1]
    text = processor.batch_decode(
        outputs[:, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]
    return text.strip()
# -----------------------------
# SPEECH → SPEECH PIPELINE
# -----------------------------
def speech_to_speech(audio_file):
    """Full pipeline: input audio -> German transcription -> synthesized speech.

    Args:
        audio_file: Path to the recorded/uploaded audio, or None.

    Returns:
        (transcription, wav_path) tuple; ("", None) when no audio was provided.
    """
    if audio_file is None:
        return "", None
    # Ensure the audio is decodable before spending time in the models.
    _audio, _ = librosa.load(audio_file, sr=TARGET_SR)
    # ---------- STT ----------
    transcription = transcribe(audio_file)
    # ---------- TTS ----------
    tts_inputs = tts_tokenizer(
        transcription,
        return_tensors="pt",
    ).to(tts_model.device)
    with torch.inference_mode():
        speech_tokens = tts_model.generate(
            **tts_inputs,
            max_new_tokens=2048,
            # Deterministic output; `temperature` removed — it is ignored
            # (with a warning) when do_sample=False.
            do_sample=False,
        )
    # Keep only newly generated tokens; generate() echoes the input prompt.
    new_tokens = speech_tokens[:, tts_inputs["input_ids"].shape[-1]:]
    # NOTE(review): this writes raw generated token ids as PCM samples. If
    # LexiVox emits audio-codec tokens, they must be decoded through the
    # model's vocoder/codec before saving — confirm against the model card.
    # Cast to float32 so soundfile accepts the array (int64 is unsupported).
    audio_out = new_tokens.to(torch.float32).cpu().numpy().squeeze()
    # Close the handle before writing so soundfile can reopen the path on all
    # platforms (Windows keeps open NamedTemporaryFiles locked).
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    tmp.close()
    sf.write(tmp.name, audio_out, TARGET_SR)
    return transcription, tmp.name
# -----------------------------
# GRADIO UI
# -----------------------------
# Single-page UI: record or upload audio, run the STT→TTS loop, then show the
# transcription text and play back the synthesized speech.
with gr.Blocks(title="Audiogemma → LexiVox (Unsloth)") as demo:
    gr.Markdown(
        """
# 🎙️ Speech → Text → Speech
**Audiogemma-3N + LexiVox (Unsloth Accelerated)**
Upload audio or use your microphone.
"""
    )
    # Input is delivered to the handler as a file path (type="filepath").
    audio_input = gr.Audio(
        sources=["microphone", "upload"],
        type="filepath",
        label="Input Audio",
    )
    run_btn = gr.Button("Run Speech Loop")
    text_output = gr.Textbox(
        label="Transcription",
        lines=4,
    )
    # Output is a path to the temporary WAV written by the pipeline.
    audio_output = gr.Audio(
        label="Synthesized Speech",
        type="filepath",
    )
    # Wire the button to the full speech→text→speech pipeline.
    run_btn.click(
        fn=speech_to_speech,
        inputs=audio_input,
        outputs=[text_output, audio_output],
    )
demo.launch()