# Javedalam's picture
# Create app.py
# ecd22c7 verified
import os
import time
import uuid
import threading
import gradio as gr
import numpy as np
import torch
import soundfile as sf
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForCausalLM,
AutoProcessor,
VitsModel,
)
# ----------------------------
# Config (CPU-friendly defaults)
# ----------------------------
# Each model id and token limit is read from the environment variable of the
# same name; the literal given here is the fallback when it is unset.
ASR_ID = os.getenv("ASR_ID", "openai/whisper-tiny")  # speech-to-text; tiny is fastest on CPU
LLM_ID = os.getenv("LLM_ID", "HuggingFaceTB/SmolLM2-135M-Instruct")  # answer generator
TTS_ID = os.getenv("TTS_ID", "facebook/mms-tts-eng")  # text-to-speech

# Generation bounds; kept short for latency.
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "120"))
MIN_NEW_TOKENS = int(os.getenv("MIN_NEW_TOKENS", "20"))

# Directory that receives one sub-folder per request; created at import time.
OUT_DIR = "outputs"
os.makedirs(OUT_DIR, exist_ok=True)
# ----------------------------
# Global singletons (loaded once)
# ----------------------------
# Everything below starts as None and is filled in by load_models().
_load_lock = threading.Lock()  # held by load_models() while it populates the singletons
_asr = None  # automatic-speech-recognition pipeline built from ASR_ID
_llm_tok = None  # tokenizer loaded from LLM_ID
_llm = None  # causal language model loaded from LLM_ID
_tts_tok = None  # tokenizer loaded from TTS_ID
_tts = None  # VITS model loaded from TTS_ID
_tts_sr = None  # sampling rate (int) taken from the TTS model's config
def _now_ms() -> float:
return time.perf_counter() * 1000.0
def load_models():
    """Load all models once per Space container.

    The first call populates the module-level singletons: ``_asr``,
    ``_llm_tok``/``_llm`` and ``_tts_tok``/``_tts``/``_tts_sr``. Later calls
    return immediately. The ASR pipeline is pinned to CPU (``device=-1``) and
    the LLM/TTS weights are loaded as float32.

    Loading happens under ``_load_lock``. Each model is built in local
    variables, switched to eval mode, and only then assigned to its global, so
    a concurrent caller taking the lock-free fast path never observes a model
    whose companion objects (tokenizer, sampling rate) are still missing.
    """
    global _asr, _llm_tok, _llm, _tts_tok, _tts, _tts_sr

    # Fast path: require *every* singleton, not only the three models.
    # Otherwise a second request thread could return while `_tts_sr` is still
    # None and fail later when the WAV file is written.
    singletons = (_asr, _llm_tok, _llm, _tts_tok, _tts, _tts_sr)
    if all(item is not None for item in singletons):
        return

    with _load_lock:
        if _asr is None:
            # CPU-only (Spaces free tier)
            _asr = pipeline(
                "automatic-speech-recognition",
                model=ASR_ID,
                device=-1,
            )

        if _llm is None or _llm_tok is None:
            llm_tok = AutoTokenizer.from_pretrained(LLM_ID)
            llm = AutoModelForCausalLM.from_pretrained(
                LLM_ID,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            )
            llm.eval()
            # Publish only after the model is ready for inference.
            _llm_tok = llm_tok
            _llm = llm

        if _tts is None or _tts_tok is None or _tts_sr is None:
            tts_tok = AutoTokenizer.from_pretrained(TTS_ID)
            tts = VitsModel.from_pretrained(
                TTS_ID,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            )
            tts.eval()
            # The model itself is assigned last so that anyone who sees
            # `_tts` also sees its tokenizer and sampling rate.
            _tts_sr = int(tts.config.sampling_rate)
            _tts_tok = tts_tok
            _tts = tts
