# VoiceChat / app.py — Gradio speech -> text -> speech demo (HuggingFace Space).
import gradio as gr
import torch
import librosa
import soundfile as sf
import tempfile
from unsloth import FastLanguageModel
import torch
from transformers import (
AutoProcessor,
AutoModelForImageTextToText,
AutoTokenizer,
)
from unsloth import FastLanguageModel
# -----------------------------
# CONFIG
# -----------------------------
STT_MODEL_ID = "EpistemeAI/Audiogemma-3N-finetune"  # speech-to-text checkpoint
TTS_MODEL_ID = "EpistemeAI/LexiVox"                 # text-to-speech checkpoint
TARGET_SR = 16000   # sample rate used both for input validation and output WAV
MAX_TOKENS = 512    # max new tokens generated per transcription
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bf16 on GPU, fp32 on CPU. NOTE(review): DTYPE is defined but never passed to
# either from_pretrained call below — confirm whether it should be.
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
# -----------------------------
# LOAD STT MODEL
# -----------------------------
print("Loading STT model...")
# Processor handles both the audio feature extraction and the chat template.
processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
# Image-text-to-text class is used because Gemma-3n multimodal checkpoints
# (including audio) are served through this auto class.
stt_model = AutoModelForImageTextToText.from_pretrained(
    STT_MODEL_ID,
    torch_dtype="auto",   # let transformers pick the checkpoint dtype
    device_map="auto",    # shard/place automatically across available devices
)
stt_model.eval()
# -----------------------------
# LOAD TTS MODEL (UNSLOTH)
# -----------------------------
print("Loading TTS model with Unsloth...")
#tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID)
# Unsloth returns (model, tokenizer) in one call; the tokenizer line above is
# therefore redundant and kept commented out.
tts_model, tts_tokenizer = FastLanguageModel.from_pretrained(
    model_name =TTS_MODEL_ID,
    max_seq_length= 2048, # Choose any for long context!
    dtype = None, # Select None for auto detection
    load_in_4bit = False, # Select True for 4bit which reduces memory usage
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
# Switch Unsloth's optimized inference path on (disables training-only code).
FastLanguageModel.for_inference(tts_model)
tts_model.eval()
# -----------------------------
# STT FUNCTION
# -----------------------------
def transcribe(audio_path, language="German"):
    """Transcribe an audio file with the Audiogemma STT model.

    Parameters
    ----------
    audio_path : str
        Filesystem path to the audio file to transcribe.
    language : str, optional
        Language named in the instruction prompt. Defaults to "German",
        preserving the original hard-coded behavior.

    Returns
    -------
    str
        The transcription text only (the chat prompt is stripped from
        the decoded output).
    """
    prompt = f"Transcribe the audio accurately in {language}."
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio_path},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    # Move every tensor to the device the (possibly sharded) model starts on.
    inputs = {k: v.to(stt_model.device) for k, v in inputs.items()}
    with torch.inference_mode():
        outputs = stt_model.generate(
            **inputs,
            max_new_tokens=MAX_TOKENS,
            # Greedy decoding. The original also passed temperature=0.2,
            # which transformers ignores (and warns about) when
            # do_sample=False, so it is removed.
            do_sample=False,
        )
    # generate() returns prompt + completion; decode only the newly
    # generated tokens, otherwise the rendered chat prompt leaks into
    # the returned transcription.
    prompt_len = inputs["input_ids"].shape[-1]
    text = processor.batch_decode(
        outputs[:, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]
    return text.strip()
# -----------------------------
# SPEECH → SPEECH PIPELINE
# -----------------------------
def speech_to_speech(audio_file):
    """Run the full speech -> text -> speech loop for the Gradio UI.

    Parameters
    ----------
    audio_file : str | None
        Filepath from the ``gr.Audio`` input; ``None`` when the user
        submitted nothing.

    Returns
    -------
    tuple[str, str | None]
        ``(transcription, wav_path)``; ``("", None)`` on empty input.
    """
    if audio_file is None:
        return "", None

    # Validate that the file decodes as audio before it reaches the model;
    # the decoded samples themselves are not used further.
    librosa.load(audio_file, sr=TARGET_SR)

    # ---------- STT ----------
    transcription = transcribe(audio_file)

    # ---------- TTS ----------
    tts_inputs = tts_tokenizer(
        transcription,
        return_tensors="pt",
    ).to(tts_model.device)
    with torch.inference_mode():
        speech_tokens = tts_model.generate(
            **tts_inputs,
            max_new_tokens=2048,
            # Greedy decoding. The original also passed temperature=0.7,
            # which transformers ignores when do_sample=False, so it is
            # removed.
            do_sample=False,
        )

    # NOTE(review): this writes the raw generated token ids straight out as
    # PCM samples. Presumably LexiVox emits audio codes that need a
    # vocoder/decoder step — confirm against the model card.
    audio_out = speech_tokens.cpu().numpy().squeeze()

    # delete=False keeps the file on disk for Gradio to serve; close the
    # handle immediately so it is not leaked (and so the path is writable
    # by sf.write on platforms with exclusive open semantics).
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    tmp.close()
    sf.write(tmp.name, audio_out, TARGET_SR)
    return transcription, tmp.name
# -----------------------------
# GRADIO UI
# -----------------------------
# Build the Gradio UI: one audio input, a run button, and two outputs
# (transcription text + synthesized speech), wired to speech_to_speech.
with gr.Blocks(title="Audiogemma → LexiVox (Unsloth)") as demo:
    gr.Markdown(
        """
        # 🎙️ Speech → Text → Speech
        **Audiogemma-3N + LexiVox (Unsloth Accelerated)**
        Upload audio or use your microphone.
        """
    )
    # type="filepath" hands speech_to_speech a path string rather than
    # a (sample_rate, ndarray) tuple.
    audio_input = gr.Audio(
        sources=["microphone", "upload"],
        type="filepath",
        label="Input Audio",
    )
    run_btn = gr.Button("Run Speech Loop")
    text_output = gr.Textbox(
        label="Transcription",
        lines=4,
    )
    # Also filepath-typed: speech_to_speech returns the temp WAV's path.
    audio_output = gr.Audio(
        label="Synthesized Speech",
        type="filepath",
    )
    run_btn.click(
        fn=speech_to_speech,
        inputs=audio_input,
        outputs=[text_output, audio_output],
    )
# Start the app (blocking call; serves the UI).
demo.launch()