Spaces:
Running
Running
File size: 5,918 Bytes
c812b33 4d66d17 c812b33 589f64a e28dfb5 c812b33 649294c c812b33 1c04da9 c812b33 4d66d17 c812b33 4d66d17 1c04da9 4d66d17 c812b33 c9304ac 6303154 c812b33 6303154 c812b33 6303154 c9304ac 6303154 c9304ac c812b33 c9304ac 6303154 c812b33 c9304ac c812b33 c832129 c812b33 c9304ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import gradio as gr
import subprocess, json, os, io, tempfile
from faster_whisper import WhisperModel
from ollama import Client as OllamaClient
# ---- CONFIG ----
LLM_MODEL = "llama3.2:3b"  # or "mistral:7b", "qwen2.5:3b"
WHISPER_SIZE = "small"  # "base", "small", "medium"
USE_SILERO = True  # set False to use Coqui XTTS v2
USE_CONTEXT = False  # disable conversational memory (no chat history sent)

import os

# Use a remote Ollama server when OLLAMA_HOST is set; otherwise fall back to
# a small local transformers model (CPU-friendly, suitable for Spaces).
USE_REMOTE_OLLAMA = bool(os.getenv("OLLAMA_HOST"))
if not USE_REMOTE_OLLAMA:
    # Transformers fallback for Spaces (CPU-friendly small instruct model)
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    HF_CHAT_MODEL = os.getenv("HF_CHAT_MODEL", "google/gemma-2-2b-it")  # small instruct model that runs on CPU
    HF_TOKEN = os.getenv("HF_TOKEN")
    _tok = AutoTokenizer.from_pretrained(HF_CHAT_MODEL, token=HF_TOKEN)
    _mdl = AutoModelForCausalLM.from_pretrained(HF_CHAT_MODEL, token=HF_TOKEN, torch_dtype="auto", device_map="auto")
    gen = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=256)

# ---- STT (faster-whisper) ----
# Run on GPU if available: compute_type="float16", device="cuda".
# The CUDA_VISIBLE_DEVICES check was duplicated in the original; hoist it once.
_HAS_CUDA = bool(os.environ.get("CUDA_VISIBLE_DEVICES"))
stt_model = WhisperModel(
    WHISPER_SIZE,
    device="cuda" if _HAS_CUDA else "cpu",
    compute_type="float16" if _HAS_CUDA else "int8",
)
def speech_to_text(audio_path: str) -> str:
    """Transcribe an English audio file with faster-whisper and return the text."""
    segments, _info = stt_model.transcribe(
        audio_path, beam_size=1, vad_filter=True, language="en"
    )
    transcript_parts = [segment.text for segment in segments]
    return "".join(transcript_parts).strip()
# ---- LLM (Ollama) ----
# NOTE(review): a module-level Ollama client was sketched here but left
# disabled — confirm how the remote chat path obtains its client.
# ollama = OllamaClient(host="http://127.0.0.1:11434")

# System prompt shared by both chat paths; it constrains replies to one
# short sentence so the TTS output stays brief and natural.
SYSTEM_PROMPT = """You are a friendly AI voice assistant.
Reply in one short, natural sentence only.
Sound warm and conversational, never formal.
Avoid multi-sentence or paragraph answers."""
def chat_with_llm(history_messages, user_text):
    """
    Generate a one-sentence assistant reply for `user_text`.

    `history_messages` is accepted for interface compatibility but is not
    sent to the model (conversational memory is disabled; only the system
    prompt plus the current user turn are used). Returns the reply text.
    """
    if USE_REMOTE_OLLAMA:
        # Bug fix: the original called `ollama.chat(...)` but `ollama` was
        # never defined (its construction is commented out at module level),
        # so this branch raised NameError. Build a client from OLLAMA_HOST,
        # which is guaranteed to be set whenever USE_REMOTE_OLLAMA is True.
        client = OllamaClient(host=os.environ["OLLAMA_HOST"])
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_text},
        ]
        resp = client.chat(model=LLM_MODEL, messages=messages)
        return resp["message"]["content"]
    else:
        # Local transformers fallback: plain-text prompt; keep only the first
        # generated line so the spoken reply stays short.
        prompt = f"{SYSTEM_PROMPT}\nUser: {user_text}\nAssistant:"
        out = gen(
            prompt,
            return_full_text=False,
            max_new_tokens=25,
            temperature=0.8,
            repetition_penalty=1.1,
        )[0]["generated_text"].split("\n")[0].strip()
        return out