def _clean_asr_text(s: str) -> str:
s = (s or "").strip()
if s.lower().startswith("question,"):
s = s[len("question,"):].strip()
return s
def _llm_answer_from_text(user_text: str) -> str:
    """Very small, reliable prompt wrapper for tiny instruct models.

    The question is normalised with ``_clean_asr_text``; an empty question
    short-circuits with a canned "please repeat" reply. Otherwise it is wrapped
    in the tokenizer's chat template when one is configured (or a minimal
    ``User:/Assistant:`` prompt when not) and decoded greedily, bounded by
    ``MIN_NEW_TOKENS`` and ``MAX_NEW_TOKENS``.

    Requires ``load_models()`` to have populated ``_llm_tok`` and ``_llm``.

    Returns:
        Only the newly generated answer text (the prompt is never echoed
        back), or a fallback sentence when the model produced nothing.
    """
    user_text = _clean_asr_text(user_text)
    if not user_text:
        return "I didn't catch that. Please repeat your question."

    # Check for an actual template rather than for the method: recent
    # transformers releases define `apply_chat_template` on every tokenizer,
    # so `hasattr` would be true even when no template is configured.
    if getattr(_llm_tok, "chat_template", None):
        messages = [{"role": "user", "content": user_text}]
        prompt = _llm_tok.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        prompt = f"User: {user_text}\nAssistant:"

    inputs = _llm_tok(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        gen = _llm.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            min_new_tokens=MIN_NEW_TOKENS,
            do_sample=False,
            eos_token_id=_llm_tok.eos_token_id,
            pad_token_id=_llm_tok.eos_token_id,
        )

    # For a causal LM, `generate` returns the prompt tokens followed by the
    # continuation. Decoding the whole sequence would put the chat template's
    # system/user turns into the answer (and into the synthesized speech), so
    # keep only what comes after the prompt.
    new_tokens = gen[0][prompt_len:]
    ans = _llm_tok.decode(new_tokens, skip_special_tokens=True).strip()
    return ans if ans else "I produced no answer. Please try again."
def _tts_speak(text: str, out_wav_path: str) -> str:
    """Synthesize *text* with the loaded TTS model and save it as a WAV file.

    Blank or missing text is replaced by a short placeholder sentence. The
    waveform is written to *out_wav_path* as float32 samples at ``_tts_sr``,
    and that same path is returned.
    """
    spoken = (text or "").strip() or "I have no text to speak."
    encoded = _tts_tok(spoken, return_tensors="pt")
    with torch.no_grad():
        waveform = _tts(**encoded).waveform
    # Drop singleton dimensions and hand a plain float32 array to soundfile.
    samples = waveform.squeeze().detach().cpu().numpy()
    sf.write(out_wav_path, samples.astype(np.float32), _tts_sr)
    return out_wav_path
