# Hugging Face Space: Voice Q&A demo (ASR -> LLM -> TTS)
| import os | |
| import time | |
| import uuid | |
| import threading | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import soundfile as sf | |
| from transformers import ( | |
| pipeline, | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| AutoProcessor, | |
| VitsModel, | |
| ) | |
# ----------------------------
# Config (CPU-friendly defaults)
# ----------------------------
# Every knob can be overridden through Space environment variables.
ASR_ID = os.getenv("ASR_ID", "openai/whisper-tiny")  # fastest on CPU
LLM_ID = os.getenv("LLM_ID", "HuggingFaceTB/SmolLM2-135M-Instruct")
TTS_ID = os.getenv("TTS_ID", "facebook/mms-tts-eng")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "120"))  # keep short for latency
MIN_NEW_TOKENS = int(os.getenv("MIN_NEW_TOKENS", "20"))
OUT_DIR = "outputs"
os.makedirs(OUT_DIR, exist_ok=True)
# ----------------------------
# Global singletons (loaded once)
# ----------------------------
# The models are expensive to construct, so they are created lazily by
# load_models() and then shared by every request in this process.
_load_lock = threading.Lock()  # guards the one-time lazy initialization in load_models()
_asr = None      # transformers ASR pipeline ("automatic-speech-recognition")
_llm_tok = None  # tokenizer for the causal LM
_llm = None      # causal LM used to answer the transcribed question
_tts_tok = None  # tokenizer for the TTS model
_tts = None      # VITS text-to-speech model
_tts_sr = None   # TTS output sampling rate in Hz (int, read from the model config)
| def _now_ms() -> float: | |
| return time.perf_counter() * 1000.0 | |
def load_models():
    """Lazily load the ASR, LLM, and TTS models once per Space container.

    Uses double-checked locking on module-level singletons: a cheap unlocked
    check for the common already-loaded case, then per-model re-checks under
    _load_lock so concurrent first requests load each model exactly once.
    """
    global _asr, _llm_tok, _llm, _tts_tok, _tts, _tts_sr
    # Fast path: everything already loaded; no lock needed for a read-only check.
    if _asr is not None and _llm is not None and _tts is not None:
        return
    with _load_lock:
        if _asr is None:
            # CPU-only (Spaces free tier); device=-1 forces CPU in the pipeline API.
            _asr = pipeline(
                "automatic-speech-recognition",
                model=ASR_ID,
                device=-1,
            )
        if _llm is None or _llm_tok is None:
            _llm_tok = AutoTokenizer.from_pretrained(LLM_ID)
            _llm = AutoModelForCausalLM.from_pretrained(
                LLM_ID,
                torch_dtype=torch.float32,  # float32: safest dtype on CPU
                low_cpu_mem_usage=True,
            )
            _llm.eval()  # inference only; disables dropout etc.
        if _tts is None or _tts_tok is None:
            _tts_tok = AutoTokenizer.from_pretrained(TTS_ID)
            _tts = VitsModel.from_pretrained(
                TTS_ID,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            )
            _tts.eval()
            # Output sample rate comes from the model config (e.g. 16 kHz for MMS).
            _tts_sr = int(_tts.config.sampling_rate)
| def _clean_asr_text(s: str) -> str: | |
| s = (s or "").strip() | |
| if s.lower().startswith("question,"): | |
| s = s[len("question,"):].strip() | |
| return s | |
def _llm_answer_from_text(user_text: str) -> str:
    """Answer *user_text* with the small instruct LLM; greedy, short decoding.

    Returns a friendly fallback string for empty input or empty generations.
    """
    user_text = _clean_asr_text(user_text)
    if not user_text:
        return "I didn't catch that. Please repeat your question."
    # BUGFIX: hasattr(tok, "apply_chat_template") is True for every modern
    # transformers tokenizer, even when no chat template is configured, in
    # which case apply_chat_template raises. Check the template itself.
    if getattr(_llm_tok, "chat_template", None):
        messages = [{"role": "user", "content": user_text}]
        prompt = _llm_tok.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        prompt = f"User: {user_text}\nAssistant:"
    inputs = _llm_tok(prompt, return_tensors="pt")
    with torch.no_grad():
        gen = _llm.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            min_new_tokens=MIN_NEW_TOKENS,
            do_sample=False,  # greedy: deterministic, lowest latency
            eos_token_id=_llm_tok.eos_token_id,
            pad_token_id=_llm_tok.eos_token_id,
        )
    # Decode ONLY the newly generated tokens. This is robust, unlike string
    # matching on "Assistant:" (which can mis-trigger if the answer itself
    # contains that marker) or prefix-stripping the decoded prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    ans = _llm_tok.decode(gen[0][prompt_len:], skip_special_tokens=True).strip()
    return ans if ans else "I produced no answer. Please try again."