# near top-level (global singletons)
_SILERO_TTS = None  # cached Silero model; populated on first tts_silero() call
def tts_silero(text: str) -> str:
    """
    Synthesize `text` to a WAV file with Silero TTS and return its path.

    The model instance is cached in _SILERO_TTS so the (large) torch.hub
    download and load happen at most once per process.
    """
    import torch, tempfile
    import soundfile as sf
    global _SILERO_TTS
    if _SILERO_TTS is None:
        obj = torch.hub.load(
            repo_or_dir="snakers4/silero-models",
            model="silero_tts",
            language="en",
            speaker="v3_en",
            trust_repo=True,  # avoids interactive trust prompt
        )
        # torch.hub may return (model, example_text, ...) — keep the model only.
        _SILERO_TTS = obj[0] if isinstance(obj, (list, tuple)) else obj
    model = _SILERO_TTS
    sample_rate = 48000
    speaker = "en_0"
    audio = model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate)
    # Bug fix: tempfile.mktemp is deprecated and race-prone; create the file
    # atomically and keep it on disk for Gradio to serve.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_wav = tmp.name
    sf.write(out_wav, audio, sample_rate)
    return out_wav
def tts_coqui_xtts(text: str) -> str:
    """
    Synthesize `text` with Coqui XTTS v2 (higher quality; GPU-friendly) and
    return the path of the generated WAV file.

    NOTE(review): the TTS model is re-instantiated on every call, unlike the
    cached Silero path — consider a module-level singleton if this backend
    is enabled.
    """
    from TTS.api import TTS
    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
    # Bug fix: tempfile.mktemp is deprecated and race-prone; create the file
    # atomically and keep it on disk for Gradio to serve.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_wav = tmp.name
    tts.tts_to_file(text=text, file_path=out_wav, speaker="female-en-5", language="en")
    return out_wav
def text_to_speech(text: str) -> str:
    """Route `text` to the configured TTS backend and return the WAV path."""
    synthesize = tts_silero if USE_SILERO else tts_coqui_xtts
    return synthesize(text)
# ---- Gradio pipeline ----
def pipeline(audio, history):
    """
    Run one voice-chat turn: audio -> transcript -> LLM reply -> TTS audio.

    `audio` is either (sample_rate, np.ndarray) or a filepath, depending on
    the Gradio version. Returns (updated_history, wav_path, status_message);
    wav_path is None and a status message is set when no usable speech came in.
    """
    if audio is None:
        return history, None, "Please speak something."
    if isinstance(audio, tuple):
        # (sr, data) -> write a temp wav for the STT model.
        import soundfile as sf
        import tempfile
        sr, data = audio
        # Bug fix: tempfile.mktemp is deprecated/race-prone; create atomically.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            audio_path = tmp.name
        sf.write(audio_path, data.astype("float32"), sr)
    else:
        audio_path = audio  # already a filepath
    user_text = speech_to_text(audio_path)
    if not user_text:
        return history, None, "Didn't catch that—could you repeat?"
    reply = chat_with_llm(history, user_text)
    # If the model formats its answer with a "Reply:" block, speak only that
    # part. Bug fix: the original looped over ["Reply:", "Correction:", "Why:"]
    # but only ever tested "Reply:" and broke immediately — the loop was a
    # no-op wrapper around this single check.
    speak_text = reply
    if "Reply:" in reply:
        speak_text = reply.split("Reply:", 1)[1].strip()
    wav_path = text_to_speech(speak_text)
    updated = (history or []) + [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": reply},
    ]
    return updated, wav_path, ""
# ---- UI: microphone in, synthesized speech out, message-style chat log ----
with gr.Blocks(title="Voice Coach") as demo:
    gr.Markdown("## 🎙️ Interactive Voice Chat")
    with gr.Row():
        audio_in = gr.Audio(sources=["microphone"], type="filepath", label="Speak")
        audio_out = gr.Audio(label="Assistant (TTS)", autoplay=True)
    chatbox = gr.Chatbot(type="messages", height=300)
    status = gr.Markdown()
    btn = gr.Button("Send")
    # Use continuous recording or press "Send" after recording; both events
    # share the same handler wiring.
    wiring = dict(inputs=[audio_in, chatbox], outputs=[chatbox, audio_out, status])
    audio_in.change(pipeline, **wiring)
    btn.click(pipeline, **wiring)

if __name__ == "__main__":
    demo.launch()
|