"""Voice Q&A Gradio app: speech → text (ASR) → answer (LLM) → speech (TTS).

Designed for CPU-only Hugging Face Spaces: tiny Whisper for transcription,
SmolLM2-135M for answering, and MMS-TTS for speech synthesis. All models are
loaded once per container and shared across requests.
"""

import os
import threading
import time
import uuid

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    VitsModel,
    pipeline,
)

# ----------------------------
# Config (CPU-friendly defaults)
# ----------------------------
ASR_ID = os.environ.get("ASR_ID", "openai/whisper-tiny")  # fastest on CPU
LLM_ID = os.environ.get("LLM_ID", "HuggingFaceTB/SmolLM2-135M-Instruct")
TTS_ID = os.environ.get("TTS_ID", "facebook/mms-tts-eng")

# Keep generation short: latency on free-tier CPU scales with token count.
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "120"))
MIN_NEW_TOKENS = int(os.environ.get("MIN_NEW_TOKENS", "20"))

OUT_DIR = "outputs"
os.makedirs(OUT_DIR, exist_ok=True)

# ----------------------------
# Global singletons (loaded once)
# ----------------------------
_load_lock = threading.Lock()  # guards one-time model loading across threads
_asr = None       # ASR pipeline
_llm_tok = None   # LLM tokenizer
_llm = None       # causal LM
_tts_tok = None   # TTS tokenizer
_tts = None       # VITS TTS model
_tts_sr = None    # TTS output sampling rate (Hz), read from model config


def _now_ms() -> float:
    """Return a monotonic timestamp in milliseconds (for latency timing)."""
    return time.perf_counter() * 1000.0


def load_models():
    """Load all models once per Space container.

    Thread-safe: a fast unlocked check skips the lock on the hot path, and
    each model is re-checked under the lock before loading (double-checked
    locking) so concurrent first requests load each model exactly once.
    """
    global _asr, _llm_tok, _llm, _tts_tok, _tts, _tts_sr
    if _asr is not None and _llm is not None and _tts is not None:
        return
    with _load_lock:
        if _asr is None:
            # CPU-only (Spaces free tier); device=-1 forces CPU inference.
            _asr = pipeline(
                "automatic-speech-recognition",
                model=ASR_ID,
                device=-1,
            )
        if _llm is None or _llm_tok is None:
            _llm_tok = AutoTokenizer.from_pretrained(LLM_ID)
            _llm = AutoModelForCausalLM.from_pretrained(
                LLM_ID,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            )
            _llm.eval()
        if _tts is None or _tts_tok is None:
            _tts_tok = AutoTokenizer.from_pretrained(TTS_ID)
            _tts = VitsModel.from_pretrained(
                TTS_ID,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            )
            _tts.eval()
            _tts_sr = int(_tts.config.sampling_rate)


def _clean_asr_text(s: str) -> str:
    """Normalize an ASR transcript: trim whitespace and drop a leading
    "question," artifact that the ASR model sometimes prepends."""
    s = (s or "").strip()
    if s.lower().startswith("question,"):
        s = s[len("question,"):].strip()
    return s


def _llm_answer_from_text(user_text: str) -> str:
    """Very small, reliable prompt wrapper for tiny instruct models.

    Builds a prompt (chat template when the tokenizer has one, otherwise a
    minimal "User:/Assistant:" wrapper), generates greedily, and extracts
    the assistant portion of the decoded output.
    """
    user_text = _clean_asr_text(user_text)
    if not user_text:
        return "I didn't catch that. Please repeat your question."

    # Use chat template if available (best), else minimal wrapper
    if hasattr(_llm_tok, "apply_chat_template"):
        messages = [{"role": "user", "content": user_text}]
        prompt = _llm_tok.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        prompt = f"User: {user_text}\nAssistant:"

    inputs = _llm_tok(prompt, return_tensors="pt")
    with torch.no_grad():
        gen = _llm.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            min_new_tokens=MIN_NEW_TOKENS,
            do_sample=False,  # greedy: deterministic and fastest on CPU
            eos_token_id=_llm_tok.eos_token_id,
            pad_token_id=_llm_tok.eos_token_id,
        )
    full = _llm_tok.decode(gen[0], skip_special_tokens=True)

    # Try to extract assistant portion
    if "Assistant:" in full:
        ans = full.split("Assistant:", 1)[-1].strip()
    else:
        ans = full.strip()

    # If it echoed the prompt, strip the prompt prefix crudely
    if ans.startswith(prompt):
        ans = ans[len(prompt):].strip()

    return ans if ans else "I produced no answer. Please try again."