def _tts_speak(text: str, out_wav_path: str) -> str:
    """Synthesize *text* with the VITS model and write a WAV to *out_wav_path*.

    Falls back to a fixed sentence when *text* is empty; returns the path.
    """
    spoken = (text or "").strip()
    if not spoken:
        spoken = "I have no text to speak."
    token_batch = _tts_tok(spoken, return_tensors="pt")
    with torch.no_grad():
        waveform = _tts(**token_batch).waveform
    samples = waveform.squeeze().detach().cpu().numpy().astype(np.float32)
    sf.write(out_wav_path, samples, _tts_sr)
    return out_wav_path
def voice_qa(audio_path: str):
    """Run the full pipeline: ASR -> LLM -> TTS, with per-stage timing.

    Gradio passes a filepath for Audio(type="filepath"); it passes None when
    the user clicks Run without recording anything.

    Returns:
        transcript, answer, tts_audio_path, debug_text, transcript_file, answer_file
    """
    # BUGFIX: guard against a missing recording — Gradio sends None/"" and the
    # ASR pipeline would crash on it. Keep the 6-output arity the UI expects.
    if not audio_path:
        msg = "No audio received. Please record a question first."
        return "", msg, None, "[ERROR] no audio input", None, None
    load_models()
    # Unique per-request output directory: timestamp + short uuid.
    run_id = time.strftime("%Y%m%d-%H%M%S") + "_" + str(uuid.uuid4())[:8]
    run_dir = os.path.join(OUT_DIR, run_id)
    os.makedirs(run_dir, exist_ok=True)
    transcript_file = os.path.join(run_dir, "transcript.txt")
    answer_file = os.path.join(run_dir, "answer.txt")
    tts_file = os.path.join(run_dir, "tts_answer.wav")
    dbg_lines = []
    t0 = _now_ms()

    # --- ASR ---
    t_asr0 = _now_ms()
    # return_timestamps=True avoids Whisper long-form errors for >30s files
    asr_out = _asr(audio_path, return_timestamps=True)
    transcript = _clean_asr_text(asr_out.get("text", ""))
    t_asr1 = _now_ms()
    with open(transcript_file, "w", encoding="utf-8") as f:
        f.write(transcript)
    dbg_lines.append(f"[ASR] model={ASR_ID}")
    dbg_lines.append(f"[ASR] ms={(t_asr1 - t_asr0):.1f}")
    dbg_lines.append(f"[ASR] chars={len(transcript)}")

    # --- LLM ---
    t_llm0 = _now_ms()
    answer = _llm_answer_from_text(transcript)
    t_llm1 = _now_ms()
    with open(answer_file, "w", encoding="utf-8") as f:
        f.write(answer)
    dbg_lines.append(f"[LLM] model={LLM_ID}")
    dbg_lines.append(f"[LLM] ms={(t_llm1 - t_llm0):.1f}")
    dbg_lines.append(f"[LLM] chars={len(answer)}")

    # --- TTS ---
    t_tts0 = _now_ms()
    _tts_speak(answer, tts_file)
    t_tts1 = _now_ms()
    dbg_lines.append(f"[TTS] model={TTS_ID}")
    dbg_lines.append(f"[TTS] ms={(t_tts1 - t_tts0):.1f}")
    dbg_lines.append(f"[TTS] out={tts_file}")

    t1 = _now_ms()
    dbg_lines.append(f"[TOTAL] ms={(t1 - t0):.1f}")
    debug_text = "\n".join(dbg_lines)
    return transcript, answer, tts_file, debug_text, transcript_file, answer_file
# ----------------------------
# Gradio UI
# ----------------------------
# BUGFIX: user-facing strings had mojibake ("β" in place of "→" / "–");
# restored the intended arrows and en dash.
with gr.Blocks(title="Voice Q&A (ASR → LLM → TTS)") as demo:
    gr.Markdown(
        "# Voice Q&A (ASR → LLM → TTS)\n"
        "Speak a question → it transcribes → answers → speaks back.\n\n"
        "**CPU-friendly defaults**: Whisper *tiny* + SmolLM2-135M + MMS TTS.\n"
    )
    with gr.Row():
        audio_in = gr.Audio(
            sources=["microphone"],
            type="filepath",
            label="Microphone input",
        )
        run_btn = gr.Button("Run (ASR → LLM → TTS)", variant="primary")
    with gr.Row():
        transcript_out = gr.Textbox(label="Transcript (ASR)", lines=4)
        answer_out = gr.Textbox(label="Answer (LLM)", lines=6)
    tts_out = gr.Audio(label="Spoken answer (TTS)", type="filepath")
    debug_out = gr.Textbox(label="Debug / timings", lines=10)
    with gr.Row():
        transcript_dl = gr.File(label="Download transcript.txt")
        answer_dl = gr.File(label="Download answer.txt")
    run_btn.click(
        fn=voice_qa,
        inputs=[audio_in],
        outputs=[transcript_out, answer_out, tts_out, debug_out, transcript_dl, answer_dl],
    )
    gr.Markdown(
        "### Notes\n"
        "- If latency is still high on free CPU, try even shorter questions (2–5 seconds).\n"
        "- You can switch ASR model by setting Space variables: `ASR_ID=openai/whisper-base` (better) or keep `whisper-tiny` (faster).\n"
    )

if __name__ == "__main__":
    demo.launch()