def voice_qa(audio_path: str):
    """
    Run one ASR -> LLM -> TTS round trip for a recorded question.

    Gradio passes a filepath for Audio(type="filepath"), or None when nothing
    was recorded. Each call gets its own directory under OUT_DIR that holds
    transcript.txt, answer.txt and tts_answer.wav.

    Return:
    transcript, answer, tts_audio_path, debug_text, transcript_file, answer_file

    Raises:
    gr.Error if no audio was provided, so the UI shows a readable message
    instead of an opaque failure from inside the ASR pipeline.
    """
    # Check before load_models() so an empty click does not pay for model loading.
    if not audio_path:
        raise gr.Error("No audio received. Please record a question first.")

    load_models()

    # Timestamp plus a short random suffix keeps concurrent runs apart.
    run_id = time.strftime("%Y%m%d-%H%M%S") + "_" + str(uuid.uuid4())[:8]
    run_dir = os.path.join(OUT_DIR, run_id)
    os.makedirs(run_dir, exist_ok=True)
    transcript_file = os.path.join(run_dir, "transcript.txt")
    answer_file = os.path.join(run_dir, "answer.txt")
    tts_file = os.path.join(run_dir, "tts_answer.wav")

    dbg_lines = []
    t0 = _now_ms()

    # --- ASR ---
    t_asr0 = _now_ms()
    # return_timestamps=True avoids Whisper long-form errors for >30s files
    asr_out = _asr(audio_path, return_timestamps=True)
    transcript = _clean_asr_text(asr_out.get("text", ""))
    t_asr1 = _now_ms()
    with open(transcript_file, "w", encoding="utf-8") as f:
        f.write(transcript)
    dbg_lines.append(f"[ASR] model={ASR_ID}")
    dbg_lines.append(f"[ASR] ms={(t_asr1 - t_asr0):.1f}")
    dbg_lines.append(f"[ASR] chars={len(transcript)}")

    # --- LLM ---
    t_llm0 = _now_ms()
    answer = _llm_answer_from_text(transcript)
    t_llm1 = _now_ms()
    with open(answer_file, "w", encoding="utf-8") as f:
        f.write(answer)
    dbg_lines.append(f"[LLM] model={LLM_ID}")
    dbg_lines.append(f"[LLM] ms={(t_llm1 - t_llm0):.1f}")
    dbg_lines.append(f"[LLM] chars={len(answer)}")

    # --- TTS ---
    t_tts0 = _now_ms()
    _tts_speak(answer, tts_file)
    t_tts1 = _now_ms()
    dbg_lines.append(f"[TTS] model={TTS_ID}")
    dbg_lines.append(f"[TTS] ms={(t_tts1 - t_tts0):.1f}")
    dbg_lines.append(f"[TTS] out={tts_file}")

    t1 = _now_ms()
    dbg_lines.append(f"[TOTAL] ms={(t1 - t0):.1f}")
    debug_text = "\n".join(dbg_lines)

    # Order must match the `outputs=` list wired to the Run button.
    return transcript, answer, tts_file, debug_text, transcript_file, answer_file
# ----------------------------
# Gradio UI
# ----------------------------
# NOTE(review): the Row nesting below was reconstructed from a copy of the file
# that had lost its indentation — confirm the layout against the deployed Space.
with gr.Blocks(title="Voice Q&A (ASR → LLM → TTS)") as demo:
    gr.Markdown(
        "# Voice Q&A (ASR → LLM → TTS)\n"
        "Speak a question → it transcribes → answers → speaks back.\n\n"
        "**CPU-friendly defaults**: Whisper *tiny* + SmolLM2-135M + MMS TTS.\n"
    )

    # Input: microphone recording handed to the callback as a file path.
    with gr.Row():
        audio_in = gr.Audio(
            sources=["microphone"],
            type="filepath",
            label="Microphone input",
        )
    run_btn = gr.Button("Run (ASR → LLM → TTS)", variant="primary")

    # Outputs: text results, the spoken answer, then timings.
    with gr.Row():
        transcript_out = gr.Textbox(label="Transcript (ASR)", lines=4)
        answer_out = gr.Textbox(label="Answer (LLM)", lines=6)
    tts_out = gr.Audio(label="Spoken answer (TTS)", type="filepath")
    debug_out = gr.Textbox(label="Debug / timings", lines=10)

    with gr.Row():
        transcript_dl = gr.File(label="Download transcript.txt")
        answer_dl = gr.File(label="Download answer.txt")

    # The outputs list follows the order of voice_qa's return tuple.
    run_btn.click(
        fn=voice_qa,
        inputs=[audio_in],
        outputs=[transcript_out, answer_out, tts_out, debug_out, transcript_dl, answer_dl],
    )

    gr.Markdown(
        "### Notes\n"
        "- If latency is still high on free CPU, try even shorter questions (2–5 seconds).\n"
        "- You can switch ASR model by setting Space variables: `ASR_ID=openai/whisper-base` (better) or keep `whisper-tiny` (faster).\n"
    )
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()