def _tts_speak(text: str, out_wav_path: str) -> str:
    """Synthesize `text` to a mono float32 WAV at `out_wav_path`.

    Returns the output path. Empty input is replaced with a spoken fallback
    so the TTS model never receives an empty string.
    """
    text = (text or "").strip()
    if not text:
        text = "I have no text to speak."
    inputs = _tts_tok(text, return_tensors="pt")
    with torch.no_grad():
        wav = _tts(**inputs).waveform
    wav = wav.squeeze().detach().cpu().numpy().astype(np.float32)
    sf.write(out_wav_path, wav, _tts_sr)
    return out_wav_path


def voice_qa(audio_path: str):
    """Run the full ASR → LLM → TTS pipeline on one recorded question.

    Gradio passes a filepath for Audio(type="filepath").
    Return: transcript, answer, tts_audio_path, debug_text,
            transcript_file, answer_file
    """
    # Guard: clicking Run with no recording would otherwise crash the ASR
    # pipeline with an opaque traceback; surface a friendly UI error instead.
    if not audio_path:
        raise gr.Error("No audio received. Please record a question first.")

    load_models()

    # Per-run output directory, named by timestamp + short random suffix
    # so concurrent runs never collide.
    run_id = time.strftime("%Y%m%d-%H%M%S") + "_" + str(uuid.uuid4())[:8]
    run_dir = os.path.join(OUT_DIR, run_id)
    os.makedirs(run_dir, exist_ok=True)

    transcript_file = os.path.join(run_dir, "transcript.txt")
    answer_file = os.path.join(run_dir, "answer.txt")
    tts_file = os.path.join(run_dir, "tts_answer.wav")

    dbg_lines = []
    t0 = _now_ms()

    # --- ASR ---
    t_asr0 = _now_ms()
    # return_timestamps=True avoids Whisper long-form errors for >30s files
    asr_out = _asr(audio_path, return_timestamps=True)
    transcript = _clean_asr_text(asr_out.get("text", ""))
    t_asr1 = _now_ms()
    with open(transcript_file, "w", encoding="utf-8") as f:
        f.write(transcript)
    dbg_lines.append(f"[ASR] model={ASR_ID}")
    dbg_lines.append(f"[ASR] ms={(t_asr1 - t_asr0):.1f}")
    dbg_lines.append(f"[ASR] chars={len(transcript)}")

    # --- LLM ---
    t_llm0 = _now_ms()
    answer = _llm_answer_from_text(transcript)
    t_llm1 = _now_ms()
    with open(answer_file, "w", encoding="utf-8") as f:
        f.write(answer)
    dbg_lines.append(f"[LLM] model={LLM_ID}")
    dbg_lines.append(f"[LLM] ms={(t_llm1 - t_llm0):.1f}")
    dbg_lines.append(f"[LLM] chars={len(answer)}")

    # --- TTS ---
    t_tts0 = _now_ms()
    _tts_speak(answer, tts_file)
    t_tts1 = _now_ms()
    dbg_lines.append(f"[TTS] model={TTS_ID}")
    dbg_lines.append(f"[TTS] ms={(t_tts1 - t_tts0):.1f}")
    dbg_lines.append(f"[TTS] out={tts_file}")

    t1 = _now_ms()
    dbg_lines.append(f"[TOTAL] ms={(t1 - t0):.1f}")
    debug_text = "\n".join(dbg_lines)

    return transcript, answer, tts_file, debug_text, transcript_file, answer_file


# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="Voice Q&A (ASR → LLM → TTS)") as demo:
    gr.Markdown(
        "# Voice Q&A (ASR → LLM → TTS)\n"
        "Speak a question → it transcribes → answers → speaks back.\n\n"
        "**CPU-friendly defaults**: Whisper *tiny* + SmolLM2-135M + MMS TTS.\n"
    )
    with gr.Row():
        audio_in = gr.Audio(
            sources=["microphone"],
            type="filepath",
            label="Microphone input",
        )
    run_btn = gr.Button("Run (ASR → LLM → TTS)", variant="primary")
    with gr.Row():
        transcript_out = gr.Textbox(label="Transcript (ASR)", lines=4)
        answer_out = gr.Textbox(label="Answer (LLM)", lines=6)
    tts_out = gr.Audio(label="Spoken answer (TTS)", type="filepath")
    debug_out = gr.Textbox(label="Debug / timings", lines=10)
    with gr.Row():
        transcript_dl = gr.File(label="Download transcript.txt")
        answer_dl = gr.File(label="Download answer.txt")

    run_btn.click(
        fn=voice_qa,
        inputs=[audio_in],
        outputs=[transcript_out, answer_out, tts_out, debug_out, transcript_dl, answer_dl],
    )

    gr.Markdown(
        "### Notes\n"
        "- If latency is still high on free CPU, try even shorter questions (2–5 seconds).\n"
        "- You can switch ASR model by setting Space variables: `ASR_ID=openai/whisper-base` (better) or keep `whisper-tiny` (faster).\n"
    )


if __name__ == "__main__":
    demo.launch()